use std::collections::HashSet; use std::error::Error; use std::{collections::HashMap, sync::Arc}; use dyn_any::StaticType; use graph_craft::document::value::UpcastNode; use graph_craft::document::NodeId; use graph_craft::executor::Executor; use graph_craft::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, TypingContext}; use graphene_std::any::{Any, TypeErasedPinned, TypeErasedPinnedRef}; use crate::node_registry; #[derive(Debug, Clone)] pub struct DynamicExecutor { output: NodeId, tree: BorrowTree, typing_context: TypingContext, } impl Default for DynamicExecutor { fn default() -> Self { Self { output: Default::default(), tree: Default::default(), typing_context: TypingContext::new(&node_registry::NODE_REGISTRY), } } } impl DynamicExecutor { pub fn new(proto_network: ProtoNetwork) -> Result { let mut typing_context = TypingContext::new(&node_registry::NODE_REGISTRY); typing_context.update(&proto_network)?; let output = proto_network.output; let tree = BorrowTree::new(proto_network, &typing_context)?; Ok(Self { tree, output, typing_context }) } pub fn update(&mut self, proto_network: ProtoNetwork) -> Result<(), String> { self.output = proto_network.output; self.typing_context.update(&proto_network)?; trace!("setting output to {}", self.output); self.tree.update(proto_network, &self.typing_context)?; Ok(()) } } impl Executor for DynamicExecutor { fn execute<'a, 's: 'a>(&'s self, input: Any<'a>) -> Result, Box> { self.tree.eval_any(self.output, input).ok_or_else(|| "Failed to execute".into()) } } pub struct NodeContainer<'n> { pub node: TypeErasedPinned<'n>, // the dependencies are only kept to ensure that the nodes are not dropped while still in use _dependencies: Vec>>, } impl<'a> core::fmt::Debug for NodeContainer<'a> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("NodeContainer").finish() } } impl<'a> NodeContainer<'a> { pub fn new(node: TypeErasedPinned<'a>, _dependencies: Vec>>) -> Self { Self { node, _dependencies } } /// Return a static reference to the TypeErasedNode /// # Safety /// This is unsafe because the returned reference is only valid as long as the NodeContainer is alive pub unsafe fn erase_lifetime(self) -> NodeContainer<'static> { std::mem::transmute(self) } } impl NodeContainer<'static> { pub unsafe fn static_ref(&self) -> TypeErasedPinnedRef<'static> { let s = &*(self as *const Self); *(&s.node.as_ref() as *const TypeErasedPinnedRef<'static>) } } #[derive(Default, Debug, Clone)] pub struct BorrowTree { nodes: HashMap>>, } impl BorrowTree { pub fn new(proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result { let mut nodes = BorrowTree::default(); for (id, node) in proto_network.nodes { nodes.push_node(id, node, typing_context)? } Ok(nodes) } /// Pushes new nodes into the tree and return orphaned nodes pub fn update(&mut self, proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result, String> { let mut old_nodes: HashSet<_> = self.nodes.keys().copied().collect(); for (id, node) in proto_network.nodes { if !self.nodes.contains_key(&id) { self.push_node(id, node, typing_context)?; old_nodes.remove(&id); } } Ok(old_nodes.into_iter().collect()) } fn node_refs(&self, nodes: &[NodeId]) -> Vec> { self.node_deps(nodes).into_iter().map(|node| unsafe { node.as_ref().static_ref() }).collect() } fn node_deps(&self, nodes: &[NodeId]) -> Vec>> { nodes.iter().map(|node| self.nodes.get(node).unwrap().clone()).collect() } fn store_node(&mut self, node: Arc>, id: NodeId) -> Arc> { self.nodes.insert(id, node.clone()); node } pub fn get(&self, id: NodeId) -> Option>> { self.nodes.get(&id).cloned() } pub fn eval<'i, I: StaticType + 'i, O: StaticType + 'i>(&self, id: NodeId, input: I) -> Option { let node = self.nodes.get(&id).cloned()?; let output = node.node.eval(Box::new(input)); dyn_any::downcast::(output).ok().map(|o| *o) } pub fn eval_any<'i, 's: 'i>(&'s self, id: NodeId, input: Any<'i>) -> Option> { let node = self.nodes.get(&id)?; Some(node.node.eval(input)) } pub fn free_node(&mut self, id: NodeId) { self.nodes.remove(&id); } pub fn push_node(&mut self, id: NodeId, proto_node: ProtoNode, typing_context: &TypingContext) -> Result<(), String> { let ProtoNode { construction_args, identifier, .. } = proto_node; match construction_args { ConstructionArgs::Value(value) => { let upcasted = UpcastNode::new(value); let node = Box::pin(upcasted) as TypeErasedPinned<'_>; let node = NodeContainer { node, _dependencies: vec![] }; let node = unsafe { node.erase_lifetime() }; self.store_node(Arc::new(node), id); } ConstructionArgs::Nodes(ids) => { let ids: Vec<_> = ids.iter().map(|(id, _)| *id).collect(); let construction_nodes = self.node_refs(&ids); let constructor = typing_context.constructor(id).ok_or(format!("No constructor found for node {:?}", identifier))?; let node = constructor(construction_nodes); let node = NodeContainer { node, _dependencies: self.node_deps(&ids), }; let node = unsafe { node.erase_lifetime() }; self.store_node(Arc::new(node), id); } }; Ok(()) } } #[cfg(test)] mod test { use graph_craft::document::value::TaggedValue; use super::*; #[test] fn push_node() { let mut tree = BorrowTree::default(); let val_1_protonode = ProtoNode::value(ConstructionArgs::Value(TaggedValue::U32(2u32))); tree.push_node(0, val_1_protonode, &TypingContext::default()).unwrap(); let _node = tree.get(0).unwrap(); assert_eq!(tree.eval(0, ()), Some(2u32)); } }