Implement cons node

This commit is contained in:
Dennis 2022-08-13 14:54:12 +02:00 committed by Keavon Chambers
parent 7f415febed
commit b06e00ce61
6 changed files with 97 additions and 213 deletions

View File

@ -20,4 +20,4 @@ dyn-any = {path = "../../libraries/dyn-any", features = ["derive"], optional = t
spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu", features = ["glam"] , optional = true} spirv-std = { git = "https://github.com/EmbarkStudios/rust-gpu", features = ["glam"] , optional = true}
async-trait = {version = "0.1", optional = true} async-trait = {version = "0.1", optional = true}
serde = "1.0" serde = {version = "1.0", features = ["derive"]}

View File

@ -1,32 +1,32 @@
use core::marker::PhantomData; use core::marker::PhantomData;
use crate::Node; use crate::Node;
pub struct FnNode<'n, T: Fn(<N as Node<'n>>::Output) -> O, N: Node<'n>, O>(T, N, PhantomData<&'n O>); pub struct FnNode<'n, T: Fn(I) -> O, I, O>(T, PhantomData<&'n (I, O)>);
impl<'n, T: Fn(<N as Node<'n>>::Output) -> O, N: Node<'n>, O> Node<'n> for FnNode<'n, T, N, O> { impl<'n, T: Fn(I) -> O, O, I> Node<'n, I> for FnNode<'n, T, I, O> {
type Output = O; type Output = O;
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: I) -> Self::Output {
self.0(self.1.eval()) self.0(input)
} }
} }
impl<'n, T: Fn(<N as Node<'n>>::Output) -> O, N: Node<'n>, O> FnNode<'n, T, N, O> { impl<'n, T: Fn(I) -> O, I, O> FnNode<'n, T, I, O> {
pub fn new(f: T, input: N) -> Self { pub fn new(f: T) -> Self {
FnNode(f, input, PhantomData) FnNode(f, PhantomData)
} }
} }
pub struct FnNodeWithState<'n, T: Fn(<N as Node<'n>>::Output, &'n State) -> O, N: Node<'n>, O, State: 'n>(T, N, State, PhantomData<&'n O>); pub struct FnNodeWithState<'n, T: Fn(I, &'n State) -> O, I, O, State: 'n>(T, State, PhantomData<&'n (O, I)>);
impl<'n, T: Fn(<N as Node<'n>>::Output, &'n State) -> O, N: Node<'n>, O: 'n, State: 'n> Node<'n> for FnNodeWithState<'n, T, N, O, State> { impl<'n, T: Fn(I, &'n State) -> O, I, O: 'n, State: 'n> Node<'n, I> for FnNodeWithState<'n, T, I, O, State> {
type Output = O; type Output = O;
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: I) -> Self::Output {
self.0(self.1.eval(), &self.2) self.0(input, &self.1)
} }
} }
impl<'n, T: Fn(<N as Node<'n>>::Output, &'n State) -> O, N: Node<'n>, O: 'n, State: 'n> FnNodeWithState<'n, T, N, O, State> { impl<'n, T: Fn(I, &'n State) -> O, I, O: 'n, State: 'n> FnNodeWithState<'n, T, I, O, State> {
pub fn new(f: T, input: N, state: State) -> Self { pub fn new(f: T, state: State) -> Self {
FnNodeWithState(f, input, state, PhantomData) FnNodeWithState(f, state, PhantomData)
} }
} }

View File

@ -8,10 +8,10 @@ use alloc::boxed::Box;
#[cfg(feature = "async")] #[cfg(feature = "async")]
use async_trait::async_trait; use async_trait::async_trait;
//pub mod generic; pub mod generic;
//pub mod ops; pub mod ops;
//pub mod structural;
pub mod raster; pub mod raster;
pub mod structural;
pub mod value; pub mod value;
pub trait Node<'n, T> { pub trait Node<'n, T> {
@ -28,12 +28,6 @@ impl<'n, N: Node<'n, T>, T> Node<'n, T> for &'n N {
} }
} }
pub trait NodeInput {
type Nodes;
fn new(input: Self::Nodes) -> Self;
}
trait Input<I> { trait Input<I> {
unsafe fn input(&self, input: I); unsafe fn input(&self, input: I);
} }

View File

@ -1,158 +1,83 @@
use core::{marker::PhantomData, ops::Add}; use core::ops::Add;
use crate::{Node, NodeInput}; use crate::Node;
#[repr(C)] pub struct AddNode;
pub struct AddNode<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>>(pub I1, pub I2, PhantomData<&'n (L, R)>); impl<'n, L: Add<R>, R> Node<'n, (L, R)> for AddNode {
impl<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>> Node<'n> for AddNode<'n, L, R, I1, I2> {
type Output = <L as Add<R>>::Output; type Output = <L as Add<R>>::Output;
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: (L, R)) -> Self::Output {
self.0.eval() + self.1.eval() input.0 + input.1
}
}
impl<'n, L: Add<R>, R, I1: Node<'n, Output = L>, I2: Node<'n, Output = R>> AddNode<'n, L, R, I1, I2> {
pub fn new(input: (I1, I2)) -> AddNode<'n, L, R, I1, I2> {
AddNode(input.0, input.1, PhantomData)
} }
} }
#[repr(C)] pub struct CloneNode;
pub struct CloneNode<'n, N: Node<'n, Output = &'n O>, O: Clone + 'n>(pub N, PhantomData<&'n ()>); impl<'n, O: Clone> Node<'n, &'n O> for CloneNode {
impl<'n, N: Node<'n, Output = &'n O>, O: Clone> Node<'n> for CloneNode<'n, N, O> {
type Output = O; type Output = O;
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: &'n O) -> Self::Output {
self.0.eval().clone() input.clone()
}
}
impl<'n, N: Node<'n, Output = &'n O>, O: Clone> CloneNode<'n, N, O> {
pub const fn new(node: N) -> CloneNode<'n, N, O> {
CloneNode(node, PhantomData)
} }
} }
#[repr(C)] pub struct FstNode;
pub struct FstNode<'n, N: Node<'n>>(pub N, PhantomData<&'n ()>); impl<'n, T: 'n, U> Node<'n, (T, U)> for FstNode {
impl<'n, T: 'n, U, N: Node<'n, Output = (T, U)>> Node<'n> for FstNode<'n, N> {
type Output = T; type Output = T;
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: (T, U)) -> Self::Output {
let (a, _) = self.0.eval(); let (a, _) = input;
a a
} }
} }
#[repr(C)]
/// Destructures a Tuple of two values and returns the first one /// Destructures a Tuple of two values and returns the first one
pub struct SndNode<'n, N: Node<'n>>(pub N, PhantomData<&'n ()>); pub struct SndNode;
impl<'n, T, U: 'n, N: Node<'n, Output = (T, U)>> Node<'n> for SndNode<'n, N> { impl<'n, T, U: 'n> Node<'n, (T, U)> for SndNode {
type Output = U; type Output = U;
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: (T, U)) -> Self::Output {
let (_, b) = self.0.eval(); let (_, b) = input;
b b
} }
} }
#[repr(C)]
/// Return a tuple with two instances of the input argument /// Return a tuple with two instances of the input argument
pub struct DupNode<'n, N: Node<'n>>(N, PhantomData<&'n ()>); pub struct DupNode;
impl<'n, N: Node<'n>> Node<'n> for DupNode<'n, N> { impl<'n, T: Clone> Node<'n, T> for DupNode {
type Output = (N::Output, N::Output); type Output = (T, T);
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: T) -> Self::Output {
(self.0.eval(), self.0.eval()) //TODO: use Copy/Clone implementation (input.clone(), input) //TODO: use Copy/Clone implementation
}
}
impl<'n, N: Node<'n>> NodeInput for DupNode<'n, N> {
type Nodes = N;
fn new(input: Self::Nodes) -> Self {
Self(input, PhantomData)
} }
} }
#[repr(C)]
/// Return the Input Argument /// Return the Input Argument
pub struct IdNode<'n, N: Node<'n>>(N, PhantomData<&'n ()>); pub struct IdNode;
impl<'n, N: Node<'n>> Node<'n> for IdNode<'n, N> { impl<'n, T> Node<'n, T> for IdNode {
type Output = N::Output; type Output = T;
fn eval(&'n self) -> Self::Output { fn eval(&'n self, input: T) -> Self::Output {
self.0.eval() input
}
}
impl<'n, N: Node<'n>> NodeInput for IdNode<'n, N> {
type Nodes = N;
fn new(input: Self::Nodes) -> Self {
Self(input, PhantomData)
} }
} }
pub fn foo() { #[cfg(test)]
let unit = crate::value::UnitNode; mod test {
let value = IdNode(crate::value::ValueNode(2u32), PhantomData); use super::*;
let value2 = crate::value::ValueNode(4u32); use crate::{generic::*, structural::*, value::*};
let dup = DupNode(&value, PhantomData);
#[test]
pub fn foo() {
let value = ComposeNode::new(ValueNode(4u32), IdNode);
let value2 = ValueNode(5u32);
let dup = DupNode.after(value);
fn int(_: (), state: &u32) -> &u32 { fn int(_: (), state: &u32) -> &u32 {
state state
} }
fn swap<'n>(input: (&'n u32, &'n u32)) -> (&'n u32, &'n u32) { fn swap(input: (u32, u32)) -> (u32, u32) {
(input.1, input.0) (input.1, input.0)
} }
let fnn = crate::generic::FnNode::new(swap, &dup); let fnn = FnNode::new(&swap);
let fns = crate::generic::FnNodeWithState::new(int, &unit, 42u32); let fns = FnNodeWithState::new(int, 42u32);
let _ = fnn.eval(); assert_eq!(fnn.eval((1u32, 2u32)), (2, 1));
let _ = fns.eval(); let _ = fns.eval(());
let snd = SndNode(&fnn, PhantomData); let snd = SndNode.after(dup);
let _ = snd.eval(); assert_eq!(snd.eval(()), &4u32);
let add = AddNode(&snd, value2, PhantomData); let sum = AddNode.after(ConsNode(snd)).eval(value2.eval(()));
let _ = add.eval(); assert_eq!(sum, 9);
}
#[cfg(target_arch = "spirv")]
pub mod gpu {
//#![deny(warnings)]
#[repr(C)]
pub struct PushConsts {
n: u32,
node: u32,
}
use super::*;
use crate::{structural::ComposeNodeOwned, Node};
//use crate::Node;
use spirv_std::glam::UVec3;
const ADD: AddNode<u32> = AddNode(PhantomData);
const OPERATION: ComposeNodeOwned<'_, (u32, u32), u32, FstNode<u32, u32>, DupNode<u32>> = ComposeNodeOwned::new(FstNode(PhantomData, PhantomData), DupNode(PhantomData));
#[allow(unused)]
#[spirv(compute(threads(64)))]
pub fn spread(
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[(u32, u32)],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [(u32, u32)],
#[spirv(push_constant)] push_consts: &PushConsts,
) {
fn node_graph(input: Input) -> Output {
let n0 = ValueNode::new(input);
let n1 = IdNode::new(n0);
let n2 = IdNode::new(n1);
return n2.eval();
}
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]);
}
}
#[allow(unused)]
#[spirv(compute(threads(64)))]
pub fn add(
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[(u32, u32)],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [u32],
#[spirv(push_constant)] push_consts: &PushConsts,
) {
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
if global_id.x < push_consts.n {
y[gid] = ADD.eval(a[gid]);
}
} }
} }

View File

@ -1,11 +1,11 @@
// use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Structure that represents a color. /// Structure that represents a color.
/// Internally alpha is stored as `f32` that ranges from `0.0` (transparent) to `1.0` (opaque). /// Internally alpha is stored as `f32` that ranges from `0.0` (transparent) to `1.0` (opaque).
/// The other components (RGB) are stored as `f32` that range from `0.0` up to `f32::MAX`, /// The other components (RGB) are stored as `f32` that range from `0.0` up to `f32::MAX`,
/// the values encode the brightness of each channel proportional to the light intensity in cd/m² (nits) in HDR, and `0.0` (black) to `1.0` (white) in SDR color. /// the values encode the brightness of each channel proportional to the light intensity in cd/m² (nits) in HDR, and `0.0` (black) to `1.0` (white) in SDR color.
#[repr(C)] #[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Default)] //, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Default, Serialize, Deserialize)]
pub struct Color { pub struct Color {
red: f32, red: f32,
green: f32, green: f32,

View File

@ -2,52 +2,13 @@ use core::marker::PhantomData;
use crate::Node; use crate::Node;
pub struct ComposeNode<'n, Inter, First, Second> { pub struct ComposeNode<'n, Input, First: Node<'n, Input>, Second> {
first: &'n First, first: First,
second: &'n Second, second: Second,
_phantom: PhantomData<&'n Input>, _phantom: PhantomData<&'n Input>,
_phantom2: PhantomData<Inter>,
} }
impl<'n, Input: 'n, Inter: 'n, First, Second> Node<'n, Input> for ComposeNode<'n, Input, Inter, First, Second> impl<'n, Input, Inter, First, Second> Node<'n, Input> for ComposeNode<'n, Input, First, Second>
where
First: Node<'n, Input, Output = Inter>,
Second: Node<'n, Inter>, /*+ Node<<First as Node<Input>>::Output<'n>>*/
{
type Output = <Second as Node<'n, Inter>>::Output;
fn eval(&'n self, input: Input) -> Self::Output {
// evaluate the first node with the given input
// and then pipe the result from the first computation
// into the second node
let arg: Inter = self.first.eval(input);
self.second.eval(arg)
}
}
impl<'n, Input, Inter, FIRST, SECOND> ComposeNode<'n, Input, Inter, FIRST, SECOND>
where
FIRST: Node<'n, Input>,
{
pub const fn new(first: &'n FIRST, second: &'n SECOND) -> Self {
ComposeNode::<'n, Input, Inter, FIRST, SECOND> {
first,
second,
_phantom: PhantomData,
_phantom2: PhantomData,
}
}
}
#[repr(C)]
pub struct ComposeNodeOwned<'n, Input, Inter, FIRST, SECOND> {
first: FIRST,
second: SECOND,
_phantom: PhantomData<&'n Input>,
_phantom2: PhantomData<Inter>,
}
impl<'n, Input: 'n, Inter: 'n, First, Second> Node<'n, Input> for ComposeNodeOwned<'n, Input, Inter, First, Second>
where where
First: Node<'n, Input, Output = Inter>, First: Node<'n, Input, Output = Inter>,
Second: Node<'n, Inter>, Second: Node<'n, Inter>,
@ -63,33 +24,37 @@ where
} }
} }
impl<'n, Input, Inter, First: 'n, Second> ComposeNodeOwned<'n, Input, Inter, First, Second> impl<'n, Input, First, Second> ComposeNode<'n, Input, First, Second>
where where
First: Node<'n, Input, Output = Inter>, First: Node<'n, Input>,
Second: Node<'n, First::Output>,
{ {
#[cfg(feature = "nightly")]
pub const fn new(first: First, second: Second) -> Self { pub const fn new(first: First, second: Second) -> Self {
ComposeNodeOwned::<'n, Input, Inter, First, Second> { ComposeNode::<'n, Input, First, Second> { first, second, _phantom: PhantomData }
first,
second,
_phantom: PhantomData,
_phantom2: PhantomData,
}
}
#[cfg(not(feature = "nightly"))]
pub fn new(first: First, second: Second) -> Self {
ComposeNodeOwned::<'n, Input, Inter, First, Second> {
first,
second,
_phantom: PhantomData,
_phantom2: PhantomData,
}
} }
} }
pub trait After<I>: Sized { pub trait After<Inter>: Sized {
fn after<'n, First: Node<'n, I>>(&'n self, first: &'n First) -> ComposeNode<'n, I, <First as Node<'n, I>>::Output, First, Self> { fn after<'n, First, Input>(self, first: First) -> ComposeNode<'n, Input, First, Self>
where
First: Node<'n, Input, Output = Inter>,
Self: Node<'n, Inter>,
{
ComposeNode::new(first, self) ComposeNode::new(first, self)
} }
} }
impl<Second: for<'n> Node<'n, I>, I> After<I> for Second {} impl<'n, Second: Node<'n, I>, I> After<I> for Second {}
pub struct ConsNode<Root>(pub Root);
impl<'n, Root, Input> Node<'n, Input> for ConsNode<Root>
where
Root: Node<'n, ()>,
{
type Output = (Input, <Root as Node<'n, ()>>::Output);
fn eval(&'n self, input: Input) -> Self::Output {
let arg = self.0.eval(());
(input, arg)
}
}