Implement basic request caching for compilation server (#1253)

* Implement basic request caching for compilation server

* Fix formatting
This commit is contained in:
Dennis Kobert 2023-05-28 00:52:10 +02:00 committed by Keavon Chambers
parent 6289d92e02
commit 9da83d3280
3 changed files with 27 additions and 8 deletions

View File

@ -1,4 +1,4 @@
use std::sync::Arc; use std::{collections::HashMap, sync::Arc, sync::RwLock};
use gpu_compiler_bin_wrapper::CompileRequest; use gpu_compiler_bin_wrapper::CompileRequest;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
@ -12,12 +12,14 @@ use axum::{
struct AppState { struct AppState {
compile_dir: tempfile::TempDir, compile_dir: tempfile::TempDir,
cache: RwLock<HashMap<CompileRequest, Result<Vec<u8>, StatusCode>>>,
} }
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let shared_state = Arc::new(AppState { let shared_state = Arc::new(AppState {
compile_dir: tempfile::tempdir().expect("failed to create tempdir"), compile_dir: tempfile::tempdir().expect("failed to create tempdir"),
cache: Default::default(),
}); });
// build our application with a single route // build our application with a single route
@ -33,9 +35,15 @@ async fn main() {
} }
async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> { async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
if let Some(result) = state.cache.read().unwrap().get(&compile_request) {
return result.clone();
}
let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml"; let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml";
compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| { let result = compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
eprintln!("compilation failed: {}", e); eprintln!("compilation failed: {}", e);
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
}) });
state.cache.write().unwrap().insert(compile_request, result.clone());
result
} }

View File

@ -41,7 +41,7 @@ pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manife
Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?) Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?)
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash, Eq)]
pub struct CompileRequest { pub struct CompileRequest {
networks: Vec<graph_craft::proto::ProtoNetwork>, networks: Vec<graph_craft::proto::ProtoNetwork>,
input_types: Vec<Type>, input_types: Vec<Type>,

View File

@ -23,7 +23,7 @@ pub type TypeErasedPinned<'n> = Pin<Box<dyn for<'i> NodeIO<'i, Any<'i>, Output =
pub type NodeConstructor = for<'a> fn(Vec<TypeErasedPinnedRef<'static>>) -> DynFuture<'static, TypeErasedPinned<'static>>; pub type NodeConstructor = for<'a> fn(Vec<TypeErasedPinnedRef<'static>>) -> DynFuture<'static, TypeErasedPinned<'static>>;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, PartialEq, Clone)] #[derive(Debug, Default, PartialEq, Clone, Hash, Eq)]
pub struct ProtoNetwork { pub struct ProtoNetwork {
// Should a proto Network even allow inputs? Don't think so // Should a proto Network even allow inputs? Don't think so
pub inputs: Vec<NodeId>, pub inputs: Vec<NodeId>,
@ -90,12 +90,23 @@ pub enum ConstructionArgs {
Inline(InlineRust), Inline(InlineRust),
} }
impl Eq for ConstructionArgs {}
impl PartialEq for ConstructionArgs { impl PartialEq for ConstructionArgs {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
match (&self, &other) { match (&self, &other) {
(Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2, (Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2,
(Self::Value(v1), Self::Value(v2)) => v1 == v2, (Self::Value(v1), Self::Value(v2)) => v1 == v2,
_ => core::mem::discriminant(self) == core::mem::discriminant(other), _ => {
use std::hash::Hasher;
use xxhash_rust::xxh3::Xxh3;
let hash = |input: &Self| {
let mut hasher = Xxh3::new();
input.hash(&mut hasher);
hasher.finish()
};
hash(self) == hash(other)
}
} }
} }
} }
@ -126,7 +137,7 @@ impl ConstructionArgs {
} }
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Clone)] #[derive(Debug, PartialEq, Clone, Hash, Eq)]
pub struct ProtoNode { pub struct ProtoNode {
pub construction_args: ConstructionArgs, pub construction_args: ConstructionArgs,
pub input: ProtoNodeInput, pub input: ProtoNodeInput,
@ -137,7 +148,7 @@ pub struct ProtoNode {
/// A ProtoNodeInput represents the input of a node in a ProtoNetwork. /// A ProtoNodeInput represents the input of a node in a ProtoNetwork.
/// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum /// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum
/// in the `document` module /// in the `document` module
#[derive(Debug, PartialEq, Eq, Clone)] #[derive(Debug, PartialEq, Eq, Clone, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ProtoNodeInput { pub enum ProtoNodeInput {
None, None,