use std::borrow::Cow; use std::collections::{HashMap, HashSet}; use std::hash::Hash; use crate::document::value; use crate::document::NodeId; use dyn_any::DynAny; use graphene_core::*; use std::pin::Pin; pub type Any<'n> = Box + 'n>; pub type TypeErasedNode<'n> = dyn for<'i> NodeIO<'i, Any<'i>, Output = Any<'i>> + 'n + Send + Sync; pub type TypeErasedPinnedRef<'n> = Pin<&'n (dyn for<'i> NodeIO<'i, Any<'i>, Output = Any<'i>> + 'n + Send + Sync)>; pub type TypeErasedPinned<'n> = Pin NodeIO<'i, Any<'i>, Output = Any<'i>> + 'n + Send + Sync>>; pub type NodeConstructor = for<'a> fn(Vec>) -> TypeErasedPinned<'static>; #[derive(Debug, Default, PartialEq)] pub struct ProtoNetwork { // Should a proto Network even allow inputs? Don't think so pub inputs: Vec, pub output: NodeId, pub nodes: Vec<(NodeId, ProtoNode)>, } impl core::fmt::Display for ProtoNetwork { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str("Proto Network with nodes: ")?; fn write_node(f: &mut core::fmt::Formatter<'_>, network: &ProtoNetwork, id: NodeId, indent: usize) -> core::fmt::Result { f.write_str(&"\t".repeat(indent))?; let Some((_, node)) = network.nodes.iter().find(|(node_id, _)|*node_id == id) else{ return f.write_str("{{Unknown Node}}"); }; f.write_str("Node: ")?; f.write_str(&node.identifier.name)?; f.write_str("\n")?; f.write_str(&"\t".repeat(indent))?; f.write_str("{\n")?; f.write_str(&"\t".repeat(indent + 1))?; f.write_str("Primary input: ")?; match &node.input { ProtoNodeInput::None => f.write_str("None")?, ProtoNodeInput::Network(ty) => f.write_fmt(format_args!("Network (type = {:?})", ty))?, ProtoNodeInput::Node(_) => f.write_str("Node")?, } f.write_str("\n")?; match &node.construction_args { ConstructionArgs::Value(value) => { f.write_str(&"\t".repeat(indent + 1))?; f.write_fmt(format_args!("Value construction argument: {value:?}"))? } ConstructionArgs::Nodes(nodes) => { for id in nodes { write_node(f, network, *id, indent + 1)?; } } } f.write_str(&"\t".repeat(indent))?; f.write_str("}\n")?; Ok(()) } let id = self.output; write_node(f, self, id, 0) } } #[derive(Debug, Clone)] pub enum ConstructionArgs { Value(value::TaggedValue), Nodes(Vec), } impl PartialEq for ConstructionArgs { fn eq(&self, other: &Self) -> bool { match (&self, &other) { (Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2, (Self::Value(v1), Self::Value(v2)) => v1 == v2, _ => core::mem::discriminant(self) == core::mem::discriminant(other), } } } impl Hash for ConstructionArgs { fn hash(&self, state: &mut H) { match self { Self::Nodes(nodes) => { "nodes".hash(state); for node in nodes { node.hash(state); } } Self::Value(value) => value.hash(state), } } } impl ConstructionArgs { pub fn new_function_args(&self) -> Vec { match self { ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("n{}", n)).collect(), ConstructionArgs::Value(value) => vec![format!("{:?}", value)], } } } #[derive(Debug, PartialEq, Clone)] pub struct ProtoNode { pub construction_args: ConstructionArgs, pub input: ProtoNodeInput, pub identifier: NodeIdentifier, } #[derive(Debug, PartialEq, Eq, Clone)] pub enum ProtoNodeInput { None, Network(Type), Node(NodeId), } impl ProtoNodeInput { pub fn unwrap_node(self) -> NodeId { match self { ProtoNodeInput::Node(id) => id, _ => panic!("tried to unwrap id from non node input \n node: {:#?}", self), } } } impl ProtoNode { pub fn stable_node_id(&self) -> Option { use std::hash::Hasher; let mut hasher = std::collections::hash_map::DefaultHasher::new(); self.identifier.name.hash(&mut hasher); self.construction_args.hash(&mut hasher); match self.input { ProtoNodeInput::None => "none".hash(&mut hasher), ProtoNodeInput::Network(ref ty) => { "network".hash(&mut hasher); ty.hash(&mut hasher); } ProtoNodeInput::Node(id) => id.hash(&mut hasher), }; Some(hasher.finish() as NodeId) } pub fn value(value: ConstructionArgs) -> Self { Self { identifier: NodeIdentifier::new("graphene_core::value::ValueNode"), construction_args: value, input: ProtoNodeInput::None, } } pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) { if let ProtoNodeInput::Node(id) = self.input { self.input = ProtoNodeInput::Node(f(id)) } if let ConstructionArgs::Nodes(ids) = &mut self.construction_args { ids.iter_mut().for_each(|id| *id = f(*id)); } } pub fn unwrap_construction_nodes(&self) -> Vec { match &self.construction_args { ConstructionArgs::Nodes(nodes) => nodes.clone(), _ => panic!("tried to unwrap nodes from non node construction args \n node: {:#?}", self), } } } impl ProtoNetwork { fn check_ref(&self, ref_id: &NodeId, id: &NodeId) { assert!( self.nodes.iter().any(|(check_id, _)| check_id == ref_id), "Node id:{} has a reference which uses node id:{} which doesn't exist in network {:#?}", id, ref_id, self ); } pub fn collect_outwards_edges(&self) -> HashMap> { let mut edges: HashMap> = HashMap::new(); for (id, node) in &self.nodes { if let ProtoNodeInput::Node(ref_id) = &node.input { self.check_ref(ref_id, id); edges.entry(*ref_id).or_default().push(*id) } if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { for ref_id in ref_nodes { self.check_ref(ref_id, id); edges.entry(*ref_id).or_default().push(*id) } } } edges } pub fn generate_stable_node_ids(&mut self) { for i in 0..self.nodes.len() { self.generate_stable_node_id(i); } } pub fn generate_stable_node_id(&mut self, index: usize) -> NodeId { let mut lookup = self.nodes.iter().map(|(id, _)| (*id, *id)).collect::>(); if let Some(sni) = self.nodes[index].1.stable_node_id() { lookup.insert(self.nodes[index].0, sni); self.replace_node_references(&lookup); self.nodes[index].0 = sni; sni } else { panic!("failed to generate stable node id for node {:#?}", self.nodes[index].1); } } pub fn collect_inwards_edges(&self) -> HashMap> { let mut edges: HashMap> = HashMap::new(); for (id, node) in &self.nodes { if let ProtoNodeInput::Node(ref_id) = &node.input { self.check_ref(ref_id, id); edges.entry(*id).or_default().push(*ref_id) } if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { for ref_id in ref_nodes { self.check_ref(ref_id, id); edges.entry(*id).or_default().push(*ref_id) } } } edges } pub fn resolve_inputs(&mut self) { let mut resolved = HashSet::new(); while !self.resolve_inputs_impl(&mut resolved) {} } fn resolve_inputs_impl(&mut self, resolved: &mut HashSet) -> bool { self.reorder_ids(); let mut lookup = self.nodes.iter().map(|(id, _)| (*id, *id)).collect::>(); let compose_node_id = self.nodes.len() as NodeId; let inputs = self.nodes.iter().map(|(_, node)| node.input.clone()).collect::>(); let resolved_lookup = resolved.clone(); if let Some((input_node, id, input)) = self.nodes.iter_mut().filter(|(id, _)| !resolved_lookup.contains(id)).find_map(|(id, node)| { if let ProtoNodeInput::Node(input_node) = node.input { resolved.insert(*id); let pre_node_input = inputs.get(input_node as usize).expect("input node should exist"); Some((input_node, *id, pre_node_input.clone())) } else { None } }) { lookup.insert(id, compose_node_id); self.replace_node_references(&lookup); self.nodes.push(( compose_node_id, ProtoNode { identifier: NodeIdentifier::new("graphene_core::structural::ComposeNode<_, _, _>"), construction_args: ConstructionArgs::Nodes(vec![input_node, id]), input, }, )); return false; } true } // Based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search // This approach excludes nodes that are not connected pub fn topological_sort(&self) -> Vec { let mut sorted = Vec::new(); let inwards_edges = self.collect_inwards_edges(); fn visit(node_id: NodeId, temp_marks: &mut HashSet, sorted: &mut Vec, inwards_edges: &HashMap>) { if sorted.contains(&node_id) { return; }; if temp_marks.contains(&node_id) { panic!("Cycle detected"); } if let Some(dependencies) = inwards_edges.get(&node_id) { temp_marks.insert(node_id); for &dependant in dependencies { visit(dependant, temp_marks, sorted, inwards_edges); } temp_marks.remove(&node_id); } sorted.push(node_id); } assert!(self.nodes.iter().any(|(id, _)| *id == self.output), "Output id {} does not exist", self.output); visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges); sorted } /*// Based on https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm pub fn topological_sort(&self) -> Vec { let mut sorted = Vec::new(); let outwards_edges = self.collect_outwards_edges(); let mut inwards_edges = self.collect_inwards_edges(); let mut no_incoming_edges: Vec<_> = self.nodes.iter().map(|entry| entry.0).filter(|id| !inwards_edges.contains_key(id)).collect(); assert_ne!(no_incoming_edges.len(), 0, "Acyclic graphs must have at least one node with no incoming edge"); while let Some(node_id) = no_incoming_edges.pop() { sorted.push(node_id); if let Some(outwards_edges) = outwards_edges.get(&node_id) { for &ref_id in outwards_edges { let dependencies = inwards_edges.get_mut(&ref_id).unwrap(); dependencies.retain(|&id| id != node_id); if dependencies.is_empty() { no_incoming_edges.push(ref_id) } } } } info!("Sorted order {sorted:?}"); sorted }*/ pub fn reorder_ids(&mut self) { let order = self.topological_sort(); // Map of node ids to indexes (which become the node ids as they are inserted into the borrow stack) let lookup: HashMap<_, _> = order.iter().enumerate().map(|(pos, id)| (*id, pos as NodeId)).collect(); self.nodes = order .iter() .enumerate() .map(|(pos, id)| { let node = self.nodes.swap_remove(self.nodes.iter().position(|(test_id, _)| test_id == id).unwrap()).1; (pos as NodeId, node) }) .collect(); self.replace_node_references(&lookup); assert_eq!(order.len(), self.nodes.len()); } fn replace_node_references(&mut self, lookup: &HashMap) { self.nodes.iter_mut().for_each(|(_, node)| { node.map_ids(|id| *lookup.get(&id).expect("node not found in lookup table")); }); self.inputs = self.inputs.iter().filter_map(|id| lookup.get(id).copied()).collect(); self.output = *lookup.get(&self.output).unwrap(); } } /// The `TypingContext` is used to store the types of the nodes indexed by their stable node id. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct TypingContext { lookup: Cow<'static, HashMap>>, inferred: HashMap, constructor: HashMap, } impl TypingContext { /// Creates a new `TypingContext` with the given lookup table. pub fn new(lookup: &'static HashMap>) -> Self { Self { lookup: Cow::Borrowed(lookup), ..Default::default() } } /// Updates the `TypingContext` wtih a given proto network. This will infer the types of the nodes /// and store them in the `inferred` field. The proto network has to be topologically sorted /// and contain fully resolved stable node ids. pub fn update(&mut self, network: &ProtoNetwork) -> Result<(), String> { for (id, node) in network.nodes.iter() { self.infer(*id, node)?; } Ok(()) } /// Returns the node constructor for a given node id. pub fn constructor(&self, node_id: NodeId) -> Option { self.constructor.get(&node_id).copied() } /// Returns the inferred types for a given node id. pub fn infer(&mut self, node_id: NodeId, node: &ProtoNode) -> Result { let identifier = node.identifier.name.clone(); // Return the inferred type if it is already known if let Some(infered) = self.inferred.get(&node_id) { return Ok(infered.clone()); } let parameters = match node.construction_args { // If the node has a value parameter we can infer the return type from it ConstructionArgs::Value(ref v) => { assert!(matches!(node.input, ProtoNodeInput::None)); let types = NodeIOTypes::new(concrete!(()), v.ty(), vec![]); self.inferred.insert(node_id, types.clone()); return Ok(types); } // If the node has nodes as parameters we can infer the types from the node outputs ConstructionArgs::Nodes(ref nodes) => nodes .iter() .map(|id| { self.inferred .get(id) .ok_or(format!("Inferring type of {node_id} depends on {id} which is not present in the typing context")) .map(|node| (node.input.clone(), node.output.clone())) }) .collect::, String>>()?, }; // Get the node input type from the proto node declaration let input = match node.input { ProtoNodeInput::None => concrete!(()), ProtoNodeInput::Network(ref ty) => ty.clone(), ProtoNodeInput::Node(id) => { let input = self .inferred .get(&id) .ok_or(format!("Inferring type of {node_id} depends on {id} which is not present in the typing context"))?; input.output.clone() } }; let impls = self.lookup.get(&node.identifier).ok_or(format!("No implementations found for {:?}", node.identifier))?; if matches!(input, Type::Generic(_)) { return Err(format!("Generic types are not supported as inputs yet {:?} occured in {:?}", &input, node.identifier)); } if parameters.iter().any(|p| matches!(p.1, Type::Generic(_))) { return Err(format!("Generic types are not supported in parameters: {:?} occured in {:?}", parameters, node.identifier)); } let covariant = |output, input| match (&output, &input) { (Type::Concrete(t1), Type::Concrete(t2)) => t1 == t2, (Type::Concrete(_), Type::Generic(_)) => true, // TODO: relax this requirement when allowing generic types as inputs (Type::Generic(_), _) => false, }; // List of all implementations that match the input and parameter types let valid_output_types = impls .keys() .filter(|node_io| { covariant(input.clone(), node_io.input.clone()) && parameters .iter() .zip(node_io.parameters.iter()) .all(|(p1, p2)| covariant(p1.0.clone(), p2.0.clone()) && covariant(p1.1.clone(), p2.1.clone())) }) .collect::>(); // Attempt to substitute generic types with concrete types and save the list of results let substitution_results = valid_output_types .iter() .map(|node_io| { collect_generics(node_io) .iter() .try_for_each(|generic| check_generic(node_io, &input, ¶meters, generic).map(|_| ())) .map(|_| { if let Type::Generic(out) = &node_io.output { ((*node_io).clone(), check_generic(node_io, &input, ¶meters, out).unwrap()) } else { ((*node_io).clone(), node_io.output.clone()) } }) }) .collect::>(); // Collect all substitutions that are valid let valid_impls = substitution_results.iter().filter_map(|result| result.as_ref().ok()).collect::>(); match valid_impls.as_slice() { [] => { dbg!(&self.inferred); Err(format!( "No implementations found for {identifier} with \ninput: {input:?} and \nparameters: {parameters:?}.\nOther Implementations found: {:?}", impls, )) } [(org_nio, output)] => { let node_io = NodeIOTypes::new(input, (*output).clone(), parameters); // Save the inferred type self.inferred.insert(node_id, node_io.clone()); self.constructor.insert(node_id, impls[org_nio]); Ok(node_io) } _ => Err(format!( "Multiple implementations found for {identifier} with input {input:?} and parameters {parameters:?} (valid types: {valid_output_types:?}" )), } } } /// Returns a list of all generic types used in the node fn collect_generics(types: &NodeIOTypes) -> Vec> { let inputs = [&types.input].into_iter().chain(types.parameters.iter().map(|(_, x)| x)); let mut generics = inputs .filter_map(|t| match t { Type::Generic(out) => Some(out.clone()), _ => None, }) .collect::>(); if let Type::Generic(out) = &types.output { generics.push(out.clone()); } generics.dedup(); generics } /// Checks if a generic type can be substituted with a concrete type and returns the concrete type fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[(Type, Type)], generic: &str) -> Result { let inputs = [(&types.input, input)] .into_iter() .chain(types.parameters.iter().map(|(_, x)| x).zip(parameters.iter().map(|(_, x)| x))); let mut concrete_inputs = inputs.filter(|(ni, _)| matches!(ni, Type::Generic(input) if generic == input)); let (_, out_ty) = concrete_inputs .next() .ok_or_else(|| format!("Generic output type {generic} is not dependent on input {input:?} or parameters {parameters:?}",))?; if concrete_inputs.any(|(_, ty)| ty != out_ty) { return Err(format!("Generic output type {generic} is dependent on multiple inputs or parameters",)); } Ok(out_ty.clone()) } #[cfg(test)] mod test { use super::*; use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput}; #[test] fn topological_sort() { let construction_network = test_network(); let sorted = construction_network.topological_sort(); println!("{:#?}", sorted); assert_eq!(sorted, vec![14, 10, 11, 1]); } #[test] fn id_reordering() { let mut construction_network = test_network(); construction_network.reorder_ids(); let sorted = construction_network.topological_sort(); println!("nodes: {:#?}", construction_network.nodes); assert_eq!(sorted, vec![0, 1, 2, 3]); let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect(); println!("{:#?}", ids); println!("nodes: {:#?}", construction_network.nodes); assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value"); assert_eq!(ids, vec![0, 1, 2, 3]); } #[test] fn id_reordering_idempotent() { let mut construction_network = test_network(); construction_network.reorder_ids(); construction_network.reorder_ids(); let sorted = construction_network.topological_sort(); assert_eq!(sorted, vec![0, 1, 2, 3]); let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect(); println!("{:#?}", ids); assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value"); assert_eq!(ids, vec![0, 1, 2, 3]); } #[test] fn input_resolution() { let mut construction_network = test_network(); construction_network.resolve_inputs(); println!("{:#?}", construction_network); assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value"); assert_eq!(construction_network.nodes.len(), 6); assert_eq!(construction_network.nodes[5].1.construction_args, ConstructionArgs::Nodes(vec![3, 4])); } #[test] fn stable_node_id_generation() { let mut construction_network = test_network(); construction_network.reorder_ids(); construction_network.generate_stable_node_ids(); construction_network.resolve_inputs(); construction_network.generate_stable_node_ids(); assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value"); let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect(); assert_eq!( ids, vec![ 15907139529964845467, 14192092348022507362, 14714934190542167928, 4518275895314664278, 13912679582583718470, 3236993912700824422 ] ); } fn test_network() -> ProtoNetwork { ProtoNetwork { inputs: vec![10], output: 1, nodes: [ ( 7, ProtoNode { identifier: "id".into(), input: ProtoNodeInput::Node(11), construction_args: ConstructionArgs::Nodes(vec![]), }, ), ( 1, ProtoNode { identifier: "id".into(), input: ProtoNodeInput::Node(11), construction_args: ConstructionArgs::Nodes(vec![]), }, ), ( 10, ProtoNode { identifier: "cons".into(), input: ProtoNodeInput::Network(concrete!(u32)), construction_args: ConstructionArgs::Nodes(vec![14]), }, ), ( 11, ProtoNode { identifier: "add".into(), input: ProtoNodeInput::Node(10), construction_args: ConstructionArgs::Nodes(vec![]), }, ), ( 14, ProtoNode { identifier: "value".into(), input: ProtoNodeInput::None, construction_args: ConstructionArgs::Value(value::TaggedValue::U32(2)), }, ), ] .into_iter() .collect(), } } }