diff --git a/editor/src/messages/portfolio/document/document_message_handler.rs b/editor/src/messages/portfolio/document/document_message_handler.rs index ebd51210..8039c62e 100644 --- a/editor/src/messages/portfolio/document/document_message_handler.rs +++ b/editor/src/messages/portfolio/document/document_message_handler.rs @@ -420,7 +420,7 @@ impl MessageHandler> for DocumentMessag .node_graph_handler .context_menu .as_ref() - .is_some_and(|context_menu| matches!(context_menu.context_menu_data, super::node_graph::utility_types::ContextMenuData::CreateNode)) + .is_some_and(|context_menu| matches!(context_menu.context_menu_data, super::node_graph::utility_types::ContextMenuData::CreateNode { compatible_type: None })) { // Close the context menu self.node_graph_handler.context_menu = None; diff --git a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs index 53d555da..caa217c0 100644 --- a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs +++ b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs @@ -3444,11 +3444,86 @@ pub fn resolve_document_node_type(identifier: &str) -> Option<&DocumentNodeDefin } pub fn collect_node_types() -> Vec { - DOCUMENT_NODE_TYPES + // Create a mapping from registry ID to document node identifier + let id_to_identifier_map: HashMap = DOCUMENT_NODE_TYPES + .iter() + .filter_map(|definition| { + if let DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) = &definition.node_template.document_node.implementation { + Some((name.to_string(), definition.identifier)) + } else { + None + } + }) + .collect(); + let mut extracted_node_types = Vec::new(); + + let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap(); + let node_metadata = graphene_core::registry::NODE_METADATA.lock().unwrap(); + for (id, metadata) in node_metadata.iter() { + if let Some(implementations) = node_registry.get(id) { + let identifier = match id_to_identifier_map.get(id) { + Some(&id) => id.to_string(), + None => continue, + }; + + // Extract category from metadata (already creates an owned String) + let category = metadata.category.unwrap_or_default().to_string(); + + // Extract input types (already creates owned Strings) + let input_types = implementations + .iter() + .flat_map(|(_, node_io)| node_io.inputs.iter().map(|ty| ty.clone().nested_type().to_string())) + .collect::>() + .into_iter() + .collect::>(); + + // Create a FrontendNodeType + let node_type = FrontendNodeType::with_owned_strings_and_input_types(identifier, category, input_types); + + // Store the created node_type + extracted_node_types.push(node_type); + } + } + + let node_types: Vec = DOCUMENT_NODE_TYPES .iter() .filter(|definition| !definition.category.is_empty()) - .map(|definition| FrontendNodeType::new(definition.identifier, definition.category)) - .collect() + .map(|definition| { + let input_types = definition + .node_template + .document_node + .inputs + .iter() + .filter_map(|node_input| node_input.as_value().map(|node_value| node_value.ty().nested_type().to_string())) + .collect::>(); + + FrontendNodeType::with_input_types(definition.identifier, definition.category, input_types) + }) + .collect(); + + // Update categories in extracted_node_types from node_types + for extracted_node in &mut extracted_node_types { + if extracted_node.category.is_empty() { + // Find matching node in node_types and update category if found + if let Some(matching_node) = node_types.iter().find(|node_type| node_type.name == extracted_node.name) { + extracted_node.category = matching_node.category.clone(); + } + } + } + let missing_nodes: Vec = node_types + .iter() + .filter(|node| !extracted_node_types.iter().any(|extracted| extracted.name == node.name)) + .cloned() + .collect(); + + // Add the missing nodes to extracted_node_types + for node in missing_nodes { + extracted_node_types.push(node); + } + // Remove entries with empty categories + extracted_node_types.retain(|node| !node.category.is_empty()); + + extracted_node_types } pub fn collect_node_descriptions() -> Vec<(String, String)> { diff --git a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs index 3e06539e..206594d0 100644 --- a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs +++ b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler.rs @@ -302,6 +302,30 @@ impl<'a> MessageHandler> for NodeGrap return; } + let Some(network_metadata) = network_interface.network_metadata(selection_network_path) else { + log::error!("Could not get network metadata in NodeGraphMessage::EnterNestedNetwork"); + return; + }; + + let click = ipp.mouse.position; + let node_graph_point = network_metadata.persistent_metadata.navigation_metadata.node_graph_to_viewport.inverse().transform_point2(click); + + // Check if clicked on empty area (no node, no input/output connector) + let clicked_id = network_interface.node_from_click(click, selection_network_path); + let clicked_input = network_interface.input_connector_from_click(click, selection_network_path); + let clicked_output = network_interface.output_connector_from_click(click, selection_network_path); + + if clicked_id.is_none() && clicked_input.is_none() && clicked_output.is_none() && self.context_menu.is_none() { + // Create a context menu with node creation options + self.context_menu = Some(ContextMenuInformation { + context_menu_coordinates: (node_graph_point.x as i32, node_graph_point.y as i32), + context_menu_data: ContextMenuData::CreateNode { compatible_type: None }, + }); + + responses.add(FrontendMessage::UpdateContextMenuInformation { + context_menu_information: self.context_menu.clone(), + }); + } let Some(node_id) = network_interface.node_from_click(ipp.mouse.position, selection_network_path) else { return; }; @@ -613,11 +637,11 @@ impl<'a> MessageHandler> for NodeGrap let currently_is_node = !network_interface.is_layer(&node_id, selection_network_path); ContextMenuData::ToggleLayer { node_id, currently_is_node } } else { - ContextMenuData::CreateNode + ContextMenuData::CreateNode { compatible_type: None } }; // TODO: Create function - let node_graph_shift = if matches!(context_menu_data, ContextMenuData::CreateNode) { + let node_graph_shift = if matches!(context_menu_data, ContextMenuData::CreateNode { compatible_type: None }) { let appear_right_of_mouse = if click.x > ipp.viewport_bounds.size().x - 180. { -180. } else { 0. }; let appear_above_mouse = if click.y > ipp.viewport_bounds.size().y - 200. { -200. } else { 0. }; DVec2::new(appear_right_of_mouse, appear_above_mouse) / network_metadata.persistent_metadata.navigation_metadata.node_graph_to_viewport.matrix2.x_axis.x @@ -1012,14 +1036,27 @@ impl<'a> MessageHandler> for NodeGrap warn!("No network_metadata"); return; }; + // Get the compatible type from the output connector + let compatible_type = output_connector.and_then(|output_connector| { + output_connector.node_id().and_then(|node_id| { + let output_index = output_connector.index(); + // Get the output types from the network interface + let output_types = network_interface.output_types(&node_id, selection_network_path); + // Extract the type if available + output_types.get(output_index).and_then(|type_option| type_option.as_ref()).map(|(output_type, _)| { + // Create a search term based on the type + format!("type:{}", output_type.clone().nested_type()) + }) + }) + }); let appear_right_of_mouse = if ipp.mouse.position.x > ipp.viewport_bounds.size().x - 173. { -173. } else { 0. }; let appear_above_mouse = if ipp.mouse.position.y > ipp.viewport_bounds.size().y - 34. { -34. } else { 0. }; let node_graph_shift = DVec2::new(appear_right_of_mouse, appear_above_mouse) / network_metadata.persistent_metadata.navigation_metadata.node_graph_to_viewport.matrix2.x_axis.x; self.context_menu = Some(ContextMenuInformation { context_menu_coordinates: ((point.x + node_graph_shift.x) as i32, (point.y + node_graph_shift.y) as i32), - context_menu_data: ContextMenuData::CreateNode, + context_menu_data: ContextMenuData::CreateNode { compatible_type }, }); responses.add(FrontendMessage::UpdateContextMenuInformation { diff --git a/editor/src/messages/portfolio/document/node_graph/utility_types.rs b/editor/src/messages/portfolio/document/node_graph/utility_types.rs index 1c380a49..b23e048d 100644 --- a/editor/src/messages/portfolio/document/node_graph/utility_types.rs +++ b/editor/src/messages/portfolio/document/node_graph/utility_types.rs @@ -107,6 +107,8 @@ pub struct FrontendNodeWire { pub struct FrontendNodeType { pub name: String, pub category: String, + #[serde(rename = "inputTypes")] + pub input_types: Option>, } impl FrontendNodeType { @@ -114,6 +116,23 @@ impl FrontendNodeType { Self { name: name.to_string(), category: category.to_string(), + input_types: None, + } + } + + pub fn with_input_types(name: &'static str, category: &'static str, input_types: Vec) -> Self { + Self { + name: name.to_string(), + category: category.to_string(), + input_types: Some(input_types), + } + } + + pub fn with_owned_strings_and_input_types(name: String, category: String, input_types: Vec) -> Self { + Self { + name, + category, + input_types: Some(input_types), } } } @@ -162,7 +181,11 @@ pub enum ContextMenuData { #[serde(rename = "currentlyIsNode")] currently_is_node: bool, }, - CreateNode, + CreateNode { + #[serde(rename = "compatibleType")] + #[serde(default)] + compatible_type: Option, + }, } #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize, specta::Type)] diff --git a/frontend/src/components/floating-menus/NodeCatalog.svelte b/frontend/src/components/floating-menus/NodeCatalog.svelte index e3321a93..60aefbbe 100644 --- a/frontend/src/components/floating-menus/NodeCatalog.svelte +++ b/frontend/src/components/floating-menus/NodeCatalog.svelte @@ -12,9 +12,10 @@ const nodeGraph = getContext("nodeGraph"); export let disabled = false; + export let initialSearchTerm = ""; let nodeSearchInput: TextInput | undefined = undefined; - let searchTerm = ""; + let searchTerm = initialSearchTerm; $: nodeCategories = buildNodeCategories($nodeGraph.nodeTypes, searchTerm); @@ -25,33 +26,60 @@ function buildNodeCategories(nodeTypes: FrontendNodeType[], searchTerm: string): [string, NodeCategoryDetails][] { const categories = new Map(); + const isTypeSearch = searchTerm.toLowerCase().startsWith("type:"); + let typeSearchTerm = ""; + let remainingSearchTerms = [searchTerm.toLowerCase()]; + + if (isTypeSearch) { + // Extract the first word after "type:" as the type search + const searchParts = searchTerm.substring(5).trim().split(/\s+/); + typeSearchTerm = searchParts[0].toLowerCase(); + + remainingSearchTerms = searchParts.slice(1).map((term) => term.toLowerCase()); + } nodeTypes.forEach((node) => { - let nameIncludesSearchTerm = node.name.toLowerCase().includes(searchTerm.toLowerCase()); + let matchesTypeSearch = true; + let matchesRemainingTerms = true; - // Quick and dirty hack to alias "Layer" to "Merge" in the search - if (node.name === "Merge") { - nameIncludesSearchTerm = nameIncludesSearchTerm || "Layer".toLowerCase().includes(searchTerm.toLowerCase()); + if (isTypeSearch && typeSearchTerm) { + matchesTypeSearch = node.inputTypes?.some((inputType) => inputType.toLowerCase().includes(typeSearchTerm)) || false; } - if (searchTerm.length > 0 && !nameIncludesSearchTerm && !node.category.toLowerCase().includes(searchTerm.toLowerCase())) { + if (remainingSearchTerms.length > 0) { + matchesRemainingTerms = remainingSearchTerms.every((term) => { + const nameMatch = node.name.toLowerCase().includes(term); + const categoryMatch = node.category.toLowerCase().includes(term); + + // Quick and dirty hack to alias "Layer" to "Merge" in the search + const layerAliasMatch = node.name === "Merge" && "layer".includes(term); + + return nameMatch || categoryMatch || layerAliasMatch; + }); + } + + // Node matches if it passes both type search and remaining terms filters + const includesSearchTerm = matchesTypeSearch && matchesRemainingTerms; + + if (searchTerm.length > 0 && !includesSearchTerm) { return; } const category = categories.get(node.category); - let open = nameIncludesSearchTerm; + let open = includesSearchTerm; if (searchTerm.length === 0) { open = false; } if (category) { - category.open = open; + category.open = category.open || open; category.nodes.push(node); - } else + } else { categories.set(node.category, { open, nodes: [node], }); + } }); const START_CATEGORIES_ORDER = ["UNCATEGORIZED", "General", "Value", "Math", "Style"]; diff --git a/frontend/src/components/views/Graph.svelte b/frontend/src/components/views/Graph.svelte index 589a8411..702287b3 100644 --- a/frontend/src/components/views/Graph.svelte +++ b/frontend/src/components/views/Graph.svelte @@ -653,8 +653,10 @@ top: `${$nodeGraph.contextMenuInformation.contextMenuCoordinates.y * $nodeGraph.transform.scale + $nodeGraph.transform.y}px`, }} > - {#if $nodeGraph.contextMenuInformation.contextMenuData === "CreateNode"} + {#if typeof $nodeGraph.contextMenuInformation.contextMenuData === "string" && $nodeGraph.contextMenuInformation.contextMenuData === "CreateNode"} createNode(e.detail)} /> + {:else if $nodeGraph.contextMenuInformation.contextMenuData && "compatibleType" in $nodeGraph.contextMenuInformation.contextMenuData} + createNode(e.detail)} /> {:else} {@const contextMenuData = $nodeGraph.contextMenuInformation.contextMenuData} diff --git a/frontend/src/messages.ts b/frontend/src/messages.ts index cf9b2f3e..133f5f95 100644 --- a/frontend/src/messages.ts +++ b/frontend/src/messages.ts @@ -46,6 +46,8 @@ const ContextTupleToVec2 = Transform((data) => { let contextMenuData = data.obj.contextMenuInformation.contextMenuData; if (contextMenuData.ToggleLayer !== undefined) { contextMenuData = { nodeId: contextMenuData.ToggleLayer.nodeId, currentlyIsNode: contextMenuData.ToggleLayer.currentlyIsNode }; + } else if (contextMenuData.CreateNode !== undefined) { + contextMenuData = { type: "CreateNode", compatibleType: contextMenuData.CreateNode.compatibleType }; } return { contextMenuCoordinates, contextMenuData }; }); @@ -185,8 +187,7 @@ export type FrontendClickTargets = { export type ContextMenuInformation = { contextMenuCoordinates: XY; - - contextMenuData: "CreateNode" | { nodeId: bigint; currentlyIsNode: boolean }; + contextMenuData: "CreateNode" | { type: "CreateNode"; compatibleType: string } | { nodeId: bigint; currentlyIsNode: boolean }; }; export type FrontendGraphDataType = "General" | "Raster" | "VectorData" | "Number" | "Group" | "Artboard"; @@ -337,6 +338,8 @@ export class FrontendNodeType { readonly name!: string; readonly category!: string; + + readonly inputTypes!: string[]; } export class NodeGraphTransform { diff --git a/node-graph/gcore/src/types.rs b/node-graph/gcore/src/types.rs index 0fc03e4d..5130dbfd 100644 --- a/node-graph/gcore/src/types.rs +++ b/node-graph/gcore/src/types.rs @@ -357,4 +357,8 @@ impl ProtoNodeIdentifier { pub const fn new(name: &'static str) -> Self { ProtoNodeIdentifier { name: Cow::Borrowed(name) } } + + pub const fn with_owned_string(name: String) -> Self { + ProtoNodeIdentifier { name: Cow::Owned(name) } + } }