Implement basic request caching for compilation server (#1253)
* Implement basic request caching for compilation server * Fix formatting
This commit is contained in:
parent
6289d92e02
commit
9da83d3280
|
|
@ -1,4 +1,4 @@
|
||||||
use std::sync::Arc;
|
use std::{collections::HashMap, sync::Arc, sync::RwLock};
|
||||||
|
|
||||||
use gpu_compiler_bin_wrapper::CompileRequest;
|
use gpu_compiler_bin_wrapper::CompileRequest;
|
||||||
use tower_http::cors::CorsLayer;
|
use tower_http::cors::CorsLayer;
|
||||||
|
|
@ -12,12 +12,14 @@ use axum::{
|
||||||
|
|
||||||
struct AppState {
|
struct AppState {
|
||||||
compile_dir: tempfile::TempDir,
|
compile_dir: tempfile::TempDir,
|
||||||
|
cache: RwLock<HashMap<CompileRequest, Result<Vec<u8>, StatusCode>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let shared_state = Arc::new(AppState {
|
let shared_state = Arc::new(AppState {
|
||||||
compile_dir: tempfile::tempdir().expect("failed to create tempdir"),
|
compile_dir: tempfile::tempdir().expect("failed to create tempdir"),
|
||||||
|
cache: Default::default(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// build our application with a single route
|
// build our application with a single route
|
||||||
|
|
@ -33,9 +35,15 @@ async fn main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
|
async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
|
||||||
|
if let Some(result) = state.cache.read().unwrap().get(&compile_request) {
|
||||||
|
return result.clone();
|
||||||
|
}
|
||||||
|
|
||||||
let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml";
|
let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml";
|
||||||
compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
|
let result = compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
|
||||||
eprintln!("compilation failed: {}", e);
|
eprintln!("compilation failed: {}", e);
|
||||||
StatusCode::INTERNAL_SERVER_ERROR
|
StatusCode::INTERNAL_SERVER_ERROR
|
||||||
})
|
});
|
||||||
|
state.cache.write().unwrap().insert(compile_request, result.clone());
|
||||||
|
result
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manife
|
||||||
Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?)
|
Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash, Eq)]
|
||||||
pub struct CompileRequest {
|
pub struct CompileRequest {
|
||||||
networks: Vec<graph_craft::proto::ProtoNetwork>,
|
networks: Vec<graph_craft::proto::ProtoNetwork>,
|
||||||
input_types: Vec<Type>,
|
input_types: Vec<Type>,
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ pub type TypeErasedPinned<'n> = Pin<Box<dyn for<'i> NodeIO<'i, Any<'i>, Output =
|
||||||
pub type NodeConstructor = for<'a> fn(Vec<TypeErasedPinnedRef<'static>>) -> DynFuture<'static, TypeErasedPinned<'static>>;
|
pub type NodeConstructor = for<'a> fn(Vec<TypeErasedPinnedRef<'static>>) -> DynFuture<'static, TypeErasedPinned<'static>>;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Default, PartialEq, Clone)]
|
#[derive(Debug, Default, PartialEq, Clone, Hash, Eq)]
|
||||||
pub struct ProtoNetwork {
|
pub struct ProtoNetwork {
|
||||||
// Should a proto Network even allow inputs? Don't think so
|
// Should a proto Network even allow inputs? Don't think so
|
||||||
pub inputs: Vec<NodeId>,
|
pub inputs: Vec<NodeId>,
|
||||||
|
|
@ -90,12 +90,23 @@ pub enum ConstructionArgs {
|
||||||
Inline(InlineRust),
|
Inline(InlineRust),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Eq for ConstructionArgs {}
|
||||||
|
|
||||||
impl PartialEq for ConstructionArgs {
|
impl PartialEq for ConstructionArgs {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
match (&self, &other) {
|
match (&self, &other) {
|
||||||
(Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2,
|
(Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2,
|
||||||
(Self::Value(v1), Self::Value(v2)) => v1 == v2,
|
(Self::Value(v1), Self::Value(v2)) => v1 == v2,
|
||||||
_ => core::mem::discriminant(self) == core::mem::discriminant(other),
|
_ => {
|
||||||
|
use std::hash::Hasher;
|
||||||
|
use xxhash_rust::xxh3::Xxh3;
|
||||||
|
let hash = |input: &Self| {
|
||||||
|
let mut hasher = Xxh3::new();
|
||||||
|
input.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
};
|
||||||
|
hash(self) == hash(other)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -126,7 +137,7 @@ impl ConstructionArgs {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, PartialEq, Clone)]
|
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||||
pub struct ProtoNode {
|
pub struct ProtoNode {
|
||||||
pub construction_args: ConstructionArgs,
|
pub construction_args: ConstructionArgs,
|
||||||
pub input: ProtoNodeInput,
|
pub input: ProtoNodeInput,
|
||||||
|
|
@ -137,7 +148,7 @@ pub struct ProtoNode {
|
||||||
/// A ProtoNodeInput represents the input of a node in a ProtoNetwork.
|
/// A ProtoNodeInput represents the input of a node in a ProtoNetwork.
|
||||||
/// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum
|
/// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum
|
||||||
/// in the `document` module
|
/// in the `document` module
|
||||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
pub enum ProtoNodeInput {
|
pub enum ProtoNodeInput {
|
||||||
None,
|
None,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue