Migrate memo nodes to node macro and make implementing other persistent nodes easier (#3552)
* Add #[data] and #[serialize] attributes to node macro - Add #[data] attribute for struct fields that aren't node parameters - Data fields are initialized with Default::default() - Passed as references to the underlying function - Excluded from registry metadata (internal state) - Generic types in data fields allowed without #[implementations] - Add #[serialize] attribute for custom Node::serialize() implementation - Receives references to all data fields - Generates serialize() method in Node trait impl - Conditional derives based on data field presence - With data fields: Debug, Clone only - Without data fields: Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash * Refactor Memo and Monitor Node to use node macro * Move Complex type into type alias * Fix format * Update node-graph/nodes/gcore/src/memo.rs Co-authored-by: Keavon Chambers <keavon@keavon.com> * Update node-graph/nodes/gcore/src/memo.rs Co-authored-by: Keavon Chambers <keavon@keavon.com> --------- Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
parent
8f25eb6ca4
commit
fafc687d84
|
|
@ -37,33 +37,102 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
};
|
||||
let struct_name = format_ident!("{}Node", struct_name);
|
||||
|
||||
let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
|
||||
// Separate data fields from regular fields
|
||||
let (data_fields, regular_fields): (Vec<_>, Vec<_>) = fields.iter().partition(|f| f.is_data_field);
|
||||
|
||||
// Extract function generics used by data fields
|
||||
let data_field_generics: Vec<_> = fn_generics
|
||||
.iter()
|
||||
.filter(|generic| {
|
||||
let generic_ident = match generic {
|
||||
syn::GenericParam::Type(type_param) => &type_param.ident,
|
||||
_ => return false,
|
||||
};
|
||||
|
||||
// Check if this generic is used in any data field type
|
||||
data_fields.iter().any(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => type_contains_ident(ty, generic_ident),
|
||||
_ => false,
|
||||
})
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Node generics for regular fields (Node0, Node1, ...)
|
||||
let node_generics: Vec<Ident> = regular_fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
|
||||
|
||||
// Extract just the idents from data_field_generics for struct type parameters
|
||||
let data_field_generic_idents: Vec<Ident> = data_field_generics
|
||||
.iter()
|
||||
.filter_map(|gp| match gp {
|
||||
syn::GenericParam::Type(tp) => Some(tp.ident.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Combined struct type parameters: data field generic idents (T, U, ...) + node generics (Node0, Node1, ...)
|
||||
// For struct type instantiation: MemoNode<T, Node0>
|
||||
let struct_type_params: Vec<Ident> = data_field_generic_idents.iter().cloned().chain(node_generics.iter().cloned()).collect();
|
||||
|
||||
// Combined struct generic parameters with bounds for struct definition
|
||||
// struct MemoNode<T: Clone, Node0>
|
||||
let struct_generic_params: Vec<TokenStream2> = data_field_generics.iter().map(|gp| quote!(#gp)).chain(node_generics.iter().map(|id| quote!(#id))).collect();
|
||||
let input_ident = &input.pat_ident;
|
||||
|
||||
let context_features = &input.context_features;
|
||||
|
||||
let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect();
|
||||
// Regular field idents and names (for function parameters)
|
||||
let field_idents: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident).collect();
|
||||
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
|
||||
let regular_field_names: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident.ident).collect();
|
||||
let data_field_names: Vec<_> = data_fields.iter().map(|f| &f.pat_ident.ident).collect();
|
||||
|
||||
let input_names: Vec<_> = fields
|
||||
// Only regular fields have input names/descriptions (for UI)
|
||||
let input_names: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|f| &f.name)
|
||||
.zip(field_names.iter())
|
||||
.zip(regular_field_names.iter())
|
||||
.map(|zipped| match zipped {
|
||||
(Some(name), _) => name.value(),
|
||||
(_, name) => name.to_string().to_case(Case::Title),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect();
|
||||
let input_descriptions: Vec<_> = regular_fields.iter().map(|f| &f.description).collect();
|
||||
|
||||
let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| {
|
||||
// Generate struct fields: data fields (concrete types) + regular fields (generic types)
|
||||
let data_field_defs = data_fields.iter().map(|field| {
|
||||
let name = &field.pat_ident.ident;
|
||||
let ty = match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty,
|
||||
_ => unreachable!("Data fields must be Regular types, not Node types"),
|
||||
};
|
||||
quote! { pub(super) #name: #ty }
|
||||
});
|
||||
|
||||
let regular_field_defs = regular_field_names.iter().zip(node_generics.iter()).map(|(name, r#gen)| {
|
||||
quote! { pub(super) #name: #r#gen }
|
||||
});
|
||||
|
||||
let struct_fields = data_field_defs.chain(regular_field_defs);
|
||||
|
||||
let mut future_idents = Vec::new();
|
||||
|
||||
let field_types: Vec<_> = fields
|
||||
// Data fields get passed as references to the underlying function
|
||||
let data_field_idents: Vec<_> = data_fields.iter().map(|f| &f.pat_ident).collect();
|
||||
let data_field_types: Vec<_> = data_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => {
|
||||
let ty = ty.clone();
|
||||
quote!(&#ty)
|
||||
}
|
||||
_ => unreachable!("Data fields must be Regular types, not Node types"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Regular fields have types passed to the function
|
||||
let field_types: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
|
||||
|
|
@ -74,7 +143,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
})
|
||||
.collect();
|
||||
|
||||
let widget_override: Vec<_> = fields
|
||||
// Only regular fields have UI metadata (data fields are internal state)
|
||||
let widget_override: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.widget_override {
|
||||
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
|
||||
|
|
@ -84,7 +154,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
})
|
||||
.collect();
|
||||
|
||||
let value_sources: Vec<_> = fields
|
||||
let value_sources: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
|
||||
|
|
@ -104,7 +174,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
})
|
||||
.collect();
|
||||
|
||||
let default_types: Vec<_> = fields
|
||||
let default_types: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
|
||||
|
|
@ -115,7 +185,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
})
|
||||
.collect();
|
||||
|
||||
let number_min_values: Vec<_> = fields
|
||||
let number_min_values: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
|
||||
|
|
@ -126,7 +196,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
_ => quote!(None),
|
||||
})
|
||||
.collect();
|
||||
let number_max_values: Vec<_> = fields
|
||||
let number_max_values: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
|
||||
|
|
@ -137,7 +207,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
_ => quote!(None),
|
||||
})
|
||||
.collect();
|
||||
let number_mode_range_values: Vec<_> = fields
|
||||
let number_mode_range_values: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField {
|
||||
|
|
@ -147,15 +217,15 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
_ => quote!(None),
|
||||
})
|
||||
.collect();
|
||||
let number_display_decimal_places: Vec<_> = fields
|
||||
let number_display_decimal_places: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
|
||||
.collect();
|
||||
let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
||||
let number_step: Vec<_> = regular_fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
||||
|
||||
let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
||||
let unit_suffix: Vec<_> = regular_fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
||||
|
||||
let exposed: Vec<_> = fields
|
||||
let exposed: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
|
||||
|
|
@ -163,7 +233,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
})
|
||||
.collect();
|
||||
|
||||
let eval_args = fields.iter().map(|field| {
|
||||
// Only eval regular fields (data fields are accessed directly as self.field_name)
|
||||
let eval_args = regular_fields.iter().map(|field| {
|
||||
let name = &field.pat_ident.ident;
|
||||
match &field.ty {
|
||||
ParsedFieldType::Regular { .. } => {
|
||||
|
|
@ -175,7 +246,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
}
|
||||
});
|
||||
|
||||
let min_max_args = fields.iter().map(|field| match &field.ty {
|
||||
// Only regular fields can have min/max constraints
|
||||
let min_max_args = regular_fields.iter().map(|field| match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
|
||||
let name = &field.pat_ident.ident;
|
||||
let mut tokens = quote!();
|
||||
|
|
@ -208,7 +280,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
let mut clauses = Vec::new();
|
||||
let mut clampable_clauses = Vec::new();
|
||||
|
||||
for (field, name) in fields.iter().zip(struct_generics.iter()) {
|
||||
for (field, name) in regular_fields.iter().zip(node_generics.iter()) {
|
||||
clauses.push(match (&field.ty, *is_async) {
|
||||
(
|
||||
ParsedFieldType::Regular(RegularParsedField {
|
||||
|
|
@ -259,13 +331,42 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
);
|
||||
struct_where_clause.predicates.extend(extra_where);
|
||||
|
||||
let new_args = struct_generics.iter().zip(field_names.iter()).map(|(r#gen, name)| {
|
||||
// Only regular fields are parameters to new()
|
||||
let new_args = node_generics.iter().zip(regular_field_names.iter()).map(|(r#gen, name)| {
|
||||
quote! { #name: #r#gen }
|
||||
});
|
||||
|
||||
// Initialize data fields with Default, regular fields with parameters
|
||||
let data_inits = data_field_names.iter().map(|name| {
|
||||
quote! { #name: Default::default() }
|
||||
});
|
||||
let regular_inits = regular_field_names.iter().map(|name| {
|
||||
quote! { #name }
|
||||
});
|
||||
let all_field_inits = data_inits.chain(regular_inits);
|
||||
|
||||
let async_keyword = is_async.then(|| quote!(async));
|
||||
let await_keyword = is_async.then(|| quote!(.await));
|
||||
|
||||
// Data fields may not implement Copy, PartialEq, etc., so only derive Debug and Clone
|
||||
let struct_derives = if data_fields.is_empty() {
|
||||
quote!(#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)])
|
||||
} else {
|
||||
quote!(#[derive(Debug, Clone)])
|
||||
};
|
||||
|
||||
// Generate serialize method if serialize attribute is specified
|
||||
let serialize_impl = if let Some(serialize_fn) = &parsed.attributes.serialize {
|
||||
let data_field_refs = data_field_names.iter().map(|name| quote!(&self.#name));
|
||||
quote! {
|
||||
fn serialize(&self) -> Option<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
|
||||
#serialize_fn(#(#data_field_refs),*)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote!()
|
||||
};
|
||||
|
||||
let eval_impl = quote! {
|
||||
type Output = #core_types::registry::DynFuture<'n, #output_type>;
|
||||
#[inline]
|
||||
|
|
@ -275,9 +376,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
|
||||
#(#eval_args)*
|
||||
#(#min_max_args)*
|
||||
self::#fn_name(__input #(, #field_names)*) #await_keyword
|
||||
self::#fn_name(__input #(, &self.#data_field_names)* #(, #regular_field_names)*) #await_keyword
|
||||
})
|
||||
}
|
||||
|
||||
#serialize_impl
|
||||
};
|
||||
|
||||
let identifier = format_ident!("{}_proto_ident", fn_name);
|
||||
|
|
@ -302,11 +405,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
/// Underlying implementation for [#struct_name]
|
||||
#[inline]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
|
||||
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #data_field_idents: #data_field_types)* #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
|
||||
|
||||
#cfg
|
||||
#[automatically_derived]
|
||||
impl<'n, #(#fn_generics,)* #(#struct_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_generics,)*>
|
||||
impl<'n, #(#fn_generics,)* #(#node_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_type_params,)*>
|
||||
#struct_where_clause
|
||||
{
|
||||
#eval_impl
|
||||
|
|
@ -340,18 +443,18 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
|||
|
||||
static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct #struct_name<#(#struct_generics,)*> {
|
||||
#struct_derives
|
||||
pub struct #struct_name<#(#struct_generic_params,)*> {
|
||||
#(#struct_fields,)*
|
||||
}
|
||||
|
||||
#[automatically_derived]
|
||||
impl<'n, #(#struct_generics,)*> #struct_name<#(#struct_generics,)*>
|
||||
impl<'n, #(#struct_generic_params,)*> #struct_name<#(#struct_type_params,)*>
|
||||
{
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(#(#new_args,)*) -> Self {
|
||||
Self {
|
||||
#(#field_names,)*
|
||||
#(#all_field_inits,)*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -493,8 +596,10 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
|
|||
|
||||
let mut constructors = Vec::new();
|
||||
let unit = parse_quote!(gcore::Context);
|
||||
let parameter_types: Vec<_> = parsed
|
||||
.fields
|
||||
|
||||
let regular_fields: Vec<_> = parsed.fields.iter().filter(|f| !f.is_data_field).collect();
|
||||
|
||||
let parameter_types: Vec<_> = regular_fields
|
||||
.iter()
|
||||
.map(|field| {
|
||||
match &field.ty {
|
||||
|
|
@ -535,7 +640,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
|
|||
let field_name = field_names[j];
|
||||
let (input_type, output_type) = &types[i.min(types.len() - 1)];
|
||||
|
||||
let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. });
|
||||
let node = matches!(regular_fields[j].ty, ParsedFieldType::Node { .. });
|
||||
|
||||
let downcast_node = quote!(
|
||||
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());
|
||||
|
|
@ -712,3 +817,23 @@ impl FilterUsedGenerics {
|
|||
self.used(&*modified).cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a type contains a reference to a specific identifier (e.g., a generic type parameter)
|
||||
fn type_contains_ident(ty: &Type, ident: &Ident) -> bool {
|
||||
struct IdentChecker<'a> {
|
||||
target: &'a Ident,
|
||||
found: bool,
|
||||
}
|
||||
|
||||
impl<'a, 'ast> syn::visit::Visit<'ast> for IdentChecker<'a> {
|
||||
fn visit_ident(&mut self, i: &'ast Ident) {
|
||||
if i == self.target {
|
||||
self.found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut checker = IdentChecker { target: ident, found: false };
|
||||
syn::visit::visit_type(&mut checker, ty);
|
||||
checker.found
|
||||
}
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ pub(crate) struct NodeFnAttributes {
|
|||
pub(crate) cfg: Option<TokenStream2>,
|
||||
/// if this node should get a gpu implementation, defaults to None
|
||||
pub(crate) shader_node: Option<ShaderNodeType>,
|
||||
/// Custom serialization function path (e.g., "my_module::custom_serialize")
|
||||
pub(crate) serialize: Option<Path>,
|
||||
// Add more attributes as needed
|
||||
}
|
||||
|
||||
|
|
@ -112,6 +114,7 @@ pub struct ParsedField {
|
|||
pub number_display_decimal_places: Option<LitInt>,
|
||||
pub number_step: Option<LitFloat>,
|
||||
pub unit: Option<LitStr>,
|
||||
pub is_data_field: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
|
@ -201,6 +204,7 @@ impl Parse for NodeFnAttributes {
|
|||
let mut properties_string = None;
|
||||
let mut cfg = None;
|
||||
let mut shader_node = None;
|
||||
let mut serialize = None;
|
||||
|
||||
let content = input;
|
||||
// let content;
|
||||
|
|
@ -270,13 +274,23 @@ impl Parse for NodeFnAttributes {
|
|||
let meta = meta.require_list()?;
|
||||
shader_node = Some(syn::parse2(meta.tokens.to_token_stream())?);
|
||||
}
|
||||
"serialize" => {
|
||||
let meta = meta.require_list()?;
|
||||
if serialize.is_some() {
|
||||
return Err(Error::new_spanned(meta, "Multiple 'serialize' attributes are not allowed"));
|
||||
}
|
||||
let parsed_path: Path = meta
|
||||
.parse_args()
|
||||
.map_err(|_| Error::new_spanned(meta, "Expected a valid path for 'serialize', e.g., serialize(my_module::custom_serialize)"))?;
|
||||
serialize = Some(parsed_path);
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::new_spanned(
|
||||
meta,
|
||||
indoc!(
|
||||
r#"
|
||||
Unsupported attribute in `node`.
|
||||
Supported attributes are 'category', 'path' 'name', 'skip_impl', 'cfg' and 'properties'.
|
||||
Supported attributes are 'category', 'path', 'name', 'skip_impl', 'cfg', 'properties', 'serialize', and 'shader_node'.
|
||||
|
||||
Example usage:
|
||||
#[node_macro::node(category("Value"), name("Test Node"))]
|
||||
|
|
@ -295,6 +309,7 @@ impl Parse for NodeFnAttributes {
|
|||
properties_string,
|
||||
cfg,
|
||||
shader_node,
|
||||
serialize,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -467,6 +482,9 @@ fn parse_node_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::
|
|||
fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> {
|
||||
let ident = &pat_ident.ident;
|
||||
|
||||
// Check if this is a data field (struct field, not a parameter)
|
||||
let is_data_field = extract_attribute(attrs, "data").is_some();
|
||||
|
||||
let default_value = extract_attribute(attrs, "default")
|
||||
.map(|attr| attr.parse_args().map_err(|e| Error::new_spanned(attr, format!("Invalid `default` value for argument '{ident}': {e}"))))
|
||||
.transpose()?;
|
||||
|
|
@ -489,6 +507,25 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
|||
|
||||
let exposed = extract_attribute(attrs, "expose").is_some();
|
||||
|
||||
// Validate data field attributes
|
||||
if is_data_field {
|
||||
if default_value.is_some() {
|
||||
return Err(Error::new_spanned(
|
||||
&pat_ident,
|
||||
"Data fields (#[data]) cannot have #[default] attribute. They are automatically initialized with Default::default()",
|
||||
));
|
||||
}
|
||||
if scope.is_some() {
|
||||
return Err(Error::new_spanned(&pat_ident, "Data fields (#[data]) cannot have #[scope] attribute"));
|
||||
}
|
||||
if exposed {
|
||||
return Err(Error::new_spanned(
|
||||
&pat_ident,
|
||||
"Data fields (#[data]) cannot be exposed (#[expose]). They are internal state, not node parameters",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let value_source = match (default_value, scope) {
|
||||
(Some(_), Some(_)) => return Err(Error::new_spanned(&pat_ident, "Cannot have both `default` and `scope` attributes")),
|
||||
(Some(default_value), _) => ParsedValueSource::Default(default_value),
|
||||
|
|
@ -586,6 +623,14 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
|||
.fold(String::new(), |acc, b| acc + &b + "\n");
|
||||
|
||||
if is_node {
|
||||
// Data fields cannot be impl Node types
|
||||
if is_data_field {
|
||||
return Err(Error::new_spanned(
|
||||
&ty,
|
||||
"Data fields (#[data]) cannot be of type `impl Node`. Data fields must be concrete types that implement Default",
|
||||
));
|
||||
}
|
||||
|
||||
let (input_type, output_type) = node_input_type
|
||||
.zip(node_output_type)
|
||||
.ok_or_else(|| Error::new_spanned(&ty, "Invalid Node type. Expected `impl Node<Input, Output = OutputType>`"))?;
|
||||
|
|
@ -610,6 +655,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
|||
number_display_decimal_places,
|
||||
number_step,
|
||||
unit,
|
||||
is_data_field,
|
||||
})
|
||||
} else {
|
||||
let implementations = extract_attribute(attrs, "implementations")
|
||||
|
|
@ -636,6 +682,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
|||
number_display_decimal_places,
|
||||
number_step,
|
||||
unit,
|
||||
is_data_field,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -826,6 +873,7 @@ mod tests {
|
|||
properties_string: None,
|
||||
cfg: None,
|
||||
shader_node: None,
|
||||
serialize: None,
|
||||
},
|
||||
fn_name: Ident::new("add", Span::call_site()),
|
||||
struct_name: Ident::new("Add", Span::call_site()),
|
||||
|
|
@ -860,6 +908,7 @@ mod tests {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
}],
|
||||
body: TokenStream2::new(),
|
||||
description: String::from("Multi\nLine\n"),
|
||||
|
|
@ -892,6 +941,7 @@ mod tests {
|
|||
properties_string: None,
|
||||
cfg: None,
|
||||
shader_node: None,
|
||||
serialize: None,
|
||||
},
|
||||
fn_name: Ident::new("transform", Span::call_site()),
|
||||
struct_name: Ident::new("Transform", Span::call_site()),
|
||||
|
|
@ -920,6 +970,7 @@ mod tests {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
},
|
||||
ParsedField {
|
||||
pat_ident: pat_ident("translate"),
|
||||
|
|
@ -941,6 +992,7 @@ mod tests {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
},
|
||||
],
|
||||
body: TokenStream2::new(),
|
||||
|
|
@ -971,6 +1023,7 @@ mod tests {
|
|||
properties_string: None,
|
||||
cfg: None,
|
||||
shader_node: None,
|
||||
serialize: None,
|
||||
},
|
||||
fn_name: Ident::new("circle", Span::call_site()),
|
||||
struct_name: Ident::new("Circle", Span::call_site()),
|
||||
|
|
@ -1005,6 +1058,7 @@ mod tests {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
}],
|
||||
body: TokenStream2::new(),
|
||||
description: "Test\n".into(),
|
||||
|
|
@ -1033,6 +1087,7 @@ mod tests {
|
|||
properties_string: None,
|
||||
cfg: None,
|
||||
shader_node: None,
|
||||
serialize: None,
|
||||
},
|
||||
fn_name: Ident::new("levels", Span::call_site()),
|
||||
struct_name: Ident::new("Levels", Span::call_site()),
|
||||
|
|
@ -1072,6 +1127,7 @@ mod tests {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
}],
|
||||
body: TokenStream2::new(),
|
||||
description: String::new(),
|
||||
|
|
@ -1107,6 +1163,7 @@ mod tests {
|
|||
properties_string: None,
|
||||
cfg: None,
|
||||
shader_node: None,
|
||||
serialize: None,
|
||||
},
|
||||
fn_name: Ident::new("add", Span::call_site()),
|
||||
struct_name: Ident::new("Add", Span::call_site()),
|
||||
|
|
@ -1141,6 +1198,7 @@ mod tests {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
}],
|
||||
body: TokenStream2::new(),
|
||||
description: String::new(),
|
||||
|
|
@ -1169,6 +1227,7 @@ mod tests {
|
|||
properties_string: None,
|
||||
cfg: None,
|
||||
shader_node: None,
|
||||
serialize: None,
|
||||
},
|
||||
fn_name: Ident::new("load_image", Span::call_site()),
|
||||
struct_name: Ident::new("LoadImage", Span::call_site()),
|
||||
|
|
@ -1203,6 +1262,7 @@ mod tests {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
}],
|
||||
body: TokenStream2::new(),
|
||||
description: String::new(),
|
||||
|
|
@ -1231,6 +1291,7 @@ mod tests {
|
|||
properties_string: None,
|
||||
cfg: None,
|
||||
shader_node: None,
|
||||
serialize: None,
|
||||
},
|
||||
fn_name: Ident::new("custom_node", Span::call_site()),
|
||||
struct_name: Ident::new("CustomNode", Span::call_site()),
|
||||
|
|
|
|||
|
|
@ -245,6 +245,7 @@ impl PerPixelAdjustCodegen<'_> {
|
|||
number_display_decimal_places: None,
|
||||
number_step: None,
|
||||
unit: None,
|
||||
is_data_field: false,
|
||||
});
|
||||
|
||||
// find exactly one gpu_image field, runtime doesn't support more than 1 atm
|
||||
|
|
|
|||
|
|
@ -102,6 +102,11 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) {
|
|||
|
||||
if !has_skip_impl && !parsed.fn_generics.is_empty() {
|
||||
for field in &parsed.fields {
|
||||
// Skip validation for data fields - they're internal state and can be generic
|
||||
if field.is_data_field {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pat_ident = &field.pat_ident;
|
||||
match &field.ty {
|
||||
ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
use core_types::WasmNotSend;
|
||||
use core_types::memo::*;
|
||||
use core_types::{Node, WasmNotSend};
|
||||
use dyn_any::DynFuture;
|
||||
use std::future::Future;
|
||||
use std::hash::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::sync::Arc;
|
||||
|
|
@ -14,94 +12,38 @@ use std::sync::Mutex;
|
|||
/// A cache hit occurs when the Option is Some and has a stored hash matching the hash of the call argument. In this case, the node returns the cached value without re-evaluating the inner node.
|
||||
///
|
||||
/// Currently, only one input-output pair is cached. Subsequent calls with different inputs will overwrite the previous cache.
|
||||
#[derive(Default)]
|
||||
pub struct MemoNode<T, CachedNode> {
|
||||
cache: Arc<Mutex<Option<(u64, T)>>>,
|
||||
node: CachedNode,
|
||||
}
|
||||
impl<'i, I: Hash + 'i, T: 'i + Clone + WasmNotSend, CachedNode: 'i> Node<'i, I> for MemoNode<T, CachedNode>
|
||||
where
|
||||
CachedNode: for<'any_input> Node<'any_input, I>,
|
||||
for<'a> <CachedNode as Node<'a, I>>::Output: Future<Output = T> + WasmNotSend,
|
||||
{
|
||||
// TODO: This should return a reference to the cached cached_value
|
||||
// but that requires a lot of lifetime magic <- This was suggested by copilot but is pretty accurate xD
|
||||
type Output = DynFuture<'i, T>;
|
||||
fn eval(&'i self, input: I) -> Self::Output {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
input.hash(&mut hasher);
|
||||
let hash = hasher.finish();
|
||||
#[node_macro::node(category(""), path(graphene_core::memo), skip_impl)]
|
||||
async fn memo<I: Hash + Send + 'n, T: Clone + WasmNotSend>(input: I, #[data] cache: Arc<Mutex<Option<(u64, T)>>>, node: impl Node<I, Output = T>) -> T {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
input.hash(&mut hasher);
|
||||
let hash = hasher.finish();
|
||||
|
||||
if let Some(data) = self.cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) {
|
||||
Box::pin(async move { data })
|
||||
} else {
|
||||
let fut = self.node.eval(input);
|
||||
let cache = self.cache.clone();
|
||||
Box::pin(async move {
|
||||
let value = fut.await;
|
||||
*cache.lock().unwrap() = Some((hash, value.clone()));
|
||||
value
|
||||
})
|
||||
}
|
||||
if let Some(data) = cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) {
|
||||
return data;
|
||||
}
|
||||
|
||||
fn reset(&self) {
|
||||
self.cache.lock().unwrap().take();
|
||||
}
|
||||
let value = node.eval(input).await;
|
||||
*cache.lock().unwrap() = Some((hash, value.clone()));
|
||||
value
|
||||
}
|
||||
|
||||
impl<T, CachedNode> MemoNode<T, CachedNode> {
|
||||
pub fn new(node: CachedNode) -> MemoNode<T, CachedNode> {
|
||||
MemoNode { cache: Default::default(), node }
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::module_inception)]
|
||||
pub mod memo {
|
||||
use core_types::ProtoNodeIdentifier;
|
||||
|
||||
pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MemoNode");
|
||||
}
|
||||
type MonitorValue<I, T> = Arc<Mutex<Option<Arc<IORecord<I, T>>>>>;
|
||||
|
||||
/// Caches the output of the last graph evaluation for introspection.
|
||||
#[derive(Default)]
|
||||
pub struct MonitorNode<I, T, N> {
|
||||
#[node_macro::node(category(""), path(graphene_core::memo), serialize(serialize_monitor), skip_impl)]
|
||||
async fn monitor<I: Clone + 'static + Send + Sync, T: Clone + 'static + Send + Sync>(
|
||||
input: I,
|
||||
#[allow(clippy::type_complexity)]
|
||||
io: Arc<Mutex<Option<Arc<IORecord<I, T>>>>>,
|
||||
node: N,
|
||||
#[data]
|
||||
io: MonitorValue<I, T>,
|
||||
node: impl Node<I, Output = T>,
|
||||
) -> T {
|
||||
let output = node.eval(input.clone()).await;
|
||||
*io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() }));
|
||||
output
|
||||
}
|
||||
|
||||
impl<'i, T, I, N> Node<'i, I> for MonitorNode<I, T, N>
|
||||
where
|
||||
I: Clone + 'static + Send + Sync,
|
||||
T: Clone + 'static + Send + Sync,
|
||||
for<'a> N: Node<'a, I, Output: Future<Output = T> + WasmNotSend> + 'i,
|
||||
{
|
||||
type Output = DynFuture<'i, T>;
|
||||
fn eval(&'i self, input: I) -> Self::Output {
|
||||
let io = self.io.clone();
|
||||
let output_fut = self.node.eval(input.clone());
|
||||
Box::pin(async move {
|
||||
let output = output_fut.await;
|
||||
*io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() }));
|
||||
output
|
||||
})
|
||||
}
|
||||
|
||||
fn serialize(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
|
||||
let io = self.io.lock().unwrap();
|
||||
(io).as_ref().map(|output| output.clone() as Arc<dyn std::any::Any + Send + Sync>)
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, T, N> MonitorNode<I, T, N> {
|
||||
pub fn new(node: N) -> MonitorNode<I, T, N> {
|
||||
MonitorNode { io: Arc::new(Mutex::new(None)), node }
|
||||
}
|
||||
}
|
||||
|
||||
pub mod monitor {
|
||||
use core_types::ProtoNodeIdentifier;
|
||||
|
||||
pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MonitorNode");
|
||||
fn serialize_monitor<I: Clone + 'static + Send + Sync, T: Clone + 'static + Send + Sync>(io: &MonitorValue<I, T>) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
|
||||
let io = io.lock().unwrap();
|
||||
io.as_ref().map(|output| output.clone() as Arc<dyn std::any::Any + Send + Sync>)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue