src/proto/msg.rs
use crate::proto::{
key::key,
scalar::{ReadScalar, ScalarError, WriteScalar},
};
/// Errors produced while encoding or decoding a protocol message.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum MsgError {
    /// The scalar layer failed while reading or writing the header byte.
    HeaderError(ScalarError),
    /// The scalar layer failed while reading or writing a body field.
    BodyError(ScalarError),
    /// The header byte was read but does not identify the message type being
    /// decoded.
    IncorrectHeader,
}

impl core::fmt::Display for MsgError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `ScalarError` is only known to implement `Debug` here, so the
        // underlying cause is rendered with `{:?}`.
        match self {
            Self::HeaderError(err) => write!(f, "message header error: {err:?}"),
            Self::BodyError(err) => write!(f, "message body error: {err:?}"),
            Self::IncorrectHeader => f.write_str("incorrect message header"),
        }
    }
}
/// Encodes a value as a complete wire message: header first, then body.
///
/// `buf` is an output cursor. Implementations are expected to write at its
/// front and leave it pointing at the still-unwritten remainder.
pub trait WriteMsg: Sized {
    /// Writes the header, then the body only if the header succeeded.
    /// The first error encountered is returned unchanged.
    fn write_msg(&self, buf: &mut &mut [u8]) -> Result<(), MsgError> {
        self.write_msg_header(buf)
            .and_then(|()| self.write_msg_body(buf))
    }

    /// Writes only the message-type header.
    fn write_msg_header(&self, buf: &mut &mut [u8]) -> Result<(), MsgError>;

    /// Writes only the payload that follows the header.
    fn write_msg_body(&self, buf: &mut &mut [u8]) -> Result<(), MsgError>;
}
/// Decodes a complete wire message from the front of a byte slice.
///
/// `buf` is an input cursor. The lifetime `'a` ties the input to `Self`, so
/// implementations may borrow decoded fields directly from the buffer.
pub trait ReadMsg<'a>: Sized {
    /// Reads and validates the header, then decodes the body only if the
    /// header step succeeded. The first error encountered is returned
    /// unchanged.
    fn read_msg(buf: &mut &'a [u8]) -> Result<Self, MsgError> {
        Self::read_msg_header(buf).and_then(|()| Self::read_msg_body(buf))
    }

    /// Consumes and checks the message-type header.
    fn read_msg_header(buf: &mut &'a [u8]) -> Result<(), MsgError>;

    /// Decodes the payload that follows the header.
    fn read_msg_body(buf: &mut &'a [u8]) -> Result<Self, MsgError>;
}
/// Declares wire message structs together with their `WriteMsg` and
/// `ReadMsg` implementations.
///
/// Syntax, repeated once per message:
///
/// ```text
/// #[attrs]
/// pub struct Name<'a>: HEADER {
///     #[attrs] pub field: Type,
///     ...
/// }
/// ```
///
/// The lifetime is optional. `HEADER` must be an integer literal that fits in
/// a `u8`. For every message the macro emits:
///
/// * the struct as written;
/// * an inherent `pub const HEADER: u8`;
/// * a `WriteMsg` impl that writes `HEADER`, then each field in declaration
///   order via `write_scalar`;
/// * a `ReadMsg` impl that reads one `u8` and fails with
///   `MsgError::IncorrectHeader` unless it equals `HEADER`. It then reads each
///   field in declaration order via `read_scalar`.
///
/// Scalar errors on the header map to `MsgError::HeaderError`. Scalar errors
/// on any field map to `MsgError::BodyError`.
macro_rules! pack {
    (
        $(
            $(#[$outer:meta])*
            $vis:vis struct $name:ident
            $(<$lt:lifetime>)?: $header:literal
            {
                $(
                    $(#[$field_meta:meta])*
                    $field_vis:vis $field_name:ident : $field_ty:ty
                ),* $(,)?
            }
        )+
    ) => {
        $(
            $(#[$outer])*
            $vis struct $name $(<$lt>)? {
                $(
                    $(#[$field_meta])*
                    $field_vis $field_name : $field_ty
                ),*
            }

            impl $(<$lt>)? $name $(<$lt>)? {
                // Typed as `u8` with no `as` cast: an out-of-range literal is a
                // compile error instead of a silently truncated byte. Both the
                // writer and the reader go through this single constant.
                /// Header byte identifying this message type on the wire.
                pub const HEADER: u8 = $header;
            }

            impl $(<$lt>)? WriteMsg for $name $(<$lt>)? {
                fn write_msg_header(&self, buf: &mut &mut [u8]) -> Result<(), MsgError> {
                    Self::HEADER.write_scalar(buf).map_err(MsgError::HeaderError)
                }

                // `self` and `buf` go unused for messages without fields.
                #[allow(unused)]
                fn write_msg_body(&self, buf: &mut &mut [u8]) -> Result<(), MsgError> {
                    $(
                        self.$field_name.write_scalar(buf).map_err(MsgError::BodyError)?;
                    )*
                    Ok(())
                }
            }

            // `'__lt` is the input buffer's lifetime. For borrowing messages the
            // two outlives bounds force it to equal the struct's own lifetime,
            // so decoded fields may borrow straight from the buffer.
            impl<'__lt $(, $lt)?> ReadMsg<'__lt> for $name $(<$lt>)?
            $(where $lt: '__lt, '__lt: $lt)?
            {
                fn read_msg_header(buf: &mut &'__lt [u8]) -> Result<(), MsgError> {
                    let header = <u8>::read_scalar(buf).map_err(MsgError::HeaderError)?;
                    if header != Self::HEADER {
                        return Err(MsgError::IncorrectHeader);
                    }
                    Ok(())
                }

                // `buf` goes unused for messages without fields.
                #[allow(unused)]
                fn read_msg_body(buf: &mut &'__lt [u8]) -> Result<Self, MsgError> {
                    // Struct-literal fields are evaluated in source order, so
                    // the body is decoded in declaration order.
                    Ok(Self {
                        $(
                            $field_name: <$field_ty>::read_scalar(buf).map_err(MsgError::BodyError)?,
                        )*
                    })
                }
            }
        )+
    };
}
pack! {
    // Every field below is a `u32` or a shared reference/slice, so all message
    // types are cheap to duplicate and can derive `Clone` + `Copy`.

    /// Header `0x1`; the body encodes `key` followed by `id`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct DeclareSubscriber<'a>: 0x1 {
        pub key: &'a key,
        pub id: u32,
    }

    /// Header `0x2`; the body encodes `key` followed by `id`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct DeclareHandler<'a>: 0x2 {
        pub key: &'a key,
        pub id: u32,
    }

    /// Header `0x3`; the body encodes `id`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct UndeclareSubscriber: 0x3 {
        pub id: u32,
    }

    /// Header `0x4`; the body encodes `id`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct UndeclareHandler: 0x4 {
        pub id: u32,
    }

    /// Header `0x5`; the body encodes `key` followed by `payload`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct Pub<'a>: 0x5 {
        pub key: &'a key,
        pub payload: &'a [u8],
    }

    // NOTE(review): the three messages below carry no body fields yet,
    // presumably as placeholders — confirm against the protocol spec.

    /// Header `0x6`; header-only, with no body fields.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct Query: 0x6 {}

    /// Header `0x7`; header-only, with no body fields.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct Reply: 0x7 {}

    /// Header `0x8`; header-only, with no body fields.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct ReplyFinal: 0x8 {}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// For every `name: Type = constructor` entry, generates `#[test] fn name`.
    /// Each such test encodes the message into a 256-byte buffer and decodes
    /// it back. It asserts that the result equals the original and that
    /// decoding consumed exactly the encoded bytes.
    ///
    /// The macro also generates `test_sequential`. That test encodes all
    /// entries back-to-back into one buffer and decodes them in the same order.
    macro_rules! test_msg_roundtrip {
        ($($name:ident: $msg_type:ty = $constructor:expr),+ $(,)?) => {
            $(
                #[test]
                fn $name() {
                    let msg: $msg_type = $constructor;
                    let mut buffer = [0u8; 256];
                    let mut buf = &mut buffer[..];
                    msg.write_msg(&mut buf).unwrap();
                    // The writer advances `buf`; whatever it moved past was written.
                    let remaining = buf.len();
                    let written_len = buffer.len() - remaining;
                    let mut read_buf = &buffer[..written_len];
                    let decoded = <$msg_type>::read_msg(&mut read_buf).unwrap();
                    assert_eq!(msg, decoded, "Decoded msg {decoded:?} not equals to original {msg:?}");
                    assert_eq!(read_buf.len(), 0, "Buffer not fully consumed");
                }
            )+

            #[test]
            fn test_sequential() {
                let mut buffer = [0u8; u16::MAX as usize];
                let mut buf = &mut buffer[..];
                $(
                    let msg: $msg_type = $constructor;
                    msg.write_msg(&mut buf).unwrap();
                )+
                let remaining = buf.len();
                let written_len = buffer.len() - remaining;
                let mut read_buf = &buffer[..written_len];
                $(
                    let msg: $msg_type = $constructor;
                    let decoded = <$msg_type>::read_msg(&mut read_buf).unwrap();
                    assert_eq!(msg, decoded, "Decoded msg {decoded:?} not equals to original {msg:?}");
                )+
                assert_eq!(read_buf.len(), 0, "Buffer not fully consumed");
            }
        };
    }

    test_msg_roundtrip! {
        test_declare_subscriber: DeclareSubscriber = DeclareSubscriber { key: key::unchecked("test/key"), id: 42 },
        test_declare_handler: DeclareHandler = DeclareHandler { key: key::unchecked("handler/key"), id: 123 },
        test_undeclare_subscriber: UndeclareSubscriber = UndeclareSubscriber { id: 999 },
        test_undeclare_handler: UndeclareHandler = UndeclareHandler { id: 888 },
        test_pub: Pub = Pub { key: key::unchecked("pub/topic"), payload: b"hello world" },
        test_pub_empty: Pub = Pub { key: key::unchecked("pub/topic"), payload: b"" },
        test_query: Query = Query {},
        test_reply: Reply = Reply {},
        test_reply_final: ReplyFinal = ReplyFinal {},
    }

    #[test]
    fn test_incorrect_header_rejected() {
        let msg = DeclareSubscriber { key: key::unchecked("test/key"), id: 42 };
        let mut buffer = [0u8; 256];
        let mut buf = &mut buffer[..];
        msg.write_msg(&mut buf).unwrap();
        let remaining = buf.len();
        let written_len = buffer.len() - remaining;
        // DeclareHandler has the same body layout as DeclareSubscriber, so only
        // the header check can tell the two apart.
        let mut read_buf = &buffer[..written_len];
        assert_eq!(
            DeclareHandler::read_msg(&mut read_buf),
            Err(MsgError::IncorrectHeader)
        );
    }

    #[test]
    fn test_all_headers_unique() {
        let headers = [
            DeclareSubscriber::HEADER,
            DeclareHandler::HEADER,
            UndeclareSubscriber::HEADER,
            UndeclareHandler::HEADER,
            Pub::HEADER,
            Query::HEADER,
            Reply::HEADER,
            ReplyFinal::HEADER,
        ];
        let mut seen = std::collections::HashSet::new();
        for header in headers {
            assert!(seen.insert(header), "Duplicate header: {}", header);
        }
    }
}