130 lines
3.9 KiB
Rust
130 lines
3.9 KiB
Rust
extern crate proc_macro;
|
|
|
|
use proc_macro::TokenStream;
|
|
use proc_macro2::TokenStream as TokenStream2;
|
|
use quote::quote;
|
|
use syn::{Data, DeriveInput, Fields, parse_macro_input};
|
|
|
|
/// Derives `CacheHash` for a struct or enum.
|
|
///
|
|
/// All fields must implement `CacheHash`. Fields annotated with `#[cache_hash(skip)]`
|
|
/// are excluded from hashing.
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```
|
|
/// # use graphene_hash::CacheHash;
|
|
/// #[derive(CacheHash)]
|
|
/// pub struct MyNode {
|
|
/// pub value: f64,
|
|
/// pub count: u32,
|
|
/// #[cache_hash(skip)]
|
|
/// pub debug_label: String,
|
|
/// }
|
|
/// ```
|
|
#[proc_macro_derive(CacheHash, attributes(cache_hash))]
|
|
pub fn derive_cache_hash(input: TokenStream) -> TokenStream {
|
|
let ast = parse_macro_input!(input as DeriveInput);
|
|
let name = &ast.ident;
|
|
let mut generics = ast.generics.clone();
|
|
for param in &mut generics.params {
|
|
if let syn::GenericParam::Type(type_param) = param {
|
|
type_param.bounds.push(syn::parse_quote!(graphene_hash::CacheHash));
|
|
}
|
|
}
|
|
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
|
|
|
|
let body = match &ast.data {
|
|
Data::Struct(s) => hash_fields(&s.fields, quote! { self }),
|
|
Data::Enum(e) => {
|
|
let arms = e.variants.iter().map(|variant| {
|
|
let variant_name = &variant.ident;
|
|
let (pattern, hash_body) = match &variant.fields {
|
|
Fields::Unit => (quote! {}, quote! {}),
|
|
Fields::Unnamed(fields) => {
|
|
let bindings: Vec<_> = (0..fields.unnamed.len())
|
|
.map(|i| {
|
|
let ident = proc_macro2::Ident::new(&format!("f{i}"), proc_macro2::Span::call_site());
|
|
quote! { #ident }
|
|
})
|
|
.collect();
|
|
let hash_stmts = fields.unnamed.iter().enumerate().filter_map(|(i, field)| {
|
|
if has_skip_attr(&field.attrs) {
|
|
return None;
|
|
}
|
|
let ident = proc_macro2::Ident::new(&format!("f{i}"), proc_macro2::Span::call_site());
|
|
Some(quote! { graphene_hash::CacheHash::cache_hash(#ident, state); })
|
|
});
|
|
(quote! { (#(#bindings,)*) }, quote! { #(#hash_stmts)* })
|
|
}
|
|
Fields::Named(fields) => {
|
|
let names: Vec<_> = fields.named.iter().map(|f| f.ident.as_ref().unwrap()).collect();
|
|
let hash_stmts = fields.named.iter().filter_map(|field| {
|
|
if has_skip_attr(&field.attrs) {
|
|
return None;
|
|
}
|
|
let ident = field.ident.as_ref().unwrap();
|
|
Some(quote! { graphene_hash::CacheHash::cache_hash(#ident, state); })
|
|
});
|
|
(quote! { { #(#names,)* } }, quote! { #(#hash_stmts)* })
|
|
}
|
|
};
|
|
quote! {
|
|
Self::#variant_name #pattern => { #hash_body }
|
|
}
|
|
});
|
|
quote! {
|
|
::core::hash::Hash::hash(&::core::mem::discriminant(self), state);
|
|
match self {
|
|
#(#arms)*
|
|
}
|
|
}
|
|
}
|
|
Data::Union(_) => return syn::Error::new(ast.ident.span(), "CacheHash cannot be derived for unions").to_compile_error().into(),
|
|
};
|
|
|
|
quote! {
|
|
impl #impl_generics graphene_hash::CacheHash for #name #ty_generics #where_clause {
|
|
fn cache_hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
|
|
#body
|
|
}
|
|
}
|
|
}
|
|
.into()
|
|
}
|
|
|
|
fn hash_fields(fields: &Fields, self_expr: TokenStream2) -> TokenStream2 {
|
|
match fields {
|
|
Fields::Unit => quote! {},
|
|
Fields::Unnamed(fields) => {
|
|
let stmts = fields.unnamed.iter().enumerate().filter_map(|(i, field)| {
|
|
if has_skip_attr(&field.attrs) {
|
|
return None;
|
|
}
|
|
let index = syn::Index::from(i);
|
|
Some(quote! { graphene_hash::CacheHash::cache_hash(&#self_expr.#index, state); })
|
|
});
|
|
quote! { #(#stmts)* }
|
|
}
|
|
Fields::Named(fields) => {
|
|
let stmts = fields.named.iter().filter_map(|field| {
|
|
if has_skip_attr(&field.attrs) {
|
|
return None;
|
|
}
|
|
let ident = field.ident.as_ref().unwrap();
|
|
Some(quote! { graphene_hash::CacheHash::cache_hash(&#self_expr.#ident, state); })
|
|
});
|
|
quote! { #(#stmts)* }
|
|
}
|
|
}
|
|
}
|
|
|
|
fn has_skip_attr(attrs: &[syn::Attribute]) -> bool {
|
|
attrs.iter().any(|attr| {
|
|
if !attr.path().is_ident("cache_hash") {
|
|
return false;
|
|
}
|
|
attr.parse_args::<syn::Ident>().map(|id| id == "skip").unwrap_or(false)
|
|
})
|
|
}
|