wiredef

6 commits
Updated 2026-04-29 20:05:07
src
src/scalar.rs
/// Errors produced while encoding or decoding scalar values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ScalarError {
    /// The buffer holds fewer bytes than the value being written or read needs.
    BufferTooSmall,
}

// Human-readable message for the error. Implemented through `core::fmt` rather
// than `std::fmt` to match the `core::` paths used elsewhere in this module.
impl core::fmt::Display for ScalarError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            ScalarError::BufferTooSmall => f.write_str("buffer too small for scalar value"),
        }
    }
}

/// Types that can be serialized into a mutable byte buffer.
///
/// The buffer is passed as `&mut &mut [u8]` so an implementation can both fill
/// bytes at the front of the slice and replace the slice with its unwritten tail.
pub trait WriteScalar
where
    Self: Sized,
{
    /// Encodes `self` at the start of `*buf` and moves `*buf` past the bytes written.
    ///
    /// # Errors
    ///
    /// Returns [`ScalarError::BufferTooSmall`] when `*buf` cannot hold the encoding.
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError>;
}

/// Types that can be decoded from the front of a borrowed byte buffer.
///
/// Because the buffer lifetime `'a` is a trait parameter, an implementing type
/// may itself carry `'a` and borrow from the input bytes.
pub trait ReadScalar<'a>
where
    Self: Sized,
{
    /// Decodes a value from the start of `*buf` and moves `*buf` past the bytes consumed.
    ///
    /// # Errors
    ///
    /// Returns [`ScalarError::BufferTooSmall`] when `*buf` is too short.
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError>;
}

/// Copies `payload` into the front of `*buf` and moves `*buf` past the copied bytes.
///
/// # Errors
///
/// Returns [`ScalarError::BufferTooSmall`] when fewer than `payload.len()` bytes
/// remain; in that case nothing is written and `*buf` is left as it was.
fn write_and_advance(buf: &mut &mut [u8], payload: &[u8]) -> Result<(), ScalarError> {
    let width = payload.len();
    if buf.len() >= width {
        // Move the slice out of `*buf` (an empty slice is left behind) so the two
        // halves keep the original lifetime and the tail can be stored back.
        let (dest, rest) = core::mem::take(buf).split_at_mut(width);
        dest.copy_from_slice(payload);
        *buf = rest;
        Ok(())
    } else {
        Err(ScalarError::BufferTooSmall)
    }
}

/// Splits `len` bytes off the front of `*buf`, returning them and moving `*buf` past them.
///
/// # Errors
///
/// Returns [`ScalarError::BufferTooSmall`] when fewer than `len` bytes remain;
/// `*buf` is left untouched in that case.
fn read_and_advance<'a>(buf: &mut &'a [u8], len: usize) -> Result<&'a [u8], ScalarError> {
    // Copy the shared slice out so both the returned prefix and the stored
    // remainder borrow the input for the full `'a`.
    let input: &'a [u8] = *buf;
    match input.get(..len) {
        Some(taken) => {
            *buf = &input[len..];
            Ok(taken)
        }
        None => Err(ScalarError::BufferTooSmall),
    }
}

macro_rules! int {
    ($($int:ty),+) => {
        $(
            impl WriteScalar for $int {
                fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
                    write_and_advance(buf, &self.to_be_bytes())
                }
            }

            impl<'a> ReadScalar<'a> for $int {
                fn read_scalar(buf: &mut &[u8]) -> Result<Self, ScalarError> {
                    Ok(<$int>::from_be_bytes(
                        read_and_advance(buf, core::mem::size_of::<$int>())?
                            .try_into()
                            .unwrap(),
                    ))
                }
            }
        )+
    };
}

int!(u8, u16, u32, u64, i8, i16, i32, i64, usize, f32, f64);

// A `bool` goes on the wire as a single byte: `1` for `true`, `0` for `false`.
impl WriteScalar for bool {
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
        let encoded = u8::from(*self);
        encoded.write_scalar(buf)
    }
}

// Decoding is lenient: `0` reads as `false` and every other byte value reads as
// `true`, so non-canonical encodings are accepted rather than rejected.
impl<'a> ReadScalar<'a> for bool {
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError> {
        u8::read_scalar(buf).map(|byte| byte != 0)
    }
}

// Fixed-size byte arrays are written verbatim: exactly `N` bytes, no length prefix.
impl<const N: usize> WriteScalar for [u8; N] {
    fn write_scalar(&self, buf: &mut &mut [u8]) -> Result<(), ScalarError> {
        write_and_advance(buf, self)
    }
}

// Reads exactly `N` raw bytes. The signature is spelled as `ReadScalar<'a>`
// declares it, consistent with the `bool` impl.
impl<'a, const N: usize> ReadScalar<'a> for [u8; N] {
    fn read_scalar(buf: &mut &'a [u8]) -> Result<Self, ScalarError> {
        let bytes = read_and_advance(buf, N)?;
        // `read_and_advance` yields exactly `N` bytes on success, so the
        // slice-to-array conversion cannot fail.
        Ok(bytes
            .try_into()
            .expect("read_and_advance returned exactly N bytes"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Expands each `name: Type = value` entry into its own `#[test]` that writes
    // `value` into a scratch buffer, reads it back as `Type`, and compares the two.
    macro_rules! test_scalar_roundtrip {
        ($($name:ident: $ty:ty = $value:expr),+ $(,)?) => {
            $(
                #[test]
                fn $name() {
                    let mut buffer = [0u8; 32];
                    let mut buf = &mut buffer[..];

                    let value: $ty = $value;
                    value.write_scalar(&mut buf).unwrap();

                    let mut read_buf = &buffer[..];
                    let result = <$ty>::read_scalar(&mut read_buf).unwrap();
                    assert_eq!(result, value);
                }
            )+
        };
    }

    test_scalar_roundtrip! {
        test_u8_write_read: u8 = 42,
        test_u16_write_read: u16 = 0x1234,
        test_u32_write_read: u32 = 0x12345678,
        test_u64_write_read: u64 = 0x123456789ABCDEF0,
        test_i8_write_read: i8 = -42,
        test_i16_write_read: i16 = -1234,
        test_i32_write_read: i32 = -12345678,
        test_i64_write_read: i64 = -123456789012345,
        // Use the library constants: hand-typed approximations of pi such as
        // `3.14159` trip Clippy's deny-by-default `approx_constant` lint.
        test_f32_write_read: f32 = core::f32::consts::PI,
        test_f64_write_read: f64 = core::f64::consts::PI,
        test_usize_write_read: usize = 12345,

        test_u8_max: u8 = u8::MAX,
        test_u16_max: u16 = u16::MAX,
        test_u32_max: u32 = u32::MAX,
        test_u64_max: u64 = u64::MAX,
        test_usize_max: usize = usize::MAX,
        test_i8_min: i8 = i8::MIN,
        test_i8_max: i8 = i8::MAX,
        test_i16_min: i16 = i16::MIN,
        test_i16_max: i16 = i16::MAX,
        test_i32_min: i32 = i32::MIN,
        test_i32_max: i32 = i32::MAX,
        test_i64_min: i64 = i64::MIN,
        test_i64_max: i64 = i64::MAX,
        test_u32_zero: u32 = 0,
        test_i32_zero: i32 = 0,
        test_f32_zero: f32 = 0.0,
        test_f64_zero: f64 = 0.0,
        test_f32_negative: f32 = -1.5,
        test_f32_neg_infinity: f32 = f32::NEG_INFINITY,
        test_f64_min_positive: f64 = f64::MIN_POSITIVE,

        test_bool_true: bool = true,
        test_bool_false: bool = false,

        test_array: [u8; 5] = [1, 2, 3, 4, 5],
        test_array_empty: [u8; 0] = [],
    }

    #[test]
    fn test_buffer_too_small_write() {
        let mut buffer = [0u8; 2];
        let mut buf = &mut buffer[..];

        let value: u32 = 0x12345678;
        let result = value.write_scalar(&mut buf);
        assert!(matches!(result, Err(ScalarError::BufferTooSmall)));
    }

    #[test]
    fn test_buffer_too_small_read() {
        let buffer = [0u8; 2];
        let mut buf = &buffer[..];

        let result = u32::read_scalar(&mut buf);
        assert!(matches!(result, Err(ScalarError::BufferTooSmall)));
    }

    // A rejected write must neither modify the bytes nor advance the cursor.
    #[test]
    fn test_failed_write_leaves_buffer_untouched() {
        let mut buffer = [0xEEu8; 2];
        let mut buf = &mut buffer[..];

        assert_eq!(
            0x12345678u32.write_scalar(&mut buf),
            Err(ScalarError::BufferTooSmall)
        );
        assert_eq!(buf.len(), 2);
        assert_eq!(buffer, [0xEE; 2]);
    }

    // A rejected read must not advance the cursor, so a smaller read can follow.
    #[test]
    fn test_failed_read_does_not_advance() {
        let buffer = [0x01u8, 0x02];
        let mut read_buf = &buffer[..];

        assert_eq!(
            u32::read_scalar(&mut read_buf),
            Err(ScalarError::BufferTooSmall)
        );
        assert_eq!(read_buf.len(), 2);
        assert_eq!(u16::read_scalar(&mut read_buf), Ok(0x0102));
    }

    #[test]
    fn test_array_too_small_read() {
        let buffer = [0u8; 3];
        let mut read_buf = &buffer[..];

        assert_eq!(
            <[u8; 4]>::read_scalar(&mut read_buf),
            Err(ScalarError::BufferTooSmall)
        );
        assert_eq!(read_buf.len(), 3);
    }

    // Any non-zero byte decodes as `true`.
    #[test]
    fn test_bool_nonzero_reads_true() {
        let buffer = [0x02u8, 0xFF];
        let mut read_buf = &buffer[..];

        assert!(bool::read_scalar(&mut read_buf).unwrap());
        assert!(bool::read_scalar(&mut read_buf).unwrap());
    }

    // Floats round-trip bit-for-bit, which `assert_eq!` cannot check for NaN.
    #[test]
    fn test_f32_nan_bits_roundtrip() {
        let mut buffer = [0u8; 4];
        let mut buf = &mut buffer[..];
        f32::NAN.write_scalar(&mut buf).unwrap();

        let mut read_buf = &buffer[..];
        let result = f32::read_scalar(&mut read_buf).unwrap();
        assert_eq!(result.to_bits(), f32::NAN.to_bits());
    }

    #[test]
    fn test_multiple_values_sequential() {
        let mut buffer = [0u8; 50];
        let mut buf = &mut buffer[..];

        let v1: u32 = 0x12345678;
        let v2: u16 = 0xABCD;
        let v3 = true;
        let v4: [u8; 3] = [1, 2, 3];

        v1.write_scalar(&mut buf).unwrap();
        v2.write_scalar(&mut buf).unwrap();
        v3.write_scalar(&mut buf).unwrap();
        v4.write_scalar(&mut buf).unwrap();

        let mut read_buf = &buffer[..];
        assert_eq!(u32::read_scalar(&mut read_buf).unwrap(), 0x12345678);
        assert_eq!(u16::read_scalar(&mut read_buf).unwrap(), 0xABCD);
        assert!(bool::read_scalar(&mut read_buf).unwrap());
        assert_eq!(<[u8; 3]>::read_scalar(&mut read_buf).unwrap(), [1, 2, 3]);
        // 4 + 2 + 1 + 3 bytes were consumed.
        assert_eq!(read_buf.len(), buffer.len() - 10);
    }

    #[test]
    fn test_big_endian_encoding() {
        let mut buffer = [0u8; 10];
        let mut buf = &mut buffer[..];

        let value: u32 = 0x12345678;
        value.write_scalar(&mut buf).unwrap();

        assert_eq!(buffer[0], 0x12);
        assert_eq!(buffer[1], 0x34);
        assert_eq!(buffer[2], 0x56);
        assert_eq!(buffer[3], 0x78);
    }

    // Signed values use two's-complement big-endian bytes.
    #[test]
    fn test_signed_big_endian_encoding() {
        let mut buffer = [0u8; 2];
        let mut buf = &mut buffer[..];

        (-2i16).write_scalar(&mut buf).unwrap();
        assert_eq!(buffer, [0xFF, 0xFE]);
    }

    #[test]
    fn test_buffer_advancement() {
        let mut buffer = [0u8; 20];
        let mut buf = &mut buffer[..];

        let initial_len = buf.len();
        42u32.write_scalar(&mut buf).unwrap();

        assert_eq!(buf.len(), initial_len - 4);
    }

    #[test]
    fn test_read_buffer_advancement() {
        let mut buffer = [0u8; 20];
        let mut write_buf = &mut buffer[..];
        42u32.write_scalar(&mut write_buf).unwrap();

        let mut read_buf = &buffer[..];
        let initial_len = read_buf.len();
        u32::read_scalar(&mut read_buf).unwrap();

        assert_eq!(read_buf.len(), initial_len - 4);
    }
}