diff --git a/node-graph/compilation-server/src/main.rs b/node-graph/compilation-server/src/main.rs index 7dc20066..90959f43 100644 --- a/node-graph/compilation-server/src/main.rs +++ b/node-graph/compilation-server/src/main.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc, sync::RwLock}; use gpu_compiler_bin_wrapper::CompileRequest; use tower_http::cors::CorsLayer; @@ -12,12 +12,14 @@ use axum::{ struct AppState { compile_dir: tempfile::TempDir, + cache: RwLock, StatusCode>>>, } #[tokio::main] async fn main() { let shared_state = Arc::new(AppState { compile_dir: tempfile::tempdir().expect("failed to create tempdir"), + cache: Default::default(), }); // build our application with a single route @@ -33,9 +35,15 @@ async fn main() { } async fn post_compile_spirv(State(state): State>, Json(compile_request): Json) -> Result, StatusCode> { + if let Some(result) = state.cache.read().unwrap().get(&compile_request) { + return result.clone(); + } + let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml"; - compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| { + let result = compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| { eprintln!("compilation failed: {}", e); StatusCode::INTERNAL_SERVER_ERROR - }) + }); + state.cache.write().unwrap().insert(compile_request, result.clone()); + result } diff --git a/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs b/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs index 43ce67e9..7690a082 100644 --- a/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs +++ b/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs @@ -41,7 +41,7 @@ pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manife Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?) } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash, Eq)] pub struct CompileRequest { networks: Vec, input_types: Vec, diff --git a/node-graph/graph-craft/src/proto.rs b/node-graph/graph-craft/src/proto.rs index a4905d63..105b582b 100644 --- a/node-graph/graph-craft/src/proto.rs +++ b/node-graph/graph-craft/src/proto.rs @@ -23,7 +23,7 @@ pub type TypeErasedPinned<'n> = Pin NodeIO<'i, Any<'i>, Output = pub type NodeConstructor = for<'a> fn(Vec>) -> DynFuture<'static, TypeErasedPinned<'static>>; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Default, PartialEq, Clone)] +#[derive(Debug, Default, PartialEq, Clone, Hash, Eq)] pub struct ProtoNetwork { // Should a proto Network even allow inputs? Don't think so pub inputs: Vec, @@ -90,12 +90,23 @@ pub enum ConstructionArgs { Inline(InlineRust), } +impl Eq for ConstructionArgs {} + impl PartialEq for ConstructionArgs { fn eq(&self, other: &Self) -> bool { match (&self, &other) { (Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2, (Self::Value(v1), Self::Value(v2)) => v1 == v2, - _ => core::mem::discriminant(self) == core::mem::discriminant(other), + _ => { + use std::hash::Hasher; + use xxhash_rust::xxh3::Xxh3; + let hash = |input: &Self| { + let mut hasher = Xxh3::new(); + input.hash(&mut hasher); + hasher.finish() + }; + hash(self) == hash(other) + } } } } @@ -126,7 +137,7 @@ impl ConstructionArgs { } #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, Hash, Eq)] pub struct ProtoNode { pub construction_args: ConstructionArgs, pub input: ProtoNodeInput, @@ -137,7 +148,7 @@ pub struct ProtoNode { /// A ProtoNodeInput represents the input of a node in a ProtoNetwork. /// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum /// in the `document` module -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ProtoNodeInput { None,