nipc

6 commits
Updated 2026-04-29 20:06:23
src/proto
src/proto/scalar.rs
use crate::proto::key::{KeyError, key};

/// Errors produced while encoding or decoding scalar values.
///
/// Note: the derived `PartialOrd`/`Ord` order variants by declaration order,
/// so reordering the variants below changes comparison results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ScalarError {
    /// Fewer bytes remain in the buffer than the value needs. Returned by both
    /// the write and the read paths.
    BufferTooSmall,
    /// A length cannot be represented on the wire, e.g. a byte slice longer
    /// than `u32::MAX` when writing its `u32` length prefix.
    LengthTooBig,
    /// Decoded string bytes were not valid UTF-8.
    Utf8,
    /// The decoded string was rejected by `key::new`.
    Key(KeyError),
}

/// A value that can be encoded into a byte buffer.
///
/// `buf` acts as a write cursor: the implementations in this module write
/// their encoding at the front of `*buf` and, on success, advance `*buf` past
/// the bytes written.
pub trait WriteScalar: Sized {
    /// Encodes `self` at the front of `*buf` and advances the cursor.
    ///
    /// # Errors
    ///
    /// Returns a [`ScalarError`] if the encoding does not fit in `*buf` or
    /// cannot be represented on the wire.
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError>;
}

/// A value that can be decoded from a byte buffer.
///
/// `'a` is the lifetime of the input bytes. Tying the output to it lets
/// borrowed types such as `&'a [u8]` and `&'a str` be decoded without copying.
pub trait ReadScalar<'a>: Sized {
    /// Decodes a value from the front of `*buf` and advances the cursor past it.
    ///
    /// # Errors
    ///
    /// Returns a [`ScalarError`] if `*buf` is too short or its bytes are not a
    /// valid encoding of `Self`.
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError>;
}

/// Copies `payload` to the front of the write cursor `*buf` and moves the
/// cursor past it.
///
/// # Errors
///
/// Returns [`ScalarError::BufferTooSmall`] if fewer than `payload.len()` bytes
/// remain. `*buf` is left untouched in that case.
fn write_and_advance(buf: &mut &mut [u8], payload: &[u8]) -> Result<(), ScalarError> {
    let needed = payload.len();
    if buf.len() < needed {
        return Err(ScalarError::BufferTooSmall);
    }

    // Move the mutable slice out of the cursor (leaving an empty one behind)
    // so it can be split into two disjoint halves that keep the full lifetime.
    let whole = core::mem::take(buf);
    let (head, tail) = whole.split_at_mut(needed);
    head.copy_from_slice(payload);
    *buf = tail;
    Ok(())
}

/// Splits the first `len` bytes off the read cursor `*buf` and returns them.
///
/// The returned slice borrows the underlying input (lifetime `'a`), not the
/// cursor itself, so it stays valid after `*buf` has moved forward.
///
/// # Errors
///
/// Returns [`ScalarError::BufferTooSmall`] if fewer than `len` bytes remain.
/// `*buf` is left untouched in that case.
fn read_and_advance<'a>(buf: &mut &'a [u8], len: usize) -> Result<&'a [u8], ScalarError> {
    let input: &'a [u8] = *buf;
    let head = input.get(..len).ok_or(ScalarError::BufferTooSmall)?;
    // `get` succeeded, so `len <= input.len()` and this slice cannot panic.
    *buf = &input[len..];
    Ok(head)
}

macro_rules! int {
    ($($int:ty),+) => {
        $(
            impl WriteScalar for $int {
                fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
                    write_and_advance(buf, &self.to_be_bytes())
                }
            }

            impl<'a> ReadScalar<'a> for $int {
                fn read_scalar(buf: &mut &[u8]) -> Result<Self, ScalarError> {
                    Ok(<$int>::from_be_bytes(
                        read_and_advance(buf, core::mem::size_of::<$int>())?
                            .try_into()
                            .unwrap(),
                    ))
                }
            }
        )+
    };
}

int!(u8, u16, u32, u64, i8, i16, i32, i64, usize, f32, f64);

/// `bool` is encoded as a single byte: `1` for `true`, `0` for `false`.
impl WriteScalar for bool {
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
        u8::from(*self).write_scalar(buf)
    }
}

/// Decoding is lenient: any non-zero byte reads back as `true`.
impl<'a> ReadScalar<'a> for bool {
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError> {
        u8::read_scalar(buf).map(|byte| byte != 0)
    }
}

/// Fixed-size byte arrays are written verbatim, with no length prefix: the
/// length `N` is part of the type, so the reader already knows it.
impl<const N: usize> WriteScalar for [u8; N] {
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
        write_and_advance(buf, self.as_slice())
    }
}

impl<'a, const N: usize> ReadScalar<'a> for [u8; N] {
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError> {
        let mut array = [0u8; N];
        // On success `read_and_advance` yields exactly `N` bytes, so the copy
        // cannot panic on a length mismatch.
        array.copy_from_slice(read_and_advance(buf, N)?);
        Ok(array)
    }
}

/// Byte slices are encoded as a big-endian `u32` length prefix followed by
/// the raw bytes.
///
/// # Errors
///
/// * [`ScalarError::LengthTooBig`] if the slice is longer than `u32::MAX`.
/// * [`ScalarError::BufferTooSmall`] if the prefix plus payload do not fit.
///
/// The space check covers the whole encoding up front, so on error `*buf` is
/// left untouched instead of holding a length prefix without its payload.
impl WriteScalar for &[u8] {
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
        let len = u32::try_from(self.len()).map_err(|_| ScalarError::LengthTooBig)?;

        // Check room for prefix + payload before writing anything. The
        // subtraction form avoids overflowing `prefix + self.len()`.
        let prefix = core::mem::size_of::<u32>();
        let fits = matches!(buf.len().checked_sub(prefix), Some(room) if room >= self.len());
        if !fits {
            return Err(ScalarError::BufferTooSmall);
        }

        len.write_scalar(buf)?;
        write_and_advance(buf, self)
    }
}

/// Decodes a `u32` length prefix followed by that many bytes, borrowing them
/// from the input without copying.
///
/// # Errors
///
/// Returns [`ScalarError::BufferTooSmall`] if the prefix or the declared
/// payload extends past the end of `*buf`. On any error `*buf` is left
/// untouched.
impl<'a> ReadScalar<'a> for &'a [u8] {
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError> {
        // Decode through a copy of the cursor; commit only on success.
        let mut cursor: &'a [u8] = *buf;
        let declared = u32::read_scalar(&mut cursor)?;
        // A length that does not fit in `usize` cannot fit in any in-memory
        // buffer either, so report it as the buffer being too small.
        let len = usize::try_from(declared).map_err(|_| ScalarError::BufferTooSmall)?;
        let bytes = read_and_advance(&mut cursor, len)?;
        *buf = cursor;
        Ok(bytes)
    }
}

/// Strings share the byte-slice encoding: a big-endian `u32` byte length
/// followed by the UTF-8 bytes.
impl WriteScalar for &str {
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
        self.as_bytes().write_scalar(buf)
    }
}

/// Decodes a length-prefixed UTF-8 string, borrowing it from the input.
///
/// # Errors
///
/// Propagates the byte-slice decoding errors, and returns
/// [`ScalarError::Utf8`] if the bytes are not valid UTF-8. On any error `*buf`
/// is left untouched, so a rejected string is not silently skipped.
impl<'a> ReadScalar<'a> for &'a str {
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError> {
        // Decode through a copy of the cursor; commit only on success.
        let mut cursor: &'a [u8] = *buf;
        let bytes = <&'a [u8]>::read_scalar(&mut cursor)?;
        let text = core::str::from_utf8(bytes).map_err(|_| ScalarError::Utf8)?;
        *buf = cursor;
        Ok(text)
    }
}

/// Keys are encoded exactly like `&str` (a `u32` length prefix followed by
/// UTF-8 bytes).
impl WriteScalar for &key {
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
        self.as_str().write_scalar(buf)
    }
}

/// Decodes a length-prefixed string and validates it with `key::new`.
///
/// # Errors
///
/// Propagates the string decoding errors, and wraps a rejection by
/// `key::new` in [`ScalarError::Key`]. On any error `*buf` is left untouched,
/// so an invalid key is not silently consumed.
impl<'a> ReadScalar<'a> for &'a key {
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError> {
        // Decode through a copy of the cursor; commit only on success.
        let mut cursor: &'a [u8] = *buf;
        let text = <&'a str>::read_scalar(&mut cursor)?;
        let parsed = key::new(text).map_err(ScalarError::Key)?;
        *buf = cursor;
        Ok(parsed)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Generates one `#[test]` per entry. Each test writes `$value` into a
    /// zeroed 32-byte buffer and asserts that reading it back yields an equal
    /// value.
    macro_rules! test_scalar_roundtrip {
        ($($name:ident: $ty:ty = $value:expr),+ $(,)?) => {
            $(
                #[test]
                fn $name() {
                    let mut buffer = [0u8; 32];
                    let mut buf = &mut buffer[..];

                    let value: $ty = $value;
                    value.write_scalar(&mut buf).unwrap();

                    let mut read_buf = &buffer[..];
                    let result = <$ty>::read_scalar(&mut read_buf).unwrap();
                    assert_eq!(result, value);
                }
            )+
        };
    }

    test_scalar_roundtrip! {
        test_u8_write_read: u8 = 42,
        test_u16_write_read: u16 = 0x1234,
        test_u32_write_read: u32 = 0x12345678,
        test_u64_write_read: u64 = 0x123456789ABCDEF0,
        test_i8_write_read: i8 = -42,
        test_i16_write_read: i16 = -1234,
        test_i32_write_read: i32 = -12345678,
        test_i64_write_read: i64 = -123456789012345,
        // Use the std constants: hand-written approximations of pi are
        // rejected by clippy's deny-by-default `approx_constant` lint.
        test_f32_write_read: f32 = core::f32::consts::PI,
        test_f64_write_read: f64 = core::f64::consts::PI,
        test_usize_write_read: usize = 12345,

        test_u8_max: u8 = u8::MAX,
        test_u16_max: u16 = u16::MAX,
        test_u32_max: u32 = u32::MAX,
        test_u64_max: u64 = u64::MAX,
        test_i8_min: i8 = i8::MIN,
        test_i8_max: i8 = i8::MAX,
        test_i16_min: i16 = i16::MIN,
        test_i16_max: i16 = i16::MAX,
        test_i32_min: i32 = i32::MIN,
        test_i32_max: i32 = i32::MAX,
        test_i64_min: i64 = i64::MIN,
        test_i64_max: i64 = i64::MAX,
        test_u32_zero: u32 = 0,
        test_i32_zero: i32 = 0,
        test_f32_zero: f32 = 0.0,
        test_f64_zero: f64 = 0.0,

        test_bool_true: bool = true,
        test_bool_false: bool = false,

        test_array: [u8; 5] = [1, 2, 3, 4, 5],
        test_slice: &[u8] = &[10, 20, 30, 40, 50],
        test_slice_empty: &[u8] = &[],

        test_str: &str = "Hello, World!",
        test_str_empty: &str = "",
        test_str_unicode: &str = "Hëllö 世界 🦀",
    }

    /// A 4-byte value must not fit into a 2-byte buffer.
    #[test]
    fn test_buffer_too_small_write() {
        let mut buffer = [0u8; 2];
        let mut buf = &mut buffer[..];

        let value: u32 = 0x12345678;
        let result = value.write_scalar(&mut buf);
        assert!(matches!(result, Err(ScalarError::BufferTooSmall)));
    }

    /// Reading a 4-byte value from 2 bytes must fail.
    #[test]
    fn test_buffer_too_small_read() {
        let buffer = [0u8; 2];
        let mut buf = &buffer[..];

        let result = u32::read_scalar(&mut buf);
        assert!(matches!(result, Err(ScalarError::BufferTooSmall)));
    }

    /// A well-formed length prefix followed by non-UTF-8 bytes is rejected.
    #[test]
    fn test_str_invalid_utf8() {
        let mut buffer = [0u8; 20];
        buffer[0..4].copy_from_slice(&3u32.to_be_bytes());
        buffer[4] = 0xFF;
        buffer[5] = 0xFF;
        buffer[6] = 0xFF;

        let mut read_buf = &buffer[..];
        let result = <&str>::read_scalar(&mut read_buf);
        assert!(matches!(result, Err(ScalarError::Utf8)));
    }

    /// Values of different types written back-to-back read back in order.
    #[test]
    fn test_multiple_values_sequential() {
        let mut buffer = [0u8; 50];
        let mut buf = &mut buffer[..];

        let v1: u32 = 0x12345678;
        let v2: u16 = 0xABCD;
        let v3 = true;
        let v4 = "test";

        v1.write_scalar(&mut buf).unwrap();
        v2.write_scalar(&mut buf).unwrap();
        v3.write_scalar(&mut buf).unwrap();
        v4.write_scalar(&mut buf).unwrap();

        let mut read_buf = &buffer[..];
        assert_eq!(u32::read_scalar(&mut read_buf).unwrap(), 0x12345678);
        assert_eq!(u16::read_scalar(&mut read_buf).unwrap(), 0xABCD);
        assert!(bool::read_scalar(&mut read_buf).unwrap());
        assert_eq!(<&str>::read_scalar(&mut read_buf).unwrap(), "test");
    }

    /// Integers are laid out most-significant byte first.
    #[test]
    fn test_big_endian_encoding() {
        let mut buffer = [0u8; 10];
        let mut buf = &mut buffer[..];

        let value: u32 = 0x12345678;
        value.write_scalar(&mut buf).unwrap();

        assert_eq!(buffer[0], 0x12);
        assert_eq!(buffer[1], 0x34);
        assert_eq!(buffer[2], 0x56);
        assert_eq!(buffer[3], 0x78);
    }

    /// A successful write shrinks the cursor by exactly the encoded size.
    #[test]
    fn test_buffer_advancement() {
        let mut buffer = [0u8; 20];
        let mut buf = &mut buffer[..];

        let initial_len = buf.len();
        42u32.write_scalar(&mut buf).unwrap();

        assert_eq!(buf.len(), initial_len - 4);
    }

    /// A successful read shrinks the cursor by exactly the encoded size.
    #[test]
    fn test_read_buffer_advancement() {
        let mut buffer = [0u8; 20];
        let mut write_buf = &mut buffer[..];
        42u32.write_scalar(&mut write_buf).unwrap();

        let mut read_buf = &buffer[..];
        let initial_len = read_buf.len();
        u32::read_scalar(&mut read_buf).unwrap();

        assert_eq!(read_buf.len(), initial_len - 4);
    }
}