Add the "memoize" attribute to the node macro (#4065)

* Add memoization attribute to node macro

* Fix memoization insertion for networks without conversion nodes
This commit is contained in:
Dennis Kobert 2026-04-28 13:55:38 +02:00 committed by GitHub
parent 3eba762135
commit e0368435b9
6 changed files with 47 additions and 6 deletions

View File

@ -36,6 +36,7 @@ pub(super) fn post_process_nodes(custom: Vec<DocumentNodeDefinition>) -> HashMap
description, description,
properties, properties,
context_features, context_features,
memoize: _,
} = metadata; } = metadata;
let Some(implementations) = &node_registry.get(id) else { continue }; let Some(implementations) = &node_registry.get(id) else { continue };

View File

@ -16,6 +16,7 @@ pub struct NodeMetadata {
pub description: &'static str, pub description: &'static str,
pub properties: Option<&'static str>, pub properties: Option<&'static str>,
pub context_features: Vec<ContextFeature>, pub context_features: Vec<ContextFeature>,
pub memoize: bool,
} }
// Translation struct between macro and definition // Translation struct between macro and definition

View File

@ -401,6 +401,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
let import_name = format_ident!("_IMPORT_STUB_{}", mod_name.to_string().to_case(Case::UpperSnake)); let import_name = format_ident!("_IMPORT_STUB_{}", mod_name.to_string().to_case(Case::UpperSnake));
let properties = &attributes.properties_string.as_ref().map(|value| quote!(Some(#value))).unwrap_or(quote!(None)); let properties = &attributes.properties_string.as_ref().map(|value| quote!(Some(#value))).unwrap_or(quote!(None));
let memoize_flag = attributes.memoize;
let cfg = crate::shader_nodes::modify_cfg(attributes); let cfg = crate::shader_nodes::modify_cfg(attributes);
let node_input_accessor = generate_node_input_references(parsed, fn_generics, &field_idents, core_types, &identifier, &cfg); let node_input_accessor = generate_node_input_references(parsed, fn_generics, &field_idents, core_types, &identifier, &cfg);
@ -498,6 +499,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
description: #description, description: #description,
properties: #properties, properties: #properties,
context_features: vec![#(ContextFeature::#context_features,)*], context_features: vec![#(ContextFeature::#context_features,)*],
memoize: #memoize_flag,
fields: vec![ fields: vec![
#( #(
FieldMetadata { FieldMetadata {

View File

@ -52,7 +52,8 @@ pub(crate) struct NodeFnAttributes {
pub(crate) shader_node: Option<ShaderNodeType>, pub(crate) shader_node: Option<ShaderNodeType>,
/// Custom serialization function path (e.g., "my_module::custom_serialize") /// Custom serialization function path (e.g., "my_module::custom_serialize")
pub(crate) serialize: Option<Path>, pub(crate) serialize: Option<Path>,
// Add more attributes as needed /// Whether the preprocessor should add a Memo node after this node in the generated subnetwork
pub(crate) memoize: bool,
} }
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
@ -259,6 +260,7 @@ impl Parse for NodeFnAttributes {
let mut cfg = None; let mut cfg = None;
let mut shader_node = None; let mut shader_node = None;
let mut serialize = None; let mut serialize = None;
let mut memoize = false;
let content = input; let content = input;
// let content; // let content;
@ -377,13 +379,25 @@ impl Parse for NodeFnAttributes {
.map_err(|_| Error::new_spanned(meta, "Expected a valid path for 'serialize', e.g., serialize(my_module::custom_serialize)"))?; .map_err(|_| Error::new_spanned(meta, "Expected a valid path for 'serialize', e.g., serialize(my_module::custom_serialize)"))?;
serialize = Some(parsed_path); serialize = Some(parsed_path);
} }
// Instructs the preprocessor to insert a Memo node after this node in the generated subnetwork,
// caching its output across evaluations with identical inputs.
//
// Example usage:
// #[node_macro::node(..., memoize, ...)]
"memoize" => {
let path = meta.require_path_only()?;
if memoize {
return Err(Error::new_spanned(path, "Multiple 'memoize' attributes are not allowed"));
}
memoize = true;
}
_ => { _ => {
return Err(Error::new_spanned( return Err(Error::new_spanned(
meta, meta,
indoc!( indoc!(
r#" r#"
Unsupported attribute in `node`. Unsupported attribute in `node`.
Supported attributes are 'category', 'name', 'path', 'skip_impl', 'properties', 'cfg', 'shader_node', and 'serialize'. Supported attributes are 'category', 'name', 'path', 'skip_impl', 'properties', 'cfg', 'shader_node', 'serialize', and 'memoize'.
Example usage: Example usage:
#[node_macro::node(..., name("Test Node"), ...)] #[node_macro::node(..., name("Test Node"), ...)]
"# "#
@ -415,6 +429,7 @@ impl Parse for NodeFnAttributes {
cfg, cfg,
shader_node, shader_node,
serialize, serialize,
memoize,
}) })
} }
} }
@ -1020,6 +1035,7 @@ mod tests {
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None, serialize: None,
memoize: false,
}, },
fn_name: Ident::new("add", Span::call_site()), fn_name: Ident::new("add", Span::call_site()),
struct_name: Ident::new("Add", Span::call_site()), struct_name: Ident::new("Add", Span::call_site()),
@ -1088,6 +1104,7 @@ mod tests {
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None, serialize: None,
memoize: false,
}, },
fn_name: Ident::new("transform", Span::call_site()), fn_name: Ident::new("transform", Span::call_site()),
struct_name: Ident::new("Transform", Span::call_site()), struct_name: Ident::new("Transform", Span::call_site()),
@ -1170,6 +1187,7 @@ mod tests {
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None, serialize: None,
memoize: false,
}, },
fn_name: Ident::new("circle", Span::call_site()), fn_name: Ident::new("circle", Span::call_site()),
struct_name: Ident::new("Circle", Span::call_site()), struct_name: Ident::new("Circle", Span::call_site()),
@ -1234,6 +1252,7 @@ mod tests {
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None, serialize: None,
memoize: false,
}, },
fn_name: Ident::new("levels", Span::call_site()), fn_name: Ident::new("levels", Span::call_site()),
struct_name: Ident::new("Levels", Span::call_site()), struct_name: Ident::new("Levels", Span::call_site()),
@ -1310,6 +1329,7 @@ mod tests {
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None, serialize: None,
memoize: false,
}, },
fn_name: Ident::new("add", Span::call_site()), fn_name: Ident::new("add", Span::call_site()),
struct_name: Ident::new("Add", Span::call_site()), struct_name: Ident::new("Add", Span::call_site()),
@ -1374,6 +1394,7 @@ mod tests {
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None, serialize: None,
memoize: false,
}, },
fn_name: Ident::new("load_image", Span::call_site()), fn_name: Ident::new("load_image", Span::call_site()),
struct_name: Ident::new("LoadImage", Span::call_site()), struct_name: Ident::new("LoadImage", Span::call_site()),
@ -1438,6 +1459,7 @@ mod tests {
cfg: None, cfg: None,
shader_node: None, shader_node: None,
serialize: None, serialize: None,
memoize: false,
}, },
fn_name: Ident::new("custom_node", Span::call_site()), fn_name: Ident::new("custom_node", Span::call_site()),
struct_name: Ident::new("CustomNode", Span::call_site()), struct_name: Ident::new("CustomNode", Span::call_site()),

View File

@ -1326,7 +1326,7 @@ pub async fn flatten_path<T: IntoGraphicTable + 'n + Send>(_: impl Ctx, #[implem
} }
/// Convert vector geometry into a polyline composed of evenly spaced points. /// Convert vector geometry into a polyline composed of evenly spaced points.
#[node_macro::node(category("Vector: Modifier"), path(core_types::vector), properties("sample_polyline_properties"))] #[node_macro::node(category("Vector: Modifier"), path(core_types::vector), properties("sample_polyline_properties"), memoize)]
async fn sample_polyline( async fn sample_polyline(
_: impl Ctx, _: impl Ctx,
content: Table<Vector>, content: Table<Vector>,

View File

@ -42,7 +42,7 @@ pub fn generate_node_substitutions() -> HashMap<ProtoNodeIdentifier, DocumentNod
for (id, metadata) in core_types::registry::NODE_METADATA.lock().unwrap().iter() { for (id, metadata) in core_types::registry::NODE_METADATA.lock().unwrap().iter() {
let id = id.clone(); let id = id.clone();
let NodeMetadata { fields, .. } = metadata; let NodeMetadata { fields, memoize, .. } = metadata;
let Some(implementations) = node_registry.get(&id) else { continue }; let Some(implementations) = node_registry.get(&id) else { continue };
let valid_call_args: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); let valid_call_args: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect();
let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() });
@ -111,7 +111,7 @@ pub fn generate_node_substitutions() -> HashMap<ProtoNodeIdentifier, DocumentNod
}) })
.collect(); .collect();
if generated_nodes == 0 { if generated_nodes == 0 && !memoize {
continue; continue;
} }
@ -127,12 +127,27 @@ pub fn generate_node_substitutions() -> HashMap<ProtoNodeIdentifier, DocumentNod
nodes.insert(NodeId(input_count as u64), document_node); nodes.insert(NodeId(input_count as u64), document_node);
// If memoize is requested, append a Memo node after the main node and redirect the export through it
let export_node_id = if *memoize {
let memo_node_id = NodeId(input_count as u64 + 1);
let memo_node = DocumentNode {
inputs: vec![NodeInput::node(NodeId(input_count as u64), 0)],
implementation: DocumentNodeImplementation::ProtoNode(graphene_core::memo::memo::IDENTIFIER.clone()),
visible: true,
..Default::default()
};
nodes.insert(memo_node_id, memo_node);
memo_node_id
} else {
NodeId(input_count as u64)
};
let node = DocumentNode { let node = DocumentNode {
inputs, inputs,
call_argument: input_type.clone(), call_argument: input_type.clone(),
implementation: DocumentNodeImplementation::Network(NodeNetwork { implementation: DocumentNodeImplementation::Network(NodeNetwork {
exports: vec![NodeInput::Node { exports: vec![NodeInput::Node {
node_id: NodeId(input_count as u64), node_id: export_node_id,
output_index: 0, output_index: 0,
}], }],
nodes, nodes,