elvwf

8 commits
Updated 2026-06-13 11:19:44
derive/src
derive/src/error.rs
use std::collections::{HashMap, HashSet};

use darling::*;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{spanned::Spanned, *};

/// Enum-level options read from the `#[wfe(...)]` attributes on an enum.
#[derive(Debug, FromAttributes)]
#[darling(attributes(wfe))]
pub struct EnumError {
    /// Identifier of the domain this enum is registered under; other enums
    /// name it in their `extends` option to inherit its variants.
    domain: Ident,
    /// Domains whose variants this enum inherits. Only the keys are read
    /// (the unit values carry no information); defaults to empty when the
    /// option is omitted.
    #[darling(default)]
    extends: HashMap<Ident, ()>,
}

/// A single enum variant together with its `#[wfe(...)]` options.
#[derive(Debug, FromVariant, Clone)]
#[darling(attributes(wfe))]
pub struct EnumVariant {
    /// Name of the variant.
    ident: Ident,
    /// Message string emitted by the generated `Display` impl for this variant.
    error: LitStr,
    /// Explicit discriminant expression, if any. Code generation later
    /// rejects variants whose discriminant is not an integer literal.
    discriminant: Option<Expr>,
}

/// Fully resolved description of one error enum, ready for code generation.
struct Error {
    /// Name of the enum to generate.
    ident: Ident,
    /// Every domain reached by following `extends` transitively; one `From`
    /// impl is generated per entry.
    all_extended: Vec<Ident>,
    /// The enum's own variants followed by all variants inherited from the
    /// extended domains.
    all_variants: Vec<EnumVariant>,
}

pub fn expand(input: ItemMod) -> syn::Result<TokenStream> {
    let (_, items) = input
        .content
        .as_ref()
        .ok_or_else(|| syn::Error::new(input.span(), "module should contain a body"))?;

    let errors = collect_errors(items)?;

    let mut ts = Vec::with_capacity(errors.len());
    for (_, error) in &errors {
        ts.push(generate_enum(error, &errors)?);
    }

    Ok(quote!(#(#ts)*))
}

fn collect_errors(items: &[Item]) -> syn::Result<HashMap<Ident, Error>> {
    struct ErrorP1 {
        ident: Ident,
        extends: Vec<Ident>,
        variants: Vec<EnumVariant>,
    }

    let mut errors_p1 = HashMap::<Ident, ErrorP1>::new();

    for item in items {
        let Item::Enum(item) = item else {
            return Err(syn::Error::new(
                item.span(),
                "module must only contain enums",
            ));
        };

        let EnumError { domain, extends } = EnumError::from_attributes(&item.attrs)?;

        let error = errors_p1.entry(domain).or_insert_with(|| ErrorP1 {
            ident: item.ident.clone(),
            extends: extends.keys().cloned().collect(),
            variants: Vec::new(),
        });

        for variant in &item.variants {
            error.variants.push(EnumVariant::from_variant(variant)?);
        }
    }

    let mut errors = HashMap::<Ident, Error>::new();

    for (domain, error) in &errors_p1 {
        let mut all_variants: Vec<&EnumVariant> = error.variants.iter().collect();
        let mut visited = HashSet::<&Ident>::new();
        let mut to_visit: Vec<&Ident> = error.extends.iter().collect();

        while let Some(domain) = to_visit.pop() {
            if !visited.insert(domain) {
                continue;
            }

            let Some(extended) = errors_p1.get(domain) else {
                return Err(syn::Error::new(
                    error.ident.span(),
                    "extending an unknown domain",
                ));
            };

            all_variants.extend(&extended.variants);

            for domain in &extended.extends {
                if !visited.contains(domain) {
                    to_visit.push(domain);
                }
            }
        }

        let all_variants = Vec::from_iter(all_variants.into_iter().cloned());

        errors.insert(
            domain.clone(),
            Error {
                ident: error.ident.clone(),
                all_extended: visited.into_iter().cloned().collect(),
                all_variants,
            },
        );
    }

    Ok(errors)
}

fn generate_enum(error: &Error, errors: &HashMap<Ident, Error>) -> syn::Result<TokenStream> {
    let ident = &error.ident;
    let all_variants = &error.all_variants;

    let mut vts = Vec::with_capacity(all_variants.len());
    let mut dts = Vec::with_capacity(all_variants.len());
    let mut fts = Vec::new();

    for domain in &error.all_extended {
        let extended = errors.get(domain).unwrap();
        let eid = &extended.ident;

        let fvts = extended.all_variants.iter().map(|variant| {
            let vid = &variant.ident;
            quote!(#eid::#vid => Self::#vid)
        });

        fts.push(quote!(
            impl From<#eid> for #ident {
                fn from(value: #eid) -> Self {
                    match value {
                        #(#fvts),*
                    }
                }
            }
        ));
    }

    for variant in all_variants {
        let EnumVariant {
            ident,
            error,
            discriminant,
        } = variant;

        let Some(Expr::Lit(ExprLit {
            lit: Lit::Int(discriminant),
            ..
        })) = discriminant
        else {
            return Err(syn::Error::new(
                ident.span(),
                "variant must have an int discriminant",
            ));
        };

        vts.push(quote!(#ident = #discriminant));
        dts.push(quote!(Self::#ident => #error));
    }

    Ok(quote! {
        #[derive(Debug, Clone, PartialEq, Eq)]
        #[repr(u8)]
        pub enum #ident {
            #(#vts),*
        }

        impl ::core::fmt::Display for #ident {
            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.write_str(match self {
                    #(#dts),*
                })
            }
        }

        impl ::core::error::Error for #ident {}

        #(#fts)*
    })
}