waylite

2 commits
Updated 2026-04-23 15:23:13
src
src/scalar.rs
/// Error produced by the scalar encode/decode helpers in this module.
///
/// Each size-related variant records how many bytes the operation needed
/// (`required_len`), how many bytes the buffer actually had (`available_len`),
/// and `loc`, a `file:line:column` string identifying where the error was built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarError {
    /// The 32-bit length prefix of a string or array could not be written.
    CouldNotWriteLen {
        required_len: usize,
        available_len: usize,
        loc: &'static str,
    },
    /// A fixed-size value, or the body of a string/array, did not fit in the
    /// output buffer.
    CouldNotWriteValue {
        required_len: usize,
        available_len: usize,
        loc: &'static str,
    },

    /// The 32-bit length prefix of a string or array could not be read.
    CouldNotReadLen {
        required_len: usize,
        available_len: usize,
        loc: &'static str,
    },
    /// The input ended before a fixed-size value, or the body of a
    /// string/array, was complete.
    CouldNotReadValue {
        required_len: usize,
        available_len: usize,
        loc: &'static str,
    },
    /// A decoded string's bytes were not valid UTF-8.
    InvalidUTF8,
}

impl core::fmt::Display for ScalarError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (what, required_len, available_len, loc) = match *self {
            ScalarError::CouldNotWriteLen {
                required_len,
                available_len,
                loc,
            } => ("could not write length", required_len, available_len, loc),
            ScalarError::CouldNotWriteValue {
                required_len,
                available_len,
                loc,
            } => ("could not write value", required_len, available_len, loc),
            ScalarError::CouldNotReadLen {
                required_len,
                available_len,
                loc,
            } => ("could not read length", required_len, available_len, loc),
            ScalarError::CouldNotReadValue {
                required_len,
                available_len,
                loc,
            } => ("could not read value", required_len, available_len, loc),
            ScalarError::InvalidUTF8 => return f.write_str("invalid UTF-8 in string"),
        };
        write!(
            f,
            "{}: need {} bytes, {} available (at {})",
            what, required_len, available_len, loc
        )
    }
}

/// Builds a `ScalarError::CouldNotWriteLen` whose `loc` is the
/// `file:line:column` of the `ewlen!` invocation site.
macro_rules! ewlen {
    ($required:expr, $available:expr) => {
        ScalarError::CouldNotWriteLen {
            required_len: $required,
            available_len: $available,
            loc: concat!(file!(), ":", line!(), ":", column!()),
        }
    };
}

/// Builds a `ScalarError::CouldNotWriteValue` whose `loc` is the
/// `file:line:column` of the `ewvalue!` invocation site.
macro_rules! ewvalue {
    ($required:expr, $available:expr) => {
        ScalarError::CouldNotWriteValue {
            required_len: $required,
            available_len: $available,
            loc: concat!(file!(), ":", line!(), ":", column!()),
        }
    };
}

/// Builds a `ScalarError::CouldNotReadLen` whose `loc` is the
/// `file:line:column` of the `erlen!` invocation site.
macro_rules! erlen {
    ($required:expr, $available:expr) => {
        ScalarError::CouldNotReadLen {
            required_len: $required,
            available_len: $available,
            loc: concat!(file!(), ":", line!(), ":", column!()),
        }
    };
}

/// Builds a `ScalarError::CouldNotReadValue` whose `loc` is the
/// `file:line:column` of the `ervalue!` invocation site.
macro_rules! ervalue {
    ($required:expr, $available:expr) => {
        ScalarError::CouldNotReadValue {
            required_len: $required,
            available_len: $available,
            loc: concat!(file!(), ":", line!(), ":", column!()),
        }
    };
}

/// Writes `v` as four native-endian bytes at the front of `buf`, then moves
/// `buf` past them.
///
/// # Errors
/// `ScalarError::CouldNotWriteValue` (with `required_len = 4`) when fewer than
/// four bytes remain; `buf` is left unchanged in that case.
pub fn write_i32(buf: &mut &mut [u8], v: i32) -> Result<(), ScalarError> {
    let available = buf.len();
    if available >= 4 {
        let (dst, rest) = core::mem::take(buf).split_at_mut(4);
        dst.copy_from_slice(&v.to_ne_bytes());
        *buf = rest;
        Ok(())
    } else {
        Err(ewvalue!(4, available))
    }
}

/// Writes `v` as four native-endian bytes at the front of `buf`, then moves
/// `buf` past them.
///
/// # Errors
/// `ScalarError::CouldNotWriteValue` (with `required_len = 4`) when fewer than
/// four bytes remain; `buf` is left unchanged in that case.
pub fn write_u32(buf: &mut &mut [u8], v: u32) -> Result<(), ScalarError> {
    let encoded = v.to_ne_bytes();
    match buf.len() {
        remaining if remaining < encoded.len() => Err(ewvalue!(4, remaining)),
        _ => {
            let (dst, rest) = core::mem::take(buf).split_at_mut(encoded.len());
            dst.copy_from_slice(&encoded);
            *buf = rest;
            Ok(())
        }
    }
}

/// Writes `s` as a length-prefixed, NUL-terminated string padded with zero
/// bytes to a 4-byte boundary, then moves `buf` past everything written.
///
/// Layout: a native-endian `u32` holding `s.len() + 1` (the count includes the
/// NUL terminator), then the bytes of `s`, a `0` terminator, and 0–3 zero
/// padding bytes. An empty `s` is written as a bare length of `0` with no body.
///
/// NOTE(review): in the Wayland wire format a length of 0 denotes a *null*
/// string, so `""` and "no string" are indistinguishable here; confirm callers
/// rely on that. `s` is also not checked for interior NUL bytes.
///
/// # Errors
/// * `ScalarError::CouldNotWriteLen` when fewer than 4 bytes remain for the
///   length prefix, or when the length does not fit in 32 bits.
/// * `ScalarError::CouldNotWriteValue` when the body plus padding does not
///   fit after the prefix; `required_len` is the padded body size.
///
/// On error nothing is committed: `buf` still starts where it did.
pub fn write_string(buf: &mut &mut [u8], s: &str) -> Result<(), ScalarError> {
    if buf.len() < 4 {
        return Err(ewlen!(4, buf.len()));
    }

    if s.is_empty() {
        return write_u32(buf, 0);
    }

    let bytes = s.as_bytes();
    let len = bytes.len() + 1; // + NUL terminator
    let padding = (4 - (len & 3)) & 3;
    let padded = len + padding;

    // Check the whole encoding fits *before* writing anything, so a failure
    // never leaves a dangling length prefix and the padding skip cannot panic.
    let available = buf.len() - 4;
    if padded > available {
        return Err(ewvalue!(padded, available));
    }

    // The length prefix is 32 bits wide; refuse lengths that would truncate.
    if len > u32::MAX as usize {
        return Err(ewlen!(len, u32::MAX as usize));
    }
    write_u32(buf, len as u32)?;

    let (head, tail) = core::mem::take(buf).split_at_mut(padded);
    let (text, trailer) = head.split_at_mut(bytes.len());
    text.copy_from_slice(bytes);
    // NUL terminator followed by zeroed padding: never leak stale buffer bytes.
    trailer.fill(0);
    *buf = tail;

    Ok(())
}

/// Writes a new object id as a single native-endian `u32` by delegating to
/// [`write_u32`].
///
/// Unlike [`write_new_id_untyped`], no interface name or version is written.
///
/// # Errors
/// Whatever [`write_u32`] returns (insufficient space for four bytes).
pub fn write_new_id(buf: &mut &mut [u8], id: u32) -> Result<(), ScalarError> {
    write_u32(buf, id)
}

/// Writes an untyped new-id argument: the interface name (see
/// [`write_string`]), then `version`, then `id`, each `u32` native-endian.
///
/// The write is all-or-nothing with respect to the cursor. If any part fails,
/// `buf` is not advanced, so the caller never sees a half-written argument.
/// Bytes past the cursor may still have been scribbled on.
///
/// # Errors
/// Propagates the first error from [`write_string`] or [`write_u32`].
pub fn write_new_id_untyped(
    buf: &mut &mut [u8],
    interface: &str,
    version: u32,
    id: u32,
) -> Result<(), ScalarError> {
    let before = buf.len();

    // Encode through a temporary reborrow; `buf` itself only moves on success.
    let consumed = {
        let mut cursor: &mut [u8] = &mut **buf;
        write_string(&mut cursor, interface)?;
        write_u32(&mut cursor, version)?;
        write_u32(&mut cursor, id)?;
        before - cursor.len()
    };

    let (_, rest) = core::mem::take(buf).split_at_mut(consumed);
    *buf = rest;
    Ok(())
}

/// Writes `bytes` as a length-prefixed byte array padded with zero bytes to a
/// 4-byte boundary, then moves `buf` past everything written.
///
/// Layout: a native-endian `u32` holding `bytes.len()`, then the bytes
/// verbatim, then 0–3 zero padding bytes. An empty array is just the `0`
/// length.
///
/// # Errors
/// * `ScalarError::CouldNotWriteLen` when fewer than 4 bytes remain for the
///   length prefix, or when the length does not fit in 32 bits.
/// * `ScalarError::CouldNotWriteValue` when the data plus padding does not
///   fit after the prefix; `required_len` is the padded data size.
///
/// On error nothing is committed: `buf` still starts where it did.
pub fn write_array(buf: &mut &mut [u8], bytes: &[u8]) -> Result<(), ScalarError> {
    if buf.len() < 4 {
        return Err(ewlen!(4, buf.len()));
    }

    let len = bytes.len();
    let padding = (4 - (len & 3)) & 3;
    let padded = len + padding;

    // Check the whole encoding fits *before* writing anything, so a failure
    // never leaves a dangling length prefix and the padding skip cannot panic.
    let available = buf.len() - 4;
    if padded > available {
        return Err(ewvalue!(padded, available));
    }

    // The length prefix is 32 bits wide; refuse lengths that would truncate.
    if len > u32::MAX as usize {
        return Err(ewlen!(len, u32::MAX as usize));
    }
    write_u32(buf, len as u32)?;

    let (head, tail) = core::mem::take(buf).split_at_mut(padded);
    let (data, pad) = head.split_at_mut(len);
    data.copy_from_slice(bytes);
    // Zero the alignment padding so stale buffer contents never reach the wire.
    pad.fill(0);
    *buf = tail;

    Ok(())
}

/// Reads a native-endian `i32` from the first four bytes of `buf`, then moves
/// `buf` past them.
///
/// # Errors
/// `ScalarError::CouldNotReadValue` (with `required_len = 4`) when fewer than
/// four bytes remain; `buf` is left unchanged in that case.
pub fn read_i32(buf: &mut &[u8]) -> Result<i32, ScalarError> {
    let input = *buf;
    match input.get(..4) {
        Some(prefix) => {
            let mut raw = [0u8; 4];
            raw.copy_from_slice(prefix);
            *buf = &input[4..];
            Ok(i32::from_ne_bytes(raw))
        }
        None => Err(ervalue!(4, input.len())),
    }
}

/// Reads a native-endian `u32` from the first four bytes of `buf`, then moves
/// `buf` past them.
///
/// # Errors
/// `ScalarError::CouldNotReadValue` (with `required_len = 4`) when fewer than
/// four bytes remain; `buf` is left unchanged in that case.
pub fn read_u32(buf: &mut &[u8]) -> Result<u32, ScalarError> {
    if buf.len() >= 4 {
        let (word, rest) = buf.split_at(4);
        let value = u32::from_ne_bytes([word[0], word[1], word[2], word[3]]);
        *buf = rest;
        Ok(value)
    } else {
        Err(ervalue!(4, buf.len()))
    }
}

/// Reads a length-prefixed string written by [`write_string`] and returns a
/// zero-copy `&str` into `buf`.
///
/// The `u32` length counts a trailing terminator byte. That byte is dropped
/// from the returned string, and the body is followed by 0–3 alignment bytes
/// that are skipped. A length of `0` yields `""`.
///
/// NOTE(review): the terminator byte is assumed to be NUL and is not
/// verified, and interior NUL bytes are accepted.
///
/// # Errors
/// * `ScalarError::CouldNotReadLen` when fewer than 4 bytes hold the prefix.
/// * `ScalarError::CouldNotReadValue` when the body plus its alignment
///   padding is truncated; `required_len` is the padded body size.
/// * `ScalarError::InvalidUTF8` when the body is not valid UTF-8.
///
/// On error `buf` is not advanced.
pub fn read_string<'a>(buf: &mut &'a [u8]) -> Result<&'a str, ScalarError> {
    // Parse through a copy so `buf` is only advanced once everything succeeds.
    let mut cursor: &'a [u8] = *buf;
    let len = read_u32(&mut cursor).map_err(|_| erlen!(4, buf.len()))? as usize;

    if len == 0 {
        *buf = cursor;
        return Ok("");
    }

    let padding = (4 - (len & 3)) & 3;
    // `len` comes from the peer; saturate so a huge value can't wrap on 32-bit.
    let padded = len.saturating_add(padding);
    if cursor.len() < padded {
        return Err(ervalue!(padded, cursor.len()));
    }

    let (body, rest) = cursor.split_at(padded);
    let s = core::str::from_utf8(&body[..len - 1]).map_err(|_| ScalarError::InvalidUTF8)?;

    *buf = rest;
    Ok(s)
}

/// Reads a new object id encoded as a single native-endian `u32` by
/// delegating to [`read_u32`].
///
/// Unlike [`read_new_id_untyped`], no interface name or version is read.
///
/// # Errors
/// Whatever [`read_u32`] returns (fewer than four bytes available).
pub fn read_new_id(buf: &mut &[u8]) -> Result<u32, ScalarError> {
    read_u32(buf)
}

/// Reads an untyped new-id argument, as written by [`write_new_id_untyped`]:
/// the interface name, then the version, then the object id.
///
/// Returns `(interface, version, id)`; `interface` borrows from `buf`.
///
/// # Errors
/// Propagates the first error from [`read_string`] or [`read_u32`]. On error
/// `buf` is not advanced, so a truncated argument is never half-consumed.
pub fn read_new_id_untyped<'a>(buf: &mut &'a [u8]) -> Result<(&'a str, u32, u32), ScalarError> {
    let mut cursor: &'a [u8] = *buf;
    let interface = read_string(&mut cursor)?;
    let version = read_u32(&mut cursor)?;
    let id = read_u32(&mut cursor)?;
    *buf = cursor;
    Ok((interface, version, id))
}

/// Reads a length-prefixed byte array written by [`write_array`] and returns
/// a zero-copy slice of `buf` holding exactly the array's bytes.
///
/// The body is followed by 0–3 alignment bytes, which are skipped.
///
/// # Errors
/// * `ScalarError::CouldNotReadLen` when fewer than 4 bytes hold the prefix.
/// * `ScalarError::CouldNotReadValue` when the data plus its alignment
///   padding is truncated; `required_len` is the padded data size.
///
/// On error `buf` is not advanced.
pub fn read_array<'a>(buf: &mut &'a [u8]) -> Result<&'a [u8], ScalarError> {
    // Parse through a copy so `buf` is only advanced once everything succeeds.
    let mut cursor: &'a [u8] = *buf;
    let len = read_u32(&mut cursor).map_err(|_| erlen!(4, buf.len()))? as usize;

    let padding = (4 - (len & 3)) & 3;
    // `len` comes from the peer; saturate so a huge value can't wrap on 32-bit.
    let padded = len.saturating_add(padding);
    if cursor.len() < padded {
        return Err(ervalue!(padded, cursor.len()));
    }

    let (body, rest) = cursor.split_at(padded);
    *buf = rest;
    Ok(&body[..len])
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `f` against a zeroed 512-byte scratch buffer and returns only the
    /// bytes it consumed (i.e. how far it advanced the write cursor).
    fn roundtrip_buf(f: impl FnOnce(&mut &mut [u8])) -> Vec<u8> {
        let mut storage = vec![0u8; 512];
        let mut w: &mut [u8] = &mut storage;
        f(&mut w);
        let written = 512 - w.len();
        storage[..written].to_vec()
    }

    #[test]
    fn i32_roundtrip_positive() {
        let buf = roundtrip_buf(|w| write_i32(w, 42).unwrap());
        assert_eq!(read_i32(&mut buf.as_slice()).unwrap(), 42);
    }

    #[test]
    fn i32_roundtrip_negative() {
        let buf = roundtrip_buf(|w| write_i32(w, -1).unwrap());
        assert_eq!(read_i32(&mut buf.as_slice()).unwrap(), -1);
    }

    #[test]
    fn i32_roundtrip_min_max() {
        for v in [i32::MIN, i32::MAX, 0] {
            let buf = roundtrip_buf(|w| write_i32(w, v).unwrap());
            assert_eq!(read_i32(&mut buf.as_slice()).unwrap(), v);
        }
    }

    #[test]
    fn u32_roundtrip() {
        for v in [0u32, 1, u32::MAX, 0xDEAD_BEEF] {
            let buf = roundtrip_buf(|w| write_u32(w, v).unwrap());
            assert_eq!(read_u32(&mut buf.as_slice()).unwrap(), v);
        }
    }

    #[test]
    fn string_empty() {
        let buf = roundtrip_buf(|w| write_string(w, "").unwrap());
        assert_eq!(read_string(&mut buf.as_slice()).unwrap(), "");
    }

    #[test]
    fn string_1_char() {
        // len=2 (a + \0), padding=2  →  total = 4 + 4 = 8 bytes
        let buf = roundtrip_buf(|w| write_string(w, "a").unwrap());
        assert_eq!(buf.len(), 8);
        assert_eq!(read_string(&mut buf.as_slice()).unwrap(), "a");
    }

    #[test]
    fn string_3_chars_no_padding() {
        // len=4 (abc\0), padding=0  →  total = 4 + 4 = 8 bytes
        let buf = roundtrip_buf(|w| write_string(w, "abc").unwrap());
        assert_eq!(buf.len(), 8);
        assert_eq!(read_string(&mut buf.as_slice()).unwrap(), "abc");
    }

    #[test]
    fn string_4_chars_with_padding() {
        // len=5 (abcd\0), padding=3  →  total = 4 + 8 = 12 bytes
        let buf = roundtrip_buf(|w| write_string(w, "abcd").unwrap());
        assert_eq!(buf.len(), 12);
        assert_eq!(read_string(&mut buf.as_slice()).unwrap(), "abcd");
    }

    #[test]
    fn string_unicode() {
        let s = "héllo";
        let buf = roundtrip_buf(|w| write_string(w, s).unwrap());
        assert_eq!(read_string(&mut buf.as_slice()).unwrap(), s);
    }

    #[test]
    fn string_advances_cursor() {
        // Write two strings back-to-back and read both
        let mut storage = vec![0u8; 256];
        let mut w: &mut [u8] = &mut storage;
        write_string(&mut w, "foo").unwrap();
        write_string(&mut w, "bar").unwrap();

        let written = 256 - w.len();
        let mut r: &[u8] = &storage[..written];
        assert_eq!(read_string(&mut r).unwrap(), "foo");
        assert_eq!(read_string(&mut r).unwrap(), "bar");
        assert!(r.is_empty());
    }

    #[test]
    fn mixed_roundtrip() {
        let mut storage = vec![0u8; 256];
        let mut w: &mut [u8] = &mut storage;
        write_i32(&mut w, -7).unwrap();
        write_u32(&mut w, 42).unwrap();
        write_string(&mut w, "rust").unwrap();
        write_i32(&mut w, 0).unwrap();

        let written = 256 - w.len();
        let mut r: &[u8] = &storage[..written];
        assert_eq!(read_i32(&mut r).unwrap(), -7);
        assert_eq!(read_u32(&mut r).unwrap(), 42);
        assert_eq!(read_string(&mut r).unwrap(), "rust");
        assert_eq!(read_i32(&mut r).unwrap(), 0);
        assert!(r.is_empty());
    }

    #[test]
    fn new_id_roundtrip() {
        for id in [0u32, 1, 42, u32::MAX] {
            let buf = roundtrip_buf(|w| write_new_id(w, id).unwrap());
            assert_eq!(read_new_id(&mut buf.as_slice()).unwrap(), id);
        }
    }

    #[test]
    fn new_id_untyped_roundtrip() {
        let cases = [("wl_compositor", 1u32, 5u32), ("wl_seat", 7, 1), ("", 0, 0)];
        for (iface, version, id) in cases {
            let buf = roundtrip_buf(|w| write_new_id_untyped(w, iface, version, id).unwrap());
            let (r_iface, r_version, r_id) = read_new_id_untyped(&mut buf.as_slice()).unwrap();
            assert_eq!(r_iface, iface);
            assert_eq!(r_version, version);
            assert_eq!(r_id, id);
        }
    }

    #[test]
    fn new_id_untyped_cursor_exhausted() {
        // After decode the buffer should be fully consumed
        let buf = roundtrip_buf(|w| write_new_id_untyped(w, "wl_output", 3, 99).unwrap());
        let mut r: &[u8] = &buf;
        read_new_id_untyped(&mut r).unwrap();
        assert!(r.is_empty());
    }

    #[test]
    fn array_empty() {
        let buf = roundtrip_buf(|w| write_array(w, &[]).unwrap());
        // Empty array: just a 4-byte length of 0, no padding needed
        assert_eq!(buf.len(), 4);
        let data = read_array(&mut buf.as_slice()).unwrap();
        assert!(data.is_empty());
    }

    #[test]
    fn array_1_byte_has_3_byte_padding() {
        let buf = roundtrip_buf(|w| write_array(w, &[0xAB]).unwrap());
        assert_eq!(buf.len(), 4 + 4); // 4 (len) + 1 (data) + 3 (padding)
        assert_eq!(read_array(&mut buf.as_slice()).unwrap(), &[0xAB]);
    }

    #[test]
    fn array_4_bytes_no_padding() {
        let data = [1u8, 2, 3, 4];
        let buf = roundtrip_buf(|w| write_array(w, &data).unwrap());
        assert_eq!(buf.len(), 4 + 4); // 4 (len) + 4 (data) + 0 (padding)
        assert_eq!(read_array(&mut buf.as_slice()).unwrap(), &data);
    }

    #[test]
    fn array_5_bytes_has_3_byte_padding() {
        let data = [0u8, 1, 2, 3, 4];
        let buf = roundtrip_buf(|w| write_array(w, &data).unwrap());
        assert_eq!(buf.len(), 4 + 8); // 4 (len) + 5 (data) + 3 (padding)
        assert_eq!(read_array(&mut buf.as_slice()).unwrap(), &data);
    }

    #[test]
    fn array_arbitrary_bytes() {
        let data: Vec<u8> = (0u8..=255).collect();
        let buf = roundtrip_buf(|w| write_array(w, &data).unwrap());
        assert_eq!(read_array(&mut buf.as_slice()).unwrap(), data.as_slice());
    }

    #[test]
    fn array_zero_copy_slice() {
        let data = b"hello";
        let buf = roundtrip_buf(|w| write_array(w, data).unwrap());
        let boxed: Box<[u8]> = buf.into();
        let mut r: &[u8] = &boxed;
        let out = read_array(&mut r).unwrap();
        assert_eq!(out.as_ptr(), boxed[4..].as_ptr());
    }

    #[test]
    fn array_two_consecutive_advances_cursor() {
        let a = b"foo";
        let b_data = b"hello!";
        let mut storage = vec![0u8; 256];
        let mut w: &mut [u8] = &mut storage;
        write_array(&mut w, a).unwrap();
        write_array(&mut w, b_data).unwrap();

        let written = 256 - w.len();
        let mut r: &[u8] = &storage[..written];
        assert_eq!(read_array(&mut r).unwrap(), a.as_slice());
        assert_eq!(read_array(&mut r).unwrap(), b_data.as_slice());
        assert!(r.is_empty());
    }

    #[test]
    fn full_roundtrip() {
        let mut storage = vec![0u8; 512];
        let mut w: &mut [u8] = &mut storage;

        write_new_id(&mut w, 3).unwrap();
        write_new_id_untyped(&mut w, "wl_compositor", 4, 7).unwrap();
        write_array(&mut w, &[0xDE, 0xAD, 0xBE, 0xEF, 0xFF]).unwrap();
        write_string(&mut w, "wayland").unwrap();
        write_i32(&mut w, -42).unwrap();
        write_u32(&mut w, 1337).unwrap();

        let written = 512 - w.len();
        let mut r: &[u8] = &storage[..written];

        assert_eq!(read_new_id(&mut r).unwrap(), 3);
        let (iface, ver, id) = read_new_id_untyped(&mut r).unwrap();
        assert_eq!((iface, ver, id), ("wl_compositor", 4, 7));
        assert_eq!(read_array(&mut r).unwrap(), &[0xDE, 0xAD, 0xBE, 0xEF, 0xFF]);
        assert_eq!(read_string(&mut r).unwrap(), "wayland");
        assert_eq!(read_i32(&mut r).unwrap(), -42);
        assert_eq!(read_u32(&mut r).unwrap(), 1337);
        assert!(r.is_empty());
    }

    // ---- error paths ----

    #[test]
    fn write_i32_short_buffer_errors_without_advancing() {
        let mut storage = [0u8; 3];
        let mut w: &mut [u8] = &mut storage;
        let err = write_i32(&mut w, 1).unwrap_err();
        assert!(matches!(
            err,
            ScalarError::CouldNotWriteValue {
                required_len: 4,
                available_len: 3,
                ..
            }
        ));
        assert_eq!(w.len(), 3);
    }

    #[test]
    fn read_u32_short_buffer_errors_without_advancing() {
        let bytes = [1u8, 2, 3];
        let mut r: &[u8] = &bytes;
        let err = read_u32(&mut r).unwrap_err();
        assert!(matches!(
            err,
            ScalarError::CouldNotReadValue {
                required_len: 4,
                available_len: 3,
                ..
            }
        ));
        assert_eq!(r.len(), 3);
    }

    #[test]
    fn write_string_without_room_for_len_errors() {
        let mut storage = [0u8; 2];
        let mut w: &mut [u8] = &mut storage;
        assert!(matches!(
            write_string(&mut w, "abc"),
            Err(ScalarError::CouldNotWriteLen { .. })
        ));
    }

    #[test]
    fn write_array_without_room_for_data_errors() {
        let mut storage = [0u8; 6];
        let mut w: &mut [u8] = &mut storage;
        assert!(matches!(
            write_array(&mut w, &[1, 2, 3, 4, 5]),
            Err(ScalarError::CouldNotWriteValue { .. })
        ));
    }

    #[test]
    fn read_string_missing_len_errors() {
        let mut r: &[u8] = &[0, 0];
        assert!(matches!(
            read_string(&mut r),
            Err(ScalarError::CouldNotReadLen {
                required_len: 4,
                available_len: 2,
                ..
            })
        ));
    }

    #[test]
    fn read_string_truncated_body_errors() {
        // Length prefix claims 8 bytes but only 3 follow.
        let mut bytes = 8u32.to_ne_bytes().to_vec();
        bytes.extend_from_slice(b"abc");
        assert!(matches!(
            read_string(&mut bytes.as_slice()),
            Err(ScalarError::CouldNotReadValue { .. })
        ));
    }

    #[test]
    fn read_string_invalid_utf8_errors() {
        // len=2: one invalid byte + terminator, then 2 padding bytes.
        let mut bytes = 2u32.to_ne_bytes().to_vec();
        bytes.extend_from_slice(&[0xFF, 0, 0, 0]);
        assert_eq!(
            read_string(&mut bytes.as_slice()),
            Err(ScalarError::InvalidUTF8)
        );
    }

    #[test]
    fn read_array_truncated_body_errors() {
        // Length prefix claims 4 bytes but only 2 follow.
        let mut bytes = 4u32.to_ne_bytes().to_vec();
        bytes.extend_from_slice(&[1, 2]);
        assert!(matches!(
            read_array(&mut bytes.as_slice()),
            Err(ScalarError::CouldNotReadValue { .. })
        ));
    }
}