From e84b9bd5bde987322fdceb099c17a12280c4f040 Mon Sep 17 00:00:00 2001 From: Dennis Date: Wed, 8 Jun 2022 09:52:58 +0200 Subject: [PATCH] Compile node graph description to GPU code --- node-graph/5 | 75 +++++++++++++ node-graph/Cargo.lock | 200 +++++++++++++++++++++++++++++++--- node-graph/gcore/src/lib.rs | 22 +++- node-graph/gcore/src/ops.rs | 48 ++++++-- node-graph/gcore/src/value.rs | 32 +++++- node-graph/gstd/Cargo.toml | 4 + node-graph/gstd/src/lib.rs | 95 ++++++++++++++++ node-graph/gstd/src/main.rs | 90 +++++++++------ node-graph/gstd/src/memo.rs | 5 +- node-graph/gstd/src/value.rs | 3 +- 10 files changed, 498 insertions(+), 76 deletions(-) create mode 100644 node-graph/5 diff --git a/node-graph/5 b/node-graph/5 new file mode 100644 index 00000000..8963b045 --- /dev/null +++ b/node-graph/5 @@ -0,0 +1,75 @@ +pub mod value; +pub use graphene_core::{generic, ops /*, structural*/}; + +#[cfg(feature = "caching")] +pub mod caching; +#[cfg(feature = "memoization")] +pub mod memo; + +pub use graphene_core::*; + +use dyn_any::{downcast_ref, DynAny, StaticType}; +pub type DynNode<'n, T> = &'n (dyn Node<'n, Output = T> + 'n); +pub type DynAnyNode<'n> = &'n (dyn Node<'n, Output = &'n dyn DynAny<'n>> + 'n); + +pub trait DynamicInput<'n> { + fn set_kwarg_by_name(&mut self, name: &str, value: DynAnyNode<'n>); + fn set_arg_by_index(&mut self, index: usize, value: DynAnyNode<'n>); +} + +use quote::quote; +use syn::{Expr, ExprPath, Type}; + +/// Given a Node call tree, construct a function +/// that takes an input tuple and evaluates the call graph +/// on the gpu an fn node is constructed that takes a value +/// node as input +struct NodeGraph { + /// Collection of nodes with their corresponding inputs. + /// The first node always always has to be an Input Node. + nodes: Vec, +} +enum NodeKind { + Value(Expr), + Input(Type), + Node(ExprPath, Vec), +} + +impl NodeGraph { + pub fn serialize(&self) -> String { + let mut output = String::new(); + let output_type = if let Some(NodeKind::Node(expr, _)) = self.nodes.last() { + expr + } else { + panic!("last node wasn't a valid node") + }; + let output_type = quote! {#output_type::Output}; + let input_type = if let Some(NodeKind::Input(type_)) = self.nodes.first() { + type_ + } else { + panic!("first node wasn't an input node") + }; + + let mut nodes = Vec::new(); + for (id, node) in self.nodes.iter().enumerate() { + let nid = |id| format!("n{id}"); + let id = nid(&id as u64); + let line = match node { + NodeKind::Value(val) => { + quote! {let #id = graphene_core::value::ValueNode::new(#val);} + } + NodeKind::Node(node, ids) => { + let ids = ids.iter().map(nid).collect::>(); + quote! {let #id = #node::new(((#(#ids),)*));} + } + }; + nodes.push(line) + } + let function = quote! { + fn node_graph(input: #input_type) -> #output_type { + #(#nodes)* + } + }; + function.to_string() + } +} diff --git a/node-graph/Cargo.lock b/node-graph/Cargo.lock index 3a4bce7d..4508a7ff 100644 --- a/node-graph/Cargo.lock +++ b/node-graph/Cargo.lock @@ -17,6 +17,18 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33954243bd79057c2de7338850b85983a44588021f8a5fee574a8888c6de4344" +[[package]] +name = "arrayref" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" + +[[package]] +name = "arrayvec" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" + [[package]] name = "arrayvec" version = "0.7.2" @@ -34,22 +46,68 @@ dependencies = [ "syn", ] +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "base64" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" + [[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "blake2b_simd" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587" +dependencies = [ + "arrayref", + "arrayvec 0.5.2", + "constant_time_eq", +] + [[package]] name = "borrow_stack" version = "0.1.0" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "cc" +version = "1.0.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" + +[[package]] +name = "cfg-if" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" + [[package]] name = "cfg-if" version = "1.0.0" @@ -108,6 +166,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + [[package]] name = "countme" version = "3.0.1" @@ -131,7 +195,7 @@ version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5aaa7bd5fb665c6864b5f963dd9097905c54125909c7aa94c9e18507cdbe6c53" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "crossbeam-utils", ] @@ -141,7 +205,7 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "crossbeam-epoch", "crossbeam-utils", ] @@ -153,7 +217,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1145cf131a2c6ba0615079ab6a638f7e1973ac9c2634fcbeaaad6114246efe8c" dependencies = [ "autocfg", - "cfg-if", + "cfg-if 1.0.0", "crossbeam-utils", "lazy_static", "memoffset", @@ -166,7 +230,7 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf124c720b7686e3c2663cf54062ab0f68a88af2fb6a030e87e30bf721fcb38" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "lazy_static", ] @@ -176,12 +240,23 @@ version = "5.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3495912c9c1ccf2e18976439f4443f3fee0fd61f424ff99fde6a66b15ecb448f" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "hashbrown 0.12.1", "lock_api", "parking_lot_core 0.9.3", ] +[[package]] +name = "dirs" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fd78930633bd1c6e35c4b42b1df7b0cbc6bc191146e512bb3bedf243fcc3901" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "dissimilar" version = "1.0.4" @@ -255,6 +330,17 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "wasi", +] + [[package]] name = "glam" version = "0.20.5" @@ -295,9 +381,13 @@ dependencies = [ "lock_api", "once_cell", "parking_lot 0.12.1", + "pretty-token-stream", + "proc-macro2", + "quote", "ra_ap_ide", "ra_ap_ide_db", "storage-map", + "syn", ] [[package]] @@ -357,7 +447,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", ] [[package]] @@ -403,7 +493,7 @@ version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", ] [[package]] @@ -436,6 +526,19 @@ dependencies = [ "windows-sys 0.28.0", ] +[[package]] +name = "nix" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c722bee1037d430d0f8e687bbdbf222f27cc6e4e68d5caf630857bb2b6dbdce" +dependencies = [ + "bitflags", + "cc", + "cfg-if 0.1.10", + "libc", + "void", +] + [[package]] name = "num-traits" version = "0.2.14" @@ -495,10 +598,10 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "instant", "libc", - "redox_syscall", + "redox_syscall 0.2.13", "smallvec", "winapi", ] @@ -509,9 +612,9 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "libc", - "redox_syscall", + "redox_syscall 0.2.13", "smallvec", "windows-sys 0.36.1", ] @@ -557,6 +660,17 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +[[package]] +name = "pretty-token-stream" +version = "0.1.0" +dependencies = [ + "atty", + "cfg-if 0.1.10", + "nix", + "proc-macro2", + "term", +] + [[package]] name = "proc-macro2" version = "1.0.36" @@ -639,7 +753,7 @@ version = "0.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02f334df1ac0ceeb8d317e6fc71e6473344ddc8ba75613021a1213ee7b2b3dee" dependencies = [ - "arrayvec", + "arrayvec 0.7.2", "either", "itertools", "once_cell", @@ -663,7 +777,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "400a6db17870220d63834cba06883dba10d42beac6f8b3d773e82fe248762adf" dependencies = [ "anymap", - "arrayvec", + "arrayvec 0.7.2", "bitflags", "cov-mark", "dashmap", @@ -718,7 +832,7 @@ version = "0.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df104472251b6c25c8c1efd8aaaf1c16b76afe2d06e17951478d321a2a1fa9d4" dependencies = [ - "arrayvec", + "arrayvec 0.7.2", "chalk-ir", "chalk-recursive", "chalk-solve", @@ -816,7 +930,7 @@ version = "0.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f0dee6238d74065219229ce0cc75c62ce1ae5630cba3ab5859e0b340705f8d7" dependencies = [ - "arrayvec", + "arrayvec 0.7.2", "cov-mark", "either", "fst", @@ -922,7 +1036,7 @@ version = "0.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "15128f7955c99c82ba17db6d3c3c921719bcb9a6912054486949149d02c4e79a" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "countme", "libc", "once_cell", @@ -1032,6 +1146,12 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "redox_syscall" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" + [[package]] name = "redox_syscall" version = "0.2.13" @@ -1041,6 +1161,17 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_users" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de0737333e7a9502c789a36d7c7fa6092a49895d4faa31ca5df163857ded2e9d" +dependencies = [ + "getrandom", + "redox_syscall 0.1.57", + "rust-argon2", +] + [[package]] name = "rowan" version = "0.15.5" @@ -1054,6 +1185,18 @@ dependencies = [ "text-size", ] +[[package]] +name = "rust-argon2" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb" +dependencies = [ + "base64", + "blake2b_simd", + "constant_time_eq", + "crossbeam-utils", +] + [[package]] name = "rustc-ap-rustc_lexer" version = "725.0.0" @@ -1191,6 +1334,17 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "term" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd106a334b7657c10b7c540a0106114feadeb4dc314513e97df481d5d966f42" +dependencies = [ + "byteorder", + "dirs", + "winapi", +] + [[package]] name = "text-size" version = "1.1.0" @@ -1218,7 +1372,7 @@ version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d0ecdcb44a79f0fe9844f0c4f33a342cbcbb5117de8001e6ba0dc2351327d09" dependencies = [ - "cfg-if", + "cfg-if 1.0.0", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -1304,6 +1458,18 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "winapi" version = "0.3.9" diff --git a/node-graph/gcore/src/lib.rs b/node-graph/gcore/src/lib.rs index 95f0abdb..91bde8ad 100644 --- a/node-graph/gcore/src/lib.rs +++ b/node-graph/gcore/src/lib.rs @@ -14,7 +14,7 @@ pub mod ops; pub mod value; pub trait Node<'n> { - type Output: 'n; // TODO: replace with generic associated type + type Output; // TODO: replace with generic associated type fn eval(&'n self) -> Self::Output; } @@ -27,12 +27,26 @@ impl<'n, N: Node<'n>> Node<'n> for &'n N { } } +pub trait NodeInput { + type Nodes; + + fn new(input: Self::Nodes) -> Self; +} + +trait FQN { + fn fqn(&self) -> &'static str; +} + +trait Input { + unsafe fn input(&self, input: I); +} + #[cfg(feature = "async")] #[async_trait] pub trait AsyncNode<'n> { - type Output: 'n; // TODO: replace with generic associated type + type Output; // TODO: replace with generic associated type - async fn eval(&'n self) -> Self::Output; + async fn eval_async(&'n self) -> Self::Output; } #[cfg(feature = "async")] @@ -40,7 +54,7 @@ pub trait AsyncNode<'n> { impl<'n, N: Node<'n> + Sync> AsyncNode<'n> for N { type Output = N::Output; - async fn eval(&'n self) -> Self::Output { + async fn eval_async(&'n self) -> Self::Output { Node::eval(self) } } diff --git a/node-graph/gcore/src/ops.rs b/node-graph/gcore/src/ops.rs index ace01f8a..673455a9 100644 --- a/node-graph/gcore/src/ops.rs +++ b/node-graph/gcore/src/ops.rs @@ -1,9 +1,9 @@ use core::{marker::PhantomData, ops::Add}; -use crate::Node; +use crate::{Node, NodeInput}; #[repr(C)] -struct AddNode<'n, L: Add, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>( +pub struct AddNode<'n, L: Add, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>( pub I1, pub I2, PhantomData<&'n (L, R)>, @@ -16,6 +16,13 @@ impl<'n, L: Add, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>> Node< self.0.eval() + self.1.eval() } } +impl<'n, L: Add, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>> + AddNode<'n, L, R, I1, I2> +{ + pub fn new(input: (I1, I2)) -> AddNode<'n, L, R, I1, I2> { + AddNode(input.0, input.1, PhantomData) + } +} #[repr(C)] pub struct CloneNode<'n, N: Node<'n, Output = &'n O>, O: Clone + 'n>(pub N, PhantomData<&'n ()>); @@ -25,6 +32,11 @@ impl<'n, N: Node<'n, Output = &'n O>, O: Clone> Node<'n> for CloneNode<'n, N, O> self.0.eval().clone() } } +impl<'n, N: Node<'n, Output = &'n O>, O: Clone> CloneNode<'n, N, O> { + pub const fn new(node: N) -> CloneNode<'n, N, O> { + CloneNode(node, PhantomData) + } +} #[repr(C)] pub struct FstNode<'n, N: Node<'n>>(pub N, PhantomData<&'n ()>); @@ -56,13 +68,12 @@ impl<'n, N: Node<'n>> Node<'n> for DupNode<'n, N> { (self.0.eval(), self.0.eval()) //TODO: use Copy/Clone implementation } } +impl<'n, N: Node<'n>> NodeInput for DupNode<'n, N> { + type Nodes = N; -#[repr(C)] -/// Return the unit value -pub struct UnitNode; -impl<'n> Node<'n> for UnitNode { - type Output = (); - fn eval(&'n self) -> Self::Output {} + fn new(input: Self::Nodes) -> Self { + Self(input, PhantomData) + } } #[repr(C)] @@ -74,11 +85,18 @@ impl<'n, N: Node<'n>> Node<'n> for IdNode<'n, N> { self.0.eval() } } +impl<'n, N: Node<'n>> NodeInput for IdNode<'n, N> { + type Nodes = N; + + fn new(input: Self::Nodes) -> Self { + Self(input, PhantomData) + } +} pub fn foo() { - let unit = UnitNode; - let value = IdNode(crate::value::ValueNode::new(2u32), PhantomData); - let value2 = crate::value::ValueNode::new(4u32); + let unit = crate::value::UnitNode; + let value = IdNode(crate::value::ValueNode(2u32), PhantomData); + let value2 = crate::value::ValueNode(4u32); let dup = DupNode(&value, PhantomData); fn int(_: (), state: &u32) -> &u32 { state @@ -120,10 +138,16 @@ pub mod gpu { #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [(u32, u32)], #[spirv(push_constant)] push_consts: &PushConsts, ) { + fn node_graph(input: Input) -> Output { + let n0 = ValueNode::new(input); + let n1 = IdNode::new(n0); + let n2 = IdNode::new(n1); + return n2.eval(); + } let gid = global_id.x as usize; // Only process up to n, which is the length of the buffers. if global_id.x < push_consts.n { - y[gid] = OPERATION.eval(a[gid]); + y[gid] = node_graph(a[gid]); } } #[allow(unused)] diff --git a/node-graph/gcore/src/value.rs b/node-graph/gcore/src/value.rs index b38c4c86..e05e1d57 100644 --- a/node-graph/gcore/src/value.rs +++ b/node-graph/gcore/src/value.rs @@ -1,4 +1,6 @@ use core::marker::PhantomData; +use core::mem::MaybeUninit; +use core::sync::atomic::AtomicBool; use crate::Node; @@ -18,8 +20,7 @@ impl<'n, T: 'n> Node<'n> for ValueNode { &self.0 } } - -impl<'n, T> ValueNode { +impl ValueNode { pub const fn new(value: T) -> ValueNode { ValueNode(value) } @@ -33,8 +34,29 @@ impl<'n, T: Default + 'n> Node<'n> for DefaultNode { T::default() } } -impl DefaultNode { - pub const fn new() -> DefaultNode { - DefaultNode(PhantomData) + +#[repr(C)] +/// Return the unit value +pub struct UnitNode; +impl<'n> Node<'n> for UnitNode { + type Output = (); + fn eval(&'n self) -> Self::Output {} +} + +pub struct InputNode(MaybeUninit, AtomicBool); +impl<'n, T: 'n> Node<'n> for InputNode { + type Output = &'n T; + fn eval(&'n self) -> Self::Output { + if self.1.load(core::sync::atomic::Ordering::SeqCst) { + unsafe { self.0.assume_init_ref() } + } else { + panic!("tried to access an input before setting it") + } + } +} + +impl InputNode { + pub const fn new() -> InputNode { + InputNode(MaybeUninit::uninit(), AtomicBool::new(false)) } } diff --git a/node-graph/gstd/Cargo.toml b/node-graph/gstd/Cargo.toml index 5069da7d..c7888196 100644 --- a/node-graph/gstd/Cargo.toml +++ b/node-graph/gstd/Cargo.toml @@ -26,3 +26,7 @@ ide_db = { version = "*", package = "ra_ap_ide_db" , optional = true } storage-map = { version = "*", optional = true } lock_api = { version= "*", optional = true } parking_lot = { version = "*", optional = true } +pretty-token-stream = {path = "../../pretty-token-stream"} +syn = {version = "1.0", default-features = false, features = ["parsing", "printing"]} +proc-macro2 = {version = "1.0", default-features = false, features = ["proc-macro"]} +quote = {version = "1.0", default-features = false } diff --git a/node-graph/gstd/src/lib.rs b/node-graph/gstd/src/lib.rs index 20b513dc..1c18d263 100644 --- a/node-graph/gstd/src/lib.rs +++ b/node-graph/gstd/src/lib.rs @@ -16,3 +16,98 @@ pub trait DynamicInput<'n> { fn set_kwarg_by_name(&mut self, name: &str, value: DynAnyNode<'n>); fn set_arg_by_index(&mut self, index: usize, value: DynAnyNode<'n>); } + +use quote::quote; +use syn::{Expr, ExprPath, Type}; + +/// Given a Node call tree, construct a function +/// that takes an input tuple and evaluates the call graph +/// on the gpu an fn node is constructed that takes a value +/// node as input +pub struct NodeGraph { + /// Collection of nodes with their corresponding inputs. + /// The first node always always has to be an Input Node. + pub nodes: Vec, + pub output: Type, + pub input: Type, +} +pub enum NodeKind { + Value(Expr), + Input, + Node(ExprPath, Vec), +} + +impl NodeGraph { + pub fn serialize_function(&self) -> proc_macro2::TokenStream { + let output_type = &self.output; + let input_type = &self.input; + + fn nid(id: &usize) -> syn::Ident { + let str = format!("n{id}"); + syn::Ident::new(str.as_str(), proc_macro2::Span::call_site()) + } + let mut nodes = Vec::new(); + for (ref id, node) in self.nodes.iter().enumerate() { + let id = nid(id).clone(); + let line = match node { + NodeKind::Value(val) => { + quote! {let #id = graphene_core::value::ValueNode::new(#val);} + } + NodeKind::Node(node, ids) => { + let ids = ids.iter().map(nid).collect::>(); + quote! {let #id = #node::new((#(&#ids),*));} + } + NodeKind::Input => { + quote! { let n0 = graphene_core::value::ValueNode::new(input);} + } + }; + nodes.push(line) + } + let last_id = self.nodes.len() - 1; + let last_id = nid(&last_id); + let ret = quote! { #last_id.eval() }; + let function = quote! { + fn node_graph(input: #input_type) -> #output_type { + #(#nodes)* + #ret + } + }; + function + } + pub fn serialize_gpu(&self, name: &str) -> proc_macro2::TokenStream { + let function = self.serialize_function(); + let output_type = &self.output; + let input_type = &self.input; + + quote! { + #[cfg(target_arch = "spirv")] + pub mod gpu { + //#![deny(warnings)] + #[repr(C)] + pub struct PushConsts { + n: u32, + node: u32, + } + use super::*; + + use spirv_std::glam::UVec3; + + #[allow(unused)] + #[spirv(compute(threads(64)))] + pub fn #name( + #[spirv(global_invocation_id)] global_id: UVec3, + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[#input_type], + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [#output_type], + #[spirv(push_constant)] push_consts: &PushConsts, + ) { + #function + let gid = global_id.x as usize; + // Only process up to n, which is the length of the buffers. + if global_id.x < push_consts.n { + y[gid] = node_graph(a[gid]); + } + } + } + } + } +} diff --git a/node-graph/gstd/src/main.rs b/node-graph/gstd/src/main.rs index acbc1e24..05f9b2e7 100644 --- a/node-graph/gstd/src/main.rs +++ b/node-graph/gstd/src/main.rs @@ -91,41 +91,61 @@ impl<'n> NodeStore<'n> { } fn main() { - use dyn_any::{downcast_ref, DynAny, StaticType}; - //let mut mul = mul::MulNode::new(); - let mut stack: borrow_stack::FixedSizeStack>> = - borrow_stack::FixedSizeStack::new(42); - unsafe { stack.push(Box::new(AnyValueNode::new(1f32))) }; - //let node = unsafe { stack.get(0) }; - //let boxed = Box::new(StorageNode::new(node)); - //unsafe { stack.push(boxed) }; - let result = unsafe { &stack.get()[0] }.eval(); - dbg!(downcast_ref::(result)); - /*unsafe { - stack - .push(Box::new(AnyRefNode::new(stack.get(0).as_ref())) - as Box>) - };*/ - let f = (3.2f32, 3.1f32); - let a = ValueNode::new(1.); - let id = std::any::TypeId::of::<&f32>(); - let any_a = AnyRefNode::new(&a); - /*let _mul2 = mul::MulNodeInput { - a: None, - b: Some(&any_a), - }; - let mut mul2 = mul::new!(); - //let cached = memo::CacheNode::new(&mul1); - //let foo = value::AnyRefNode::new(&cached); - mul2.set_arg_by_index(0, &any_a);*/ - let int = value::IntNode::<32>; - Node::eval(&int); - println!("{}", Node::eval(&int)); - //let _add: u32 = ops::AddNode::::default().eval((int.exec(), int.exec())); - //let fnode = generic::FnNode::new(|(a, b): &(i32, i32)| a - b); - //let sub = fnode.any(&("a", 2)); - //let cache = memo::CacheNode::new(&fnode); - //let cached_result = cache.eval(&(2, 3)); + use graphene_std::*; + use quote::quote; + use syn::parse::Parse; + let nodes = vec![ + NodeKind::Input, + NodeKind::Value(syn::parse_quote!(1u32)), + NodeKind::Node(syn::parse_quote!(graphene_core::ops::AddNode), vec![0, 0]), + ]; + //println!("{}", node_graph(1)); + + let nodegraph = NodeGraph { + nodes, + input: syn::Type::Verbatim(quote! {u32}), + output: syn::Type::Verbatim(quote! {u32}), + }; + + let pretty = pretty_token_stream::Pretty::new(nodegraph.serialize_gpu("add")); + pretty.print(); + /* + use dyn_any::{downcast_ref, DynAny, StaticType}; + //let mut mul = mul::MulNode::new(); + let mut stack: borrow_stack::FixedSizeStack>> = + borrow_stack::FixedSizeStack::new(42); + unsafe { stack.push(Box::new(AnyValueNode::new(1f32))) }; + //let node = unsafe { stack.get(0) }; + //let boxed = Box::new(StorageNode::new(node)); + //unsafe { stack.push(boxed) }; + let result = unsafe { &stack.get()[0] }.eval(); + dbg!(downcast_ref::(result)); + /*unsafe { + stack + .push(Box::new(AnyRefNode::new(stack.get(0).as_ref())) + as Box>) + };*/ + let f = (3.2f32, 3.1f32); + let a = ValueNode::new(1.); + let id = std::any::TypeId::of::<&f32>(); + let any_a = AnyRefNode::new(&a); + /*let _mul2 = mul::MulNodeInput { + a: None, + b: Some(&any_a), + }; + let mut mul2 = mul::new!(); + //let cached = memo::CacheNode::new(&mul1); + //let foo = value::AnyRefNode::new(&cached); + mul2.set_arg_by_index(0, &any_a);*/ + let int = value::IntNode::<32>; + Node::eval(&int); + println!("{}", Node::eval(&int)); + //let _add: u32 = ops::AddNode::::default().eval((int.exec(), int.exec())); + //let fnode = generic::FnNode::new(|(a, b): &(i32, i32)| a - b); + //let sub = fnode.any(&("a", 2)); + //let cache = memo::CacheNode::new(&fnode); + //let cached_result = cache.eval(&(2, 3)); + */ //println!("{}", cached_result) } diff --git a/node-graph/gstd/src/memo.rs b/node-graph/gstd/src/memo.rs index cbf97952..f97fd21b 100644 --- a/node-graph/gstd/src/memo.rs +++ b/node-graph/gstd/src/memo.rs @@ -8,7 +8,10 @@ pub struct CacheNode<'n, CachedNode: Node<'n>> { cache: OnceCell, _phantom: PhantomData<&'n ()>, } -impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode> { +impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode> +where + CashedNode::Output: 'n, +{ type Output = &'n CashedNode::Output; fn eval(&'n self) -> Self::Output { self.cache.get_or_init(|| self.node.eval()) diff --git a/node-graph/gstd/src/value.rs b/node-graph/gstd/src/value.rs index 14b5c205..46fc9299 100644 --- a/node-graph/gstd/src/value.rs +++ b/node-graph/gstd/src/value.rs @@ -24,8 +24,7 @@ pub struct StorageNode<'n>(&'n dyn Node<'n, Output = &'n dyn DynAny<'n>>); impl<'n> Node<'n> for StorageNode<'n> { type Output = &'n (dyn DynAny<'n>); fn eval(&'n self) -> Self::Output { - let value = self.0.eval(); - value + self.0.eval() } } impl<'n> StorageNode<'n> {