From b2397b06c6f022bf75f149a15cda5b5199a59dd0 Mon Sep 17 00:00:00 2001 From: Vlad Rakhmanin <91677083+toadkarter@users.noreply.github.com> Date: Sat, 30 Sep 2023 11:07:29 +0100 Subject: [PATCH] Fix crash when a cycle is introduced into the graph (#1427) * Changing return of topological_sort to Result and propagating error * Simplifying "compile()" method, adding "expect()" to tests. * Removing Result type from "map_gpu()" * Reverting to assertion and removing unnecessary returns --- .../graph-craft/src/graphene_compiler.rs | 19 ++-- node-graph/graph-craft/src/proto.rs | 88 ++++++++++++++----- node-graph/gstd/src/gpu_nodes.rs | 27 +++--- 3 files changed, 94 insertions(+), 40 deletions(-) diff --git a/node-graph/graph-craft/src/graphene_compiler.rs b/node-graph/graph-craft/src/graphene_compiler.rs index daf263ad..f6fe73e7 100644 --- a/node-graph/graph-craft/src/graphene_compiler.rs +++ b/node-graph/graph-craft/src/graphene_compiler.rs @@ -8,7 +8,7 @@ use crate::proto::{LocalFuture, ProtoNetwork}; pub struct Compiler {} impl Compiler { - pub fn compile(&self, mut network: NodeNetwork) -> impl Iterator { + pub fn compile(&self, mut network: NodeNetwork) -> Result, String> { println!("flattening"); let node_ids = network.nodes.keys().copied().collect::>(); for id in node_ids { @@ -17,15 +17,20 @@ impl Compiler { network.remove_redundant_id_nodes(); network.remove_dead_nodes(); let proto_networks = network.into_proto_networks(); - proto_networks.map(move |mut proto_network| { - proto_network.resolve_inputs(); - proto_network.generate_stable_node_ids(); - proto_network - }) + + let proto_networks_result: Vec = proto_networks + .map(move |mut proto_network| { + proto_network.resolve_inputs()?; + proto_network.generate_stable_node_ids(); + Ok(proto_network) + }) + .collect::, String>>()?; + + Ok(proto_networks_result.into_iter()) } pub fn compile_single(&self, network: NodeNetwork) -> Result { assert_eq!(network.outputs.len(), 1, "Graph with multiple outputs not yet handled"); - let Some(proto_network) = self.compile(network).next() else { + let Some(proto_network) = self.compile(network)?.next() else { return Err("Failed to convert graph into proto graph".to_string()); }; Ok(proto_network) diff --git a/node-graph/graph-craft/src/proto.rs b/node-graph/graph-craft/src/proto.rs index 66951823..a2f41c85 100644 --- a/node-graph/graph-craft/src/proto.rs +++ b/node-graph/graph-craft/src/proto.rs @@ -336,9 +336,9 @@ impl ProtoNetwork { edges } - pub fn resolve_inputs(&mut self) { + pub fn resolve_inputs(&mut self) -> Result<(), String> { // Perform topological sort once - self.reorder_ids(); + self.reorder_ids()?; let max_id = self.nodes.len() as NodeId - 1; @@ -370,7 +370,8 @@ impl ProtoNetwork { self.replace_node_id(&outwards_edges, node_id, compose_node_id, true); } } - self.reorder_ids(); + self.reorder_ids()?; + Ok(()) } fn replace_node_id(&mut self, outwards_edges: &HashMap>, node_id: u64, compose_node_id: u64, skip_lambdas: bool) { @@ -392,33 +393,35 @@ impl ProtoNetwork { } }); } - // Based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search // This approach excludes nodes that are not connected - pub fn topological_sort(&self) -> Vec { + pub fn topological_sort(&self) -> Result, String> { let mut sorted = Vec::new(); let inwards_edges = self.collect_inwards_edges(); - fn visit(node_id: NodeId, temp_marks: &mut HashSet, sorted: &mut Vec, inwards_edges: &HashMap>, network: &ProtoNetwork) { + fn visit(node_id: NodeId, temp_marks: &mut HashSet, sorted: &mut Vec, inwards_edges: &HashMap>, network: &ProtoNetwork) -> Result<(), String> { if sorted.contains(&node_id) { - return; + return Ok(()); }; if temp_marks.contains(&node_id) { - panic!("Cycle detected {:#?}, {:#?}", &inwards_edges, &network); + return Err(format!("Cycle detected {:#?}, {:#?}", &inwards_edges, &network)); } if let Some(dependencies) = inwards_edges.get(&node_id) { temp_marks.insert(node_id); for &dependant in dependencies { - visit(dependant, temp_marks, sorted, inwards_edges, network); + visit(dependant, temp_marks, sorted, inwards_edges, network)?; } temp_marks.remove(&node_id); } sorted.push(node_id); + Ok(()) } - assert!(self.nodes.iter().any(|(id, _)| *id == self.output), "Output id {} does not exist", self.output); - visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges, self); - sorted + if !self.nodes.iter().any(|(id, _)| *id == self.output) { + return Err(format!("Output id {} does not exist", self.output)); + } + visit(self.output, &mut HashSet::new(), &mut sorted, &inwards_edges, self)?; + Ok(sorted) } fn is_topologically_sorted(&self) -> bool { @@ -465,8 +468,8 @@ impl ProtoNetwork { sorted }*/ - fn reorder_ids(&mut self) { - let order = self.topological_sort(); + fn reorder_ids(&mut self) -> Result<(), String> { + let order = self.topological_sort()?; // Map of node ids to their current index in the nodes vector let current_positions: HashMap<_, _> = self.nodes.iter().enumerate().map(|(pos, (id, _))| (*id, pos)).collect(); @@ -492,6 +495,7 @@ impl ProtoNetwork { self.output = *new_positions.get(&self.output).unwrap(); assert_eq!(order.len(), self.nodes.len()); + Ok(()) } } @@ -687,17 +691,24 @@ mod test { #[test] fn topological_sort() { let construction_network = test_network(); - let sorted = construction_network.topological_sort(); - + let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network."); println!("{:#?}", sorted); assert_eq!(sorted, vec![14, 10, 11, 1]); } + #[test] + fn topological_sort_with_cycles() { + let construction_network = test_network_with_cycles(); + let sorted = construction_network.topological_sort(); + + assert!(sorted.is_err()) + } + #[test] fn id_reordering() { let mut construction_network = test_network(); - construction_network.reorder_ids(); - let sorted = construction_network.topological_sort(); + construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network."); + let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network."); println!("nodes: {:#?}", construction_network.nodes); assert_eq!(sorted, vec![0, 1, 2, 3]); let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect(); @@ -710,9 +721,9 @@ mod test { #[test] fn id_reordering_idempotent() { let mut construction_network = test_network(); - construction_network.reorder_ids(); - construction_network.reorder_ids(); - let sorted = construction_network.topological_sort(); + construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network."); + construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network."); + let sorted = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network."); assert_eq!(sorted, vec![0, 1, 2, 3]); let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect(); println!("{:#?}", ids); @@ -723,7 +734,7 @@ mod test { #[test] fn input_resolution() { let mut construction_network = test_network(); - construction_network.resolve_inputs(); + construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network."); println!("{:#?}", construction_network); assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value"); assert_eq!(construction_network.nodes.len(), 6); @@ -733,7 +744,7 @@ mod test { #[test] fn stable_node_id_generation() { let mut construction_network = test_network(); - construction_network.resolve_inputs(); + construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network."); construction_network.generate_stable_node_ids(); assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value"); let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect(); @@ -810,4 +821,35 @@ mod test { .collect(), } } + + fn test_network_with_cycles() -> ProtoNetwork { + ProtoNetwork { + inputs: vec![1], + output: 1, + nodes: [ + ( + 1, + ProtoNode { + identifier: "id".into(), + input: ProtoNodeInput::Node(2, false), + construction_args: ConstructionArgs::Nodes(vec![]), + document_node_path: vec![], + skip_deduplication: false, + }, + ), + ( + 2, + ProtoNode { + identifier: "id".into(), + input: ProtoNodeInput::Node(1, false), + construction_args: ConstructionArgs::Nodes(vec![]), + document_node_path: vec![], + skip_deduplication: false, + }, + ), + ] + .into_iter() + .collect(), + } + } } diff --git a/node-graph/gstd/src/gpu_nodes.rs b/node-graph/gstd/src/gpu_nodes.rs index 6c280c26..244f6c1a 100644 --- a/node-graph/gstd/src/gpu_nodes.rs +++ b/node-graph/gstd/src/gpu_nodes.rs @@ -26,10 +26,10 @@ pub struct GpuCompiler { // TODO: Move to graph-craft #[node_macro::node_fn(GpuCompiler)] -async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader { +async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> Result { let compiler = graph_craft::graphene_compiler::Compiler {}; let DocumentNodeImplementation::Network(ref network) = node.implementation else { panic!() }; - let proto_networks: Vec<_> = compiler.compile(network.clone()).collect(); + let proto_networks: Vec<_> = compiler.compile(network.clone())?.collect(); for network in proto_networks.iter() { typing_context.update(network).expect("Failed to type check network"); @@ -43,7 +43,7 @@ async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingConte .collect(); let output_types = proto_networks.iter().map(|network| typing_context.type_of(network.output).unwrap().output.clone()).collect(); - compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap() + Ok(compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap()) } pub struct MapGpuNode { @@ -97,7 +97,10 @@ async fn map_gpu<'a: 'input>(image: ImageFrame, node: DocumentNode, edito self.cache.borrow().get(&node.name).unwrap().clone() } else { let name = node.name.clone(); - let compute_pass_descriptor = create_compute_pass_descriptor(node, &image, executor, quantization).await; + let Ok(compute_pass_descriptor) = create_compute_pass_descriptor(node, &image, executor, quantization).await else { + log::error!("Error creating compute pass descriptor in 'map_gpu()"); + return ImageFrame::empty(); + }; self.cache.borrow_mut().insert(name, compute_pass_descriptor.clone()); log::error!("created compute pass"); compute_pass_descriptor @@ -156,7 +159,7 @@ async fn create_compute_pass_descriptor( image: &ImageFrame, executor: &&WgpuExecutor, quantization: QuantizationChannels, -) -> ComputePass { +) -> Result, String> { let compiler = graph_craft::graphene_compiler::Compiler {}; let inner_network = NodeNetwork::value_network(node); @@ -246,7 +249,7 @@ async fn create_compute_pass_descriptor( ..Default::default() }; log::debug!("compiling network"); - let proto_networks = compiler.compile(network.clone()).collect(); + let proto_networks = compiler.compile(network.clone())?.collect(); log::debug!("compiling shader"); let shader = compilation_client::compile( proto_networks, @@ -344,10 +347,10 @@ async fn create_compute_pass_descriptor( }; log::debug!("created pipeline"); - ComputePass { + Ok(ComputePass { pipeline_layout: pipeline, readback_buffer: Some(readback_buffer.clone()), - } + }) } /* #[node_macro::node_fn(MapGpuNode)] @@ -417,7 +420,7 @@ pub struct BlendGpuImageNode { async fn blend_gpu_image(foreground: ImageFrame, background: ImageFrame, blend_mode: BlendMode, opacity: f32) -> ImageFrame { let foreground_size = DVec2::new(foreground.image.width as f64, foreground.image.height as f64); let background_size = DVec2::new(background.image.width as f64, background.image.height as f64); - // Transforms a point from the background image to the forground image + // Transforms a point from the background image to the foreground image let bg_to_fg = DAffine2::from_scale(foreground_size) * foreground.transform.inverse() * background.transform * DAffine2::from_scale(1. / background_size); let transform_matrix: Mat2 = bg_to_fg.matrix2.as_mat2(); @@ -464,7 +467,11 @@ async fn blend_gpu_image(foreground: ImageFrame, background: ImageFrame