Migrate memo nodes to node macro and make implementing other persistent nodes easier (#3552)

* Add #[data] and #[serialize] attributes to node macro

- Add #[data] attribute for struct fields that aren't node parameters
  - Data fields are initialized with Default::default()
  - Passed as references to the underlying function
  - Excluded from registry metadata (internal state)
  - Generic types in data fields allowed without #[implementations]

- Add #[serialize] attribute for custom Node::serialize() implementation
  - Receives references to all data fields
  - Generates serialize() method in Node trait impl

- Conditional derives based on data field presence
  - With data fields: Debug, Clone only
  - Without data fields: Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash

* Refactor Memo and Monitor Node to use node macro

* Move Complex type into type alias

* Fix format

* Update node-graph/nodes/gcore/src/memo.rs

Co-authored-by: Keavon Chambers <keavon@keavon.com>

* Update node-graph/nodes/gcore/src/memo.rs

Co-authored-by: Keavon Chambers <keavon@keavon.com>

---------

Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
Dennis Kobert 2026-01-05 22:28:02 +01:00 committed by GitHub
parent 8f25eb6ca4
commit fafc687d84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 249 additions and 115 deletions

View File

@ -37,33 +37,102 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
}; };
let struct_name = format_ident!("{}Node", struct_name); let struct_name = format_ident!("{}Node", struct_name);
let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect(); // Separate data fields from regular fields
let (data_fields, regular_fields): (Vec<_>, Vec<_>) = fields.iter().partition(|f| f.is_data_field);
// Extract function generics used by data fields
let data_field_generics: Vec<_> = fn_generics
.iter()
.filter(|generic| {
let generic_ident = match generic {
syn::GenericParam::Type(type_param) => &type_param.ident,
_ => return false,
};
// Check if this generic is used in any data field type
data_fields.iter().any(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => type_contains_ident(ty, generic_ident),
_ => false,
})
})
.cloned()
.collect();
// Node generics for regular fields (Node0, Node1, ...)
let node_generics: Vec<Ident> = regular_fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
// Extract just the idents from data_field_generics for struct type parameters
let data_field_generic_idents: Vec<Ident> = data_field_generics
.iter()
.filter_map(|gp| match gp {
syn::GenericParam::Type(tp) => Some(tp.ident.clone()),
_ => None,
})
.collect();
// Combined struct type parameters: data field generic idents (T, U, ...) + node generics (Node0, Node1, ...)
// For struct type instantiation: MemoNode<T, Node0>
let struct_type_params: Vec<Ident> = data_field_generic_idents.iter().cloned().chain(node_generics.iter().cloned()).collect();
// Combined struct generic parameters with bounds for struct definition
// struct MemoNode<T: Clone, Node0>
let struct_generic_params: Vec<TokenStream2> = data_field_generics.iter().map(|gp| quote!(#gp)).chain(node_generics.iter().map(|id| quote!(#id))).collect();
let input_ident = &input.pat_ident; let input_ident = &input.pat_ident;
let context_features = &input.context_features; let context_features = &input.context_features;
let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect(); // Regular field idents and names (for function parameters)
let field_idents: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident).collect();
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect(); let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
let regular_field_names: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident.ident).collect();
let data_field_names: Vec<_> = data_fields.iter().map(|f| &f.pat_ident.ident).collect();
let input_names: Vec<_> = fields // Only regular fields have input names/descriptions (for UI)
let input_names: Vec<_> = regular_fields
.iter() .iter()
.map(|f| &f.name) .map(|f| &f.name)
.zip(field_names.iter()) .zip(regular_field_names.iter())
.map(|zipped| match zipped { .map(|zipped| match zipped {
(Some(name), _) => name.value(), (Some(name), _) => name.value(),
(_, name) => name.to_string().to_case(Case::Title), (_, name) => name.to_string().to_case(Case::Title),
}) })
.collect(); .collect();
let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect(); let input_descriptions: Vec<_> = regular_fields.iter().map(|f| &f.description).collect();
let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| { // Generate struct fields: data fields (concrete types) + regular fields (generic types)
let data_field_defs = data_fields.iter().map(|field| {
let name = &field.pat_ident.ident;
let ty = match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty,
_ => unreachable!("Data fields must be Regular types, not Node types"),
};
quote! { pub(super) #name: #ty }
});
let regular_field_defs = regular_field_names.iter().zip(node_generics.iter()).map(|(name, r#gen)| {
quote! { pub(super) #name: #r#gen } quote! { pub(super) #name: #r#gen }
}); });
let struct_fields = data_field_defs.chain(regular_field_defs);
let mut future_idents = Vec::new(); let mut future_idents = Vec::new();
let field_types: Vec<_> = fields // Data fields get passed as references to the underlying function
let data_field_idents: Vec<_> = data_fields.iter().map(|f| &f.pat_ident).collect();
let data_field_types: Vec<_> = data_fields
.iter()
.map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => {
let ty = ty.clone();
quote!(&#ty)
}
_ => unreachable!("Data fields must be Regular types, not Node types"),
})
.collect();
// Regular fields have types passed to the function
let field_types: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.ty { .map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(), ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
@ -74,7 +143,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
}) })
.collect(); .collect();
let widget_override: Vec<_> = fields // Only regular fields have UI metadata (data fields are internal state)
let widget_override: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.widget_override { .map(|field| match &field.widget_override {
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None), ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
@ -84,7 +154,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
}) })
.collect(); .collect();
let value_sources: Vec<_> = fields let value_sources: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.ty { .map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source { ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
@ -104,7 +174,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
}) })
.collect(); .collect();
let default_types: Vec<_> = fields let default_types: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.ty { .map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() { ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
@ -115,7 +185,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
}) })
.collect(); .collect();
let number_min_values: Vec<_> = fields let number_min_values: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.ty { .map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) { ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
@ -126,7 +196,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
_ => quote!(None), _ => quote!(None),
}) })
.collect(); .collect();
let number_max_values: Vec<_> = fields let number_max_values: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.ty { .map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) { ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
@ -137,7 +207,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
_ => quote!(None), _ => quote!(None),
}) })
.collect(); .collect();
let number_mode_range_values: Vec<_> = fields let number_mode_range_values: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.ty { .map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ParsedFieldType::Regular(RegularParsedField {
@ -147,15 +217,15 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
_ => quote!(None), _ => quote!(None),
}) })
.collect(); .collect();
let number_display_decimal_places: Vec<_> = fields let number_display_decimal_places: Vec<_> = regular_fields
.iter() .iter()
.map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))) .map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
.collect(); .collect();
let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); let number_step: Vec<_> = regular_fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect(); let unit_suffix: Vec<_> = regular_fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
let exposed: Vec<_> = fields let exposed: Vec<_> = regular_fields
.iter() .iter()
.map(|field| match &field.ty { .map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed), ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
@ -163,7 +233,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
}) })
.collect(); .collect();
let eval_args = fields.iter().map(|field| { // Only eval regular fields (data fields are accessed directly as self.field_name)
let eval_args = regular_fields.iter().map(|field| {
let name = &field.pat_ident.ident; let name = &field.pat_ident.ident;
match &field.ty { match &field.ty {
ParsedFieldType::Regular { .. } => { ParsedFieldType::Regular { .. } => {
@ -175,7 +246,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
} }
}); });
let min_max_args = fields.iter().map(|field| match &field.ty { // Only regular fields can have min/max constraints
let min_max_args = regular_fields.iter().map(|field| match &field.ty {
ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => { ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
let name = &field.pat_ident.ident; let name = &field.pat_ident.ident;
let mut tokens = quote!(); let mut tokens = quote!();
@ -208,7 +280,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
let mut clauses = Vec::new(); let mut clauses = Vec::new();
let mut clampable_clauses = Vec::new(); let mut clampable_clauses = Vec::new();
for (field, name) in fields.iter().zip(struct_generics.iter()) { for (field, name) in regular_fields.iter().zip(node_generics.iter()) {
clauses.push(match (&field.ty, *is_async) { clauses.push(match (&field.ty, *is_async) {
( (
ParsedFieldType::Regular(RegularParsedField { ParsedFieldType::Regular(RegularParsedField {
@ -259,13 +331,42 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
); );
struct_where_clause.predicates.extend(extra_where); struct_where_clause.predicates.extend(extra_where);
let new_args = struct_generics.iter().zip(field_names.iter()).map(|(r#gen, name)| { // Only regular fields are parameters to new()
let new_args = node_generics.iter().zip(regular_field_names.iter()).map(|(r#gen, name)| {
quote! { #name: #r#gen } quote! { #name: #r#gen }
}); });
// Initialize data fields with Default, regular fields with parameters
let data_inits = data_field_names.iter().map(|name| {
quote! { #name: Default::default() }
});
let regular_inits = regular_field_names.iter().map(|name| {
quote! { #name }
});
let all_field_inits = data_inits.chain(regular_inits);
let async_keyword = is_async.then(|| quote!(async)); let async_keyword = is_async.then(|| quote!(async));
let await_keyword = is_async.then(|| quote!(.await)); let await_keyword = is_async.then(|| quote!(.await));
// Data fields may not implement Copy, PartialEq, etc., so only derive Debug and Clone
let struct_derives = if data_fields.is_empty() {
quote!(#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)])
} else {
quote!(#[derive(Debug, Clone)])
};
// Generate serialize method if serialize attribute is specified
let serialize_impl = if let Some(serialize_fn) = &parsed.attributes.serialize {
let data_field_refs = data_field_names.iter().map(|name| quote!(&self.#name));
quote! {
fn serialize(&self) -> Option<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
#serialize_fn(#(#data_field_refs),*)
}
}
} else {
quote!()
};
let eval_impl = quote! { let eval_impl = quote! {
type Output = #core_types::registry::DynFuture<'n, #output_type>; type Output = #core_types::registry::DynFuture<'n, #output_type>;
#[inline] #[inline]
@ -275,9 +376,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
#(#eval_args)* #(#eval_args)*
#(#min_max_args)* #(#min_max_args)*
self::#fn_name(__input #(, #field_names)*) #await_keyword self::#fn_name(__input #(, &self.#data_field_names)* #(, #regular_field_names)*) #await_keyword
}) })
} }
#serialize_impl
}; };
let identifier = format_ident!("{}_proto_ident", fn_name); let identifier = format_ident!("{}_proto_ident", fn_name);
@ -302,11 +405,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
/// Underlying implementation for [#struct_name] /// Underlying implementation for [#struct_name]
#[inline] #[inline]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #field_idents: #field_types)*) -> #output_type #where_clause #body #vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #data_field_idents: #data_field_types)* #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
#cfg #cfg
#[automatically_derived] #[automatically_derived]
impl<'n, #(#fn_generics,)* #(#struct_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_generics,)*> impl<'n, #(#fn_generics,)* #(#node_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_type_params,)*>
#struct_where_clause #struct_where_clause
{ {
#eval_impl #eval_impl
@ -340,18 +443,18 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData; static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData;
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #struct_derives
pub struct #struct_name<#(#struct_generics,)*> { pub struct #struct_name<#(#struct_generic_params,)*> {
#(#struct_fields,)* #(#struct_fields,)*
} }
#[automatically_derived] #[automatically_derived]
impl<'n, #(#struct_generics,)*> #struct_name<#(#struct_generics,)*> impl<'n, #(#struct_generic_params,)*> #struct_name<#(#struct_type_params,)*>
{ {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new(#(#new_args,)*) -> Self { pub fn new(#(#new_args,)*) -> Self {
Self { Self {
#(#field_names,)* #(#all_field_inits,)*
} }
} }
} }
@ -493,8 +596,10 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
let mut constructors = Vec::new(); let mut constructors = Vec::new();
let unit = parse_quote!(gcore::Context); let unit = parse_quote!(gcore::Context);
let parameter_types: Vec<_> = parsed
.fields let regular_fields: Vec<_> = parsed.fields.iter().filter(|f| !f.is_data_field).collect();
let parameter_types: Vec<_> = regular_fields
.iter() .iter()
.map(|field| { .map(|field| {
match &field.ty { match &field.ty {
@ -535,7 +640,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
let field_name = field_names[j]; let field_name = field_names[j];
let (input_type, output_type) = &types[i.min(types.len() - 1)]; let (input_type, output_type) = &types[i.min(types.len() - 1)];
let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. }); let node = matches!(regular_fields[j].ty, ParsedFieldType::Node { .. });
let downcast_node = quote!( let downcast_node = quote!(
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone()); let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());
@ -712,3 +817,23 @@ impl FilterUsedGenerics {
self.used(&*modified).cloned().collect() self.used(&*modified).cloned().collect()
} }
} }
/// Check if a type contains a reference to a specific identifier (e.g., a generic type parameter)
fn type_contains_ident(ty: &Type, ident: &Ident) -> bool {
struct IdentChecker<'a> {
target: &'a Ident,
found: bool,
}
impl<'a, 'ast> syn::visit::Visit<'ast> for IdentChecker<'a> {
fn visit_ident(&mut self, i: &'ast Ident) {
if i == self.target {
self.found = true;
}
}
}
let mut checker = IdentChecker { target: ident, found: false };
syn::visit::visit_type(&mut checker, ty);
checker.found
}

View File

@ -50,6 +50,8 @@ pub(crate) struct NodeFnAttributes {
pub(crate) cfg: Option<TokenStream2>, pub(crate) cfg: Option<TokenStream2>,
/// if this node should get a gpu implementation, defaults to None /// if this node should get a gpu implementation, defaults to None
pub(crate) shader_node: Option<ShaderNodeType>, pub(crate) shader_node: Option<ShaderNodeType>,
/// Custom serialization function path (e.g., "my_module::custom_serialize")
pub(crate) serialize: Option<Path>,
// Add more attributes as needed // Add more attributes as needed
} }
@ -112,6 +114,7 @@ pub struct ParsedField {
pub number_display_decimal_places: Option<LitInt>, pub number_display_decimal_places: Option<LitInt>,
pub number_step: Option<LitFloat>, pub number_step: Option<LitFloat>,
pub unit: Option<LitStr>, pub unit: Option<LitStr>,
pub is_data_field: bool,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -201,6 +204,7 @@ impl Parse for NodeFnAttributes {
let mut properties_string = None; let mut properties_string = None;
let mut cfg = None; let mut cfg = None;
let mut shader_node = None; let mut shader_node = None;
let mut serialize = None;
let content = input; let content = input;
// let content; // let content;
@ -270,13 +274,23 @@ impl Parse for NodeFnAttributes {
let meta = meta.require_list()?; let meta = meta.require_list()?;
shader_node = Some(syn::parse2(meta.tokens.to_token_stream())?); shader_node = Some(syn::parse2(meta.tokens.to_token_stream())?);
} }
"serialize" => {
let meta = meta.require_list()?;
if serialize.is_some() {
return Err(Error::new_spanned(meta, "Multiple 'serialize' attributes are not allowed"));
}
let parsed_path: Path = meta
.parse_args()
.map_err(|_| Error::new_spanned(meta, "Expected a valid path for 'serialize', e.g., serialize(my_module::custom_serialize)"))?;
serialize = Some(parsed_path);
}
_ => { _ => {
return Err(Error::new_spanned( return Err(Error::new_spanned(
meta, meta,
indoc!( indoc!(
r#" r#"
Unsupported attribute in `node`. Unsupported attribute in `node`.
Supported attributes are 'category', 'path' 'name', 'skip_impl', 'cfg' and 'properties'. Supported attributes are 'category', 'path', 'name', 'skip_impl', 'cfg', 'properties', 'serialize', and 'shader_node'.
Example usage: Example usage:
#[node_macro::node(category("Value"), name("Test Node"))] #[node_macro::node(category("Value"), name("Test Node"))]
@ -295,6 +309,7 @@ impl Parse for NodeFnAttributes {
properties_string, properties_string,
cfg, cfg,
shader_node, shader_node,
serialize,
}) })
} }
} }
@ -467,6 +482,9 @@ fn parse_node_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::
fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> { fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> {
let ident = &pat_ident.ident; let ident = &pat_ident.ident;
// Check if this is a data field (struct field, not a parameter)
let is_data_field = extract_attribute(attrs, "data").is_some();
let default_value = extract_attribute(attrs, "default") let default_value = extract_attribute(attrs, "default")
.map(|attr| attr.parse_args().map_err(|e| Error::new_spanned(attr, format!("Invalid `default` value for argument '{ident}': {e}")))) .map(|attr| attr.parse_args().map_err(|e| Error::new_spanned(attr, format!("Invalid `default` value for argument '{ident}': {e}"))))
.transpose()?; .transpose()?;
@ -489,6 +507,25 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
let exposed = extract_attribute(attrs, "expose").is_some(); let exposed = extract_attribute(attrs, "expose").is_some();
// Validate data field attributes
if is_data_field {
if default_value.is_some() {
return Err(Error::new_spanned(
&pat_ident,
"Data fields (#[data]) cannot have #[default] attribute. They are automatically initialized with Default::default()",
));
}
if scope.is_some() {
return Err(Error::new_spanned(&pat_ident, "Data fields (#[data]) cannot have #[scope] attribute"));
}
if exposed {
return Err(Error::new_spanned(
&pat_ident,
"Data fields (#[data]) cannot be exposed (#[expose]). They are internal state, not node parameters",
));
}
}
let value_source = match (default_value, scope) { let value_source = match (default_value, scope) {
(Some(_), Some(_)) => return Err(Error::new_spanned(&pat_ident, "Cannot have both `default` and `scope` attributes")), (Some(_), Some(_)) => return Err(Error::new_spanned(&pat_ident, "Cannot have both `default` and `scope` attributes")),
(Some(default_value), _) => ParsedValueSource::Default(default_value), (Some(default_value), _) => ParsedValueSource::Default(default_value),
@ -586,6 +623,14 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
.fold(String::new(), |acc, b| acc + &b + "\n"); .fold(String::new(), |acc, b| acc + &b + "\n");
if is_node { if is_node {
// Data fields cannot be impl Node types
if is_data_field {
return Err(Error::new_spanned(
&ty,
"Data fields (#[data]) cannot be of type `impl Node`. Data fields must be concrete types that implement Default",
));
}
let (input_type, output_type) = node_input_type let (input_type, output_type) = node_input_type
.zip(node_output_type) .zip(node_output_type)
.ok_or_else(|| Error::new_spanned(&ty, "Invalid Node type. Expected `impl Node<Input, Output = OutputType>`"))?; .ok_or_else(|| Error::new_spanned(&ty, "Invalid Node type. Expected `impl Node<Input, Output = OutputType>`"))?;
@ -610,6 +655,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
number_display_decimal_places, number_display_decimal_places,
number_step, number_step,
unit, unit,
is_data_field,
}) })
} else { } else {
let implementations = extract_attribute(attrs, "implementations") let implementations = extract_attribute(attrs, "implementations")
@ -636,6 +682,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
number_display_decimal_places, number_display_decimal_places,
number_step, number_step,
unit, unit,
is_data_field,
}) })
} }
} }
@ -826,6 +873,7 @@ mod tests {
properties_string: None, properties_string: None,
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None,
}, },
fn_name: Ident::new("add", Span::call_site()), fn_name: Ident::new("add", Span::call_site()),
struct_name: Ident::new("Add", Span::call_site()), struct_name: Ident::new("Add", Span::call_site()),
@ -860,6 +908,7 @@ mod tests {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
description: String::from("Multi\nLine\n"), description: String::from("Multi\nLine\n"),
@ -892,6 +941,7 @@ mod tests {
properties_string: None, properties_string: None,
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None,
}, },
fn_name: Ident::new("transform", Span::call_site()), fn_name: Ident::new("transform", Span::call_site()),
struct_name: Ident::new("Transform", Span::call_site()), struct_name: Ident::new("Transform", Span::call_site()),
@ -920,6 +970,7 @@ mod tests {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}, },
ParsedField { ParsedField {
pat_ident: pat_ident("translate"), pat_ident: pat_ident("translate"),
@ -941,6 +992,7 @@ mod tests {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}, },
], ],
body: TokenStream2::new(), body: TokenStream2::new(),
@ -971,6 +1023,7 @@ mod tests {
properties_string: None, properties_string: None,
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None,
}, },
fn_name: Ident::new("circle", Span::call_site()), fn_name: Ident::new("circle", Span::call_site()),
struct_name: Ident::new("Circle", Span::call_site()), struct_name: Ident::new("Circle", Span::call_site()),
@ -1005,6 +1058,7 @@ mod tests {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
description: "Test\n".into(), description: "Test\n".into(),
@ -1033,6 +1087,7 @@ mod tests {
properties_string: None, properties_string: None,
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None,
}, },
fn_name: Ident::new("levels", Span::call_site()), fn_name: Ident::new("levels", Span::call_site()),
struct_name: Ident::new("Levels", Span::call_site()), struct_name: Ident::new("Levels", Span::call_site()),
@ -1072,6 +1127,7 @@ mod tests {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
description: String::new(), description: String::new(),
@ -1107,6 +1163,7 @@ mod tests {
properties_string: None, properties_string: None,
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None,
}, },
fn_name: Ident::new("add", Span::call_site()), fn_name: Ident::new("add", Span::call_site()),
struct_name: Ident::new("Add", Span::call_site()), struct_name: Ident::new("Add", Span::call_site()),
@ -1141,6 +1198,7 @@ mod tests {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
description: String::new(), description: String::new(),
@ -1169,6 +1227,7 @@ mod tests {
properties_string: None, properties_string: None,
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None,
}, },
fn_name: Ident::new("load_image", Span::call_site()), fn_name: Ident::new("load_image", Span::call_site()),
struct_name: Ident::new("LoadImage", Span::call_site()), struct_name: Ident::new("LoadImage", Span::call_site()),
@ -1203,6 +1262,7 @@ mod tests {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}], }],
body: TokenStream2::new(), body: TokenStream2::new(),
description: String::new(), description: String::new(),
@ -1231,6 +1291,7 @@ mod tests {
properties_string: None, properties_string: None,
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None,
}, },
fn_name: Ident::new("custom_node", Span::call_site()), fn_name: Ident::new("custom_node", Span::call_site()),
struct_name: Ident::new("CustomNode", Span::call_site()), struct_name: Ident::new("CustomNode", Span::call_site()),

View File

@ -245,6 +245,7 @@ impl PerPixelAdjustCodegen<'_> {
number_display_decimal_places: None, number_display_decimal_places: None,
number_step: None, number_step: None,
unit: None, unit: None,
is_data_field: false,
}); });
// find exactly one gpu_image field, runtime doesn't support more than 1 atm // find exactly one gpu_image field, runtime doesn't support more than 1 atm

View File

@ -102,6 +102,11 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) {
if !has_skip_impl && !parsed.fn_generics.is_empty() { if !has_skip_impl && !parsed.fn_generics.is_empty() {
for field in &parsed.fields { for field in &parsed.fields {
// Skip validation for data fields - they're internal state and can be generic
if field.is_data_field {
continue;
}
let pat_ident = &field.pat_ident; let pat_ident = &field.pat_ident;
match &field.ty { match &field.ty {
ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => { ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => {

View File

@ -1,7 +1,5 @@
use core_types::WasmNotSend;
use core_types::memo::*; use core_types::memo::*;
use core_types::{Node, WasmNotSend};
use dyn_any::DynFuture;
use std::future::Future;
use std::hash::DefaultHasher; use std::hash::DefaultHasher;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::sync::Arc; use std::sync::Arc;
@ -14,94 +12,38 @@ use std::sync::Mutex;
/// A cache hit occurs when the Option is Some and has a stored hash matching the hash of the call argument. In this case, the node returns the cached value without re-evaluating the inner node. /// A cache hit occurs when the Option is Some and has a stored hash matching the hash of the call argument. In this case, the node returns the cached value without re-evaluating the inner node.
/// ///
/// Currently, only one input-output pair is cached. Subsequent calls with different inputs will overwrite the previous cache. /// Currently, only one input-output pair is cached. Subsequent calls with different inputs will overwrite the previous cache.
#[derive(Default)] #[node_macro::node(category(""), path(graphene_core::memo), skip_impl)]
pub struct MemoNode<T, CachedNode> { async fn memo<I: Hash + Send + 'n, T: Clone + WasmNotSend>(input: I, #[data] cache: Arc<Mutex<Option<(u64, T)>>>, node: impl Node<I, Output = T>) -> T {
cache: Arc<Mutex<Option<(u64, T)>>>, let mut hasher = DefaultHasher::new();
node: CachedNode, input.hash(&mut hasher);
} let hash = hasher.finish();
impl<'i, I: Hash + 'i, T: 'i + Clone + WasmNotSend, CachedNode: 'i> Node<'i, I> for MemoNode<T, CachedNode>
where
CachedNode: for<'any_input> Node<'any_input, I>,
for<'a> <CachedNode as Node<'a, I>>::Output: Future<Output = T> + WasmNotSend,
{
// TODO: This should return a reference to the cached cached_value
// but that requires a lot of lifetime magic <- This was suggested by copilot but is pretty accurate xD
type Output = DynFuture<'i, T>;
fn eval(&'i self, input: I) -> Self::Output {
let mut hasher = DefaultHasher::new();
input.hash(&mut hasher);
let hash = hasher.finish();
if let Some(data) = self.cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) { if let Some(data) = cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) {
Box::pin(async move { data }) return data;
} else {
let fut = self.node.eval(input);
let cache = self.cache.clone();
Box::pin(async move {
let value = fut.await;
*cache.lock().unwrap() = Some((hash, value.clone()));
value
})
}
} }
fn reset(&self) { let value = node.eval(input).await;
self.cache.lock().unwrap().take(); *cache.lock().unwrap() = Some((hash, value.clone()));
} value
} }
impl<T, CachedNode> MemoNode<T, CachedNode> { type MonitorValue<I, T> = Arc<Mutex<Option<Arc<IORecord<I, T>>>>>;
pub fn new(node: CachedNode) -> MemoNode<T, CachedNode> {
MemoNode { cache: Default::default(), node }
}
}
#[allow(clippy::module_inception)]
pub mod memo {
use core_types::ProtoNodeIdentifier;
pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MemoNode");
}
/// Caches the output of the last graph evaluation for introspection. /// Caches the output of the last graph evaluation for introspection.
#[derive(Default)] #[node_macro::node(category(""), path(graphene_core::memo), serialize(serialize_monitor), skip_impl)]
pub struct MonitorNode<I, T, N> { async fn monitor<I: Clone + 'static + Send + Sync, T: Clone + 'static + Send + Sync>(
input: I,
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
io: Arc<Mutex<Option<Arc<IORecord<I, T>>>>>, #[data]
node: N, io: MonitorValue<I, T>,
node: impl Node<I, Output = T>,
) -> T {
let output = node.eval(input.clone()).await;
*io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() }));
output
} }
impl<'i, T, I, N> Node<'i, I> for MonitorNode<I, T, N> fn serialize_monitor<I: Clone + 'static + Send + Sync, T: Clone + 'static + Send + Sync>(io: &MonitorValue<I, T>) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
where let io = io.lock().unwrap();
I: Clone + 'static + Send + Sync, io.as_ref().map(|output| output.clone() as Arc<dyn std::any::Any + Send + Sync>)
T: Clone + 'static + Send + Sync,
for<'a> N: Node<'a, I, Output: Future<Output = T> + WasmNotSend> + 'i,
{
type Output = DynFuture<'i, T>;
fn eval(&'i self, input: I) -> Self::Output {
let io = self.io.clone();
let output_fut = self.node.eval(input.clone());
Box::pin(async move {
let output = output_fut.await;
*io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() }));
output
})
}
fn serialize(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
let io = self.io.lock().unwrap();
(io).as_ref().map(|output| output.clone() as Arc<dyn std::any::Any + Send + Sync>)
}
}
impl<I, T, N> MonitorNode<I, T, N> {
pub fn new(node: N) -> MonitorNode<I, T, N> {
MonitorNode { io: Arc::new(Mutex::new(None)), node }
}
}
pub mod monitor {
use core_types::ProtoNodeIdentifier;
pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MonitorNode");
} }