1075 lines
37 KiB
Rust
1075 lines
37 KiB
Rust
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
|
|
use graphene_core::{NodeIdentifier, Type};
|
|
|
|
use dyn_any::{DynAny, StaticType};
|
|
use glam::IVec2;
|
|
pub use graphene_core::uuid::generate_uuid;
|
|
use graphene_core::TypeDescriptor;
|
|
use std::borrow::Cow;
|
|
use std::collections::{HashMap, HashSet};
|
|
|
|
pub mod value;
|
|
|
|
/// Identifier of a node; used as the key of `NodeNetwork::nodes`.
pub type NodeId = u64;
|
|
|
|
/// Combines two ids into a single id by feeding `a` and then `b` through one `DefaultHasher`.
///
/// The order of the arguments matters: `a` is always hashed before `b`.
fn merge_ids(a: u64, b: u64) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    for part in [a, b] {
        part.hash(&mut state);
    }
    state.finish()
}
|
|
|
|
#[derive(Clone, Debug, PartialEq, Default, specta::Type)]
|
|
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
|
pub struct DocumentNodeMetadata {
|
|
pub position: IVec2,
|
|
}
|
|
|
|
impl DocumentNodeMetadata {
|
|
pub fn position(position: impl Into<IVec2>) -> Self {
|
|
Self { position: position.into() }
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, PartialEq)]
|
|
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
|
pub struct DocumentNode {
|
|
pub name: String,
|
|
pub inputs: Vec<NodeInput>,
|
|
pub implementation: DocumentNodeImplementation,
|
|
pub metadata: DocumentNodeMetadata,
|
|
}
|
|
|
|
impl DocumentNode {
|
|
pub fn populate_first_network_input(&mut self, node_id: NodeId, output_index: usize, offset: usize, lambda: bool) {
|
|
let input = self
|
|
.inputs
|
|
.iter()
|
|
.enumerate()
|
|
.filter(|(_, input)| matches!(input, NodeInput::Network(_) | NodeInput::ShortCircut(_)))
|
|
.nth(offset)
|
|
.unwrap_or_else(|| panic!("no network input found for {self:#?} and offset: {offset}"));
|
|
|
|
let index = input.0;
|
|
self.inputs[index] = NodeInput::Node { node_id, output_index, lambda };
|
|
}
|
|
|
|
fn resolve_proto_node(mut self) -> ProtoNode {
|
|
assert_ne!(self.inputs.len(), 0, "Resolving document node {:#?} with no inputs", self);
|
|
let first = self.inputs.remove(0);
|
|
if let DocumentNodeImplementation::Unresolved(fqn) = self.implementation {
|
|
let (input, mut args) = match first {
|
|
NodeInput::Value { tagged_value, .. } => {
|
|
assert_eq!(self.inputs.len(), 0, "{}, {:?}", &self.name, &self.inputs);
|
|
(ProtoNodeInput::None, ConstructionArgs::Value(tagged_value))
|
|
}
|
|
NodeInput::Node { node_id, output_index, lambda } => {
|
|
assert_eq!(output_index, 0, "Outputs should be flattened before converting to protonode.");
|
|
(ProtoNodeInput::Node(node_id, lambda), ConstructionArgs::Nodes(vec![]))
|
|
}
|
|
NodeInput::Network(ty) => (ProtoNodeInput::Network(ty), ConstructionArgs::Nodes(vec![])),
|
|
NodeInput::ShortCircut(ty) => (ProtoNodeInput::ShortCircut(ty), ConstructionArgs::Nodes(vec![])),
|
|
};
|
|
assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::Network(_))), "recieved non resolved parameter");
|
|
assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::ShortCircut(_))), "recieved non resolved parameter");
|
|
assert!(
|
|
!self.inputs.iter().any(|input| matches!(input, NodeInput::Value { .. })),
|
|
"recieved value as parameter. inupts: {:#?}, construction_args: {:#?}",
|
|
&self.inputs,
|
|
&args
|
|
);
|
|
|
|
if let ConstructionArgs::Nodes(nodes) = &mut args {
|
|
nodes.extend(self.inputs.iter().map(|input| match input {
|
|
NodeInput::Node { node_id, lambda, .. } => (*node_id, *lambda),
|
|
_ => unreachable!(),
|
|
}));
|
|
}
|
|
ProtoNode {
|
|
identifier: fqn,
|
|
input,
|
|
construction_args: args,
|
|
}
|
|
} else {
|
|
unreachable!("tried to resolve not flattened node on resolved node");
|
|
}
|
|
}
|
|
|
|
/// Converts all node id inputs to a new id based on a HashMap.
|
|
///
|
|
/// If the node is not in the hashmap then a default input is found based on the node name and input index.
|
|
pub fn map_ids<P>(mut self, default_input: P, new_ids: &HashMap<NodeId, NodeId>) -> Self
|
|
where
|
|
P: Fn(String, usize) -> Option<NodeInput>,
|
|
{
|
|
for (index, input) in self.inputs.iter_mut().enumerate() {
|
|
let &mut NodeInput::Node{node_id: id, output_index, lambda} = input else {
|
|
continue;
|
|
};
|
|
if let Some(&new_id) = new_ids.get(&id) {
|
|
*input = NodeInput::Node {
|
|
node_id: new_id,
|
|
output_index,
|
|
lambda,
|
|
};
|
|
} else if let Some(new_input) = default_input(self.name.clone(), index) {
|
|
*input = new_input;
|
|
} else {
|
|
warn!("Node does not exist in library with that many inputs");
|
|
}
|
|
}
|
|
self
|
|
}
|
|
}
|
|
|
|
/// Represents the possible inputs to a node.
|
|
///
|
|
/// # More about short circuting
|
|
///
|
|
/// In Graphite nodes are functions and by default, these are composed into a single function
|
|
/// by inserting Compose nodes.
|
|
///
|
|
/// ```text
|
|
/// ┌─────────────────┐ ┌──────────────────┐ ┌──────────────────┐
|
|
/// │ │◄──────────────┤ │◄───────────────┤ │
|
|
/// │ A │ │ B │ │ C │
|
|
/// │ ├──────────────►│ ├───────────────►│ │
|
|
/// └─────────────────┘ └──────────────────┘ └──────────────────┘
|
|
/// ```
|
|
///
|
|
/// This is equivalent to calling c(b(a(input))) when evaluating c with input ( `c.eval(input)`).
|
|
/// But sometimes we might want to have a little more control over the order of execution.
|
|
/// This is why we allow nodes to opt out of the input forwarding by consuming the input directly.
|
|
///
|
|
/// ```text
|
|
/// ┌─────────────────────┐ ┌─────────────┐
|
|
/// │ │◄───────────────┤ │
|
|
/// │ Cache Node │ │ C │
|
|
/// │ ├───────────────►│ │
|
|
/// ┌──────────────────┐ ├─────────────────────┤ └─────────────┘
|
|
/// │ │◄──────────────┤ │
|
|
/// │ A │ │ * Cached Node │
|
|
/// │ ├──────────────►│ │
|
|
/// └──────────────────┘ └─────────────────────┘
|
|
/// ```
|
|
///
|
|
/// In this case the Cache node actually consumes its input and then manually forwards it to its parameter Node.
|
|
/// This is necessary because the Cache Node needs to short-circut the actual node evaluation.
|
|
#[derive(Debug, Clone, PartialEq, Hash)]
|
|
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
|
pub enum NodeInput {
|
|
Node {
|
|
node_id: NodeId,
|
|
output_index: usize,
|
|
lambda: bool,
|
|
},
|
|
Value {
|
|
tagged_value: crate::document::value::TaggedValue,
|
|
exposed: bool,
|
|
},
|
|
Network(Type),
|
|
/// A short circuting input represents an input that is not resolved through function composition
|
|
/// but actually consuming the provided input instead of passing it to its predecessor.
|
|
/// See [NodeInput] docs for more explanation.
|
|
ShortCircut(Type),
|
|
}
|
|
|
|
impl NodeInput {
|
|
pub const fn node(node_id: NodeId, output_index: usize) -> Self {
|
|
Self::Node { node_id, output_index, lambda: false }
|
|
}
|
|
pub const fn lambda(node_id: NodeId, output_index: usize) -> Self {
|
|
Self::Node { node_id, output_index, lambda: true }
|
|
}
|
|
pub const fn value(tagged_value: crate::document::value::TaggedValue, exposed: bool) -> Self {
|
|
Self::Value { tagged_value, exposed }
|
|
}
|
|
fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
|
|
if let &mut NodeInput::Node { node_id, output_index, lambda } = self {
|
|
*self = NodeInput::Node {
|
|
node_id: f(node_id),
|
|
output_index,
|
|
lambda,
|
|
}
|
|
}
|
|
}
|
|
pub fn is_exposed(&self) -> bool {
|
|
match self {
|
|
NodeInput::Node { .. } => true,
|
|
NodeInput::Value { exposed, .. } => *exposed,
|
|
NodeInput::Network(_) => false,
|
|
NodeInput::ShortCircut(_) => false,
|
|
}
|
|
}
|
|
pub fn ty(&self) -> Type {
|
|
match self {
|
|
NodeInput::Node { .. } => unreachable!("ty() called on NodeInput::Node"),
|
|
NodeInput::Value { tagged_value, .. } => tagged_value.ty(),
|
|
NodeInput::Network(ty) => ty.clone(),
|
|
NodeInput::ShortCircut(ty) => ty.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, PartialEq)]
|
|
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
|
pub enum DocumentNodeImplementation {
|
|
Network(NodeNetwork),
|
|
Unresolved(NodeIdentifier),
|
|
}
|
|
|
|
impl Default for DocumentNodeImplementation {
|
|
fn default() -> Self {
|
|
Self::Unresolved(NodeIdentifier::new("graphene_cored::ops::IdNode"))
|
|
}
|
|
}
|
|
|
|
impl DocumentNodeImplementation {
|
|
pub fn get_network(&self) -> Option<&NodeNetwork> {
|
|
if let DocumentNodeImplementation::Network(n) = self {
|
|
Some(n)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn get_network_mut(&mut self) -> Option<&mut NodeNetwork> {
|
|
if let DocumentNodeImplementation::Network(n) = self {
|
|
Some(n)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type)]
|
|
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
|
pub struct NodeOutput {
|
|
pub node_id: NodeId,
|
|
pub node_output_index: usize,
|
|
}
|
|
impl NodeOutput {
|
|
pub fn new(node_id: NodeId, node_output_index: usize) -> Self {
|
|
Self { node_id, node_output_index }
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Debug, Default, PartialEq, DynAny)]
|
|
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
|
pub struct NodeNetwork {
|
|
pub inputs: Vec<NodeId>,
|
|
pub outputs: Vec<NodeOutput>,
|
|
pub nodes: HashMap<NodeId, DocumentNode>,
|
|
/// These nodes are replaced with identity nodes when flattening
|
|
pub disabled: Vec<NodeId>,
|
|
/// In the case where a new node is chosen as output - what was the original
|
|
pub previous_outputs: Option<Vec<NodeOutput>>,
|
|
}
|
|
|
|
/// Graph modification functions
|
|
impl NodeNetwork {
|
|
/// Get the original output nodes of this network, ignoring any preview node
|
|
pub fn original_outputs(&self) -> &Vec<NodeOutput> {
|
|
self.previous_outputs.as_ref().unwrap_or(&self.outputs)
|
|
}
|
|
|
|
pub fn input_types<'a>(&'a self) -> impl Iterator<Item = Type> + 'a {
|
|
self.inputs.iter().map(move |id| self.nodes[id].inputs.get(0).map(|i| i.ty().clone()).unwrap_or(concrete!(())))
|
|
}
|
|
|
|
/// An empty graph
|
|
pub fn value_network(node: DocumentNode) -> Self {
|
|
Self {
|
|
inputs: vec![0],
|
|
outputs: vec![NodeOutput::new(0, 0)],
|
|
nodes: [(0, node)].into_iter().collect(),
|
|
disabled: vec![],
|
|
previous_outputs: None,
|
|
}
|
|
}
|
|
/// A graph with just an input node
|
|
pub fn new_network() -> Self {
|
|
Self {
|
|
inputs: vec![0],
|
|
outputs: vec![NodeOutput::new(0, 0)],
|
|
nodes: [(
|
|
0,
|
|
DocumentNode {
|
|
name: "Input Frame".into(),
|
|
inputs: vec![NodeInput::ShortCircut(concrete!(u32))],
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
|
|
metadata: DocumentNodeMetadata { position: (8, 4).into() },
|
|
},
|
|
)]
|
|
.into_iter()
|
|
.collect(),
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
/// Appends a new node to the network after the output node and sets it as the new output
|
|
pub fn push_node(&mut self, mut node: DocumentNode) -> NodeId {
|
|
let id = self.nodes.len().try_into().expect("Too many nodes in network");
|
|
// Set the correct position for the new node
|
|
if let Some(pos) = self.nodes.get(&self.original_outputs()[0].node_id).map(|n| n.metadata.position) {
|
|
node.metadata.position = pos + IVec2::new(8, 0);
|
|
}
|
|
if self.outputs.is_empty() {
|
|
self.outputs.push(NodeOutput::new(id, 0));
|
|
}
|
|
let input = NodeInput::node(self.outputs[0].node_id, self.outputs[0].node_output_index);
|
|
if node.inputs.is_empty() {
|
|
node.inputs.push(input);
|
|
} else {
|
|
node.inputs[0] = input;
|
|
}
|
|
self.nodes.insert(id, node);
|
|
self.outputs = vec![NodeOutput::new(id, 0)];
|
|
id
|
|
}
|
|
|
|
/// Adds a output identity node to the network
|
|
pub fn push_output_node(&mut self) -> NodeId {
|
|
let node = DocumentNode {
|
|
name: "Output".into(),
|
|
inputs: vec![],
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
|
|
metadata: DocumentNodeMetadata { position: (0, 0).into() },
|
|
};
|
|
self.push_node(node)
|
|
}
|
|
|
|
/// Adds a Cache and a Clone node to the network
|
|
pub fn push_cache_node(&mut self, ty: Type) -> NodeId {
|
|
let node = DocumentNode {
|
|
name: "Cache".into(),
|
|
inputs: vec![],
|
|
implementation: DocumentNodeImplementation::Network(NodeNetwork {
|
|
inputs: vec![0],
|
|
outputs: vec![NodeOutput::new(1, 0)],
|
|
nodes: vec![
|
|
(
|
|
0,
|
|
DocumentNode {
|
|
name: "CacheNode".to_string(),
|
|
inputs: vec![NodeInput::ShortCircut(concrete!(())), NodeInput::Network(ty)],
|
|
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_std::memo::CacheNode")),
|
|
metadata: Default::default(),
|
|
},
|
|
),
|
|
(
|
|
1,
|
|
DocumentNode {
|
|
name: "CloneNode".to_string(),
|
|
inputs: vec![NodeInput::node(0, 0)],
|
|
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::CloneNode<_>")),
|
|
metadata: Default::default(),
|
|
},
|
|
),
|
|
]
|
|
.into_iter()
|
|
.collect(),
|
|
..Default::default()
|
|
}),
|
|
metadata: DocumentNodeMetadata { position: (0, 0).into() },
|
|
};
|
|
self.push_node(node)
|
|
}
|
|
|
|
/// Get the nested network given by the path of node ids
|
|
pub fn nested_network(&self, nested_path: &[NodeId]) -> Option<&Self> {
|
|
let mut network = Some(self);
|
|
|
|
for segment in nested_path {
|
|
network = network.and_then(|network| network.nodes.get(segment)).and_then(|node| node.implementation.get_network());
|
|
}
|
|
network
|
|
}
|
|
|
|
/// Get the mutable nested network given by the path of node ids
|
|
pub fn nested_network_mut(&mut self, nested_path: &[NodeId]) -> Option<&mut Self> {
|
|
let mut network = Some(self);
|
|
|
|
for segment in nested_path {
|
|
network = network.and_then(|network| network.nodes.get_mut(segment)).and_then(|node| node.implementation.get_network_mut());
|
|
}
|
|
network
|
|
}
|
|
|
|
/// Check if the specified node id is connected to the output
|
|
pub fn connected_to_output(&self, target_node_id: NodeId, ignore_imaginate: bool) -> bool {
|
|
// If the node is the output then return true
|
|
if self.outputs.iter().any(|&NodeOutput { node_id, .. }| node_id == target_node_id) {
|
|
return true;
|
|
}
|
|
// Get the outputs
|
|
let Some(mut stack) = self.outputs.iter().map(|&output| self.nodes.get(&output.node_id)).collect::<Option<Vec<_>>>() else {
|
|
return false;
|
|
};
|
|
let mut already_visited = HashSet::new();
|
|
already_visited.extend(self.outputs.iter().map(|output| output.node_id));
|
|
|
|
while let Some(node) = stack.pop() {
|
|
// Skip the imaginate node inputs
|
|
if ignore_imaginate && node.name == "Imaginate" {
|
|
continue;
|
|
}
|
|
|
|
for input in &node.inputs {
|
|
if let &NodeInput::Node { node_id: ref_id, .. } = input {
|
|
// Skip if already viewed
|
|
if already_visited.contains(&ref_id) {
|
|
continue;
|
|
}
|
|
// If the target node is used as input then return true
|
|
if ref_id == target_node_id {
|
|
return true;
|
|
}
|
|
// Add the referenced node to the stack
|
|
let Some(ref_node) = self.nodes.get(&ref_id) else {
|
|
continue;
|
|
};
|
|
already_visited.insert(ref_id);
|
|
stack.push(ref_node);
|
|
}
|
|
}
|
|
}
|
|
|
|
false
|
|
}
|
|
|
|
/// Is the node being used directly as an output?
|
|
pub fn outputs_contain(&self, node_id: NodeId) -> bool {
|
|
self.outputs.iter().any(|output| output.node_id == node_id)
|
|
}
|
|
|
|
/// Is the node being used directly as an original output?
|
|
pub fn original_outputs_contain(&self, node_id: NodeId) -> bool {
|
|
self.original_outputs().iter().any(|output| output.node_id == node_id)
|
|
}
|
|
|
|
/// Is the node being used directly as a previous output?
|
|
pub fn previous_outputs_contain(&self, node_id: NodeId) -> Option<bool> {
|
|
self.previous_outputs.as_ref().map(|outputs| outputs.iter().any(|output| output.node_id == node_id))
|
|
}
|
|
|
|
/// A iterator of all nodes connected by primary inputs.
|
|
///
|
|
/// Used for the properties panel and tools.
|
|
pub fn primary_flow(&self) -> impl Iterator<Item = (&DocumentNode, u64)> {
|
|
struct FlowIter<'a> {
|
|
stack: Vec<NodeId>,
|
|
network: &'a NodeNetwork,
|
|
}
|
|
impl<'a> Iterator for FlowIter<'a> {
|
|
type Item = (&'a DocumentNode, NodeId);
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
loop {
|
|
let node_id = self.stack.pop()?;
|
|
if let Some(document_node) = self.network.nodes.get(&node_id) {
|
|
self.stack.extend(
|
|
document_node
|
|
.inputs
|
|
.iter()
|
|
.take(1) // Only show the primary input
|
|
.filter_map(|input| if let NodeInput::Node { node_id: ref_id, .. } = input { Some(*ref_id) } else { None }),
|
|
);
|
|
return Some((document_node, node_id));
|
|
};
|
|
}
|
|
}
|
|
}
|
|
FlowIter {
|
|
stack: self.outputs.iter().map(|output| output.node_id).collect(),
|
|
network: &self,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Functions for compiling the network
|
|
impl NodeNetwork {
|
|
/// Applies `f` to every node id stored in this network: the input ids, the output ids, the
/// disabled ids, the previous output ids, the keys of `nodes` and all node references in the
/// nodes' inputs. Nested networks are not descended into.
pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId + Copy) {
    for id in &mut self.inputs {
        *id = f(*id);
    }
    for output in &mut self.outputs {
        output.node_id = f(output.node_id);
    }
    for id in &mut self.disabled {
        *id = f(*id);
    }
    if let Some(previous) = &mut self.previous_outputs {
        for output in previous {
            output.node_id = f(output.node_id);
        }
    }

    // Rebuild the node map because the keys themselves change.
    for (id, mut node) in std::mem::take(&mut self.nodes) {
        for input in &mut node.inputs {
            input.map_ids(f);
        }
        self.nodes.insert(f(id), node);
    }
}
|
|
|
|
/// Collect a hashmap of nodes with a list of the nodes that use it as input
|
|
pub fn collect_outwards_links(&self) -> HashMap<NodeId, Vec<NodeId>> {
|
|
let mut outwards_links: HashMap<u64, Vec<u64>> = HashMap::new();
|
|
for (node_id, node) in &self.nodes {
|
|
for input in &node.inputs {
|
|
if let NodeInput::Node { node_id: ref_id, .. } = input {
|
|
outwards_links.entry(*ref_id).or_default().push(*node_id)
|
|
}
|
|
}
|
|
}
|
|
outwards_links
|
|
}
|
|
|
|
/// When a node has multiple outputs, we actually just duplicate the node and evaluate each output separately
|
|
pub fn duplicate_outputs(&mut self, mut gen_id: &mut impl FnMut() -> NodeId) {
|
|
let mut duplicating_nodes = HashMap::new();
|
|
// Find the nodes where the inputs require duplicating
|
|
for node in &mut self.nodes.values_mut() {
|
|
// Recursively duplicate children
|
|
if let DocumentNodeImplementation::Network(network) = &mut node.implementation {
|
|
network.duplicate_outputs(gen_id);
|
|
}
|
|
|
|
for input in &mut node.inputs {
|
|
let &mut NodeInput::Node { node_id, output_index, .. } = input else {
|
|
continue;
|
|
};
|
|
// Use the initial node when getting the first output
|
|
if output_index == 0 {
|
|
continue;
|
|
}
|
|
// Get the existing duplicated node id (or create a new one)
|
|
let duplicated_node_id = *duplicating_nodes.entry((node_id, output_index)).or_insert_with(&mut gen_id);
|
|
// Use the first output from the duplicated node
|
|
*input = NodeInput::node(duplicated_node_id, 0);
|
|
}
|
|
}
|
|
// Find the network outputs that require duplicating
|
|
for network_output in &mut self.outputs {
|
|
// Use the initial node when getting the first output
|
|
if network_output.node_output_index == 0 {
|
|
continue;
|
|
}
|
|
// Get the existing duplicated node id (or create a new one)
|
|
let duplicated_node_id = *duplicating_nodes.entry((network_output.node_id, network_output.node_output_index)).or_insert_with(&mut gen_id);
|
|
// Use the first output from the duplicated node
|
|
*network_output = NodeOutput::new(duplicated_node_id, 0);
|
|
}
|
|
// Duplicate the nodes
|
|
for ((original_node_id, output_index), new_node_id) in duplicating_nodes {
|
|
let Some(original_node) = self.nodes.get(&original_node_id) else {
|
|
continue;
|
|
};
|
|
let mut new_node = original_node.clone();
|
|
// Update the required outputs from a nested network to be just the relevant output
|
|
if let DocumentNodeImplementation::Network(network) = &mut new_node.implementation {
|
|
if network.outputs.is_empty() {
|
|
continue;
|
|
}
|
|
network.outputs = vec![network.outputs[output_index]];
|
|
}
|
|
self.nodes.insert(new_node_id, new_node);
|
|
}
|
|
|
|
// Ensure all nodes only have one output
|
|
for node in self.nodes.values_mut() {
|
|
if let DocumentNodeImplementation::Network(network) = &mut node.implementation {
|
|
if network.outputs.is_empty() {
|
|
continue;
|
|
}
|
|
network.outputs = vec![network.outputs[0]];
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Removes unused nodes from the graph. Returns a list of bools which represent if each of the inputs have been retained
|
|
pub fn remove_dead_nodes(&mut self) -> Vec<bool> {
|
|
// Take all the nodes out of the nodes list
|
|
let mut old_nodes = std::mem::take(&mut self.nodes);
|
|
let mut stack = self.outputs.iter().map(|output| output.node_id).collect::<Vec<_>>();
|
|
while let Some(node_id) = stack.pop() {
|
|
let Some((node_id, mut document_node)) = old_nodes.remove_entry(&node_id) else {
|
|
continue;
|
|
};
|
|
// Remove dead nodes from child networks
|
|
if let DocumentNodeImplementation::Network(network) = &mut document_node.implementation {
|
|
// Remove inputs to the parent node if they have been removed from the child
|
|
let mut retain_inputs = network.remove_dead_nodes().into_iter();
|
|
document_node.inputs.retain(|_| retain_inputs.next().unwrap_or(true))
|
|
}
|
|
// Visit all nodes that this node references
|
|
stack.extend(
|
|
document_node
|
|
.inputs
|
|
.iter()
|
|
.filter_map(|input| if let NodeInput::Node { node_id, .. } = input { Some(node_id) } else { None }),
|
|
);
|
|
// Add the node back to the list of nodes
|
|
self.nodes.insert(node_id, document_node);
|
|
}
|
|
|
|
// Check if inputs are used and store for return value
|
|
let are_inputs_used = self.inputs.iter().map(|input| self.nodes.contains_key(input)).collect();
|
|
// Remove unused inputs from graph
|
|
self.inputs.retain(|input| self.nodes.contains_key(input));
|
|
|
|
are_inputs_used
|
|
}
|
|
|
|
pub fn flatten(&mut self, node: NodeId) {
|
|
self.flatten_with_fns(node, merge_ids, generate_uuid)
|
|
}
|
|
|
|
/// Recursively dissolve non-primitive document nodes and return a single flattened network of nodes.
|
|
pub fn flatten_with_fns(&mut self, node: NodeId, map_ids: impl Fn(NodeId, NodeId) -> NodeId + Copy, gen_id: impl Fn() -> NodeId + Copy) {
|
|
let (id, mut node) = self
|
|
.nodes
|
|
.remove_entry(&node)
|
|
.unwrap_or_else(|| panic!("The node which was supposed to be flattened does not exist in the network, id {} network {:#?}", node, self));
|
|
|
|
if self.disabled.contains(&id) {
|
|
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into());
|
|
node.inputs.drain(1..);
|
|
self.nodes.insert(id, node);
|
|
return;
|
|
}
|
|
// replace value inputs with value nodes
|
|
for input in &mut node.inputs {
|
|
let mut dummy_input = NodeInput::ShortCircut(concrete!(()));
|
|
std::mem::swap(&mut dummy_input, input);
|
|
if let NodeInput::Value { tagged_value, exposed } = dummy_input {
|
|
let value_node_id = gen_id();
|
|
let value_node_id = map_ids(id, value_node_id);
|
|
self.nodes.insert(
|
|
value_node_id,
|
|
DocumentNode {
|
|
name: "Value".into(),
|
|
inputs: vec![NodeInput::Value { tagged_value, exposed }],
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()),
|
|
metadata: DocumentNodeMetadata::default(),
|
|
},
|
|
);
|
|
*input = NodeInput::Node {
|
|
node_id: value_node_id,
|
|
output_index: 0,
|
|
lambda: false,
|
|
};
|
|
} else {
|
|
*input = dummy_input;
|
|
}
|
|
}
|
|
|
|
match node.implementation {
|
|
DocumentNodeImplementation::Network(mut inner_network) => {
|
|
// Connect all network inputs to either the parent network nodes, or newly created value nodes.
|
|
inner_network.map_ids(|inner_id| map_ids(id, inner_id));
|
|
let new_nodes = inner_network.nodes.keys().cloned().collect::<Vec<_>>();
|
|
// Copy nodes from the inner network into the parent network
|
|
self.nodes.extend(inner_network.nodes);
|
|
self.disabled.extend(inner_network.disabled);
|
|
|
|
let mut network_offsets = HashMap::new();
|
|
for (document_input, network_input) in node.inputs.into_iter().zip(inner_network.inputs.iter()) {
|
|
let offset = network_offsets.entry(network_input).or_insert(0);
|
|
match document_input {
|
|
NodeInput::Node { node_id, output_index, lambda } => {
|
|
let network_input = self.nodes.get_mut(network_input).unwrap();
|
|
network_input.populate_first_network_input(node_id, output_index, *offset, lambda);
|
|
}
|
|
NodeInput::Network(_) => {
|
|
*network_offsets.get_mut(network_input).unwrap() += 1;
|
|
if let Some(index) = self.inputs.iter().position(|i| *i == id) {
|
|
self.inputs[index] = *network_input;
|
|
}
|
|
}
|
|
NodeInput::ShortCircut(_) => (),
|
|
NodeInput::Value { .. } => unreachable!("Value inputs should have been replaced with value nodes"),
|
|
}
|
|
}
|
|
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into());
|
|
node.inputs = inner_network
|
|
.outputs
|
|
.iter()
|
|
.map(|&NodeOutput { node_id, node_output_index }| NodeInput::Node {
|
|
node_id,
|
|
output_index: node_output_index,
|
|
lambda: false,
|
|
})
|
|
.collect();
|
|
|
|
for node_id in new_nodes {
|
|
self.flatten_with_fns(node_id, map_ids, gen_id);
|
|
}
|
|
}
|
|
DocumentNodeImplementation::Unresolved(_) => {}
|
|
}
|
|
assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict");
|
|
self.nodes.insert(id, node);
|
|
}
|
|
|
|
pub fn into_proto_networks(self) -> impl Iterator<Item = ProtoNetwork> {
|
|
let mut nodes: Vec<_> = self.nodes.into_iter().map(|(id, node)| (id, node.resolve_proto_node())).collect();
|
|
nodes.sort_unstable_by_key(|(i, _)| *i);
|
|
|
|
// Create a network to evaluate each output
|
|
self.outputs.into_iter().map(move |output| ProtoNetwork {
|
|
inputs: self.inputs.clone(),
|
|
output: output.node_id,
|
|
nodes: nodes.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod test {
|
|
use super::*;
|
|
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
|
|
use graphene_core::NodeIdentifier;
|
|
|
|
fn gen_node_id() -> NodeId {
|
|
static mut NODE_ID: NodeId = 3;
|
|
unsafe {
|
|
NODE_ID += 1;
|
|
NODE_ID
|
|
}
|
|
}
|
|
|
|
fn add_network() -> NodeNetwork {
|
|
NodeNetwork {
|
|
inputs: vec![0, 0],
|
|
outputs: vec![NodeOutput::new(1, 0)],
|
|
nodes: [
|
|
(
|
|
0,
|
|
DocumentNode {
|
|
name: "Cons".into(),
|
|
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::Network(concrete!(u32))],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
|
|
},
|
|
),
|
|
(
|
|
1,
|
|
DocumentNode {
|
|
name: "Add".into(),
|
|
inputs: vec![NodeInput::node(0, 0)],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::AddNode".into()),
|
|
},
|
|
),
|
|
]
|
|
.into_iter()
|
|
.collect(),
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
/// Mapping every id of `add_network()` with `id + 1` must shift inputs, outputs, node keys and
/// node references by one.
#[test]
fn map_ids() {
    let mut network = add_network();
    network.map_ids(|id| id + 1);

    let cons = DocumentNode {
        name: "Cons".into(),
        inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::Network(concrete!(u32))],
        metadata: DocumentNodeMetadata::default(),
        implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
    };
    let add = DocumentNode {
        name: "Add".into(),
        inputs: vec![NodeInput::node(1, 0)],
        metadata: DocumentNodeMetadata::default(),
        implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::AddNode".into()),
    };
    let maped_add = NodeNetwork {
        inputs: vec![1, 1],
        outputs: vec![NodeOutput::new(2, 0)],
        nodes: [(1, cons), (2, add)].into_iter().collect(),
        ..Default::default()
    };

    assert_eq!(network, maped_add);
}
|
|
|
|
/// Flattening an "Inc" node that wraps `add_network()` (with a constant second input) must produce
/// exactly `flat_network()`.
#[test]
fn flatten_add() {
    let inc = DocumentNode {
        name: "Inc".into(),
        inputs: vec![
            NodeInput::Network(concrete!(u32)),
            NodeInput::Value {
                tagged_value: crate::document::value::TaggedValue::U32(2),
                exposed: false,
            },
        ],
        implementation: DocumentNodeImplementation::Network(add_network()),
        metadata: DocumentNodeMetadata::default(),
    };
    let mut network = NodeNetwork {
        inputs: vec![1],
        outputs: vec![NodeOutput::new(1, 0)],
        nodes: [(1, inc)].into_iter().collect(),
        ..Default::default()
    };

    network.flatten_with_fns(1, |self_id, inner_id| self_id * 10 + inner_id, gen_node_id);

    let expected = flat_network();
    assert_eq!(expected, network);
}
|
|
|
|
/// Resolving a "Cons" document node turns its first (network) input into the proto node input and
/// its second (node) input into a construction argument.
#[test]
fn resolve_proto_node_add() {
    let proto_node = DocumentNode {
        name: "Cons".into(),
        inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::node(0, 0)],
        metadata: DocumentNodeMetadata::default(),
        implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
    }
    .resolve_proto_node();

    let reference = ProtoNode {
        identifier: "graphene_core::structural::ConsNode".into(),
        input: ProtoNodeInput::Network(concrete!(u32)),
        construction_args: ConstructionArgs::Nodes(vec![(0, false)]),
    };

    assert_eq!(proto_node, reference);
}
|
|
|
|
/// Converting `flat_network()` into proto networks must yield the expected proto network for its
/// single output, with the nodes sorted by id.
#[test]
fn resolve_flatten_add_as_proto_network() {
    let id_node = ProtoNode {
        identifier: "graphene_core::ops::IdNode".into(),
        input: ProtoNodeInput::Node(11, false),
        construction_args: ConstructionArgs::Nodes(vec![]),
    };
    let cons_node = ProtoNode {
        identifier: "graphene_core::structural::ConsNode".into(),
        input: ProtoNodeInput::Network(concrete!(u32)),
        construction_args: ConstructionArgs::Nodes(vec![(14, false)]),
    };
    let add_node = ProtoNode {
        identifier: "graphene_core::ops::AddNode".into(),
        input: ProtoNodeInput::Node(10, false),
        construction_args: ConstructionArgs::Nodes(vec![]),
    };
    let value_node = ProtoNode::value(ConstructionArgs::Value(crate::document::value::TaggedValue::U32(2)));

    let construction_network = ProtoNetwork {
        inputs: vec![10],
        output: 1,
        nodes: [(1, id_node), (10, cons_node), (11, add_node), (14, value_node)].into_iter().collect(),
    };

    let resolved_network: Vec<_> = flat_network().into_proto_networks().collect();

    println!("{:#?}", resolved_network[0]);
    println!("{:#?}", construction_network);
    assert_eq!(resolved_network[0], construction_network);
}
|
|
|
|
fn flat_network() -> NodeNetwork {
|
|
NodeNetwork {
|
|
inputs: vec![10],
|
|
outputs: vec![NodeOutput::new(1, 0)],
|
|
nodes: [
|
|
(
|
|
1,
|
|
DocumentNode {
|
|
name: "Inc".into(),
|
|
inputs: vec![NodeInput::node(11, 0)],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
|
|
},
|
|
),
|
|
(
|
|
10,
|
|
DocumentNode {
|
|
name: "Cons".into(),
|
|
inputs: vec![NodeInput::Network(concrete!(u32)), NodeInput::node(14, 0)],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::structural::ConsNode".into()),
|
|
},
|
|
),
|
|
(
|
|
14,
|
|
DocumentNode {
|
|
name: "Value".into(),
|
|
inputs: vec![NodeInput::Value {
|
|
tagged_value: crate::document::value::TaggedValue::U32(2),
|
|
exposed: false,
|
|
}],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()),
|
|
},
|
|
),
|
|
(
|
|
11,
|
|
DocumentNode {
|
|
name: "Add".into(),
|
|
inputs: vec![NodeInput::node(10, 0)],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::AddNode".into()),
|
|
},
|
|
),
|
|
]
|
|
.into_iter()
|
|
.collect(),
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
fn two_node_identity() -> NodeNetwork {
|
|
NodeNetwork {
|
|
inputs: vec![1, 2],
|
|
outputs: vec![NodeOutput::new(1, 0), NodeOutput::new(2, 0)],
|
|
nodes: [
|
|
(
|
|
1,
|
|
DocumentNode {
|
|
name: "Identity 1".into(),
|
|
inputs: vec![NodeInput::Network(concrete!(u32))],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
|
|
},
|
|
),
|
|
(
|
|
2,
|
|
DocumentNode {
|
|
name: "Identity 2".into(),
|
|
inputs: vec![NodeInput::Network(concrete!(u32))],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
|
|
},
|
|
),
|
|
]
|
|
.into_iter()
|
|
.collect(),
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
fn output_duplicate(network_outputs: Vec<NodeOutput>, result_node_input: NodeInput) -> NodeNetwork {
|
|
let mut network = NodeNetwork {
|
|
inputs: Vec::new(),
|
|
outputs: network_outputs,
|
|
nodes: [
|
|
(
|
|
10,
|
|
DocumentNode {
|
|
name: "Nested network".into(),
|
|
inputs: vec![
|
|
NodeInput::value(crate::document::value::TaggedValue::F32(1.), false),
|
|
NodeInput::value(crate::document::value::TaggedValue::F32(2.), false),
|
|
],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Network(two_node_identity()),
|
|
},
|
|
),
|
|
(
|
|
11,
|
|
DocumentNode {
|
|
name: "Result".into(),
|
|
inputs: vec![result_node_input],
|
|
metadata: DocumentNodeMetadata::default(),
|
|
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
|
|
},
|
|
),
|
|
]
|
|
.into_iter()
|
|
.collect(),
|
|
..Default::default()
|
|
};
|
|
let mut new_ids = 101..;
|
|
network.duplicate_outputs(&mut || new_ids.next().unwrap());
|
|
network.remove_dead_nodes();
|
|
network
|
|
}
|
|
|
|
/// Exports only output 1 of the nested node and checks that a single duplicated
/// nested node (id 101) remains, reduced to the second input and "Identity 2".
#[test]
fn simple_duplicate() {
    let network = output_duplicate(vec![NodeOutput::new(10, 1)], NodeInput::node(10, 0));

    // Outer network: one output, now sourced from output 0 of the duplicate.
    assert_eq!(network.outputs.len(), 1, "The number of outputs should remain as 1");
    assert_eq!(network.outputs[0], NodeOutput::new(101, 0), "The outer network output should be from a duplicated inner network");

    let remaining_ids: Vec<_> = network.nodes.keys().copied().collect();
    assert_eq!(remaining_ids, vec![101], "Should just call nested network");

    // The duplicate keeps its name but only the input feeding the exported output.
    let duplicate = network.nodes.get(&101).unwrap();
    assert_eq!(duplicate.name, "Nested network", "Name should not change");

    let expected_inputs = vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)];
    assert_eq!(duplicate.inputs, expected_inputs, "Input should be 2");

    // Inside the duplicate only the second identity node is wired up.
    let inner = duplicate.implementation.get_network().expect("Implementation should be network");
    assert_eq!(inner.inputs, vec![2], "The input should be sent to the second node");
    assert_eq!(inner.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
    assert_eq!(inner.nodes.get(&2).unwrap().name, "Identity 2", "The node should be identity 2");
}
|
|
|
|
/// Exports both outputs of the nested node in swapped order (`(10, 1)` then `(10, 0)`).
///
/// Expects the first exported output to come from a duplicate (id 101) and the second
/// from the original node 10, with each nested node reduced to the single input and
/// inner identity node that feeds its exported output.
#[test]
fn out_of_order_duplicate() {
    let result = output_duplicate(vec![NodeOutput::new(10, 1), NodeOutput::new(10, 0)], NodeInput::node(10, 0));

    // Check the count first so a wrong number of outputs fails with a clear message
    // instead of an index-out-of-bounds panic (or going unnoticed if there are extras).
    assert_eq!(result.outputs.len(), 2, "The number of outputs should remain as 2");
    assert_eq!(result.outputs[0], NodeOutput::new(101, 0), "The first network output should be from a duplicated nested network");
    assert_eq!(result.outputs[1], NodeOutput::new(10, 0), "The second network output should be from the original nested network");
    assert!(
        result.nodes.contains_key(&10) && result.nodes.contains_key(&101) && result.nodes.len() == 2,
        "Network should contain the original nested node and its duplicate"
    );

    // (outer node id, constant fed into it, id of the only inner node it should still use)
    for (node_id, input_value, inner_id) in [(10, 1., 1), (101, 2., 2)] {
        let nested_network_node = result.nodes.get(&node_id).unwrap();
        assert_eq!(nested_network_node.name, "Nested network", "Name of node {node_id} should not change");
        assert_eq!(
            nested_network_node.inputs,
            vec![NodeInput::value(crate::document::value::TaggedValue::F32(input_value), false)],
            "Input of node {node_id} should be stable"
        );

        let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
        assert_eq!(inner_network.inputs, vec![inner_id], "The input of node {node_id} should be sent to inner node {inner_id}");
        assert_eq!(
            inner_network.outputs,
            vec![NodeOutput::new(inner_id, 0)],
            "The output of node {node_id} should be inner node {inner_id}"
        );
        assert_eq!(
            inner_network.nodes.get(&inner_id).unwrap().name,
            format!("Identity {inner_id}"),
            "The inner node of node {node_id} should be identity {inner_id}"
        );
    }
}
|
|
/// Exports the "Result" node while that node consumes output 1 of the nested node.
///
/// Expects the result node to be rewired to output 0 of a duplicate (id 101) that
/// only keeps the second input and "Identity 2".
#[test]
fn using_other_node_duplicate() {
    let network = output_duplicate(vec![NodeOutput::new(11, 0)], NodeInput::node(10, 1));

    // The exported output is untouched; only the node feeding it is replaced.
    assert_eq!(network.outputs, vec![NodeOutput::new(11, 0)], "The network output should be the result node");

    let has_expected_nodes = network.nodes.contains_key(&11) && network.nodes.contains_key(&101) && network.nodes.len() == 2;
    assert!(has_expected_nodes, "Network should contain a duplicated node and a result node");

    let result_node = network.nodes.get(&11).unwrap();
    assert_eq!(result_node.inputs, vec![NodeInput::node(101, 0)], "Result node should refer to duplicate node as input");

    // The duplicate keeps its name but only the input feeding the consumed output.
    let duplicate = network.nodes.get(&101).unwrap();
    assert_eq!(duplicate.name, "Nested network", "Name should not change");

    let expected_inputs = vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)];
    assert_eq!(duplicate.inputs, expected_inputs, "Input should be 2");

    // Inside the duplicate only the second identity node is wired up.
    let inner = duplicate.implementation.get_network().expect("Implementation should be network");
    assert_eq!(inner.inputs, vec![2], "The input should be sent to the second node");
    assert_eq!(inner.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
    assert_eq!(inner.nodes.get(&2).unwrap().name, "Identity 2", "The node should be identity 2");
}
|
|
}
|