nipc

6 commits
Updated 2026-04-29 20:06:23
src/proto
src/proto/msg.rs
use crate::proto::{
    key::key,
    scalar::{ReadScalar, ScalarError, WriteScalar},
};

/// Errors reported while encoding or decoding a message.
///
/// The underlying [`ScalarError`] is preserved so callers can tell whether the
/// failure happened while handling the header or while handling a body field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum MsgError {
    /// Reading or writing the header scalar failed.
    HeaderError(ScalarError),
    /// Reading or writing one of the body fields failed.
    BodyError(ScalarError),

    /// A header was read successfully, but it is not the value expected by
    /// the message type being decoded.
    IncorrectHeader,
}

/// Encoding side of a message: a header stage followed by a body stage.
///
/// Implementors provide the two stages; [`WriteMsg::write_msg`] chains them.
pub trait WriteMsg: Sized {
    /// Writes only the header of this message into `buf`.
    fn write_msg_header(&self, buf: &mut &mut [u8]) -> Result<(), MsgError>;

    /// Writes only the body of this message into `buf`.
    fn write_msg_body(&self, buf: &mut &mut [u8]) -> Result<(), MsgError>;

    /// Writes the whole message: header first, then body.
    ///
    /// # Errors
    ///
    /// Returns the header stage's error without running the body stage, or
    /// the body stage's error if the header was written successfully.
    fn write_msg(&self, buf: &mut &mut [u8]) -> Result<(), MsgError> {
        self.write_msg_header(buf)
            .and_then(|()| self.write_msg_body(buf))
    }
}

/// Decoding side of a message: a header stage followed by a body stage.
///
/// `'a` is the lifetime of the input bytes, so decoded values may borrow from
/// the buffer they were read from.
pub trait ReadMsg<'a>: Sized {
    /// Reads and validates only the header from `buf`.
    fn read_msg_header(buf: &mut &'a [u8]) -> Result<(), MsgError>;

    /// Reads only the body from `buf` and builds the message.
    fn read_msg_body(buf: &mut &'a [u8]) -> Result<Self, MsgError>;

    /// Reads a whole message: header first, then body.
    ///
    /// # Errors
    ///
    /// Returns the header stage's error without running the body stage, or
    /// the body stage's error if the header was accepted.
    fn read_msg(buf: &mut &'a [u8]) -> Result<Self, MsgError> {
        Self::read_msg_header(buf).and_then(|()| Self::read_msg_body(buf))
    }
}

/// Declares wire messages and generates their encode/decode implementations.
///
/// Each entry has the shape
/// `[attributes] vis struct Name[<'lt>]: HEADER { [attributes] vis field: Type, ... }`
/// and expands to:
///
/// * the struct itself, with the given attributes and fields;
/// * an associated `HEADER: u8` constant holding the header literal;
/// * a `WriteMsg` impl that writes `HEADER`, then every field in declaration
///   order via `write_scalar`;
/// * a `ReadMsg` impl that reads a `u8`, rejects it with
///   `MsgError::IncorrectHeader` if it differs from `HEADER`, then reads
///   every field in declaration order via `read_scalar`.
///
/// Scalar failures are wrapped in `MsgError::HeaderError` or
/// `MsgError::BodyError` depending on the stage in which they occur.
macro_rules! pack {
    (
        $(
            $(#[$outer:meta])*
            $vis:vis struct $name:ident
            $(<$lt:lifetime>)?: $header:literal
            {
                $(
                    $(#[$field_meta:meta])*
                    $field_vis:vis $field_name:ident : $field_ty:ty
                ),* $(,)?
            }
        )+
    ) => {
        $(
            $(#[$outer])*
            $vis struct $name $(<$lt>)? {
                $(
                    $(#[$field_meta])*
                    $field_vis $field_name: $field_ty,
                )*
            }

            impl $(<$lt>)? $name $(<$lt>)? {
                pub const HEADER: u8 = $header as u8;
            }

            impl $(<$lt>)? WriteMsg for $name $(<$lt>)? {
                fn write_msg_header(&self, buf: &mut &mut [u8]) -> Result<(), MsgError> {
                    Self::HEADER
                        .write_scalar(buf)
                        .map_err(MsgError::HeaderError)
                }

                // `buf` is unused when the message declares no fields.
                #[allow(unused)]
                fn write_msg_body(&self, buf: &mut &mut [u8]) -> Result<(), MsgError> {
                    $(
                        self.$field_name
                            .write_scalar(buf)
                            .map_err(MsgError::BodyError)?;
                    )*

                    Ok(())
                }
            }

            // The input lifetime `'__lt` and the struct's own lifetime (if any)
            // are constrained to outlive each other, i.e. to be the same
            // lifetime, so borrowed fields can point into the input buffer.
            impl<'__lt $(, $lt)?> ReadMsg<'__lt> for $name $(<$lt>)?
            $(where $lt: '__lt, '__lt: $lt)?
            {
                fn read_msg_header(buf: &mut &'__lt [u8]) -> Result<(), MsgError> {
                    let tag = <u8>::read_scalar(buf).map_err(MsgError::HeaderError)?;

                    if tag == Self::HEADER {
                        Ok(())
                    } else {
                        Err(MsgError::IncorrectHeader)
                    }
                }

                // `buf` is unused when the message declares no fields.
                #[allow(unused)]
                fn read_msg_body(buf: &mut &'__lt [u8]) -> Result<Self, MsgError> {
                    // Fields are decoded one by one in declaration order,
                    // stopping at the first failure.
                    $(
                        let $field_name =
                            <$field_ty>::read_scalar(buf).map_err(MsgError::BodyError)?;
                    )*

                    Ok(Self { $($field_name),* })
                }
            }
        )+
    };
}

// Wire message definitions. `pack!` expands each entry into the struct, a
// `HEADER` constant with the value after the colon, and `WriteMsg`/`ReadMsg`
// impls that encode the header followed by the fields in declaration order.
pack! {
    /// Message with header `0x1`; the body is `key` followed by `id`.
    ///
    /// NOTE(review): presumably announces a subscriber for `key`, identified
    /// by `id` — confirm against the protocol handling code.
    #[derive(Debug, PartialEq, Eq)]
    pub struct DeclareSubscriber<'a>: 0x1 {
        pub key: &'a key,
        pub id: u32,
    }

    /// Message with header `0x2`; the body is `key` followed by `id`.
    ///
    /// NOTE(review): presumably announces a handler for `key`, identified by
    /// `id` — confirm against the protocol handling code.
    #[derive(Debug, PartialEq, Eq)]
    pub struct DeclareHandler<'a>: 0x2 {
        pub key: &'a key,
        pub id: u32,
    }

    /// Message with header `0x3`; the body is a single `id`.
    ///
    /// NOTE(review): presumably `id` refers to an earlier
    /// `DeclareSubscriber` — confirm.
    #[derive(Debug, PartialEq, Eq)]
    pub struct UndeclareSubscriber: 0x3 {
        pub id: u32,
    }
    /// Message with header `0x4`; the body is a single `id`.
    ///
    /// NOTE(review): presumably `id` refers to an earlier `DeclareHandler` —
    /// confirm.
    #[derive(Debug, PartialEq, Eq)]
    pub struct UndeclareHandler: 0x4 {
        pub id: u32,
    }

    /// Message with header `0x5`; the body is `key` followed by `payload`.
    /// Both fields borrow from the same buffer lifetime `'a`.
    #[derive(Debug, PartialEq, Eq)]
    pub struct Pub<'a>: 0x5 {
        pub key: &'a key,
        pub payload: &'a [u8],
    }
    /// Message with header `0x6` and no body fields.
    ///
    /// NOTE(review): the empty body looks like a placeholder — confirm
    /// whether fields are still to be added.
    #[derive(Debug, PartialEq, Eq)]
    pub struct Query: 0x6 {}
    /// Message with header `0x7` and no body fields.
    ///
    /// NOTE(review): the empty body looks like a placeholder — confirm
    /// whether fields are still to be added.
    #[derive(Debug, PartialEq, Eq)]
    pub struct Reply: 0x7 {}
    /// Message with header `0x8` and no body fields.
    ///
    /// NOTE(review): the empty body looks like a placeholder — confirm
    /// whether fields are still to be added.
    #[derive(Debug, PartialEq, Eq)]
    pub struct ReplyFinal: 0x8 {}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Generates the encode/decode tests for a list of messages.
    ///
    /// For every `name: Type = constructor` entry a test `name` is emitted
    /// that writes the message, reads it back, and checks both equality and
    /// that the written bytes were fully consumed. One additional
    /// `test_sequential` test writes all messages back to back into a single
    /// buffer and decodes them in the same order.
    macro_rules! test_msg_roundtrip {
        ($($name:ident: $msg_type:ty = $constructor:expr),+ $(,)?) => {
            $(
                #[test]
                fn $name() {
                    // Scratch buffer size for a single message.
                    const CAPACITY: usize = 256;

                    let msg: $msg_type = $constructor;
                    let mut buffer = [0u8; CAPACITY];
                    let mut buf = &mut buffer[..];

                    msg.write_msg(&mut buf).unwrap();

                    // The writer advances `buf`, so what is left tells how
                    // many bytes were produced.
                    let written_len = CAPACITY - buf.len();
                    let mut read_buf = &buffer[..written_len];

                    let decoded = <$msg_type>::read_msg(&mut read_buf).unwrap();

                    assert_eq!(msg, decoded, "Decoded msg {decoded:?} not equals to original {msg:?}");
                    assert_eq!(read_buf.len(), 0, "Buffer not fully consumed");
                }
            )+

            #[test]
            fn test_sequential() {
                // Scratch buffer size for the whole message sequence.
                const CAPACITY: usize = u16::MAX as usize;

                let mut buffer = [0u8; CAPACITY];
                let mut buf = &mut buffer[..];

                $(
                    let msg: $msg_type = $constructor;
                    msg.write_msg(&mut buf).unwrap();
                )+

                let written_len = CAPACITY - buf.len();
                let mut read_buf = &buffer[..written_len];

                $(
                    let msg: $msg_type = $constructor;
                    let decoded = <$msg_type>::read_msg(&mut read_buf).unwrap();

                    assert_eq!(msg, decoded, "Decoded msg {decoded:?} not equals to original {msg:?}");
                )+

                assert_eq!(read_buf.len(), 0, "Buffer not fully consumed");
            }
        };
    }

    test_msg_roundtrip! {
        test_declare_subscriber: DeclareSubscriber = DeclareSubscriber { key: key::unchecked("test/key"), id: 42 },
        test_declare_handler: DeclareHandler = DeclareHandler { key: key::unchecked("handler/key"), id: 123 },
        test_undeclare_subscriber: UndeclareSubscriber = UndeclareSubscriber { id: 999 },
        test_undeclare_handler: UndeclareHandler = UndeclareHandler { id: 888 },
        test_pub: Pub = Pub { key: key::unchecked("pub/topic"), payload: b"hello world" },
        test_pub_empty: Pub = Pub { key: key::unchecked("pub/topic"), payload: b"" },
        test_query: Query = Query {},
        test_reply: Reply = Reply {},
        test_reply_final: ReplyFinal = ReplyFinal {},
    }

    /// Decoding bytes produced for one message type as another type must
    /// fail with `IncorrectHeader` instead of yielding a value.
    #[test]
    fn test_header_mismatch_is_rejected() {
        const CAPACITY: usize = 256;

        let msg = Query {};
        let mut buffer = [0u8; CAPACITY];
        let mut buf = &mut buffer[..];

        msg.write_msg(&mut buf).unwrap();

        let written_len = CAPACITY - buf.len();
        let mut read_buf = &buffer[..written_len];

        assert_eq!(Reply::read_msg(&mut read_buf), Err(MsgError::IncorrectHeader));
    }

    /// Every message type must have its own header value, otherwise decoding
    /// could not tell the types apart.
    #[test]
    fn test_all_headers_unique() {
        let headers = [
            DeclareSubscriber::HEADER,
            DeclareHandler::HEADER,
            UndeclareSubscriber::HEADER,
            UndeclareHandler::HEADER,
            Pub::HEADER,
            Query::HEADER,
            Reply::HEADER,
            ReplyFinal::HEADER,
        ];

        let mut seen = std::collections::HashSet::new();
        for header in headers {
            assert!(seen.insert(header), "Duplicate header: {}", header);
        }
    }
}