Implement basic request caching for compilation server (#1253)
* Implement basic request caching for compilation server * Fix formatting
This commit is contained in:
parent
6289d92e02
commit
9da83d3280
|
|
@ -1,4 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use std::{collections::HashMap, sync::Arc, sync::RwLock};
|
||||
|
||||
use gpu_compiler_bin_wrapper::CompileRequest;
|
||||
use tower_http::cors::CorsLayer;
|
||||
|
|
@ -12,12 +12,14 @@ use axum::{
|
|||
|
||||
struct AppState {
|
||||
compile_dir: tempfile::TempDir,
|
||||
cache: RwLock<HashMap<CompileRequest, Result<Vec<u8>, StatusCode>>>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let shared_state = Arc::new(AppState {
|
||||
compile_dir: tempfile::tempdir().expect("failed to create tempdir"),
|
||||
cache: Default::default(),
|
||||
});
|
||||
|
||||
// build our application with a single route
|
||||
|
|
@ -33,9 +35,15 @@ async fn main() {
|
|||
}
|
||||
|
||||
async fn post_compile_spirv(State(state): State<Arc<AppState>>, Json(compile_request): Json<CompileRequest>) -> Result<Vec<u8>, StatusCode> {
|
||||
if let Some(result) = state.cache.read().unwrap().get(&compile_request) {
|
||||
return result.clone();
|
||||
}
|
||||
|
||||
let path = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/../gpu-compiler/Cargo.toml";
|
||||
compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
|
||||
let result = compile_request.compile(state.compile_dir.path().to_str().expect("non utf8 tempdir path"), &path).map_err(|e| {
|
||||
eprintln!("compilation failed: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})
|
||||
});
|
||||
state.cache.write().unwrap().insert(compile_request, result.clone());
|
||||
result
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manife
|
|||
Ok(std::fs::read(compile_dir.unwrap().to_owned() + "/shader.spv")?)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Hash, Eq)]
|
||||
pub struct CompileRequest {
|
||||
networks: Vec<graph_craft::proto::ProtoNetwork>,
|
||||
input_types: Vec<Type>,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ pub type TypeErasedPinned<'n> = Pin<Box<dyn for<'i> NodeIO<'i, Any<'i>, Output =
|
|||
pub type NodeConstructor = for<'a> fn(Vec<TypeErasedPinnedRef<'static>>) -> DynFuture<'static, TypeErasedPinned<'static>>;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, PartialEq, Clone)]
|
||||
#[derive(Debug, Default, PartialEq, Clone, Hash, Eq)]
|
||||
pub struct ProtoNetwork {
|
||||
// Should a proto Network even allow inputs? Don't think so
|
||||
pub inputs: Vec<NodeId>,
|
||||
|
|
@ -90,12 +90,23 @@ pub enum ConstructionArgs {
|
|||
Inline(InlineRust),
|
||||
}
|
||||
|
||||
impl Eq for ConstructionArgs {}
|
||||
|
||||
impl PartialEq for ConstructionArgs {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (&self, &other) {
|
||||
(Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2,
|
||||
(Self::Value(v1), Self::Value(v2)) => v1 == v2,
|
||||
_ => core::mem::discriminant(self) == core::mem::discriminant(other),
|
||||
_ => {
|
||||
use std::hash::Hasher;
|
||||
use xxhash_rust::xxh3::Xxh3;
|
||||
let hash = |input: &Self| {
|
||||
let mut hasher = Xxh3::new();
|
||||
input.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
};
|
||||
hash(self) == hash(other)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -126,7 +137,7 @@ impl ConstructionArgs {
|
|||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
#[derive(Debug, PartialEq, Clone, Hash, Eq)]
|
||||
pub struct ProtoNode {
|
||||
pub construction_args: ConstructionArgs,
|
||||
pub input: ProtoNodeInput,
|
||||
|
|
@ -137,7 +148,7 @@ pub struct ProtoNode {
|
|||
/// A ProtoNodeInput represents the input of a node in a ProtoNetwork.
|
||||
/// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum
|
||||
/// in the `document` module
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub enum ProtoNodeInput {
|
||||
None,
|
||||
|
|
|
|||
Loading…
Reference in New Issue