Compile node graph description to GPU code
This commit is contained in:
parent
998f37d1b0
commit
e84b9bd5bd
|
|
@ -0,0 +1,75 @@
|
|||
pub mod value;
|
||||
pub use graphene_core::{generic, ops /*, structural*/};
|
||||
|
||||
#[cfg(feature = "caching")]
|
||||
pub mod caching;
|
||||
#[cfg(feature = "memoization")]
|
||||
pub mod memo;
|
||||
|
||||
pub use graphene_core::*;
|
||||
|
||||
use dyn_any::{downcast_ref, DynAny, StaticType};
|
||||
pub type DynNode<'n, T> = &'n (dyn Node<'n, Output = T> + 'n);
|
||||
pub type DynAnyNode<'n> = &'n (dyn Node<'n, Output = &'n dyn DynAny<'n>> + 'n);
|
||||
|
||||
pub trait DynamicInput<'n> {
|
||||
fn set_kwarg_by_name(&mut self, name: &str, value: DynAnyNode<'n>);
|
||||
fn set_arg_by_index(&mut self, index: usize, value: DynAnyNode<'n>);
|
||||
}
|
||||
|
||||
use quote::quote;
|
||||
use syn::{Expr, ExprPath, Type};
|
||||
|
||||
/// Given a Node call tree, construct a function
|
||||
/// that takes an input tuple and evaluates the call graph
|
||||
/// on the gpu an fn node is constructed that takes a value
|
||||
/// node as input
|
||||
struct NodeGraph {
|
||||
/// Collection of nodes with their corresponding inputs.
|
||||
/// The first node always always has to be an Input Node.
|
||||
nodes: Vec<NodeKind>,
|
||||
}
|
||||
enum NodeKind {
|
||||
Value(Expr),
|
||||
Input(Type),
|
||||
Node(ExprPath, Vec<u64>),
|
||||
}
|
||||
|
||||
impl NodeGraph {
|
||||
pub fn serialize(&self) -> String {
|
||||
let mut output = String::new();
|
||||
let output_type = if let Some(NodeKind::Node(expr, _)) = self.nodes.last() {
|
||||
expr
|
||||
} else {
|
||||
panic!("last node wasn't a valid node")
|
||||
};
|
||||
let output_type = quote! {#output_type::Output};
|
||||
let input_type = if let Some(NodeKind::Input(type_)) = self.nodes.first() {
|
||||
type_
|
||||
} else {
|
||||
panic!("first node wasn't an input node")
|
||||
};
|
||||
|
||||
let mut nodes = Vec::new();
|
||||
for (id, node) in self.nodes.iter().enumerate() {
|
||||
let nid = |id| format!("n{id}");
|
||||
let id = nid(&id as u64);
|
||||
let line = match node {
|
||||
NodeKind::Value(val) => {
|
||||
quote! {let #id = graphene_core::value::ValueNode::new(#val);}
|
||||
}
|
||||
NodeKind::Node(node, ids) => {
|
||||
let ids = ids.iter().map(nid).collect::<Vec<_>>();
|
||||
quote! {let #id = #node::new(((#(#ids),)*));}
|
||||
}
|
||||
};
|
||||
nodes.push(line)
|
||||
}
|
||||
let function = quote! {
|
||||
fn node_graph(input: #input_type) -> #output_type {
|
||||
#(#nodes)*
|
||||
}
|
||||
};
|
||||
function.to_string()
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,18 @@ version = "0.12.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33954243bd79057c2de7338850b85983a44588021f8a5fee574a8888c6de4344"
|
||||
|
||||
[[package]]
|
||||
name = "arrayref"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.2"
|
||||
|
|
@ -34,22 +46,68 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "blake2b_simd"
|
||||
version = "0.5.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"arrayvec 0.5.2",
|
||||
"constant_time_eq",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "borrow_stack"
|
||||
version = "0.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
|
|
@ -108,6 +166,12 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "constant_time_eq"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
|
||||
|
||||
[[package]]
|
||||
name = "countme"
|
||||
version = "3.0.1"
|
||||
|
|
@ -131,7 +195,7 @@ version = "0.5.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aaa7bd5fb665c6864b5f963dd9097905c54125909c7aa94c9e18507cdbe6c53"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
|
|
@ -141,7 +205,7 @@ version = "0.8.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
|
@ -153,7 +217,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "1145cf131a2c6ba0615079ab6a638f7e1973ac9c2634fcbeaaad6114246efe8c"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"crossbeam-utils",
|
||||
"lazy_static",
|
||||
"memoffset",
|
||||
|
|
@ -166,7 +230,7 @@ version = "0.8.8"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bf124c720b7686e3c2663cf54062ab0f68a88af2fb6a030e87e30bf721fcb38"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
|
|
@ -176,12 +240,23 @@ version = "5.3.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3495912c9c1ccf2e18976439f4443f3fee0fd61f424ff99fde6a66b15ecb448f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"hashbrown 0.12.1",
|
||||
"lock_api",
|
||||
"parking_lot_core 0.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fd78930633bd1c6e35c4b42b1df7b0cbc6bc191146e512bb3bedf243fcc3901"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"redox_users",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dissimilar"
|
||||
version = "1.0.4"
|
||||
|
|
@ -255,6 +330,17 @@ version = "0.4.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a"
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "glam"
|
||||
version = "0.20.5"
|
||||
|
|
@ -295,9 +381,13 @@ dependencies = [
|
|||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot 0.12.1",
|
||||
"pretty-token-stream",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"ra_ap_ide",
|
||||
"ra_ap_ide_db",
|
||||
"storage-map",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -357,7 +447,7 @@ version = "0.1.12"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -403,7 +493,7 @@ version = "0.4.17"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -436,6 +526,19 @@ dependencies = [
|
|||
"windows-sys 0.28.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nix"
|
||||
version = "0.14.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c722bee1037d430d0f8e687bbdbf222f27cc6e4e68d5caf630857bb2b6dbdce"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cc",
|
||||
"cfg-if 0.1.10",
|
||||
"libc",
|
||||
"void",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.14"
|
||||
|
|
@ -495,10 +598,10 @@ version = "0.8.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"instant",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"redox_syscall 0.2.13",
|
||||
"smallvec",
|
||||
"winapi",
|
||||
]
|
||||
|
|
@ -509,9 +612,9 @@ version = "0.9.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"redox_syscall 0.2.13",
|
||||
"smallvec",
|
||||
"windows-sys 0.36.1",
|
||||
]
|
||||
|
|
@ -557,6 +660,17 @@ version = "0.2.9"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
|
||||
|
||||
[[package]]
|
||||
name = "pretty-token-stream"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"atty",
|
||||
"cfg-if 0.1.10",
|
||||
"nix",
|
||||
"proc-macro2",
|
||||
"term",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.36"
|
||||
|
|
@ -639,7 +753,7 @@ version = "0.0.104"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02f334df1ac0ceeb8d317e6fc71e6473344ddc8ba75613021a1213ee7b2b3dee"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"arrayvec 0.7.2",
|
||||
"either",
|
||||
"itertools",
|
||||
"once_cell",
|
||||
|
|
@ -663,7 +777,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "400a6db17870220d63834cba06883dba10d42beac6f8b3d773e82fe248762adf"
|
||||
dependencies = [
|
||||
"anymap",
|
||||
"arrayvec",
|
||||
"arrayvec 0.7.2",
|
||||
"bitflags",
|
||||
"cov-mark",
|
||||
"dashmap",
|
||||
|
|
@ -718,7 +832,7 @@ version = "0.0.104"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df104472251b6c25c8c1efd8aaaf1c16b76afe2d06e17951478d321a2a1fa9d4"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"arrayvec 0.7.2",
|
||||
"chalk-ir",
|
||||
"chalk-recursive",
|
||||
"chalk-solve",
|
||||
|
|
@ -816,7 +930,7 @@ version = "0.0.104"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4f0dee6238d74065219229ce0cc75c62ce1ae5630cba3ab5859e0b340705f8d7"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"arrayvec 0.7.2",
|
||||
"cov-mark",
|
||||
"either",
|
||||
"fst",
|
||||
|
|
@ -922,7 +1036,7 @@ version = "0.0.104"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15128f7955c99c82ba17db6d3c3c921719bcb9a6912054486949149d02c4e79a"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"countme",
|
||||
"libc",
|
||||
"once_cell",
|
||||
|
|
@ -1032,6 +1146,12 @@ dependencies = [
|
|||
"num_cpus",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.1.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce"
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.13"
|
||||
|
|
@ -1041,6 +1161,17 @@ dependencies = [
|
|||
"bitflags",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_users"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de0737333e7a9502c789a36d7c7fa6092a49895d4faa31ca5df163857ded2e9d"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"redox_syscall 0.1.57",
|
||||
"rust-argon2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rowan"
|
||||
version = "0.15.5"
|
||||
|
|
@ -1054,6 +1185,18 @@ dependencies = [
|
|||
"text-size",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-argon2"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"blake2b_simd",
|
||||
"constant_time_eq",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc-ap-rustc_lexer"
|
||||
version = "725.0.0"
|
||||
|
|
@ -1191,6 +1334,17 @@ dependencies = [
|
|||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "term"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edd106a334b7657c10b7c540a0106114feadeb4dc314513e97df481d5d966f42"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"dirs",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "text-size"
|
||||
version = "1.1.0"
|
||||
|
|
@ -1218,7 +1372,7 @@ version = "0.1.34"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d0ecdcb44a79f0fe9844f0c4f33a342cbcbb5117de8001e6ba0dc2351327d09"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cfg-if 1.0.0",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
|
|
@ -1304,6 +1458,18 @@ version = "0.9.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "void"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.9.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ pub mod ops;
|
|||
pub mod value;
|
||||
|
||||
pub trait Node<'n> {
|
||||
type Output: 'n; // TODO: replace with generic associated type
|
||||
type Output; // TODO: replace with generic associated type
|
||||
|
||||
fn eval(&'n self) -> Self::Output;
|
||||
}
|
||||
|
|
@ -27,12 +27,26 @@ impl<'n, N: Node<'n>> Node<'n> for &'n N {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait NodeInput {
|
||||
type Nodes;
|
||||
|
||||
fn new(input: Self::Nodes) -> Self;
|
||||
}
|
||||
|
||||
trait FQN {
|
||||
fn fqn(&self) -> &'static str;
|
||||
}
|
||||
|
||||
trait Input<I> {
|
||||
unsafe fn input(&self, input: I);
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
#[async_trait]
|
||||
pub trait AsyncNode<'n> {
|
||||
type Output: 'n; // TODO: replace with generic associated type
|
||||
type Output; // TODO: replace with generic associated type
|
||||
|
||||
async fn eval(&'n self) -> Self::Output;
|
||||
async fn eval_async(&'n self) -> Self::Output;
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
|
|
@ -40,7 +54,7 @@ pub trait AsyncNode<'n> {
|
|||
impl<'n, N: Node<'n> + Sync> AsyncNode<'n> for N {
|
||||
type Output = N::Output;
|
||||
|
||||
async fn eval(&'n self) -> Self::Output {
|
||||
async fn eval_async(&'n self) -> Self::Output {
|
||||
Node::eval(self)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use core::{marker::PhantomData, ops::Add};
|
||||
|
||||
use crate::Node;
|
||||
use crate::{Node, NodeInput};
|
||||
|
||||
#[repr(C)]
|
||||
struct AddNode<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>(
|
||||
pub struct AddNode<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>(
|
||||
pub I1,
|
||||
pub I2,
|
||||
PhantomData<&'n (L, R)>,
|
||||
|
|
@ -16,6 +16,13 @@ impl<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>> Node<
|
|||
self.0.eval() + self.1.eval()
|
||||
}
|
||||
}
|
||||
impl<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>
|
||||
AddNode<'n, L, R, I1, I2>
|
||||
{
|
||||
pub fn new(input: (I1, I2)) -> AddNode<'n, L, R, I1, I2> {
|
||||
AddNode(input.0, input.1, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct CloneNode<'n, N: Node<'n, Output = &'n O>, O: Clone + 'n>(pub N, PhantomData<&'n ()>);
|
||||
|
|
@ -25,6 +32,11 @@ impl<'n, N: Node<'n, Output = &'n O>, O: Clone> Node<'n> for CloneNode<'n, N, O>
|
|||
self.0.eval().clone()
|
||||
}
|
||||
}
|
||||
impl<'n, N: Node<'n, Output = &'n O>, O: Clone> CloneNode<'n, N, O> {
|
||||
pub const fn new(node: N) -> CloneNode<'n, N, O> {
|
||||
CloneNode(node, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct FstNode<'n, N: Node<'n>>(pub N, PhantomData<&'n ()>);
|
||||
|
|
@ -56,13 +68,12 @@ impl<'n, N: Node<'n>> Node<'n> for DupNode<'n, N> {
|
|||
(self.0.eval(), self.0.eval()) //TODO: use Copy/Clone implementation
|
||||
}
|
||||
}
|
||||
impl<'n, N: Node<'n>> NodeInput for DupNode<'n, N> {
|
||||
type Nodes = N;
|
||||
|
||||
#[repr(C)]
|
||||
/// Return the unit value
|
||||
pub struct UnitNode;
|
||||
impl<'n> Node<'n> for UnitNode {
|
||||
type Output = ();
|
||||
fn eval(&'n self) -> Self::Output {}
|
||||
fn new(input: Self::Nodes) -> Self {
|
||||
Self(input, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
|
|
@ -74,11 +85,18 @@ impl<'n, N: Node<'n>> Node<'n> for IdNode<'n, N> {
|
|||
self.0.eval()
|
||||
}
|
||||
}
|
||||
impl<'n, N: Node<'n>> NodeInput for IdNode<'n, N> {
|
||||
type Nodes = N;
|
||||
|
||||
fn new(input: Self::Nodes) -> Self {
|
||||
Self(input, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn foo() {
|
||||
let unit = UnitNode;
|
||||
let value = IdNode(crate::value::ValueNode::new(2u32), PhantomData);
|
||||
let value2 = crate::value::ValueNode::new(4u32);
|
||||
let unit = crate::value::UnitNode;
|
||||
let value = IdNode(crate::value::ValueNode(2u32), PhantomData);
|
||||
let value2 = crate::value::ValueNode(4u32);
|
||||
let dup = DupNode(&value, PhantomData);
|
||||
fn int(_: (), state: &u32) -> &u32 {
|
||||
state
|
||||
|
|
@ -120,10 +138,16 @@ pub mod gpu {
|
|||
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [(u32, u32)],
|
||||
#[spirv(push_constant)] push_consts: &PushConsts,
|
||||
) {
|
||||
fn node_graph(input: Input) -> Output {
|
||||
let n0 = ValueNode::new(input);
|
||||
let n1 = IdNode::new(n0);
|
||||
let n2 = IdNode::new(n1);
|
||||
return n2.eval();
|
||||
}
|
||||
let gid = global_id.x as usize;
|
||||
// Only process up to n, which is the length of the buffers.
|
||||
if global_id.x < push_consts.n {
|
||||
y[gid] = OPERATION.eval(a[gid]);
|
||||
y[gid] = node_graph(a[gid]);
|
||||
}
|
||||
}
|
||||
#[allow(unused)]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
use core::marker::PhantomData;
|
||||
use core::mem::MaybeUninit;
|
||||
use core::sync::atomic::AtomicBool;
|
||||
|
||||
use crate::Node;
|
||||
|
||||
|
|
@ -18,8 +20,7 @@ impl<'n, T: 'n> Node<'n> for ValueNode<T> {
|
|||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'n, T> ValueNode<T> {
|
||||
impl<T> ValueNode<T> {
|
||||
pub const fn new(value: T) -> ValueNode<T> {
|
||||
ValueNode(value)
|
||||
}
|
||||
|
|
@ -33,8 +34,29 @@ impl<'n, T: Default + 'n> Node<'n> for DefaultNode<T> {
|
|||
T::default()
|
||||
}
|
||||
}
|
||||
impl<T> DefaultNode<T> {
|
||||
pub const fn new() -> DefaultNode<T> {
|
||||
DefaultNode(PhantomData)
|
||||
|
||||
#[repr(C)]
|
||||
/// Return the unit value
|
||||
pub struct UnitNode;
|
||||
impl<'n> Node<'n> for UnitNode {
|
||||
type Output = ();
|
||||
fn eval(&'n self) -> Self::Output {}
|
||||
}
|
||||
|
||||
pub struct InputNode<T>(MaybeUninit<T>, AtomicBool);
|
||||
impl<'n, T: 'n> Node<'n> for InputNode<T> {
|
||||
type Output = &'n T;
|
||||
fn eval(&'n self) -> Self::Output {
|
||||
if self.1.load(core::sync::atomic::Ordering::SeqCst) {
|
||||
unsafe { self.0.assume_init_ref() }
|
||||
} else {
|
||||
panic!("tried to access an input before setting it")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> InputNode<T> {
|
||||
pub const fn new() -> InputNode<T> {
|
||||
InputNode(MaybeUninit::uninit(), AtomicBool::new(false))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,3 +26,7 @@ ide_db = { version = "*", package = "ra_ap_ide_db" , optional = true }
|
|||
storage-map = { version = "*", optional = true }
|
||||
lock_api = { version= "*", optional = true }
|
||||
parking_lot = { version = "*", optional = true }
|
||||
pretty-token-stream = {path = "../../pretty-token-stream"}
|
||||
syn = {version = "1.0", default-features = false, features = ["parsing", "printing"]}
|
||||
proc-macro2 = {version = "1.0", default-features = false, features = ["proc-macro"]}
|
||||
quote = {version = "1.0", default-features = false }
|
||||
|
|
|
|||
|
|
@ -16,3 +16,98 @@ pub trait DynamicInput<'n> {
|
|||
fn set_kwarg_by_name(&mut self, name: &str, value: DynAnyNode<'n>);
|
||||
fn set_arg_by_index(&mut self, index: usize, value: DynAnyNode<'n>);
|
||||
}
|
||||
|
||||
use quote::quote;
|
||||
use syn::{Expr, ExprPath, Type};
|
||||
|
||||
/// Given a Node call tree, construct a function
|
||||
/// that takes an input tuple and evaluates the call graph
|
||||
/// on the gpu an fn node is constructed that takes a value
|
||||
/// node as input
|
||||
pub struct NodeGraph {
|
||||
/// Collection of nodes with their corresponding inputs.
|
||||
/// The first node always always has to be an Input Node.
|
||||
pub nodes: Vec<NodeKind>,
|
||||
pub output: Type,
|
||||
pub input: Type,
|
||||
}
|
||||
pub enum NodeKind {
|
||||
Value(Expr),
|
||||
Input,
|
||||
Node(ExprPath, Vec<usize>),
|
||||
}
|
||||
|
||||
impl NodeGraph {
|
||||
pub fn serialize_function(&self) -> proc_macro2::TokenStream {
|
||||
let output_type = &self.output;
|
||||
let input_type = &self.input;
|
||||
|
||||
fn nid(id: &usize) -> syn::Ident {
|
||||
let str = format!("n{id}");
|
||||
syn::Ident::new(str.as_str(), proc_macro2::Span::call_site())
|
||||
}
|
||||
let mut nodes = Vec::new();
|
||||
for (ref id, node) in self.nodes.iter().enumerate() {
|
||||
let id = nid(id).clone();
|
||||
let line = match node {
|
||||
NodeKind::Value(val) => {
|
||||
quote! {let #id = graphene_core::value::ValueNode::new(#val);}
|
||||
}
|
||||
NodeKind::Node(node, ids) => {
|
||||
let ids = ids.iter().map(nid).collect::<Vec<_>>();
|
||||
quote! {let #id = #node::new((#(&#ids),*));}
|
||||
}
|
||||
NodeKind::Input => {
|
||||
quote! { let n0 = graphene_core::value::ValueNode::new(input);}
|
||||
}
|
||||
};
|
||||
nodes.push(line)
|
||||
}
|
||||
let last_id = self.nodes.len() - 1;
|
||||
let last_id = nid(&last_id);
|
||||
let ret = quote! { #last_id.eval() };
|
||||
let function = quote! {
|
||||
fn node_graph(input: #input_type) -> #output_type {
|
||||
#(#nodes)*
|
||||
#ret
|
||||
}
|
||||
};
|
||||
function
|
||||
}
|
||||
pub fn serialize_gpu(&self, name: &str) -> proc_macro2::TokenStream {
|
||||
let function = self.serialize_function();
|
||||
let output_type = &self.output;
|
||||
let input_type = &self.input;
|
||||
|
||||
quote! {
|
||||
#[cfg(target_arch = "spirv")]
|
||||
pub mod gpu {
|
||||
//#![deny(warnings)]
|
||||
#[repr(C)]
|
||||
pub struct PushConsts {
|
||||
n: u32,
|
||||
node: u32,
|
||||
}
|
||||
use super::*;
|
||||
|
||||
use spirv_std::glam::UVec3;
|
||||
|
||||
#[allow(unused)]
|
||||
#[spirv(compute(threads(64)))]
|
||||
pub fn #name(
|
||||
#[spirv(global_invocation_id)] global_id: UVec3,
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[#input_type],
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [#output_type],
|
||||
#[spirv(push_constant)] push_consts: &PushConsts,
|
||||
) {
|
||||
#function
|
||||
let gid = global_id.x as usize;
|
||||
// Only process up to n, which is the length of the buffers.
|
||||
if global_id.x < push_consts.n {
|
||||
y[gid] = node_graph(a[gid]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,6 +91,26 @@ impl<'n> NodeStore<'n> {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
use graphene_std::*;
|
||||
use quote::quote;
|
||||
use syn::parse::Parse;
|
||||
let nodes = vec![
|
||||
NodeKind::Input,
|
||||
NodeKind::Value(syn::parse_quote!(1u32)),
|
||||
NodeKind::Node(syn::parse_quote!(graphene_core::ops::AddNode), vec![0, 0]),
|
||||
];
|
||||
|
||||
//println!("{}", node_graph(1));
|
||||
|
||||
let nodegraph = NodeGraph {
|
||||
nodes,
|
||||
input: syn::Type::Verbatim(quote! {u32}),
|
||||
output: syn::Type::Verbatim(quote! {u32}),
|
||||
};
|
||||
|
||||
let pretty = pretty_token_stream::Pretty::new(nodegraph.serialize_gpu("add"));
|
||||
pretty.print();
|
||||
/*
|
||||
use dyn_any::{downcast_ref, DynAny, StaticType};
|
||||
//let mut mul = mul::MulNode::new();
|
||||
let mut stack: borrow_stack::FixedSizeStack<Box<dyn Node<'_, Output = &dyn DynAny>>> =
|
||||
|
|
@ -126,6 +146,6 @@ fn main() {
|
|||
//let sub = fnode.any(&("a", 2));
|
||||
//let cache = memo::CacheNode::new(&fnode);
|
||||
//let cached_result = cache.eval(&(2, 3));
|
||||
|
||||
*/
|
||||
//println!("{}", cached_result)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,10 @@ pub struct CacheNode<'n, CachedNode: Node<'n>> {
|
|||
cache: OnceCell<CachedNode::Output>,
|
||||
_phantom: PhantomData<&'n ()>,
|
||||
}
|
||||
impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode> {
|
||||
impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode>
|
||||
where
|
||||
CashedNode::Output: 'n,
|
||||
{
|
||||
type Output = &'n CashedNode::Output;
|
||||
fn eval(&'n self) -> Self::Output {
|
||||
self.cache.get_or_init(|| self.node.eval())
|
||||
|
|
|
|||
|
|
@ -24,8 +24,7 @@ pub struct StorageNode<'n>(&'n dyn Node<'n, Output = &'n dyn DynAny<'n>>);
|
|||
impl<'n> Node<'n> for StorageNode<'n> {
|
||||
type Output = &'n (dyn DynAny<'n>);
|
||||
fn eval(&'n self) -> Self::Output {
|
||||
let value = self.0.eval();
|
||||
value
|
||||
self.0.eval()
|
||||
}
|
||||
}
|
||||
impl<'n> StorageNode<'n> {
|
||||
|
|
|
|||
Loading…
Reference in New Issue