elvwf

8 commits
Updated 2026-06-13 11:19:44
derive/src
derive/src/enum.rs
use darling::*;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::*;

use crate::utils::*;

#[derive(Debug, FromVariant)]
#[darling(attributes(wf))]
struct WiredField {
    ident: Ident,
    fields: darling::ast::Fields<Type>,
}

/// Contents of the `discriminant(...)` entry of the enum-level `#[wf(...)]` attribute.
#[derive(FromMeta, Debug)]
struct DiscriminantSpec {
    /// Name of the header constant that identifies a variant.
    ///
    /// `expand` reads both `<VariantTy>::<header>` and `<VariantTy>::MASK_<header>`
    /// from the variant types.
    header: Ident,
}

/// Parsed input of the derive: an enum whose variants are all tuple variants,
/// annotated with `#[wf(discriminant(header = ...))]`.
#[derive(FromDeriveInput, Debug)]
#[darling(supports(enum_tuple), attributes(wf))]
pub(crate) struct WiredInput {
    /// Name of the enum.
    ident: Ident,
    /// The enum's generic parameters and where clause.
    generics: Generics,
    /// The enum's variants; struct inputs are rejected by `supports(enum_tuple)`,
    /// so the struct-field parameter is ignored.
    data: ast::Data<WiredField, darling::util::Ignored>,
    /// Required `discriminant(...)` setting.
    discriminant: DiscriminantSpec,
}

pub fn expand(input: &WiredInput) -> syn::Result<TokenStream> {
    let ident = &input.ident;
    let mut generics = input.generics.clone();
    let lt = match generics.lifetimes().next() {
        Some(lt) => lt.lifetime.clone(),
        None => {
            let lt = Lifetime::new("'a", proc_macro2::Span::call_site());
            generics
                .params
                .insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone())));
            lt
        }
    };

    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
    let (impl_generics, _, _) = generics.split_for_impl();

    let variants = input
        .data
        .as_ref()
        .take_enum()
        .expect("supports(enum_tuple) guarantees an enum");

    let Some(WiredField {
        fields: vfields, ..
    }) = variants.first()
    else {
        return Err(syn::Error::new(
            input.ident.span(),
            "enums must contain at least 1 variant",
        ));
    };

    let Some(bty) = vfields.fields.first() else {
        return Err(syn::Error::new(
            input.ident.span(),
            "enum variants must contain only 1 tuple field",
        ));
    };

    let ebty = expr_strip(bty);

    let mut len = Vec::<TokenStream>::with_capacity(variants.len());
    let mut header = Vec::<TokenStream>::with_capacity(variants.len());
    let mut encode = Vec::<TokenStream>::with_capacity(variants.len());
    let mut decode = Vec::<TokenStream>::with_capacity(variants.len());
    let mut where_msg = Vec::<TokenStream>::with_capacity(variants.len());
    let mut assert_disc = Vec::<TokenStream>::with_capacity(variants.len());

    let mdisc = format_ident!("MASK_{}", input.discriminant.header);
    let disc = &input.discriminant.header;

    for (i, f) in variants.iter().enumerate() {
        let Some(ty) = f.fields.fields.first() else {
            return Err(syn::Error::new(
                input.ident.span(),
                "enum variants must contain only 1 tuple field",
            ));
        };

        let sty = strip_lifetimes(ty);
        let ety = expr_strip(ty);

        if i != 0 {
            where_msg
                .push(quote!(#ty: elvwf::Wired<#lt, Header = <#bty as elvwf::Wired<#lt>>::Header>));

            assert_disc.push(quote!(assert!(#ety::#mdisc == #ebty::#mdisc);));
        }

        let ident = &f.ident;

        len.push(quote!(Self::#ident(x) => elvwf::msg::body_len::<#sty>(x)));
        header.push(quote!(Self::#ident(x) => elvwf::msg::header::<#sty>(x)));
        encode.push(quote!(Self::#ident(x) => elvwf::msg::encode_body::<#sty>(buf, x)));
        decode.push(
            quote!(#ety::#disc => Ok(Self::#ident(elvwf::msg::decode_body::<#sty>(buf, len, h)?))),
        );
    }

    let where_clause: WhereClause = if let Some(wc) = where_clause {
        parse_quote! {
            #wc
            #(#where_msg),*
        }
    } else {
        parse_quote! {
            where
                #(#where_msg),*
        }
    };

    Ok(quote! {
        impl #impl_generics #ident #ty_generics #where_clause {
            const fn __elvwf_assert() {
                const {
                    #(#assert_disc)*
                }
            }
        }

        #[allow(unused_braces)]
        impl #impl_generics elvwf::Wired<#lt> for #ident #ty_generics #where_clause {
            type Header = <#bty as elvwf::Wired<#lt>>::Header;
            type HeaderCodec = <#bty as elvwf::Wired<#lt>>::HeaderCodec;

            fn body_len(&self) -> usize {
                match self {
                    #(#len),*
                }
            }

            fn header(&self) -> Result<Self::Header, elvwf::Error> {
                match self {
                    #(#header),*
                }            }

            fn encode_body(self, buf: &mut &mut [u8]) -> Result<(), elvwf::Error> {
                match self {
                    #(#encode),*
                }
            }

            fn decode_body(buf: &mut &#lt[u8], len: usize, h: Self::Header) -> Result<Self, elvwf::Error> {
                match elvwf::header::get::<_, _, <#bty as elvwf::Wired<#lt>>::Header>(h, #ebty::#mdisc)? {
                    #(#decode),*,
                    _ => Err(elvwf::Error::UnknownVariant),
                }
            }
        }
    })
}