diff --git a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs index b3c5dafe..b9712fda 100644 --- a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs +++ b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs @@ -793,6 +793,28 @@ fn static_nodes() -> Vec { outputs: vec![DocumentOutputType::new("Vector", FrontendGraphDataType::Subpath)], properties: node_properties::stroke_properties, }, + DocumentNodeType { + name: "Image Segmentation", + category: "Image Adjustments", + identifier: NodeImplementation::proto("graphene_std::image_segmentation::ImageSegmentationNode<_>"), + inputs: vec![ + DocumentInputType::value("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true), + DocumentInputType::value("Mask", TaggedValue::ImageFrame(ImageFrame::empty()), true), + ], + outputs: vec![DocumentOutputType::new("Segments", FrontendGraphDataType::Raster)], + properties: node_properties::no_properties, + }, + DocumentNodeType { + name: "Index", + category: "Image Adjustments", + identifier: NodeImplementation::proto("graphene_core::raster::IndexNode<_>"), + inputs: vec![ + DocumentInputType::value("Segmentation", TaggedValue::Segments(vec![ImageFrame::empty()]), true), + DocumentInputType::value("Index", TaggedValue::U32(0), false), + ], + outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)], + properties: node_properties::index_node_properties, + }, ] } diff --git a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs index d71ad768..1e485d0e 100644 --- a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs +++ b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/node_properties.rs @@ -1264,6 +1264,12 @@ pub fn no_properties(_document_node: &DocumentNode, _node_id: NodeId, _context: string_properties("Node has no properties") } +pub fn index_node_properties(document_node: &DocumentNode, node_id: NodeId, _context: &mut NodePropertiesContext) -> Vec { + let index = number_widget(document_node, node_id, 1, "Index", NumberInput::default().min(0.), true); + + vec![LayoutGroup::Row { widgets: index }] +} + pub fn generate_node_properties(document_node: &DocumentNode, node_id: NodeId, context: &mut NodePropertiesContext) -> LayoutGroup { let name = document_node.name.clone(); let layout = match super::document_node_types::resolve_document_node_type(&name) { diff --git a/node-graph/gcore/src/raster/adjustments.rs b/node-graph/gcore/src/raster/adjustments.rs index 4cc71047..ee3ac41a 100644 --- a/node-graph/gcore/src/raster/adjustments.rs +++ b/node-graph/gcore/src/raster/adjustments.rs @@ -513,3 +513,18 @@ fn exposure(color: Color, exposure: f64, offset: f64, gamma_correction: f64) -> // TODO: Remove conversion to linear when the whole node graph uses linear color result.to_gamma_srgb() } + +#[derive(Debug)] +pub struct IndexNode { + pub index: Index, +} + +#[node_macro::node_fn(IndexNode)] +pub fn index_node(input: Vec>, index: u32) -> super::ImageFrame { + if (index as usize) < input.len() { + input[index as usize].clone() + } else { + warn!("The number of segments is {} and the requested segment is {}!", input.len(), index); + super::ImageFrame::empty() + } +} diff --git a/node-graph/graph-craft/src/document/value.rs b/node-graph/graph-craft/src/document/value.rs index 74b1848c..c21ded34 100644 --- a/node-graph/graph-craft/src/document/value.rs +++ b/node-graph/graph-craft/src/document/value.rs @@ -48,6 +48,7 @@ pub enum TaggedValue { OptionalColor(Option), ManipulatorGroupIds(Vec), VecDVec2(Vec), + Segments(Vec>), } #[allow(clippy::derived_hash_with_manual_eq)] @@ -111,6 +112,12 @@ impl Hash for TaggedValue { dvec2.to_array().iter().for_each(|x| x.to_bits().hash(state)); } } + Self::Segments(segments) => { + 32.hash(state); + for segment in segments { + segment.hash(state) + } + } } } } @@ -153,6 +160,7 @@ impl<'a> TaggedValue { TaggedValue::OptionalColor(x) => Box::new(x), TaggedValue::ManipulatorGroupIds(x) => Box::new(x), TaggedValue::VecDVec2(x) => Box::new(x), + TaggedValue::Segments(x) => Box::new(x), } } @@ -194,6 +202,7 @@ impl<'a> TaggedValue { TaggedValue::OptionalColor(_) => concrete!(Option), TaggedValue::ManipulatorGroupIds(_) => concrete!(Vec), TaggedValue::VecDVec2(_) => concrete!(Vec), + TaggedValue::Segments(_) => concrete!(graphene_core::raster::IndexNode>>), } } } diff --git a/node-graph/gstd/src/image_segmentation.rs b/node-graph/gstd/src/image_segmentation.rs new file mode 100644 index 00000000..b0f3e277 --- /dev/null +++ b/node-graph/gstd/src/image_segmentation.rs @@ -0,0 +1,126 @@ +use std::collections::hash_map::HashMap; + +use graphene_core::raster::{Color, ImageFrame}; +use graphene_core::Node; + +fn apply_mask(image_frame: &mut ImageFrame, x: usize, y: usize, multiplier: u8) { + let color = &mut image_frame.image.data[y * image_frame.image.width as usize + x]; + let color8 = color.to_rgba8(); + *color = Color::from_rgba8_srgb(color8[0] * multiplier, color8[1] * multiplier, color8[2] * multiplier, color8[3] * multiplier); +} + +pub struct Mask { + pub data: Vec, + pub width: usize, + pub height: usize, +} + +impl Mask { + fn sample(&self, u: f32, v: f32) -> u8 { + let x = (u * (self.width as f32)) as usize; + let y = (v * (self.height as f32)) as usize; + + self.data[y * self.width + x] + } +} + +fn image_segmentation(input_image: &ImageFrame, input_mask: &Mask) -> Vec> { + const NUM_LABELS: usize = u8::MAX as usize; + let mut result = Vec::>::with_capacity(NUM_LABELS); + let mut current_label = 0_usize; + let mut label_appeared = [false; NUM_LABELS + 1]; + let mut max_label = 0_usize; + + if input_mask.data.is_empty() { + warn!("The mask for the segmentation node is empty!"); + return vec![ImageFrame::empty()]; + } + + result.push(input_image.clone()); + let result_last = result.last_mut().unwrap(); + + for y in 0..input_image.image.height { + let v = (y as f32) / (input_image.image.height as f32); + for x in 0..input_image.image.width { + let u = (x as f32) / (input_image.image.width as f32); + let label = input_mask.sample(u, v) as usize; + let multiplier = (label == current_label) as u8; + + apply_mask(result_last, x as usize, y as usize, multiplier); + + if label < NUM_LABELS { + label_appeared[label] = true; + max_label = max_label.max(label); + } + } + } + + if !label_appeared[current_label] { + result.pop(); + } + + for i in 1..=max_label.max(NUM_LABELS) { + current_label = i; + + if !label_appeared[current_label] { + continue; + } + + result.push(input_image.clone()); + let result_last = result.last_mut().unwrap(); + + for y in 0..input_image.image.height { + let v = (y as f32) / (input_image.image.height as f32); + for x in 0..input_image.image.width { + let u = (x as f32) / (input_image.image.width as f32); + let label = input_mask.sample(u, v) as usize; + let multiplier = (label == current_label) as u8; + + apply_mask(result_last, x as usize, y as usize, multiplier); + } + } + } + + result +} + +fn convert_image_to_mask(input: &ImageFrame) -> Vec { + let mut result = vec![0_u8; (input.image.width * input.image.height) as usize]; + let mut colors = HashMap::<[u8; 4], usize>::new(); + let mut last_value = 0_usize; + + for (color, result) in input.image.data.iter().zip(result.iter_mut()) { + let color = color.to_rgba8(); + if let Some(value) = colors.get(&color) { + *result = *value as u8; + } else { + if last_value > u8::MAX as usize { + warn!("The limit for number of segments ({}) has been exceeded!", u8::MAX); + break; + } + + *result = last_value as u8; + colors.insert(color, last_value); + last_value += 1; + } + } + + result +} + +#[derive(Debug)] +pub struct ImageSegmentationNode { + pub(crate) mask_image: MaskImage, +} + +#[node_macro::node_fn(ImageSegmentationNode)] +pub(crate) fn image_segmentation(image: ImageFrame, mask_image: ImageFrame) -> Vec> { + let mask_data = convert_image_to_mask(&mask_image); + let mask = Mask { + data: mask_data, + width: mask_image.image.width as usize, + height: mask_image.image.height as usize, + }; + + image_segmentation(&image, &mask) +} diff --git a/node-graph/gstd/src/lib.rs b/node-graph/gstd/src/lib.rs index 681ed9fd..88b90947 100644 --- a/node-graph/gstd/src/lib.rs +++ b/node-graph/gstd/src/lib.rs @@ -20,4 +20,6 @@ pub mod quantization; pub use graphene_core::*; +pub mod image_segmentation; + pub mod brush; diff --git a/node-graph/interpreted-executor/src/node_registry.rs b/node-graph/interpreted-executor/src/node_registry.rs index 96fb070d..4fc7fe73 100644 --- a/node-graph/interpreted-executor/src/node_registry.rs +++ b/node-graph/interpreted-executor/src/node_registry.rs @@ -218,6 +218,8 @@ fn node_registry() -> HashMap, params: [LuminanceCalculation]), raster_node!(graphene_core::raster::LevelsNode<_, _, _, _, _>, params: [f64, f64, f64, f64, f64]), + register_node!(graphene_std::image_segmentation::ImageSegmentationNode<_>, input: ImageFrame, params: [ImageFrame]), + register_node!(graphene_core::raster::IndexNode<_>, input: Vec>, params: [u32]), vec![ ( NodeIdentifier::new("graphene_core::raster::BlendNode<_, _, _, _>"),