Compile node graph description to GPU code

This commit is contained in:
Dennis 2022-06-08 09:52:58 +02:00 committed by Keavon Chambers
parent 998f37d1b0
commit e84b9bd5bd
10 changed files with 498 additions and 76 deletions

75
node-graph/5 Normal file
View File

@ -0,0 +1,75 @@
pub mod value;
pub use graphene_core::{generic, ops /*, structural*/};
#[cfg(feature = "caching")]
pub mod caching;
#[cfg(feature = "memoization")]
pub mod memo;
pub use graphene_core::*;
use dyn_any::{downcast_ref, DynAny, StaticType};
pub type DynNode<'n, T> = &'n (dyn Node<'n, Output = T> + 'n);
pub type DynAnyNode<'n> = &'n (dyn Node<'n, Output = &'n dyn DynAny<'n>> + 'n);
pub trait DynamicInput<'n> {
fn set_kwarg_by_name(&mut self, name: &str, value: DynAnyNode<'n>);
fn set_arg_by_index(&mut self, index: usize, value: DynAnyNode<'n>);
}
use quote::quote;
use syn::{Expr, ExprPath, Type};
/// Given a Node call tree, construct a function
/// that takes an input tuple and evaluates the call graph
/// on the gpu an fn node is constructed that takes a value
/// node as input
struct NodeGraph {
/// Collection of nodes with their corresponding inputs.
/// The first node always always has to be an Input Node.
nodes: Vec<NodeKind>,
}
enum NodeKind {
Value(Expr),
Input(Type),
Node(ExprPath, Vec<u64>),
}
impl NodeGraph {
pub fn serialize(&self) -> String {
let mut output = String::new();
let output_type = if let Some(NodeKind::Node(expr, _)) = self.nodes.last() {
expr
} else {
panic!("last node wasn't a valid node")
};
let output_type = quote! {#output_type::Output};
let input_type = if let Some(NodeKind::Input(type_)) = self.nodes.first() {
type_
} else {
panic!("first node wasn't an input node")
};
let mut nodes = Vec::new();
for (id, node) in self.nodes.iter().enumerate() {
let nid = |id| format!("n{id}");
let id = nid(&id as u64);
let line = match node {
NodeKind::Value(val) => {
quote! {let #id = graphene_core::value::ValueNode::new(#val);}
}
NodeKind::Node(node, ids) => {
let ids = ids.iter().map(nid).collect::<Vec<_>>();
quote! {let #id = #node::new(((#(#ids),)*));}
}
};
nodes.push(line)
}
let function = quote! {
fn node_graph(input: #input_type) -> #output_type {
#(#nodes)*
}
};
function.to_string()
}
}

200
node-graph/Cargo.lock generated
View File

@ -17,6 +17,18 @@ version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33954243bd79057c2de7338850b85983a44588021f8a5fee574a8888c6de4344"
[[package]]
name = "arrayref"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544"
[[package]]
name = "arrayvec"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
[[package]]
name = "arrayvec"
version = "0.7.2"
@ -34,22 +46,68 @@ dependencies = [
"syn",
]
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "base64"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "blake2b_simd"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587"
dependencies = [
"arrayref",
"arrayvec 0.5.2",
"constant_time_eq",
]
[[package]]
name = "borrow_stack"
version = "0.1.0"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "cc"
version = "1.0.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
[[package]]
name = "cfg-if"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
[[package]]
name = "cfg-if"
version = "1.0.0"
@ -108,6 +166,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "constant_time_eq"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
[[package]]
name = "countme"
version = "3.0.1"
@ -131,7 +195,7 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aaa7bd5fb665c6864b5f963dd9097905c54125909c7aa94c9e18507cdbe6c53"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-utils",
]
@ -141,7 +205,7 @@ version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6455c0ca19f0d2fbf751b908d5c55c1f5cbc65e03c4225427254b46890bdde1e"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-epoch",
"crossbeam-utils",
]
@ -153,7 +217,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1145cf131a2c6ba0615079ab6a638f7e1973ac9c2634fcbeaaad6114246efe8c"
dependencies = [
"autocfg",
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-utils",
"lazy_static",
"memoffset",
@ -166,7 +230,7 @@ version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf124c720b7686e3c2663cf54062ab0f68a88af2fb6a030e87e30bf721fcb38"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"lazy_static",
]
@ -176,12 +240,23 @@ version = "5.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3495912c9c1ccf2e18976439f4443f3fee0fd61f424ff99fde6a66b15ecb448f"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"hashbrown 0.12.1",
"lock_api",
"parking_lot_core 0.9.3",
]
[[package]]
name = "dirs"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fd78930633bd1c6e35c4b42b1df7b0cbc6bc191146e512bb3bedf243fcc3901"
dependencies = [
"libc",
"redox_users",
"winapi",
]
[[package]]
name = "dissimilar"
version = "1.0.4"
@ -255,6 +330,17 @@ version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a"
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if 1.0.0",
"libc",
"wasi",
]
[[package]]
name = "glam"
version = "0.20.5"
@ -295,9 +381,13 @@ dependencies = [
"lock_api",
"once_cell",
"parking_lot 0.12.1",
"pretty-token-stream",
"proc-macro2",
"quote",
"ra_ap_ide",
"ra_ap_ide_db",
"storage-map",
"syn",
]
[[package]]
@ -357,7 +447,7 @@ version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@ -403,7 +493,7 @@ version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@ -436,6 +526,19 @@ dependencies = [
"windows-sys 0.28.0",
]
[[package]]
name = "nix"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c722bee1037d430d0f8e687bbdbf222f27cc6e4e68d5caf630857bb2b6dbdce"
dependencies = [
"bitflags",
"cc",
"cfg-if 0.1.10",
"libc",
"void",
]
[[package]]
name = "num-traits"
version = "0.2.14"
@ -495,10 +598,10 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"instant",
"libc",
"redox_syscall",
"redox_syscall 0.2.13",
"smallvec",
"winapi",
]
@ -509,9 +612,9 @@ version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"libc",
"redox_syscall",
"redox_syscall 0.2.13",
"smallvec",
"windows-sys 0.36.1",
]
@ -557,6 +660,17 @@ version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
[[package]]
name = "pretty-token-stream"
version = "0.1.0"
dependencies = [
"atty",
"cfg-if 0.1.10",
"nix",
"proc-macro2",
"term",
]
[[package]]
name = "proc-macro2"
version = "1.0.36"
@ -639,7 +753,7 @@ version = "0.0.104"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02f334df1ac0ceeb8d317e6fc71e6473344ddc8ba75613021a1213ee7b2b3dee"
dependencies = [
"arrayvec",
"arrayvec 0.7.2",
"either",
"itertools",
"once_cell",
@ -663,7 +777,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "400a6db17870220d63834cba06883dba10d42beac6f8b3d773e82fe248762adf"
dependencies = [
"anymap",
"arrayvec",
"arrayvec 0.7.2",
"bitflags",
"cov-mark",
"dashmap",
@ -718,7 +832,7 @@ version = "0.0.104"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df104472251b6c25c8c1efd8aaaf1c16b76afe2d06e17951478d321a2a1fa9d4"
dependencies = [
"arrayvec",
"arrayvec 0.7.2",
"chalk-ir",
"chalk-recursive",
"chalk-solve",
@ -816,7 +930,7 @@ version = "0.0.104"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f0dee6238d74065219229ce0cc75c62ce1ae5630cba3ab5859e0b340705f8d7"
dependencies = [
"arrayvec",
"arrayvec 0.7.2",
"cov-mark",
"either",
"fst",
@ -922,7 +1036,7 @@ version = "0.0.104"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "15128f7955c99c82ba17db6d3c3c921719bcb9a6912054486949149d02c4e79a"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"countme",
"libc",
"once_cell",
@ -1032,6 +1146,12 @@ dependencies = [
"num_cpus",
]
[[package]]
name = "redox_syscall"
version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce"
[[package]]
name = "redox_syscall"
version = "0.2.13"
@ -1041,6 +1161,17 @@ dependencies = [
"bitflags",
]
[[package]]
name = "redox_users"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de0737333e7a9502c789a36d7c7fa6092a49895d4faa31ca5df163857ded2e9d"
dependencies = [
"getrandom",
"redox_syscall 0.1.57",
"rust-argon2",
]
[[package]]
name = "rowan"
version = "0.15.5"
@ -1054,6 +1185,18 @@ dependencies = [
"text-size",
]
[[package]]
name = "rust-argon2"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb"
dependencies = [
"base64",
"blake2b_simd",
"constant_time_eq",
"crossbeam-utils",
]
[[package]]
name = "rustc-ap-rustc_lexer"
version = "725.0.0"
@ -1191,6 +1334,17 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "term"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edd106a334b7657c10b7c540a0106114feadeb4dc314513e97df481d5d966f42"
dependencies = [
"byteorder",
"dirs",
"winapi",
]
[[package]]
name = "text-size"
version = "1.1.0"
@ -1218,7 +1372,7 @@ version = "0.1.34"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d0ecdcb44a79f0fe9844f0c4f33a342cbcbb5117de8001e6ba0dc2351327d09"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@ -1304,6 +1458,18 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "void"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]]
name = "winapi"
version = "0.3.9"

View File

@ -14,7 +14,7 @@ pub mod ops;
pub mod value;
pub trait Node<'n> {
type Output: 'n; // TODO: replace with generic associated type
type Output; // TODO: replace with generic associated type
fn eval(&'n self) -> Self::Output;
}
@ -27,12 +27,26 @@ impl<'n, N: Node<'n>> Node<'n> for &'n N {
}
}
pub trait NodeInput {
type Nodes;
fn new(input: Self::Nodes) -> Self;
}
trait FQN {
fn fqn(&self) -> &'static str;
}
trait Input<I> {
unsafe fn input(&self, input: I);
}
#[cfg(feature = "async")]
#[async_trait]
pub trait AsyncNode<'n> {
type Output: 'n; // TODO: replace with generic associated type
type Output; // TODO: replace with generic associated type
async fn eval(&'n self) -> Self::Output;
async fn eval_async(&'n self) -> Self::Output;
}
#[cfg(feature = "async")]
@ -40,7 +54,7 @@ pub trait AsyncNode<'n> {
impl<'n, N: Node<'n> + Sync> AsyncNode<'n> for N {
type Output = N::Output;
async fn eval(&'n self) -> Self::Output {
async fn eval_async(&'n self) -> Self::Output {
Node::eval(self)
}
}

View File

@ -1,9 +1,9 @@
use core::{marker::PhantomData, ops::Add};
use crate::Node;
use crate::{Node, NodeInput};
#[repr(C)]
struct AddNode<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>(
pub struct AddNode<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>(
pub I1,
pub I2,
PhantomData<&'n (L, R)>,
@ -16,6 +16,13 @@ impl<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>> Node<
self.0.eval() + self.1.eval()
}
}
impl<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>
AddNode<'n, L, R, I1, I2>
{
pub fn new(input: (I1, I2)) -> AddNode<'n, L, R, I1, I2> {
AddNode(input.0, input.1, PhantomData)
}
}
#[repr(C)]
pub struct CloneNode<'n, N: Node<'n, Output = &'n O>, O: Clone + 'n>(pub N, PhantomData<&'n ()>);
@ -25,6 +32,11 @@ impl<'n, N: Node<'n, Output = &'n O>, O: Clone> Node<'n> for CloneNode<'n, N, O>
self.0.eval().clone()
}
}
impl<'n, N: Node<'n, Output = &'n O>, O: Clone> CloneNode<'n, N, O> {
pub const fn new(node: N) -> CloneNode<'n, N, O> {
CloneNode(node, PhantomData)
}
}
#[repr(C)]
pub struct FstNode<'n, N: Node<'n>>(pub N, PhantomData<&'n ()>);
@ -56,13 +68,12 @@ impl<'n, N: Node<'n>> Node<'n> for DupNode<'n, N> {
(self.0.eval(), self.0.eval()) //TODO: use Copy/Clone implementation
}
}
impl<'n, N: Node<'n>> NodeInput for DupNode<'n, N> {
type Nodes = N;
#[repr(C)]
/// Return the unit value
pub struct UnitNode;
impl<'n> Node<'n> for UnitNode {
type Output = ();
fn eval(&'n self) -> Self::Output {}
fn new(input: Self::Nodes) -> Self {
Self(input, PhantomData)
}
}
#[repr(C)]
@ -74,11 +85,18 @@ impl<'n, N: Node<'n>> Node<'n> for IdNode<'n, N> {
self.0.eval()
}
}
impl<'n, N: Node<'n>> NodeInput for IdNode<'n, N> {
type Nodes = N;
fn new(input: Self::Nodes) -> Self {
Self(input, PhantomData)
}
}
pub fn foo() {
let unit = UnitNode;
let value = IdNode(crate::value::ValueNode::new(2u32), PhantomData);
let value2 = crate::value::ValueNode::new(4u32);
let unit = crate::value::UnitNode;
let value = IdNode(crate::value::ValueNode(2u32), PhantomData);
let value2 = crate::value::ValueNode(4u32);
let dup = DupNode(&value, PhantomData);
fn int(_: (), state: &u32) -> &u32 {
state
@ -120,10 +138,16 @@ pub mod gpu {
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [(u32, u32)],
#[spirv(push_constant)] push_consts: &PushConsts,
) {
fn node_graph(input: Input) -> Output {
let n0 = ValueNode::new(input);
let n1 = IdNode::new(n0);
let n2 = IdNode::new(n1);
return n2.eval();
}
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
y[gid] = OPERATION.eval(a[gid]);
y[gid] = node_graph(a[gid]);
}
}
#[allow(unused)]

View File

@ -1,4 +1,6 @@
use core::marker::PhantomData;
use core::mem::MaybeUninit;
use core::sync::atomic::AtomicBool;
use crate::Node;
@ -18,8 +20,7 @@ impl<'n, T: 'n> Node<'n> for ValueNode<T> {
&self.0
}
}
impl<'n, T> ValueNode<T> {
impl<T> ValueNode<T> {
pub const fn new(value: T) -> ValueNode<T> {
ValueNode(value)
}
@ -33,8 +34,29 @@ impl<'n, T: Default + 'n> Node<'n> for DefaultNode<T> {
T::default()
}
}
impl<T> DefaultNode<T> {
pub const fn new() -> DefaultNode<T> {
DefaultNode(PhantomData)
#[repr(C)]
/// Return the unit value
pub struct UnitNode;
impl<'n> Node<'n> for UnitNode {
type Output = ();
fn eval(&'n self) -> Self::Output {}
}
pub struct InputNode<T>(MaybeUninit<T>, AtomicBool);
impl<'n, T: 'n> Node<'n> for InputNode<T> {
type Output = &'n T;
fn eval(&'n self) -> Self::Output {
if self.1.load(core::sync::atomic::Ordering::SeqCst) {
unsafe { self.0.assume_init_ref() }
} else {
panic!("tried to access an input before setting it")
}
}
}
impl<T> InputNode<T> {
pub const fn new() -> InputNode<T> {
InputNode(MaybeUninit::uninit(), AtomicBool::new(false))
}
}

View File

@ -26,3 +26,7 @@ ide_db = { version = "*", package = "ra_ap_ide_db" , optional = true }
storage-map = { version = "*", optional = true }
lock_api = { version= "*", optional = true }
parking_lot = { version = "*", optional = true }
pretty-token-stream = {path = "../../pretty-token-stream"}
syn = {version = "1.0", default-features = false, features = ["parsing", "printing"]}
proc-macro2 = {version = "1.0", default-features = false, features = ["proc-macro"]}
quote = {version = "1.0", default-features = false }

View File

@ -16,3 +16,98 @@ pub trait DynamicInput<'n> {
fn set_kwarg_by_name(&mut self, name: &str, value: DynAnyNode<'n>);
fn set_arg_by_index(&mut self, index: usize, value: DynAnyNode<'n>);
}
use quote::quote;
use syn::{Expr, ExprPath, Type};
/// Given a Node call tree, construct a function
/// that takes an input tuple and evaluates the call graph
/// on the gpu an fn node is constructed that takes a value
/// node as input
pub struct NodeGraph {
/// Collection of nodes with their corresponding inputs.
/// The first node always always has to be an Input Node.
pub nodes: Vec<NodeKind>,
pub output: Type,
pub input: Type,
}
pub enum NodeKind {
Value(Expr),
Input,
Node(ExprPath, Vec<usize>),
}
impl NodeGraph {
pub fn serialize_function(&self) -> proc_macro2::TokenStream {
let output_type = &self.output;
let input_type = &self.input;
fn nid(id: &usize) -> syn::Ident {
let str = format!("n{id}");
syn::Ident::new(str.as_str(), proc_macro2::Span::call_site())
}
let mut nodes = Vec::new();
for (ref id, node) in self.nodes.iter().enumerate() {
let id = nid(id).clone();
let line = match node {
NodeKind::Value(val) => {
quote! {let #id = graphene_core::value::ValueNode::new(#val);}
}
NodeKind::Node(node, ids) => {
let ids = ids.iter().map(nid).collect::<Vec<_>>();
quote! {let #id = #node::new((#(&#ids),*));}
}
NodeKind::Input => {
quote! { let n0 = graphene_core::value::ValueNode::new(input);}
}
};
nodes.push(line)
}
let last_id = self.nodes.len() - 1;
let last_id = nid(&last_id);
let ret = quote! { #last_id.eval() };
let function = quote! {
fn node_graph(input: #input_type) -> #output_type {
#(#nodes)*
#ret
}
};
function
}
pub fn serialize_gpu(&self, name: &str) -> proc_macro2::TokenStream {
let function = self.serialize_function();
let output_type = &self.output;
let input_type = &self.input;
quote! {
#[cfg(target_arch = "spirv")]
pub mod gpu {
//#![deny(warnings)]
#[repr(C)]
pub struct PushConsts {
n: u32,
node: u32,
}
use super::*;
use spirv_std::glam::UVec3;
#[allow(unused)]
#[spirv(compute(threads(64)))]
pub fn #name(
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[#input_type],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [#output_type],
#[spirv(push_constant)] push_consts: &PushConsts,
) {
#function
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]);
}
}
}
}
}
}

View File

@ -91,41 +91,61 @@ impl<'n> NodeStore<'n> {
}
fn main() {
use dyn_any::{downcast_ref, DynAny, StaticType};
//let mut mul = mul::MulNode::new();
let mut stack: borrow_stack::FixedSizeStack<Box<dyn Node<'_, Output = &dyn DynAny>>> =
borrow_stack::FixedSizeStack::new(42);
unsafe { stack.push(Box::new(AnyValueNode::new(1f32))) };
//let node = unsafe { stack.get(0) };
//let boxed = Box::new(StorageNode::new(node));
//unsafe { stack.push(boxed) };
let result = unsafe { &stack.get()[0] }.eval();
dbg!(downcast_ref::<f32>(result));
/*unsafe {
stack
.push(Box::new(AnyRefNode::new(stack.get(0).as_ref()))
as Box<dyn Node<(), Output = &dyn DynAny>>)
};*/
let f = (3.2f32, 3.1f32);
let a = ValueNode::new(1.);
let id = std::any::TypeId::of::<&f32>();
let any_a = AnyRefNode::new(&a);
/*let _mul2 = mul::MulNodeInput {
a: None,
b: Some(&any_a),
};
let mut mul2 = mul::new!();
//let cached = memo::CacheNode::new(&mul1);
//let foo = value::AnyRefNode::new(&cached);
mul2.set_arg_by_index(0, &any_a);*/
let int = value::IntNode::<32>;
Node::eval(&int);
println!("{}", Node::eval(&int));
//let _add: u32 = ops::AddNode::<u32>::default().eval((int.exec(), int.exec()));
//let fnode = generic::FnNode::new(|(a, b): &(i32, i32)| a - b);
//let sub = fnode.any(&("a", 2));
//let cache = memo::CacheNode::new(&fnode);
//let cached_result = cache.eval(&(2, 3));
use graphene_std::*;
use quote::quote;
use syn::parse::Parse;
let nodes = vec![
NodeKind::Input,
NodeKind::Value(syn::parse_quote!(1u32)),
NodeKind::Node(syn::parse_quote!(graphene_core::ops::AddNode), vec![0, 0]),
];
//println!("{}", node_graph(1));
let nodegraph = NodeGraph {
nodes,
input: syn::Type::Verbatim(quote! {u32}),
output: syn::Type::Verbatim(quote! {u32}),
};
let pretty = pretty_token_stream::Pretty::new(nodegraph.serialize_gpu("add"));
pretty.print();
/*
use dyn_any::{downcast_ref, DynAny, StaticType};
//let mut mul = mul::MulNode::new();
let mut stack: borrow_stack::FixedSizeStack<Box<dyn Node<'_, Output = &dyn DynAny>>> =
borrow_stack::FixedSizeStack::new(42);
unsafe { stack.push(Box::new(AnyValueNode::new(1f32))) };
//let node = unsafe { stack.get(0) };
//let boxed = Box::new(StorageNode::new(node));
//unsafe { stack.push(boxed) };
let result = unsafe { &stack.get()[0] }.eval();
dbg!(downcast_ref::<f32>(result));
/*unsafe {
stack
.push(Box::new(AnyRefNode::new(stack.get(0).as_ref()))
as Box<dyn Node<(), Output = &dyn DynAny>>)
};*/
let f = (3.2f32, 3.1f32);
let a = ValueNode::new(1.);
let id = std::any::TypeId::of::<&f32>();
let any_a = AnyRefNode::new(&a);
/*let _mul2 = mul::MulNodeInput {
a: None,
b: Some(&any_a),
};
let mut mul2 = mul::new!();
//let cached = memo::CacheNode::new(&mul1);
//let foo = value::AnyRefNode::new(&cached);
mul2.set_arg_by_index(0, &any_a);*/
let int = value::IntNode::<32>;
Node::eval(&int);
println!("{}", Node::eval(&int));
//let _add: u32 = ops::AddNode::<u32>::default().eval((int.exec(), int.exec()));
//let fnode = generic::FnNode::new(|(a, b): &(i32, i32)| a - b);
//let sub = fnode.any(&("a", 2));
//let cache = memo::CacheNode::new(&fnode);
//let cached_result = cache.eval(&(2, 3));
*/
//println!("{}", cached_result)
}

View File

@ -8,7 +8,10 @@ pub struct CacheNode<'n, CachedNode: Node<'n>> {
cache: OnceCell<CachedNode::Output>,
_phantom: PhantomData<&'n ()>,
}
impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode> {
impl<'n, CashedNode: Node<'n>> Node<'n> for CacheNode<'n, CashedNode>
where
CashedNode::Output: 'n,
{
type Output = &'n CashedNode::Output;
fn eval(&'n self) -> Self::Output {
self.cache.get_or_init(|| self.node.eval())

View File

@ -24,8 +24,7 @@ pub struct StorageNode<'n>(&'n dyn Node<'n, Output = &'n dyn DynAny<'n>>);
impl<'n> Node<'n> for StorageNode<'n> {
type Output = &'n (dyn DynAny<'n>);
fn eval(&'n self) -> Self::Output {
let value = self.0.eval();
value
self.0.eval()
}
}
impl<'n> StorageNode<'n> {