derive/src/enum.rs
use darling::*;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::*;
use crate::utils::*;
#[derive(Debug, FromVariant)]
#[darling(attributes(wf))]
struct WiredField {
ident: Ident,
fields: darling::ast::Fields<Type>,
}
/// Contents of the container-level `discriminant(...)` option.
#[derive(FromMeta, Debug)]
struct DiscriminantSpec {
    /// Identifier of the discriminant.
    ///
    /// `expand` reads `<VariantTy>::<header>` as each variant's discriminant value. It
    /// reads `<FirstVariantTy>::MASK_<header>` as the mask passed to `elvwf::header::get`.
    header: Ident,
}
/// Parsed derive input: an enum whose variants are tuple variants, configured through
/// `#[wf(discriminant(header = ...))]`.
#[derive(FromDeriveInput, Debug)]
#[darling(attributes(wf), supports(enum_tuple))]
pub(crate) struct WiredInput {
    /// Name of the enum.
    ident: Ident,
    /// The enum's own generics; the generated impls are built from these.
    generics: Generics,
    /// The enum's variants. Struct inputs are rejected by `supports(enum_tuple)`, so the
    /// struct-field slot is ignored.
    data: ast::Data<WiredField, darling::util::Ignored>,
    /// Required `discriminant(...)` option of the `wf` attribute.
    discriminant: DiscriminantSpec,
}
pub fn expand(input: &WiredInput) -> syn::Result<TokenStream> {
let ident = &input.ident;
let mut generics = input.generics.clone();
let lt = match generics.lifetimes().next() {
Some(lt) => lt.lifetime.clone(),
None => {
let lt = Lifetime::new("'a", proc_macro2::Span::call_site());
generics
.params
.insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone())));
lt
}
};
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let (impl_generics, _, _) = generics.split_for_impl();
let variants = input
.data
.as_ref()
.take_enum()
.expect("supports(enum_tuple) guarantees an enum");
let Some(WiredField {
fields: vfields, ..
}) = variants.first()
else {
return Err(syn::Error::new(
input.ident.span(),
"enums must contain at least 1 variant",
));
};
let Some(bty) = vfields.fields.first() else {
return Err(syn::Error::new(
input.ident.span(),
"enum variants must contain only 1 tuple field",
));
};
let ebty = expr_strip(bty);
let mut len = Vec::<TokenStream>::with_capacity(variants.len());
let mut header = Vec::<TokenStream>::with_capacity(variants.len());
let mut encode = Vec::<TokenStream>::with_capacity(variants.len());
let mut decode = Vec::<TokenStream>::with_capacity(variants.len());
let mut where_msg = Vec::<TokenStream>::with_capacity(variants.len());
let mut assert_disc = Vec::<TokenStream>::with_capacity(variants.len());
let mdisc = format_ident!("MASK_{}", input.discriminant.header);
let disc = &input.discriminant.header;
for (i, f) in variants.iter().enumerate() {
let Some(ty) = f.fields.fields.first() else {
return Err(syn::Error::new(
input.ident.span(),
"enum variants must contain only 1 tuple field",
));
};
let sty = strip_lifetimes(ty);
let ety = expr_strip(ty);
if i != 0 {
where_msg
.push(quote!(#ty: elvwf::Wired<#lt, Header = <#bty as elvwf::Wired<#lt>>::Header>));
assert_disc.push(quote!(assert!(#ety::#mdisc == #ebty::#mdisc);));
}
let ident = &f.ident;
len.push(quote!(Self::#ident(x) => elvwf::msg::body_len::<#sty>(x)));
header.push(quote!(Self::#ident(x) => elvwf::msg::header::<#sty>(x)));
encode.push(quote!(Self::#ident(x) => elvwf::msg::encode_body::<#sty>(buf, x)));
decode.push(
quote!(#ety::#disc => Ok(Self::#ident(elvwf::msg::decode_body::<#sty>(buf, len, h)?))),
);
}
let where_clause: WhereClause = if let Some(wc) = where_clause {
parse_quote! {
#wc
#(#where_msg),*
}
} else {
parse_quote! {
where
#(#where_msg),*
}
};
Ok(quote! {
impl #impl_generics #ident #ty_generics #where_clause {
const fn __elvwf_assert() {
const {
#(#assert_disc)*
}
}
}
#[allow(unused_braces)]
impl #impl_generics elvwf::Wired<#lt> for #ident #ty_generics #where_clause {
type Header = <#bty as elvwf::Wired<#lt>>::Header;
type HeaderCodec = <#bty as elvwf::Wired<#lt>>::HeaderCodec;
fn body_len(&self) -> usize {
match self {
#(#len),*
}
}
fn header(&self) -> Result<Self::Header, elvwf::Error> {
match self {
#(#header),*
} }
fn encode_body(self, buf: &mut &mut [u8]) -> Result<(), elvwf::Error> {
match self {
#(#encode),*
}
}
fn decode_body(buf: &mut &#lt[u8], len: usize, h: Self::Header) -> Result<Self, elvwf::Error> {
match elvwf::header::get::<_, _, <#bty as elvwf::Wired<#lt>>::Header>(h, #ebty::#mdisc)? {
#(#decode),*,
_ => Err(elvwf::Error::UnknownVariant),
}
}
}
})
}