Migrate memo nodes to node macro and make implementing other persistent nodes easier (#3552)
* Add #[data] and #[serialize] attributes to node macro - Add #[data] attribute for struct fields that aren't node parameters - Data fields are initialized with Default::default() - Passed as references to the underlying function - Excluded from registry metadata (internal state) - Generic types in data fields allowed without #[implementations] - Add #[serialize] attribute for custom Node::serialize() implementation - Receives references to all data fields - Generates serialize() method in Node trait impl - Conditional derives based on data field presence - With data fields: Debug, Clone only - Without data fields: Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash * Refactor Memo and Monitor Node to use node macro * Move Complex type into type alias * Fix format * Update node-graph/nodes/gcore/src/memo.rs Co-authored-by: Keavon Chambers <keavon@keavon.com> * Update node-graph/nodes/gcore/src/memo.rs Co-authored-by: Keavon Chambers <keavon@keavon.com> --------- Co-authored-by: Keavon Chambers <keavon@keavon.com>
This commit is contained in:
parent
8f25eb6ca4
commit
fafc687d84
|
|
@ -37,33 +37,102 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
};
|
};
|
||||||
let struct_name = format_ident!("{}Node", struct_name);
|
let struct_name = format_ident!("{}Node", struct_name);
|
||||||
|
|
||||||
let struct_generics: Vec<Ident> = fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
|
// Separate data fields from regular fields
|
||||||
|
let (data_fields, regular_fields): (Vec<_>, Vec<_>) = fields.iter().partition(|f| f.is_data_field);
|
||||||
|
|
||||||
|
// Extract function generics used by data fields
|
||||||
|
let data_field_generics: Vec<_> = fn_generics
|
||||||
|
.iter()
|
||||||
|
.filter(|generic| {
|
||||||
|
let generic_ident = match generic {
|
||||||
|
syn::GenericParam::Type(type_param) => &type_param.ident,
|
||||||
|
_ => return false,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if this generic is used in any data field type
|
||||||
|
data_fields.iter().any(|field| match &field.ty {
|
||||||
|
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => type_contains_ident(ty, generic_ident),
|
||||||
|
_ => false,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Node generics for regular fields (Node0, Node1, ...)
|
||||||
|
let node_generics: Vec<Ident> = regular_fields.iter().enumerate().map(|(i, _)| format_ident!("Node{}", i)).collect();
|
||||||
|
|
||||||
|
// Extract just the idents from data_field_generics for struct type parameters
|
||||||
|
let data_field_generic_idents: Vec<Ident> = data_field_generics
|
||||||
|
.iter()
|
||||||
|
.filter_map(|gp| match gp {
|
||||||
|
syn::GenericParam::Type(tp) => Some(tp.ident.clone()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Combined struct type parameters: data field generic idents (T, U, ...) + node generics (Node0, Node1, ...)
|
||||||
|
// For struct type instantiation: MemoNode<T, Node0>
|
||||||
|
let struct_type_params: Vec<Ident> = data_field_generic_idents.iter().cloned().chain(node_generics.iter().cloned()).collect();
|
||||||
|
|
||||||
|
// Combined struct generic parameters with bounds for struct definition
|
||||||
|
// struct MemoNode<T: Clone, Node0>
|
||||||
|
let struct_generic_params: Vec<TokenStream2> = data_field_generics.iter().map(|gp| quote!(#gp)).chain(node_generics.iter().map(|id| quote!(#id))).collect();
|
||||||
let input_ident = &input.pat_ident;
|
let input_ident = &input.pat_ident;
|
||||||
|
|
||||||
let context_features = &input.context_features;
|
let context_features = &input.context_features;
|
||||||
|
|
||||||
let field_idents: Vec<_> = fields.iter().map(|f| &f.pat_ident).collect();
|
// Regular field idents and names (for function parameters)
|
||||||
|
let field_idents: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident).collect();
|
||||||
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
|
let field_names: Vec<_> = field_idents.iter().map(|pat_ident| &pat_ident.ident).collect();
|
||||||
|
let regular_field_names: Vec<_> = regular_fields.iter().map(|f| &f.pat_ident.ident).collect();
|
||||||
|
let data_field_names: Vec<_> = data_fields.iter().map(|f| &f.pat_ident.ident).collect();
|
||||||
|
|
||||||
let input_names: Vec<_> = fields
|
// Only regular fields have input names/descriptions (for UI)
|
||||||
|
let input_names: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|f| &f.name)
|
.map(|f| &f.name)
|
||||||
.zip(field_names.iter())
|
.zip(regular_field_names.iter())
|
||||||
.map(|zipped| match zipped {
|
.map(|zipped| match zipped {
|
||||||
(Some(name), _) => name.value(),
|
(Some(name), _) => name.value(),
|
||||||
(_, name) => name.to_string().to_case(Case::Title),
|
(_, name) => name.to_string().to_case(Case::Title),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let input_descriptions: Vec<_> = fields.iter().map(|f| &f.description).collect();
|
let input_descriptions: Vec<_> = regular_fields.iter().map(|f| &f.description).collect();
|
||||||
|
|
||||||
let struct_fields = field_names.iter().zip(struct_generics.iter()).map(|(name, r#gen)| {
|
// Generate struct fields: data fields (concrete types) + regular fields (generic types)
|
||||||
|
let data_field_defs = data_fields.iter().map(|field| {
|
||||||
|
let name = &field.pat_ident.ident;
|
||||||
|
let ty = match &field.ty {
|
||||||
|
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty,
|
||||||
|
_ => unreachable!("Data fields must be Regular types, not Node types"),
|
||||||
|
};
|
||||||
|
quote! { pub(super) #name: #ty }
|
||||||
|
});
|
||||||
|
|
||||||
|
let regular_field_defs = regular_field_names.iter().zip(node_generics.iter()).map(|(name, r#gen)| {
|
||||||
quote! { pub(super) #name: #r#gen }
|
quote! { pub(super) #name: #r#gen }
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let struct_fields = data_field_defs.chain(regular_field_defs);
|
||||||
|
|
||||||
let mut future_idents = Vec::new();
|
let mut future_idents = Vec::new();
|
||||||
|
|
||||||
let field_types: Vec<_> = fields
|
// Data fields get passed as references to the underlying function
|
||||||
|
let data_field_idents: Vec<_> = data_fields.iter().map(|f| &f.pat_ident).collect();
|
||||||
|
let data_field_types: Vec<_> = data_fields
|
||||||
|
.iter()
|
||||||
|
.map(|field| match &field.ty {
|
||||||
|
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => {
|
||||||
|
let ty = ty.clone();
|
||||||
|
quote!(&#ty)
|
||||||
|
}
|
||||||
|
_ => unreachable!("Data fields must be Regular types, not Node types"),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Regular fields have types passed to the function
|
||||||
|
let field_types: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.ty {
|
.map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
|
ParsedFieldType::Regular(RegularParsedField { ty, .. }) => ty.clone(),
|
||||||
|
|
@ -74,7 +143,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let widget_override: Vec<_> = fields
|
// Only regular fields have UI metadata (data fields are internal state)
|
||||||
|
let widget_override: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.widget_override {
|
.map(|field| match &field.widget_override {
|
||||||
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
|
ParsedWidgetOverride::None => quote!(RegistryWidgetOverride::None),
|
||||||
|
|
@ -84,7 +154,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let value_sources: Vec<_> = fields
|
let value_sources: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.ty {
|
.map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
|
ParsedFieldType::Regular(RegularParsedField { value_source, .. }) => match value_source {
|
||||||
|
|
@ -104,7 +174,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let default_types: Vec<_> = fields
|
let default_types: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.ty {
|
.map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
|
ParsedFieldType::Regular(RegularParsedField { implementations, .. }) => match implementations.first() {
|
||||||
|
|
@ -115,7 +185,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let number_min_values: Vec<_> = fields
|
let number_min_values: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.ty {
|
.map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
|
ParsedFieldType::Regular(RegularParsedField { number_soft_min, number_hard_min, .. }) => match (number_soft_min, number_hard_min) {
|
||||||
|
|
@ -126,7 +196,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
_ => quote!(None),
|
_ => quote!(None),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let number_max_values: Vec<_> = fields
|
let number_max_values: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.ty {
|
.map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
|
ParsedFieldType::Regular(RegularParsedField { number_soft_max, number_hard_max, .. }) => match (number_soft_max, number_hard_max) {
|
||||||
|
|
@ -137,7 +207,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
_ => quote!(None),
|
_ => quote!(None),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let number_mode_range_values: Vec<_> = fields
|
let number_mode_range_values: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.ty {
|
.map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField {
|
ParsedFieldType::Regular(RegularParsedField {
|
||||||
|
|
@ -147,15 +217,15 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
_ => quote!(None),
|
_ => quote!(None),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let number_display_decimal_places: Vec<_> = fields
|
let number_display_decimal_places: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
|
.map(|field| field.number_display_decimal_places.as_ref().map_or(quote!(None), |i| quote!(Some(#i))))
|
||||||
.collect();
|
.collect();
|
||||||
let number_step: Vec<_> = fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
let number_step: Vec<_> = regular_fields.iter().map(|field| field.number_step.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
||||||
|
|
||||||
let unit_suffix: Vec<_> = fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
let unit_suffix: Vec<_> = regular_fields.iter().map(|field| field.unit.as_ref().map_or(quote!(None), |i| quote!(Some(#i)))).collect();
|
||||||
|
|
||||||
let exposed: Vec<_> = fields
|
let exposed: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| match &field.ty {
|
.map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
|
ParsedFieldType::Regular(RegularParsedField { exposed, .. }) => quote!(#exposed),
|
||||||
|
|
@ -163,7 +233,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let eval_args = fields.iter().map(|field| {
|
// Only eval regular fields (data fields are accessed directly as self.field_name)
|
||||||
|
let eval_args = regular_fields.iter().map(|field| {
|
||||||
let name = &field.pat_ident.ident;
|
let name = &field.pat_ident.ident;
|
||||||
match &field.ty {
|
match &field.ty {
|
||||||
ParsedFieldType::Regular { .. } => {
|
ParsedFieldType::Regular { .. } => {
|
||||||
|
|
@ -175,7 +246,8 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let min_max_args = fields.iter().map(|field| match &field.ty {
|
// Only regular fields can have min/max constraints
|
||||||
|
let min_max_args = regular_fields.iter().map(|field| match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
|
ParsedFieldType::Regular(RegularParsedField { number_hard_min, number_hard_max, .. }) => {
|
||||||
let name = &field.pat_ident.ident;
|
let name = &field.pat_ident.ident;
|
||||||
let mut tokens = quote!();
|
let mut tokens = quote!();
|
||||||
|
|
@ -208,7 +280,7 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
let mut clauses = Vec::new();
|
let mut clauses = Vec::new();
|
||||||
let mut clampable_clauses = Vec::new();
|
let mut clampable_clauses = Vec::new();
|
||||||
|
|
||||||
for (field, name) in fields.iter().zip(struct_generics.iter()) {
|
for (field, name) in regular_fields.iter().zip(node_generics.iter()) {
|
||||||
clauses.push(match (&field.ty, *is_async) {
|
clauses.push(match (&field.ty, *is_async) {
|
||||||
(
|
(
|
||||||
ParsedFieldType::Regular(RegularParsedField {
|
ParsedFieldType::Regular(RegularParsedField {
|
||||||
|
|
@ -259,13 +331,42 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
);
|
);
|
||||||
struct_where_clause.predicates.extend(extra_where);
|
struct_where_clause.predicates.extend(extra_where);
|
||||||
|
|
||||||
let new_args = struct_generics.iter().zip(field_names.iter()).map(|(r#gen, name)| {
|
// Only regular fields are parameters to new()
|
||||||
|
let new_args = node_generics.iter().zip(regular_field_names.iter()).map(|(r#gen, name)| {
|
||||||
quote! { #name: #r#gen }
|
quote! { #name: #r#gen }
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Initialize data fields with Default, regular fields with parameters
|
||||||
|
let data_inits = data_field_names.iter().map(|name| {
|
||||||
|
quote! { #name: Default::default() }
|
||||||
|
});
|
||||||
|
let regular_inits = regular_field_names.iter().map(|name| {
|
||||||
|
quote! { #name }
|
||||||
|
});
|
||||||
|
let all_field_inits = data_inits.chain(regular_inits);
|
||||||
|
|
||||||
let async_keyword = is_async.then(|| quote!(async));
|
let async_keyword = is_async.then(|| quote!(async));
|
||||||
let await_keyword = is_async.then(|| quote!(.await));
|
let await_keyword = is_async.then(|| quote!(.await));
|
||||||
|
|
||||||
|
// Data fields may not implement Copy, PartialEq, etc., so only derive Debug and Clone
|
||||||
|
let struct_derives = if data_fields.is_empty() {
|
||||||
|
quote!(#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)])
|
||||||
|
} else {
|
||||||
|
quote!(#[derive(Debug, Clone)])
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate serialize method if serialize attribute is specified
|
||||||
|
let serialize_impl = if let Some(serialize_fn) = &parsed.attributes.serialize {
|
||||||
|
let data_field_refs = data_field_names.iter().map(|name| quote!(&self.#name));
|
||||||
|
quote! {
|
||||||
|
fn serialize(&self) -> Option<std::sync::Arc<dyn std::any::Any + Send + Sync>> {
|
||||||
|
#serialize_fn(#(#data_field_refs),*)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote!()
|
||||||
|
};
|
||||||
|
|
||||||
let eval_impl = quote! {
|
let eval_impl = quote! {
|
||||||
type Output = #core_types::registry::DynFuture<'n, #output_type>;
|
type Output = #core_types::registry::DynFuture<'n, #output_type>;
|
||||||
#[inline]
|
#[inline]
|
||||||
|
|
@ -275,9 +376,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
|
|
||||||
#(#eval_args)*
|
#(#eval_args)*
|
||||||
#(#min_max_args)*
|
#(#min_max_args)*
|
||||||
self::#fn_name(__input #(, #field_names)*) #await_keyword
|
self::#fn_name(__input #(, &self.#data_field_names)* #(, #regular_field_names)*) #await_keyword
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#serialize_impl
|
||||||
};
|
};
|
||||||
|
|
||||||
let identifier = format_ident!("{}_proto_ident", fn_name);
|
let identifier = format_ident!("{}_proto_ident", fn_name);
|
||||||
|
|
@ -302,11 +405,11 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
/// Underlying implementation for [#struct_name]
|
/// Underlying implementation for [#struct_name]
|
||||||
#[inline]
|
#[inline]
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
|
#vis #async_keyword fn #fn_name <'n, #(#fn_generics,)*> (#input_ident: #input_type #(, #data_field_idents: #data_field_types)* #(, #field_idents: #field_types)*) -> #output_type #where_clause #body
|
||||||
|
|
||||||
#cfg
|
#cfg
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<'n, #(#fn_generics,)* #(#struct_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_generics,)*>
|
impl<'n, #(#fn_generics,)* #(#node_generics,)* #(#future_idents,)*> #core_types::Node<'n, #input_type> for #mod_name::#struct_name<#(#struct_type_params,)*>
|
||||||
#struct_where_clause
|
#struct_where_clause
|
||||||
{
|
{
|
||||||
#eval_impl
|
#eval_impl
|
||||||
|
|
@ -340,18 +443,18 @@ pub(crate) fn generate_node_code(crate_ident: &CrateIdent, parsed: &ParsedNodeFn
|
||||||
|
|
||||||
static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData;
|
static #import_name: core::marker::PhantomData<(#(#all_implementation_types,)*)> = core::marker::PhantomData;
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
#struct_derives
|
||||||
pub struct #struct_name<#(#struct_generics,)*> {
|
pub struct #struct_name<#(#struct_generic_params,)*> {
|
||||||
#(#struct_fields,)*
|
#(#struct_fields,)*
|
||||||
}
|
}
|
||||||
|
|
||||||
#[automatically_derived]
|
#[automatically_derived]
|
||||||
impl<'n, #(#struct_generics,)*> #struct_name<#(#struct_generics,)*>
|
impl<'n, #(#struct_generic_params,)*> #struct_name<#(#struct_type_params,)*>
|
||||||
{
|
{
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn new(#(#new_args,)*) -> Self {
|
pub fn new(#(#new_args,)*) -> Self {
|
||||||
Self {
|
Self {
|
||||||
#(#field_names,)*
|
#(#all_field_inits,)*
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -493,8 +596,10 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
|
||||||
|
|
||||||
let mut constructors = Vec::new();
|
let mut constructors = Vec::new();
|
||||||
let unit = parse_quote!(gcore::Context);
|
let unit = parse_quote!(gcore::Context);
|
||||||
let parameter_types: Vec<_> = parsed
|
|
||||||
.fields
|
let regular_fields: Vec<_> = parsed.fields.iter().filter(|f| !f.is_data_field).collect();
|
||||||
|
|
||||||
|
let parameter_types: Vec<_> = regular_fields
|
||||||
.iter()
|
.iter()
|
||||||
.map(|field| {
|
.map(|field| {
|
||||||
match &field.ty {
|
match &field.ty {
|
||||||
|
|
@ -535,7 +640,7 @@ fn generate_register_node_impl(parsed: &ParsedNodeFn, field_names: &[&Ident], st
|
||||||
let field_name = field_names[j];
|
let field_name = field_names[j];
|
||||||
let (input_type, output_type) = &types[i.min(types.len() - 1)];
|
let (input_type, output_type) = &types[i.min(types.len() - 1)];
|
||||||
|
|
||||||
let node = matches!(parsed.fields[j].ty, ParsedFieldType::Node { .. });
|
let node = matches!(regular_fields[j].ty, ParsedFieldType::Node { .. });
|
||||||
|
|
||||||
let downcast_node = quote!(
|
let downcast_node = quote!(
|
||||||
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());
|
let #field_name: DowncastBothNode<#input_type, #output_type> = DowncastBothNode::new(args[#j].clone());
|
||||||
|
|
@ -712,3 +817,23 @@ impl FilterUsedGenerics {
|
||||||
self.used(&*modified).cloned().collect()
|
self.used(&*modified).cloned().collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if a type contains a reference to a specific identifier (e.g., a generic type parameter)
|
||||||
|
fn type_contains_ident(ty: &Type, ident: &Ident) -> bool {
|
||||||
|
struct IdentChecker<'a> {
|
||||||
|
target: &'a Ident,
|
||||||
|
found: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'ast> syn::visit::Visit<'ast> for IdentChecker<'a> {
|
||||||
|
fn visit_ident(&mut self, i: &'ast Ident) {
|
||||||
|
if i == self.target {
|
||||||
|
self.found = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut checker = IdentChecker { target: ident, found: false };
|
||||||
|
syn::visit::visit_type(&mut checker, ty);
|
||||||
|
checker.found
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,8 @@ pub(crate) struct NodeFnAttributes {
|
||||||
pub(crate) cfg: Option<TokenStream2>,
|
pub(crate) cfg: Option<TokenStream2>,
|
||||||
/// if this node should get a gpu implementation, defaults to None
|
/// if this node should get a gpu implementation, defaults to None
|
||||||
pub(crate) shader_node: Option<ShaderNodeType>,
|
pub(crate) shader_node: Option<ShaderNodeType>,
|
||||||
|
/// Custom serialization function path (e.g., "my_module::custom_serialize")
|
||||||
|
pub(crate) serialize: Option<Path>,
|
||||||
// Add more attributes as needed
|
// Add more attributes as needed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -112,6 +114,7 @@ pub struct ParsedField {
|
||||||
pub number_display_decimal_places: Option<LitInt>,
|
pub number_display_decimal_places: Option<LitInt>,
|
||||||
pub number_step: Option<LitFloat>,
|
pub number_step: Option<LitFloat>,
|
||||||
pub unit: Option<LitStr>,
|
pub unit: Option<LitStr>,
|
||||||
|
pub is_data_field: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
|
@ -201,6 +204,7 @@ impl Parse for NodeFnAttributes {
|
||||||
let mut properties_string = None;
|
let mut properties_string = None;
|
||||||
let mut cfg = None;
|
let mut cfg = None;
|
||||||
let mut shader_node = None;
|
let mut shader_node = None;
|
||||||
|
let mut serialize = None;
|
||||||
|
|
||||||
let content = input;
|
let content = input;
|
||||||
// let content;
|
// let content;
|
||||||
|
|
@ -270,13 +274,23 @@ impl Parse for NodeFnAttributes {
|
||||||
let meta = meta.require_list()?;
|
let meta = meta.require_list()?;
|
||||||
shader_node = Some(syn::parse2(meta.tokens.to_token_stream())?);
|
shader_node = Some(syn::parse2(meta.tokens.to_token_stream())?);
|
||||||
}
|
}
|
||||||
|
"serialize" => {
|
||||||
|
let meta = meta.require_list()?;
|
||||||
|
if serialize.is_some() {
|
||||||
|
return Err(Error::new_spanned(meta, "Multiple 'serialize' attributes are not allowed"));
|
||||||
|
}
|
||||||
|
let parsed_path: Path = meta
|
||||||
|
.parse_args()
|
||||||
|
.map_err(|_| Error::new_spanned(meta, "Expected a valid path for 'serialize', e.g., serialize(my_module::custom_serialize)"))?;
|
||||||
|
serialize = Some(parsed_path);
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
return Err(Error::new_spanned(
|
return Err(Error::new_spanned(
|
||||||
meta,
|
meta,
|
||||||
indoc!(
|
indoc!(
|
||||||
r#"
|
r#"
|
||||||
Unsupported attribute in `node`.
|
Unsupported attribute in `node`.
|
||||||
Supported attributes are 'category', 'path' 'name', 'skip_impl', 'cfg' and 'properties'.
|
Supported attributes are 'category', 'path', 'name', 'skip_impl', 'cfg', 'properties', 'serialize', and 'shader_node'.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
#[node_macro::node(category("Value"), name("Test Node"))]
|
#[node_macro::node(category("Value"), name("Test Node"))]
|
||||||
|
|
@ -295,6 +309,7 @@ impl Parse for NodeFnAttributes {
|
||||||
properties_string,
|
properties_string,
|
||||||
cfg,
|
cfg,
|
||||||
shader_node,
|
shader_node,
|
||||||
|
serialize,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -467,6 +482,9 @@ fn parse_node_implementations<T: Parse>(attr: &Attribute, name: &Ident) -> syn::
|
||||||
fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> {
|
fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Result<ParsedField> {
|
||||||
let ident = &pat_ident.ident;
|
let ident = &pat_ident.ident;
|
||||||
|
|
||||||
|
// Check if this is a data field (struct field, not a parameter)
|
||||||
|
let is_data_field = extract_attribute(attrs, "data").is_some();
|
||||||
|
|
||||||
let default_value = extract_attribute(attrs, "default")
|
let default_value = extract_attribute(attrs, "default")
|
||||||
.map(|attr| attr.parse_args().map_err(|e| Error::new_spanned(attr, format!("Invalid `default` value for argument '{ident}': {e}"))))
|
.map(|attr| attr.parse_args().map_err(|e| Error::new_spanned(attr, format!("Invalid `default` value for argument '{ident}': {e}"))))
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
|
@ -489,6 +507,25 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
||||||
|
|
||||||
let exposed = extract_attribute(attrs, "expose").is_some();
|
let exposed = extract_attribute(attrs, "expose").is_some();
|
||||||
|
|
||||||
|
// Validate data field attributes
|
||||||
|
if is_data_field {
|
||||||
|
if default_value.is_some() {
|
||||||
|
return Err(Error::new_spanned(
|
||||||
|
&pat_ident,
|
||||||
|
"Data fields (#[data]) cannot have #[default] attribute. They are automatically initialized with Default::default()",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if scope.is_some() {
|
||||||
|
return Err(Error::new_spanned(&pat_ident, "Data fields (#[data]) cannot have #[scope] attribute"));
|
||||||
|
}
|
||||||
|
if exposed {
|
||||||
|
return Err(Error::new_spanned(
|
||||||
|
&pat_ident,
|
||||||
|
"Data fields (#[data]) cannot be exposed (#[expose]). They are internal state, not node parameters",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let value_source = match (default_value, scope) {
|
let value_source = match (default_value, scope) {
|
||||||
(Some(_), Some(_)) => return Err(Error::new_spanned(&pat_ident, "Cannot have both `default` and `scope` attributes")),
|
(Some(_), Some(_)) => return Err(Error::new_spanned(&pat_ident, "Cannot have both `default` and `scope` attributes")),
|
||||||
(Some(default_value), _) => ParsedValueSource::Default(default_value),
|
(Some(default_value), _) => ParsedValueSource::Default(default_value),
|
||||||
|
|
@ -586,6 +623,14 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
||||||
.fold(String::new(), |acc, b| acc + &b + "\n");
|
.fold(String::new(), |acc, b| acc + &b + "\n");
|
||||||
|
|
||||||
if is_node {
|
if is_node {
|
||||||
|
// Data fields cannot be impl Node types
|
||||||
|
if is_data_field {
|
||||||
|
return Err(Error::new_spanned(
|
||||||
|
&ty,
|
||||||
|
"Data fields (#[data]) cannot be of type `impl Node`. Data fields must be concrete types that implement Default",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let (input_type, output_type) = node_input_type
|
let (input_type, output_type) = node_input_type
|
||||||
.zip(node_output_type)
|
.zip(node_output_type)
|
||||||
.ok_or_else(|| Error::new_spanned(&ty, "Invalid Node type. Expected `impl Node<Input, Output = OutputType>`"))?;
|
.ok_or_else(|| Error::new_spanned(&ty, "Invalid Node type. Expected `impl Node<Input, Output = OutputType>`"))?;
|
||||||
|
|
@ -610,6 +655,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
||||||
number_display_decimal_places,
|
number_display_decimal_places,
|
||||||
number_step,
|
number_step,
|
||||||
unit,
|
unit,
|
||||||
|
is_data_field,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
let implementations = extract_attribute(attrs, "implementations")
|
let implementations = extract_attribute(attrs, "implementations")
|
||||||
|
|
@ -636,6 +682,7 @@ fn parse_field(pat_ident: PatIdent, ty: Type, attrs: &[Attribute]) -> syn::Resul
|
||||||
number_display_decimal_places,
|
number_display_decimal_places,
|
||||||
number_step,
|
number_step,
|
||||||
unit,
|
unit,
|
||||||
|
is_data_field,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -826,6 +873,7 @@ mod tests {
|
||||||
properties_string: None,
|
properties_string: None,
|
||||||
cfg: None,
|
cfg: None,
|
||||||
shader_node: None,
|
shader_node: None,
|
||||||
|
serialize: None,
|
||||||
},
|
},
|
||||||
fn_name: Ident::new("add", Span::call_site()),
|
fn_name: Ident::new("add", Span::call_site()),
|
||||||
struct_name: Ident::new("Add", Span::call_site()),
|
struct_name: Ident::new("Add", Span::call_site()),
|
||||||
|
|
@ -860,6 +908,7 @@ mod tests {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
}],
|
}],
|
||||||
body: TokenStream2::new(),
|
body: TokenStream2::new(),
|
||||||
description: String::from("Multi\nLine\n"),
|
description: String::from("Multi\nLine\n"),
|
||||||
|
|
@ -892,6 +941,7 @@ mod tests {
|
||||||
properties_string: None,
|
properties_string: None,
|
||||||
cfg: None,
|
cfg: None,
|
||||||
shader_node: None,
|
shader_node: None,
|
||||||
|
serialize: None,
|
||||||
},
|
},
|
||||||
fn_name: Ident::new("transform", Span::call_site()),
|
fn_name: Ident::new("transform", Span::call_site()),
|
||||||
struct_name: Ident::new("Transform", Span::call_site()),
|
struct_name: Ident::new("Transform", Span::call_site()),
|
||||||
|
|
@ -920,6 +970,7 @@ mod tests {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
},
|
},
|
||||||
ParsedField {
|
ParsedField {
|
||||||
pat_ident: pat_ident("translate"),
|
pat_ident: pat_ident("translate"),
|
||||||
|
|
@ -941,6 +992,7 @@ mod tests {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
body: TokenStream2::new(),
|
body: TokenStream2::new(),
|
||||||
|
|
@ -971,6 +1023,7 @@ mod tests {
|
||||||
properties_string: None,
|
properties_string: None,
|
||||||
cfg: None,
|
cfg: None,
|
||||||
shader_node: None,
|
shader_node: None,
|
||||||
|
serialize: None,
|
||||||
},
|
},
|
||||||
fn_name: Ident::new("circle", Span::call_site()),
|
fn_name: Ident::new("circle", Span::call_site()),
|
||||||
struct_name: Ident::new("Circle", Span::call_site()),
|
struct_name: Ident::new("Circle", Span::call_site()),
|
||||||
|
|
@ -1005,6 +1058,7 @@ mod tests {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
}],
|
}],
|
||||||
body: TokenStream2::new(),
|
body: TokenStream2::new(),
|
||||||
description: "Test\n".into(),
|
description: "Test\n".into(),
|
||||||
|
|
@ -1033,6 +1087,7 @@ mod tests {
|
||||||
properties_string: None,
|
properties_string: None,
|
||||||
cfg: None,
|
cfg: None,
|
||||||
shader_node: None,
|
shader_node: None,
|
||||||
|
serialize: None,
|
||||||
},
|
},
|
||||||
fn_name: Ident::new("levels", Span::call_site()),
|
fn_name: Ident::new("levels", Span::call_site()),
|
||||||
struct_name: Ident::new("Levels", Span::call_site()),
|
struct_name: Ident::new("Levels", Span::call_site()),
|
||||||
|
|
@ -1072,6 +1127,7 @@ mod tests {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
}],
|
}],
|
||||||
body: TokenStream2::new(),
|
body: TokenStream2::new(),
|
||||||
description: String::new(),
|
description: String::new(),
|
||||||
|
|
@ -1107,6 +1163,7 @@ mod tests {
|
||||||
properties_string: None,
|
properties_string: None,
|
||||||
cfg: None,
|
cfg: None,
|
||||||
shader_node: None,
|
shader_node: None,
|
||||||
|
serialize: None,
|
||||||
},
|
},
|
||||||
fn_name: Ident::new("add", Span::call_site()),
|
fn_name: Ident::new("add", Span::call_site()),
|
||||||
struct_name: Ident::new("Add", Span::call_site()),
|
struct_name: Ident::new("Add", Span::call_site()),
|
||||||
|
|
@ -1141,6 +1198,7 @@ mod tests {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
}],
|
}],
|
||||||
body: TokenStream2::new(),
|
body: TokenStream2::new(),
|
||||||
description: String::new(),
|
description: String::new(),
|
||||||
|
|
@ -1169,6 +1227,7 @@ mod tests {
|
||||||
properties_string: None,
|
properties_string: None,
|
||||||
cfg: None,
|
cfg: None,
|
||||||
shader_node: None,
|
shader_node: None,
|
||||||
|
serialize: None,
|
||||||
},
|
},
|
||||||
fn_name: Ident::new("load_image", Span::call_site()),
|
fn_name: Ident::new("load_image", Span::call_site()),
|
||||||
struct_name: Ident::new("LoadImage", Span::call_site()),
|
struct_name: Ident::new("LoadImage", Span::call_site()),
|
||||||
|
|
@ -1203,6 +1262,7 @@ mod tests {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
}],
|
}],
|
||||||
body: TokenStream2::new(),
|
body: TokenStream2::new(),
|
||||||
description: String::new(),
|
description: String::new(),
|
||||||
|
|
@ -1231,6 +1291,7 @@ mod tests {
|
||||||
properties_string: None,
|
properties_string: None,
|
||||||
cfg: None,
|
cfg: None,
|
||||||
shader_node: None,
|
shader_node: None,
|
||||||
|
serialize: None,
|
||||||
},
|
},
|
||||||
fn_name: Ident::new("custom_node", Span::call_site()),
|
fn_name: Ident::new("custom_node", Span::call_site()),
|
||||||
struct_name: Ident::new("CustomNode", Span::call_site()),
|
struct_name: Ident::new("CustomNode", Span::call_site()),
|
||||||
|
|
|
||||||
|
|
@ -245,6 +245,7 @@ impl PerPixelAdjustCodegen<'_> {
|
||||||
number_display_decimal_places: None,
|
number_display_decimal_places: None,
|
||||||
number_step: None,
|
number_step: None,
|
||||||
unit: None,
|
unit: None,
|
||||||
|
is_data_field: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
// find exactly one gpu_image field, runtime doesn't support more than 1 atm
|
// find exactly one gpu_image field, runtime doesn't support more than 1 atm
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,11 @@ fn validate_implementations_for_generics(parsed: &ParsedNodeFn) {
|
||||||
|
|
||||||
if !has_skip_impl && !parsed.fn_generics.is_empty() {
|
if !has_skip_impl && !parsed.fn_generics.is_empty() {
|
||||||
for field in &parsed.fields {
|
for field in &parsed.fields {
|
||||||
|
// Skip validation for data fields - they're internal state and can be generic
|
||||||
|
if field.is_data_field {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let pat_ident = &field.pat_ident;
|
let pat_ident = &field.pat_ident;
|
||||||
match &field.ty {
|
match &field.ty {
|
||||||
ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => {
|
ParsedFieldType::Regular(RegularParsedField { ty, implementations, .. }) => {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
|
use core_types::WasmNotSend;
|
||||||
use core_types::memo::*;
|
use core_types::memo::*;
|
||||||
use core_types::{Node, WasmNotSend};
|
|
||||||
use dyn_any::DynFuture;
|
|
||||||
use std::future::Future;
|
|
||||||
use std::hash::DefaultHasher;
|
use std::hash::DefaultHasher;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
@ -14,94 +12,38 @@ use std::sync::Mutex;
|
||||||
/// A cache hit occurs when the Option is Some and has a stored hash matching the hash of the call argument. In this case, the node returns the cached value without re-evaluating the inner node.
|
/// A cache hit occurs when the Option is Some and has a stored hash matching the hash of the call argument. In this case, the node returns the cached value without re-evaluating the inner node.
|
||||||
///
|
///
|
||||||
/// Currently, only one input-output pair is cached. Subsequent calls with different inputs will overwrite the previous cache.
|
/// Currently, only one input-output pair is cached. Subsequent calls with different inputs will overwrite the previous cache.
|
||||||
#[derive(Default)]
|
#[node_macro::node(category(""), path(graphene_core::memo), skip_impl)]
|
||||||
pub struct MemoNode<T, CachedNode> {
|
async fn memo<I: Hash + Send + 'n, T: Clone + WasmNotSend>(input: I, #[data] cache: Arc<Mutex<Option<(u64, T)>>>, node: impl Node<I, Output = T>) -> T {
|
||||||
cache: Arc<Mutex<Option<(u64, T)>>>,
|
let mut hasher = DefaultHasher::new();
|
||||||
node: CachedNode,
|
input.hash(&mut hasher);
|
||||||
}
|
let hash = hasher.finish();
|
||||||
impl<'i, I: Hash + 'i, T: 'i + Clone + WasmNotSend, CachedNode: 'i> Node<'i, I> for MemoNode<T, CachedNode>
|
|
||||||
where
|
|
||||||
CachedNode: for<'any_input> Node<'any_input, I>,
|
|
||||||
for<'a> <CachedNode as Node<'a, I>>::Output: Future<Output = T> + WasmNotSend,
|
|
||||||
{
|
|
||||||
// TODO: This should return a reference to the cached cached_value
|
|
||||||
// but that requires a lot of lifetime magic <- This was suggested by copilot but is pretty accurate xD
|
|
||||||
type Output = DynFuture<'i, T>;
|
|
||||||
fn eval(&'i self, input: I) -> Self::Output {
|
|
||||||
let mut hasher = DefaultHasher::new();
|
|
||||||
input.hash(&mut hasher);
|
|
||||||
let hash = hasher.finish();
|
|
||||||
|
|
||||||
if let Some(data) = self.cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) {
|
if let Some(data) = cache.lock().as_ref().unwrap().as_ref().and_then(|data| (data.0 == hash).then_some(data.1.clone())) {
|
||||||
Box::pin(async move { data })
|
return data;
|
||||||
} else {
|
|
||||||
let fut = self.node.eval(input);
|
|
||||||
let cache = self.cache.clone();
|
|
||||||
Box::pin(async move {
|
|
||||||
let value = fut.await;
|
|
||||||
*cache.lock().unwrap() = Some((hash, value.clone()));
|
|
||||||
value
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reset(&self) {
|
let value = node.eval(input).await;
|
||||||
self.cache.lock().unwrap().take();
|
*cache.lock().unwrap() = Some((hash, value.clone()));
|
||||||
}
|
value
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, CachedNode> MemoNode<T, CachedNode> {
|
type MonitorValue<I, T> = Arc<Mutex<Option<Arc<IORecord<I, T>>>>>;
|
||||||
pub fn new(node: CachedNode) -> MemoNode<T, CachedNode> {
|
|
||||||
MemoNode { cache: Default::default(), node }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::module_inception)]
|
|
||||||
pub mod memo {
|
|
||||||
use core_types::ProtoNodeIdentifier;
|
|
||||||
|
|
||||||
pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MemoNode");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Caches the output of the last graph evaluation for introspection.
|
/// Caches the output of the last graph evaluation for introspection.
|
||||||
#[derive(Default)]
|
#[node_macro::node(category(""), path(graphene_core::memo), serialize(serialize_monitor), skip_impl)]
|
||||||
pub struct MonitorNode<I, T, N> {
|
async fn monitor<I: Clone + 'static + Send + Sync, T: Clone + 'static + Send + Sync>(
|
||||||
|
input: I,
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
io: Arc<Mutex<Option<Arc<IORecord<I, T>>>>>,
|
#[data]
|
||||||
node: N,
|
io: MonitorValue<I, T>,
|
||||||
|
node: impl Node<I, Output = T>,
|
||||||
|
) -> T {
|
||||||
|
let output = node.eval(input.clone()).await;
|
||||||
|
*io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() }));
|
||||||
|
output
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'i, T, I, N> Node<'i, I> for MonitorNode<I, T, N>
|
fn serialize_monitor<I: Clone + 'static + Send + Sync, T: Clone + 'static + Send + Sync>(io: &MonitorValue<I, T>) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
|
||||||
where
|
let io = io.lock().unwrap();
|
||||||
I: Clone + 'static + Send + Sync,
|
io.as_ref().map(|output| output.clone() as Arc<dyn std::any::Any + Send + Sync>)
|
||||||
T: Clone + 'static + Send + Sync,
|
|
||||||
for<'a> N: Node<'a, I, Output: Future<Output = T> + WasmNotSend> + 'i,
|
|
||||||
{
|
|
||||||
type Output = DynFuture<'i, T>;
|
|
||||||
fn eval(&'i self, input: I) -> Self::Output {
|
|
||||||
let io = self.io.clone();
|
|
||||||
let output_fut = self.node.eval(input.clone());
|
|
||||||
Box::pin(async move {
|
|
||||||
let output = output_fut.await;
|
|
||||||
*io.lock().unwrap() = Some(Arc::new(IORecord { input, output: output.clone() }));
|
|
||||||
output
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn serialize(&self) -> Option<Arc<dyn std::any::Any + Send + Sync>> {
|
|
||||||
let io = self.io.lock().unwrap();
|
|
||||||
(io).as_ref().map(|output| output.clone() as Arc<dyn std::any::Any + Send + Sync>)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<I, T, N> MonitorNode<I, T, N> {
|
|
||||||
pub fn new(node: N) -> MonitorNode<I, T, N> {
|
|
||||||
MonitorNode { io: Arc::new(Mutex::new(None)), node }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod monitor {
|
|
||||||
use core_types::ProtoNodeIdentifier;
|
|
||||||
|
|
||||||
pub const IDENTIFIER: ProtoNodeIdentifier = ProtoNodeIdentifier::new("graphene_core::memo::MonitorNode");
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue