Restructure GPU execution to model GPU pipelines in the node graph (#1088)

* Start implementing GpuExecutor for wgpu

* Implement read_output_buffer function

* Implement extraction node in the compiler

* Generate type annotations during shader compilation

* Start adding node wrappers for graph execution api

* Wrap more of the api in nodes

* Restructure Pipeline to accept arbitrary shader inputs

* Adapt nodes to new trait definitions

* Start implementing gpu-compiler trait

* Adapt shader generation

* Hardstuck on pointer casts

* Pass nodes as references in GPU code to avoid ZSTs

* Update gcore to compile on the gpu

* Fix color doc tests

* Impl Node for node refs
This commit is contained in:
Dennis Kobert 2023-04-23 10:18:31 +02:00 committed by Keavon Chambers
parent 161bbc62b4
commit bdc1ef926a
43 changed files with 1874 additions and 515 deletions

521
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -15,6 +15,7 @@ members = [
"node-graph/compilation-client",
"node-graph/vulkan-executor",
"node-graph/wgpu-executor",
"node-graph/gpu-executor",
"node-graph/future-executor",
"node-graph/gpu-compiler/gpu-compiler-bin-wrapper",
"libraries/dyn-any",

View File

@ -18,13 +18,13 @@
# dependencies not shared by any other crates, would be ignored, as the target
# list here is effectively saying which targets you are building for.
targets = [
# The triple can be any string, but only the target triples built in to
# rustc (as of 1.40) can be checked against actual config expressions
#{ triple = "x86_64-unknown-linux-musl" },
# You can also specify which target_features you promise are enabled for a
# particular target. target_features are currently not validated against
# the actual valid features supported by the target architecture.
#{ triple = "wasm32-unknown-unknown", features = ["atomics"] },
# The triple can be any string, but only the target triples built in to
# rustc (as of 1.40) can be checked against actual config expressions
#{ triple = "x86_64-unknown-linux-musl" },
# You can also specify which target_features you promise are enabled for a
# particular target. target_features are currently not validated against
# the actual valid features supported by the target architecture.
#{ triple = "wasm32-unknown-unknown", features = ["atomics"] },
]
# This section is considered when running `cargo deny check advisories`
@ -48,7 +48,7 @@ notice = "warn"
# A list of advisory IDs to ignore. Note that ignored advisories will still
# output a note when they are encountered.
ignore = [
#"RUSTSEC-0000-0000",
#"RUSTSEC-0000-0000",
"RUSTSEC-2020-0071", # This has been fixed in the version of chrono we use
]
# Threshold for security vulnerabilities, any vulnerability with a CVSS score
@ -71,25 +71,25 @@ unlicensed = "deny"
# See https://spdx.org/licenses/ for list of possible licenses
# [possible values: any SPDX 3.11 short identifier (+ optional exception)].
allow = [
"MIT",
"MIT-0",
"Apache-2.0",
"BSD-3-Clause",
"BSD-2-Clause",
"Zlib",
"MIT",
"MIT-0",
"Apache-2.0",
"BSD-3-Clause",
"BSD-2-Clause",
"Zlib",
"Unicode-DFS-2016",
"ISC",
"MPL-2.0",
"CC0-1.0",
"OpenSSL",
"BSL-1.0",
#"Apache-2.0 WITH LLVM-exception",
"Apache-2.0 WITH LLVM-exception",
]
# List of explicitly disallowed licenses
# See https://spdx.org/licenses/ for list of possible licenses
# [possible values: any SPDX 3.11 short identifier (+ optional exception)].
deny = [
#"Nokia",
#"Nokia",
]
# Lint level for licenses considered copyleft
copyleft = "deny"
@ -113,9 +113,9 @@ confidence-threshold = 0.8
# Allow 1 or more licenses on a per-crate basis, so that particular licenses
# aren't accepted for every possible crate as with the normal allow list
exceptions = [
# Each entry is the crate and version constraint, and its specific allow
# list
#{ allow = ["Zlib"], name = "adler32", version = "*" },
# Each entry is the crate and version constraint, and its specific allow
# list
#{ allow = ["Zlib"], name = "adler32", version = "*" },
]
# Some crates don't have (easily) machine readable licensing information,
@ -134,8 +134,8 @@ expression = "MIT AND ISC AND OpenSSL"
# and the crate will be checked normally, which may produce warnings or errors
# depending on the rest of your configuration
license-files = [
# Each entry is a crate relative path, and the (opaque) hash of its contents
{ path = "LICENSE", hash = 0xbd0eed23 }
# Each entry is a crate relative path, and the (opaque) hash of its contents
{ path = "LICENSE", hash = 0xbd0eed23 }
]
[licenses.private]
@ -146,7 +146,7 @@ ignore = false
# is only published to private registries, and ignore is true, the crate will
# not have its license(s) checked
registries = [
#"https://sekretz.com/registry
#"https://sekretz.com/registry
]
# This section is considered when running `cargo deny check bans`.
@ -165,29 +165,29 @@ wildcards = "allow"
highlight = "all"
# List of crates that are allowed. Use with care!
allow = [
#{ name = "ansi_term", version = "=0.11.0" },
#{ name = "ansi_term", version = "=0.11.0" },
]
# List of crates to deny
deny = [
# Each entry the name of a crate and a version range. If version is
# not specified, all versions will be matched.
#{ name = "ansi_term", version = "=0.11.0" },
#
# Wrapper crates can optionally be specified to allow the crate when it
# is a direct dependency of the otherwise banned crate
#{ name = "ansi_term", version = "=0.11.0", wrappers = [] },
# Each entry the name of a crate and a version range. If version is
# not specified, all versions will be matched.
#{ name = "ansi_term", version = "=0.11.0" },
#
# Wrapper crates can optionally be specified to allow the crate when it
# is a direct dependency of the otherwise banned crate
#{ name = "ansi_term", version = "=0.11.0", wrappers = [] },
]
# Certain crates/versions that will be skipped when doing duplicate detection.
skip = [
#{ name = "ansi_term", version = "=0.11.0" },
{ name = "cfg-if", version = "=0.1.10" },
#{ name = "ansi_term", version = "=0.11.0" },
{ name = "cfg-if", version = "=0.1.10" },
]
# Similarly to `skip` allows you to skip certain crates during duplicate
# detection. Unlike skip, it also includes the entire tree of transitive
# dependencies starting at the specified crate, up to a certain depth, which is
# by default infinite
skip-tree = [
#{ name = "ansi_term", version = "=0.11.0", depth = 20 },
#{ name = "ansi_term", version = "=0.11.0", depth = 20 },
]
# This section is considered when running `cargo deny check sources`.

View File

@ -169,13 +169,13 @@ impl<'a> ModifyInputsContext<'a> {
let NodeInput::Value {
tagged_value: TaggedValue::Subpaths(subpaths),
..
} = subpaths else{
} = subpaths else {
return;
};
let NodeInput::Value {
tagged_value: TaggedValue::ManipulatorGroupIds(mirror_angle_groups),
..
} = mirror_angle_groups else{
} = mirror_angle_groups else {
return;
};

View File

@ -1018,8 +1018,8 @@ pub fn wrap_network_in_scope(network: NodeNetwork) -> NodeNetwork {
// if the network has no network inputs, it doesn't need to be wrapped in a scope either
let Some(input_type) = input else {
return network;
};
return network;
};
let inner_network = DocumentNode {
name: "Scope".to_string(),

View File

@ -160,13 +160,13 @@ unsafe impl<'a, T: StaticTypeSized> StaticType for &'a [T] {
type Static = &'static [<T as StaticTypeSized>::Static];
}
macro_rules! impl_slice {
($($id:ident),*) => {
$(
unsafe impl<'a, T: StaticTypeSized> StaticType for $id<'a, T> {
type Static = $id<'static, <T as StaticTypeSized>::Static>;
}
)*
};
($($id:ident),*) => {
$(
unsafe impl<'a, T: StaticTypeSized> StaticType for $id<'a, T> {
type Static = $id<'static, <T as StaticTypeSized>::Static>;
}
)*
};
}
mod slice {

View File

@ -9,9 +9,18 @@ license = "MIT OR Apache-2.0"
[dependencies]
#tokio = { version = "1.0", features = ["full"] }
serde_json = "1.0"
graph-craft = { version = "0.1.0", path = "../graph-craft", features = ["serde"] }
graph-craft = { version = "0.1.0", path = "../graph-craft", features = [
"serde",
] }
gpu-executor = { version = "0.1.0", path = "../gpu-executor" }
gpu-compiler-bin-wrapper = { version = "0.1.0", path = "../gpu-compiler/gpu-compiler-bin-wrapper" }
tempfile = "3.3.0"
anyhow = "1.0.68"
reqwest = { version = "0.11", features = ["blocking", "serde_json", "json", "rustls", "rustls-tls"] }
future-executor = {path = "../future-executor"}
reqwest = { version = "0.11", features = [
"blocking",
"serde_json",
"json",
"rustls",
"rustls-tls",
] }
future-executor = { path = "../future-executor" }

View File

@ -1,15 +1,30 @@
use gpu_compiler_bin_wrapper::CompileRequest;
use graph_craft::document::*;
use gpu_executor::ShaderIO;
use graph_craft::{proto::ProtoNetwork, Type};
pub async fn compile<I, O>(network: NodeNetwork) -> Result<Vec<u8>, reqwest::Error> {
pub async fn compile(network: ProtoNetwork, inputs: Vec<Type>, output: Type, io: ShaderIO) -> Result<Shader, reqwest::Error> {
let client = reqwest::Client::new();
let compile_request = CompileRequest::new(network, std::any::type_name::<I>().to_owned(), std::any::type_name::<O>().to_owned());
let compile_request = CompileRequest::new(network, inputs.clone(), output.clone(), io.clone());
let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send();
let response = response.await?;
response.bytes().await.map(|b| b.to_vec())
// The response body is the compiled shader as raw little-endian bytes; the
// `Shader` stores it as 32-bit words. Split the payload into *disjoint*
// 4-byte chunks: `windows(4)` slides forward one byte at a time and would
// produce `len - 3` overlapping, byte-shifted words instead of `len / 4`.
// Trailing bytes that do not fill a whole word are ignored.
response.bytes().await.map(|b| Shader {
    spirv_binary: b.chunks_exact(4).map(|word| u32::from_le_bytes(word.try_into().unwrap())).collect(),
    input_types: inputs,
    output_type: output,
    io,
})
}
pub fn compile_sync<I: 'static, O: 'static>(network: NodeNetwork) -> Result<Vec<u8>, reqwest::Error> {
future_executor::block_on(compile::<I, O>(network))
pub fn compile_sync(network: ProtoNetwork, inputs: Vec<Type>, output: Type, io: ShaderIO) -> Result<Shader, reqwest::Error> {
future_executor::block_on(compile(network, inputs, output, io))
}
// TODO: should we add the entry point as a field?
/// A compiled shader with type annotations.
///
/// Returned by `compile` / `compile_sync`, which fill the annotation fields
/// with the same values that were passed in with the compile request.
pub struct Shader {
/// The compiled shader binary, stored as 32-bit words decoded from the
/// little-endian bytes of the compile response.
pub spirv_binary: Vec<u32>,
/// Types of the shader inputs, as passed to the compile call.
pub input_types: Vec<Type>,
/// Type of the shader output, as passed to the compile call.
pub output_type: Type,
/// Shader I/O description (`gpu_executor::ShaderIO`), as passed to the compile call.
pub io: ShaderIO,
}

View File

@ -1,40 +1,55 @@
use gpu_compiler_bin_wrapper::CompileRequest;
use gpu_executor::{ShaderIO, ShaderInput};
use graph_craft::concrete;
use graph_craft::document::*;
use graph_craft::*;
use std::borrow::Cow;
use std::time::Duration;
fn main() {
let client = reqwest::blocking::Client::new();
let network = NodeNetwork {
inputs: vec![0],
outputs: vec![NodeOutput::new(0, 0)],
disabled: vec![],
previous_outputs: None,
nodes: [(
0,
DocumentNode {
name: "Inc".into(),
inputs: vec![NodeInput::Network(concrete!(u32))],
implementation: DocumentNodeImplementation::Network(add_network()),
metadata: DocumentNodeMetadata::default(),
},
)]
.into_iter()
.collect(),
// let network = NodeNetwork {
// inputs: vec![0],
// outputs: vec![NodeOutput::new(0, 0)],
// disabled: vec![],
// previous_outputs: None,
// nodes: [(
// 0,
// DocumentNode {
// name: "Inc".into(),
// inputs: vec![NodeInput::Network(concrete!(u32))],
// implementation: DocumentNodeImplementation::Network(add_network()),
// metadata: DocumentNodeMetadata::default(),
// },
// )]
// .into_iter()
// .collect(),
// };
let network = add_network();
let compiler = graph_craft::executor::Compiler {};
let proto_network = compiler.compile_single(network, true).unwrap();
let io = ShaderIO {
inputs: vec![ShaderInput::StorageBuffer((), concrete!(u32))],
output: ShaderInput::OutputBuffer((), concrete!(&mut [u32])),
};
let compile_request = CompileRequest::new(network, "u32".to_owned(), "u32".to_owned());
let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send().unwrap();
let compile_request = CompileRequest::new(proto_network, vec![concrete!(u32)], concrete!(u32), io);
let response = client
.post("http://localhost:3000/compile/spirv")
.timeout(Duration::from_secs(30))
.json(&compile_request)
.send()
.unwrap();
println!("response: {:?}", response);
}
fn add_network() -> NodeNetwork {
NodeNetwork {
inputs: vec![0],
outputs: vec![NodeOutput::new(1, 0)],
inputs: vec![],
outputs: vec![NodeOutput::new(0, 0)],
disabled: vec![],
previous_outputs: None,
nodes: [
@ -42,20 +57,20 @@ fn add_network() -> NodeNetwork {
0,
DocumentNode {
name: "Dup".into(),
inputs: vec![NodeInput::Network(concrete!(u32))],
inputs: vec![NodeInput::value(value::TaggedValue::U32(5u32), false)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::DupNode")),
},
),
(
1,
DocumentNode {
name: "Add".into(),
inputs: vec![NodeInput::node(0, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode")),
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
},
),
// (
// 1,
// DocumentNode {
// name: "Add".into(),
// inputs: vec![NodeInput::node(0, 0)],
// metadata: DocumentNodeMetadata::default(),
// implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode")),
// },
// ),
]
.into_iter()
.collect(),

View File

@ -9,10 +9,23 @@ license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
std = ["dyn-any", "dyn-any/std", "alloc", "glam/std", "specta"]
default = ["async", "serde", "kurbo", "log", "std"]
std = [
"dyn-any",
"dyn-any/std",
"alloc",
"glam/std",
"specta",
"num-traits/std",
]
default = ["async", "serde", "kurbo", "log", "std", "rand_chacha"]
log = ["dep:log"]
serde = ["dep:serde", "glam/serde", "bezier-rs/serde", "base64"]
serde = [
    "dep:serde",
    "glam/serde",
    "bezier-rs/serde",
    "base64",
]
gpu = ["spirv-std", "glam/bytemuck", "dyn-any", "glam/libm"]
async = ["async-trait", "alloc"]
nightly = []
@ -25,7 +38,7 @@ dyn-any = { path = "../../libraries/dyn-any", features = [
"glam",
], optional = true, default-features = false }
spirv-std = { version = "0.5", features = ["glam"], optional = true }
spirv-std = { version = "0.7", optional = true }
bytemuck = { version = "1.8", features = ["derive"] }
async-trait = { version = "0.1", optional = true }
serde = { version = "1.0", features = [
@ -33,11 +46,11 @@ serde = { version = "1.0", features = [
], optional = true, default-features = false }
log = { version = "0.4", optional = true }
rand_chacha = { version = "0.3.1", optional = true }
bezier-rs = { path = "../../libraries/bezier-rs", optional = true }
kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
"serde",
], optional = true }
rand_chacha = "0.3.1"
spin = "0.9.2"
glam = { version = "^0.22", default-features = false, features = [
"scalar-math",
@ -47,6 +60,7 @@ base64 = { version = "0.13", optional = true }
specta.workspace = true
specta.optional = true
once_cell = { version = "1.17.0", default-features = false, optional = true }
num = "0.4.0"
num-derive = "0.3.3"
num-traits = "0.2.15"
num-derive = { version = "0.3.3" }
num-traits = { version = "0.2.15", default-features = false, features = [
"i128",
] }

View File

@ -1,4 +1,7 @@
use crate::{raster::Sample, Color};
use bytemuck::{Pod, Zeroable};
use spirv_std::image::{Image2d, SampledImage};
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Pod, Zeroable)]
@ -6,3 +9,12 @@ pub struct PushConstants {
pub n: u32,
pub node: u32,
}
/// Lets a sampled 2D image be used through the crate's `Sample` trait,
/// producing `Color` pixels.
impl Sample for SampledImage<Image2d> {
type Pixel = Color;
/// Samples the image at `pos` and converts the four components of the
/// result (`x`, `y`, `z`, `w`) into a `Color` via `Color::from_rgbaf32`.
///
/// The return value is whatever `from_rgbaf32` yields for those components,
/// so it is `None` when that constructor rejects them.
fn sample(&self, pos: glam::DVec2) -> Option<Self::Pixel> {
// NOTE(review): this call is presumably meant to resolve to the inherent
// `SampledImage::sample` rather than to this trait method (which would
// recurse) — confirm, including that a `DVec2` is an accepted coordinate.
let color = self.sample(pos);
Color::from_rgbaf32(color.x, color.y, color.z, color.w)
}
}

View File

@ -18,6 +18,8 @@ pub mod value;
#[cfg(feature = "gpu")]
pub mod gpu;
pub mod storage;
pub mod raster;
#[cfg(feature = "alloc")]
pub mod transform;
@ -44,8 +46,8 @@ pub use types::*;
pub trait NodeIO<'i, Input: 'i>: 'i + Node<'i, Input>
where
Self::Output: 'i + StaticType,
Input: 'i + StaticType,
Self::Output: 'i + StaticTypeSized,
Input: 'i + StaticTypeSized,
{
fn input_type(&self) -> TypeId {
TypeId::of::<Input::Static>()
@ -54,7 +56,7 @@ where
core::any::type_name::<Input>()
}
fn output_type(&self) -> core::any::TypeId {
TypeId::of::<<Self::Output as StaticType>::Static>()
TypeId::of::<<Self::Output as StaticTypeSized>::Static>()
}
fn output_type_name(&self) -> &'static str {
core::any::type_name::<Self::Output>()
@ -62,8 +64,8 @@ where
#[cfg(feature = "alloc")]
fn to_node_io(&self, parameters: Vec<Type>) -> NodeIOTypes {
NodeIOTypes {
input: concrete!(<Input as StaticType>::Static),
output: concrete!(<Self::Output as StaticType>::Static),
input: concrete!(<Input as StaticTypeSized>::Static),
output: concrete!(<Self::Output as StaticTypeSized>::Static),
parameters,
}
}
@ -71,8 +73,8 @@ where
impl<'i, N: Node<'i, I>, I> NodeIO<'i, I> for N
where
N::Output: 'i + StaticType,
I: 'i + StaticType,
N::Output: 'i + StaticTypeSized,
I: 'i + StaticTypeSized,
{
}
@ -83,6 +85,13 @@ where
(**self).eval(input)
}
}*/
/// Makes a shared reference to a node usable as a node itself.
///
/// The reference lifetime `'s` must outlive the evaluation lifetime `'i`, and
/// the output type is the same as that of the referenced node.
impl<'i, 's: 'i, I: 'i, O: 'i, N: Node<'i, I, Output = O>> Node<'i, I> for &'s N {
type Output = O;
/// Forwards the evaluation to the referenced node.
fn eval(&'i self, input: I) -> Self::Output {
// `self` is `&&N` here; dereference twice to call `N::eval` directly.
(**self).eval(input)
}
}
impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> {
type Output = O;
@ -92,7 +101,7 @@ impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> {
}
use core::pin::Pin;
use dyn_any::StaticType;
use dyn_any::StaticTypeSized;
#[cfg(feature = "alloc")]
impl<'i, I: 'i, O: 'i> Node<'i, I> for Pin<Box<dyn for<'a> Node<'a, I, Output = O> + 'i>> {
type Output = O;

View File

@ -1,5 +1,5 @@
use core::marker::PhantomData;
use core::ops::Add;
use core::ops::{Add, Mul};
use crate::Node;
@ -30,6 +30,27 @@ where
first + second
}
/// Node that multiplies its input by a second operand stored as the
/// `second` parameter.
pub struct MulParameterNode<Second> {
second: Second,
}
// Implementation of `MulParameterNode` (attached via `node_fn`): returns
// `first * second`, with the result type given by `U`'s `Mul<T>` impl.
#[node_macro::node_fn(MulParameterNode)]
fn flat_map<U, T>(first: U, second: T) -> <U as Mul<T>>::Output
where
U: Mul<T>,
{
first * second
}
/// Node that reports the size of a `crate::Type`.
/// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
struct SizeOfNode {}
// Implementation of `SizeOfNode` (attached via `node_fn`): forwards to
// `Type::size`, so the result is `None` whenever that method returns `None`.
#[cfg(feature = "std")]
#[node_macro::node_fn(SizeOfNode)]
fn flat_map(ty: crate::Type) -> Option<usize> {
ty.size()
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct SomeNode;
#[node_macro::node_fn(SomeNode)]

View File

@ -2,6 +2,9 @@ use crate::raster::Color;
use crate::Node;
use dyn_any::{DynAny, StaticType};
#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::Float;
#[derive(Clone, Debug, DynAny, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Quantization {

View File

@ -4,50 +4,52 @@ use crate::Node;
use bytemuck::{Pod, Zeroable};
use glam::DVec2;
use num::Num;
#[cfg(not(target_arch = "spirv"))]
use num_traits::{cast::cast as num_cast, Num, NumCast};
#[cfg(target_arch = "spirv")]
use spirv_std::num_traits::float::Float;
use spirv_std::num_traits::{cast::cast as num_cast, float::Float, FromPrimitive, Num, NumCast, ToPrimitive};
pub use self::color::Color;
pub mod adjustments;
#[cfg(not(target_arch = "spirv"))]
pub mod brightness_contrast;
pub mod color;
pub use adjustments::*;
pub trait Channel: Copy + Debug + num::Num + num::NumCast {
pub trait Channel: Copy + Debug + Num + NumCast {
fn to_linear<Out: Linear>(self) -> Out;
fn from_linear<In: Linear>(linear: In) -> Self;
fn to_f32(self) -> f32 {
num::cast(self).expect("Failed to convert channel to f32")
num_cast(self).expect("Failed to convert channel to f32")
}
fn from_f32(value: f32) -> Self {
num::cast(value).expect("Failed to convert f32 to channel")
num_cast(value).expect("Failed to convert f32 to channel")
}
fn to_f64(self) -> f64 {
num::cast(self).expect("Failed to convert channel to f64")
num_cast(self).expect("Failed to convert channel to f64")
}
fn from_f64(value: f64) -> Self {
num::cast(value).expect("Failed to convert f64 to channel")
num_cast(value).expect("Failed to convert f64 to channel")
}
fn to_channel<Out: Channel>(self) -> Out {
num::cast(self).expect("Failed to convert channel to channel")
num_cast(self).expect("Failed to convert channel to channel")
}
}
pub trait Linear: num::NumCast + Num {}
pub trait Linear: NumCast + Num {}
impl Linear for f32 {}
impl Linear for f64 {}
impl<T: Linear + Debug + Copy> Channel for T {
#[inline(always)]
fn to_linear<Out: Linear>(self) -> Out {
num::cast(self).expect("Failed to convert channel to linear")
num_cast(self).expect("Failed to convert channel to linear")
}
#[inline(always)]
fn from_linear<In: Linear>(linear: In) -> Self {
num::cast(linear).expect("Failed to convert linear to channel")
num_cast(linear).expect("Failed to convert linear to channel")
}
}
@ -58,16 +60,16 @@ struct SRGBGammaFloat(f32);
impl Channel for SRGBGammaFloat {
#[inline(always)]
fn to_linear<Out: Linear>(self) -> Out {
let channel = num::cast::<_, f32>(self).expect("Failed to convert srgb to linear");
let channel = num_cast::<_, f32>(self).expect("Failed to convert srgb to linear");
let out = if channel <= 0.04045 { channel / 12.92 } else { ((channel + 0.055) / 1.055).powf(2.4) };
num::cast(out).expect("Failed to convert srgb to linear")
num_cast(out).expect("Failed to convert srgb to linear")
}
#[inline(always)]
fn from_linear<In: Linear>(linear: In) -> Self {
let linear = num::cast::<_, f32>(linear).expect("Failed to convert linear to srgb");
let linear = num_cast::<_, f32>(linear).expect("Failed to convert linear to srgb");
let out = if linear <= 0.0031308 { linear * 12.92 } else { 1.055 * linear.powf(1. / 2.4) - 0.055 };
num::cast(out).expect("Failed to convert linear to srgb")
num_cast(out).expect("Failed to convert linear to srgb")
}
}
pub trait RGBPrimaries {
@ -98,6 +100,7 @@ impl<T> Serde for T {}
// TODO: Come up with a better name for this trait
pub trait Pixel: Clone + Pod + Zeroable {
#[cfg(not(target_arch = "spirv"))]
fn to_bytes(&self) -> Vec<u8> {
bytemuck::bytes_of(self).to_vec()
}
@ -107,7 +110,7 @@ pub trait Pixel: Clone + Pod + Zeroable {
}
fn byte_size() -> usize {
std::mem::size_of::<Self>()
core::mem::size_of::<Self>()
}
}
pub trait RGB: Pixel {
@ -448,6 +451,8 @@ pub struct ImageSlice<'a, Pixel> {
pub data: &'a [Pixel],
#[cfg(target_arch = "spirv")]
pub data: &'a (),
#[cfg(target_arch = "spirv")]
pub _marker: PhantomData<Pixel>,
}
unsafe impl<P: StaticTypeSized> StaticType for ImageSlice<'_, P> {
@ -470,20 +475,17 @@ impl<'a, P> Default for ImageSlice<'a, P> {
width: Default::default(),
height: Default::default(),
data: &NOTHING,
_marker: PhantomData,
}
}
}
#[cfg(not(target_arch = "spirv"))]
impl<P: Copy + Debug + Pixel> Raster for ImageSlice<'_, P> {
type Pixel = P;
#[cfg(not(target_arch = "spirv"))]
fn get_pixel(&self, x: u32, y: u32) -> Option<P> {
self.data.get((x + y * self.width) as usize).copied()
}
#[cfg(target_arch = "spirv")]
fn get_pixel(&self, _x: u32, _y: u32) -> P {
Color::default()
}
fn width(&self) -> u32 {
self.width
}
@ -605,6 +607,8 @@ mod image {
}
}
// TODO: Evaluate if this will be a problem for our use case.
/// Warning: This is an approximation of a hash, and is not guaranteed to not collide.
impl<P: Hash + Pixel> Hash for Image<P> {
fn hash<H: Hasher>(&self, state: &mut H) {
const HASH_SAMPLES: u64 = 1000;
@ -661,7 +665,7 @@ mod image {
let Image { width, height, data } = self;
let to_gamma = |x| SRGBGammaFloat::from_linear(x);
let to_u8 = |x| (num::cast::<_, f32>(x).unwrap() * 255.) as u8;
let to_u8 = |x| (num_cast::<_, f32>(x).unwrap() * 255.) as u8;
let result_bytes = data
.into_iter()
@ -670,7 +674,7 @@ mod image {
to_u8(to_gamma(color.r() / color.a().to_channel())),
to_u8(to_gamma(color.g() / color.a().to_channel())),
to_u8(to_gamma(color.b() / color.a().to_channel())),
(num::cast::<_, f32>(color.a()).unwrap() * 255.) as u8,
(num_cast::<_, f32>(color.a()).unwrap() * 255.) as u8,
]
})
.collect();

View File

@ -3,6 +3,7 @@ use crate::Node;
use core::fmt::Debug;
use dyn_any::{DynAny, StaticType};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(target_arch = "spirv")]
@ -457,8 +458,9 @@ fn vibrance_node(color: Color, vibrance: f64) -> Color {
}
}
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, DynAny, specta::Type)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, DynAny)]
pub enum RedGreenBlue {
Red,
Green,
@ -542,8 +544,9 @@ fn channel_mixer_node(
color.to_linear_srgb()
}
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, DynAny, specta::Type)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, DynAny)]
pub enum RelativeAbsolute {
Relative,
Absolute,
@ -559,7 +562,9 @@ impl core::fmt::Display for RelativeAbsolute {
}
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, DynAny, specta::Type)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "std", derive(specta::Type))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, DynAny)]
pub enum SelectiveColorChoice {
Reds,
Yellows,
@ -797,17 +802,26 @@ fn exposure(color: Color, exposure: f64, offset: f64, gamma_correction: f64) ->
adjusted.map_rgb(|c: f32| c.clamp(0., 1.))
}
#[derive(Debug)]
pub struct IndexNode<Index> {
pub index: Index,
}
#[cfg(feature = "alloc")]
pub use index_node::IndexNode;
#[node_macro::node_fn(IndexNode)]
pub fn index_node(input: Vec<super::ImageFrame<Color>>, index: u32) -> super::ImageFrame<Color> {
if (index as usize) < input.len() {
input[index as usize].clone()
} else {
warn!("The number of segments is {} and the requested segment is {}!", input.len(), index);
super::ImageFrame::empty()
// The index node takes a `Vec`, so the whole module is only compiled when the
// `alloc` feature is enabled.
#[cfg(feature = "alloc")]
mod index_node {
use crate::raster::{Color, ImageFrame};
use crate::Node;
/// Node that picks one image frame out of a list of frames.
#[derive(Debug)]
pub struct IndexNode<Index> {
/// Parameter supplying the position of the frame to select.
pub index: Index,
}
/// Returns a clone of the frame at position `index` in `input`.
///
/// If `index` is out of range, a warning is logged and an empty frame
/// (`ImageFrame::empty()`) is returned instead.
#[node_macro::node_fn(IndexNode)]
pub fn index_node(input: Vec<ImageFrame<Color>>, index: u32) -> ImageFrame<Color> {
if (index as usize) < input.len() {
input[index as usize].clone()
} else {
warn!("The number of segments is {} and the requested segment is {}!", input.len(), index);
ImageFrame::empty()
}
}
}

View File

@ -30,7 +30,7 @@ fn brightness_contrast_legacy_node(_primary: (), brightness: f32, contrast: f32)
let brightness = brightness / 255.;
let contrast = contrast / 100.;
let contrast = if contrast > 0. { (contrast * std::f32::consts::FRAC_PI_2 - 0.01).tan() } else { contrast };
let contrast = if contrast > 0. { (contrast * core::f32::consts::FRAC_PI_2 - 0.01).tan() } else { contrast };
let combined = brightness * contrast + brightness - contrast / 2.;
@ -172,7 +172,7 @@ fn solve_cubic_splines(cubic_spline_values: &CubicSplines) -> [f32; 4] {
// Eliminate the current column in all rows below the current one
for row_below_current in row + 1..4 {
assert!(augmented_matrix[row][row].abs() > std::f32::EPSILON);
assert!(augmented_matrix[row][row].abs() > core::f32::EPSILON);
let scale_factor = augmented_matrix[row_below_current][row] / augmented_matrix[row][row];
for col in row..5 {
@ -184,7 +184,7 @@ fn solve_cubic_splines(cubic_spline_values: &CubicSplines) -> [f32; 4] {
// Gaussian elimination: back substitution
let mut solutions = [0.; 4];
for col in (0..4).rev() {
assert!(augmented_matrix[col][col].abs() > std::f32::EPSILON);
assert!(augmented_matrix[col][col].abs() > core::f32::EPSILON);
solutions[col] = augmented_matrix[col][4] / augmented_matrix[col][col];

View File

@ -53,6 +53,7 @@ impl RGB for Color {
}
impl Pixel for Color {
#[cfg(not(target_arch = "spirv"))]
fn to_bytes(&self) -> Vec<u8> {
self.to_rgba8_srgb().to_vec()
}
@ -121,7 +122,6 @@ impl Color {
/// let color = Color::from_rgbaf32(1.0, 1.0, 1.0, f32::NAN);
/// assert!(color == None);
/// ```
#[cfg(not(target_arch = "spirv"))]
pub fn from_rgbaf32(red: f32, green: f32, blue: f32, alpha: f32) -> Option<Color> {
if alpha > 1. || [red, green, blue, alpha].iter().any(|c| c.is_sign_negative() || !c.is_finite()) {
return None;
@ -492,7 +492,7 @@ impl Color {
/// ```
/// use graphene_core::raster::color::Color;
/// let color = Color::from_rgbaf32(0.114, 0.103, 0.98, 0.97).unwrap();
/// assert!(color.components() == (0.114, 0.103, 0.98, 0.97));
/// assert_eq!(color.components(), (0.114, 0.103, 0.98, 0.97));
/// ```
pub fn components(&self) -> (f32, f32, f32, f32) {
(self.red, self.green, self.blue, self.alpha)
@ -585,7 +585,6 @@ impl Color {
/// use graphene_core::raster::color::Color;
/// let color = Color::from_rgba_str("7C67FA61").unwrap();
/// ```
#[cfg(not(target_arch = "spirv"))]
pub fn from_rgba_str(color_str: &str) -> Option<Color> {
if color_str.len() != 8 {
return None;
@ -603,7 +602,6 @@ impl Color {
/// use graphene_core::raster::color::Color;
/// let color = Color::from_rgb_str("7C67FA").unwrap();
/// ```
#[cfg(not(target_arch = "spirv"))]
pub fn from_rgb_str(color_str: &str) -> Option<Color> {
if color_str.len() != 6 {
return None;

View File

@ -0,0 +1,34 @@
use crate::Node;
use core::marker::PhantomData;
use core::ops::{DerefMut, Index, IndexMut};
/// Node that writes its input value into an indexable storage.
///
/// `storage` and `index` are the node's parameters; `S` and `I` are only
/// carried as `PhantomData` markers.
struct SetNode<S, I, Storage, Index> {
storage: Storage,
index: Index,
_s: PhantomData<S>,
_i: PhantomData<I>,
}
// Implementation of `SetNode` (attached via `node_fn`): looks up the slot at
// `index` with `IndexMut`, dereferences it mutably and overwrites the target
// with `value`. Nothing is returned.
// NOTE(review): the `'any_input` lifetime is not declared in this block and is
// presumably supplied by the `node_fn` macro — confirm.
#[node_macro::node_fn(SetNode<_S, _I>)]
fn set_node<T, _S, _I>(value: T, storage: &'any_input mut _S, index: _I)
where
_S: IndexMut<_I>,
_S::Output: DerefMut<Target = T> + Sized,
{
*storage.index_mut(index).deref_mut() = value;
}
/// Node that reads an element out of an indexable storage.
///
/// `storage` is the node's parameter; `S` is only carried as a `PhantomData`
/// marker.
struct GetNode<S, Storage> {
storage: Storage,
_s: PhantomData<S>,
}
// Implementation of `GetNode` (attached via `node_fn`): the node input is the
// index, and the result is a reference to the element `Index::index` returns.
// NOTE(review): the `'any_input` / `'input` lifetimes are not declared in this
// block and are presumably supplied by the `node_fn` macro — confirm.
#[node_macro::node_fn(GetNode<_S>)]
fn get_node<_S, I>(index: I, storage: &'any_input _S) -> &'input _S::Output
where
_S: Index<I>,
_S::Output: Sized,
{
storage.index(index)
}

View File

@ -2,60 +2,47 @@ use core::marker::PhantomData;
use crate::Node;
pub struct ComposeNode<First: for<'i> Node<'i, I>, Second: for<'i> Node<'i, <First as Node<'i, I>>::Output>, I> {
#[derive(Clone)]
pub struct ComposeNode<First, Second, I> {
first: First,
second: Second,
phantom: PhantomData<I>,
}
impl<'i, Input: 'i, First, Second> Node<'i, Input> for ComposeNode<First, Second, Input>
impl<'i, 'f: 'i, 's: 'i, Input: 'i, First, Second> Node<'i, Input> for ComposeNode<First, Second, Input>
where
First: for<'a> Node<'a, Input> + 'i,
Second: for<'a> Node<'a, <First as Node<'a, Input>>::Output> + 'i,
First: Node<'i, Input>,
Second: Node<'i, <First as Node<'i, Input>>::Output> + 'i,
{
type Output = <Second as Node<'i, <First as Node<'i, Input>>::Output>>::Output;
fn eval(&'i self, input: Input) -> Self::Output {
let arg = self.first.eval(input);
self.second.eval(arg)
let second = &self.second;
second.eval(arg)
}
}
impl<First, Second, Input> ComposeNode<First, Second, Input>
impl<'i, First, Second, Input: 'i> ComposeNode<First, Second, Input>
where
First: for<'a> Node<'a, Input>,
Second: for<'a> Node<'a, <First as Node<'a, Input>>::Output>,
First: Node<'i, Input>,
Second: Node<'i, <First as Node<'i, Input>>::Output>,
{
pub const fn new(first: First, second: Second) -> Self {
ComposeNode::<First, Second, Input> { first, second, phantom: PhantomData }
}
}
// impl Clone for ComposeNode<First, Second, Input>
impl<First, Second, Input> Clone for ComposeNode<First, Second, Input>
where
First: for<'a> Node<'a, Input> + Clone,
Second: for<'a> Node<'a, <First as Node<'a, Input>>::Output> + Clone,
{
fn clone(&self) -> Self {
ComposeNode::<First, Second, Input> {
first: self.first.clone(),
second: self.second.clone(),
phantom: PhantomData,
}
}
}
pub trait Then<'i, Input: 'i>: Sized {
fn then<Second>(self, second: Second) -> ComposeNode<Self, Second, Input>
where
Self: for<'a> Node<'a, Input>,
Second: for<'a> Node<'a, <Self as Node<'a, Input>>::Output>,
Self: Node<'i, Input>,
Second: Node<'i, <Self as Node<'i, Input>>::Output>,
{
ComposeNode::new(self, second)
}
}
impl<'i, First: for<'a> Node<'a, Input>, Input: 'i> Then<'i, Input> for First {}
impl<'i, First: Node<'i, Input>, Input: 'i> Then<'i, Input> for First {}
pub struct ConsNode<I: From<()>, Root>(pub Root, PhantomData<I>);
@ -89,4 +76,16 @@ mod test {
let type_erased = &compose as &dyn for<'i> Node<'i, (), Output = &'i u32>;
assert_eq!(type_erased.eval(()), &4u32);
}
// Checks that a node can be evaluated through a shared reference, and that a
// composition built from references to nodes evaluates to the same value.
#[test]
fn test_ref_eval() {
let value = ValueNode::new(5);
// Evaluate the node via `&value` rather than `value` itself.
assert_eq!((&value).eval(()), &5);
let id = IdNode::new();
// Both stages of the composition are borrowed, not moved.
let compose = ComposeNode::new(&value, &id);
assert_eq!(compose.eval(()), &5);
}
}

View File

@ -2,6 +2,7 @@ use core::any::TypeId;
#[cfg(not(feature = "std"))]
pub use alloc::borrow::Cow;
use dyn_any::StaticType;
#[cfg(feature = "std")]
pub use std::borrow::Cow;
@ -28,6 +29,8 @@ macro_rules! concrete {
Type::Concrete(TypeDescriptor {
id: Some(core::any::TypeId::of::<$type>()),
name: Cow::Borrowed(core::any::type_name::<$type>()),
size: core::mem::size_of::<$type>(),
align: core::mem::align_of::<$type>(),
})
};
}
@ -65,6 +68,8 @@ pub struct TypeDescriptor {
#[specta(skip)]
pub id: Option<TypeId>,
pub name: Cow<'static, str>,
pub size: usize,
pub align: usize,
}
impl core::hash::Hash for TypeDescriptor {
@ -137,6 +142,32 @@ impl Type {
}
}
impl Type {
    /// Builds a [`Type::Concrete`] descriptor for `T`.
    ///
    /// The id and name are taken from `T::Static`, while the recorded size and
    /// alignment are those of `T` itself.
    pub fn new<T: StaticType + Sized>() -> Self {
        let descriptor = TypeDescriptor {
            id: Some(TypeId::of::<T::Static>()),
            name: Cow::Borrowed(core::any::type_name::<T::Static>()),
            size: core::mem::size_of::<T>(),
            align: core::mem::align_of::<T>(),
        };
        Self::Concrete(descriptor)
    }

    /// Returns the size stored in the descriptor of a concrete type.
    ///
    /// Generic and function types carry no descriptor, so they yield `None`.
    pub fn size(&self) -> Option<usize> {
        match self {
            Self::Concrete(descriptor) => Some(descriptor.size),
            Self::Generic(_) | Self::Fn(_, _) => None,
        }
    }

    /// Returns the alignment stored in the descriptor of a concrete type.
    ///
    /// Generic and function types carry no descriptor, so they yield `None`.
    pub fn align(&self) -> Option<usize> {
        match self {
            Self::Concrete(descriptor) => Some(descriptor.align),
            Self::Generic(_) | Self::Fn(_, _) => None,
        }
    }
}
impl core::fmt::Debug for Type {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {

View File

@ -40,7 +40,7 @@ impl<T: Clone> Clone for ValueNode<T> {
}
impl<T: Clone + Copy> Copy for ValueNode<T> {}
#[derive(Clone)]
#[derive(Clone, Copy)]
pub struct ClonedNode<T: Clone>(pub T);
impl<'i, T: Clone + 'i> Node<'i, ()> for ClonedNode<T> {
@ -61,7 +61,22 @@ impl<T: Clone> From<T> for ClonedNode<T> {
ClonedNode::new(value)
}
}
impl<T: Clone + Copy> Copy for ClonedNode<T> {}
/// A node wrapping a single `Copy` value.
///
/// Evaluating it with `()` as input returns the wrapped value by copy.
#[derive(Clone, Copy)]
pub struct CopiedNode<T: Copy>(pub T);
/// Evaluation ignores the unit input and returns a copy of the wrapped value.
impl<'i, T: Copy + 'i> Node<'i, ()> for CopiedNode<T> {
// The output is `T` by value, not a reference into the node.
type Output = T;
fn eval(&'i self, _input: ()) -> Self::Output {
self.0
}
}
impl<T: Copy> CopiedNode<T> {
pub const fn new(value: T) -> CopiedNode<T> {
CopiedNode(value)
}
}
#[derive(Default)]
pub struct DefaultNode<T>(PhantomData<T>);

View File

@ -122,6 +122,20 @@ name = "bytemuck"
version = "1.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aaa3a8d9a1ca92e282c96a32d6511b695d7d994d1d102ba85d279f9b2756947f"
dependencies = [
"bytemuck_derive",
]
[[package]]
name = "bytemuck_derive"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aca418a974d83d40a0c1f0c5cba6ff4bc28d8df099109ca459a2118d40b6322"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "byteorder"
@ -293,7 +307,7 @@ dependencies = [
[[package]]
name = "dyn-any"
version = "0.2.1"
version = "0.3.1"
dependencies = [
"dyn-any-derive",
"glam",
@ -302,7 +316,7 @@ dependencies = [
[[package]]
name = "dyn-any-derive"
version = "0.2.1"
version = "0.3.0"
dependencies = [
"proc-macro2",
"quote",
@ -346,6 +360,106 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "futures"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
[[package]]
name = "futures-executor"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-intrusive"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f"
dependencies = [
"futures-core",
"lock_api",
"parking_lot",
]
[[package]]
name = "futures-io"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
[[package]]
name = "futures-macro"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2"
[[package]]
name = "futures-task"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
[[package]]
name = "futures-util"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "fxhash"
version = "0.2.1"
@ -382,6 +496,8 @@ version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774"
dependencies = [
"bytemuck",
"num-traits",
"serde",
]
@ -418,6 +534,7 @@ dependencies = [
"bytemuck",
"dyn-any",
"glam",
"gpu-executor",
"graph-craft",
"graphene-core",
"log",
@ -430,6 +547,26 @@ dependencies = [
"tera",
]
[[package]]
name = "gpu-executor"
version = "0.1.0"
dependencies = [
"anyhow",
"base64",
"bytemuck",
"dyn-any",
"futures",
"futures-intrusive",
"glam",
"graph-craft",
"graphene-core",
"log",
"node-macro",
"num-traits",
"serde",
"spirv",
]
[[package]]
name = "graph-craft"
version = "0.1.0"
@ -454,17 +591,22 @@ name = "graphene-core"
version = "0.1.0"
dependencies = [
"async-trait",
"base64",
"bezier-rs",
"bytemuck",
"dyn-any",
"glam",
"kurbo",
"log",
"node-macro",
"num-derive",
"num-traits",
"once_cell",
"rand_chacha",
"serde",
"specta",
"spin",
"spirv-std",
]
[[package]]
@ -549,6 +691,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "internal-iterator"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a668ef46056a63366da9d74f48062da9ece1a27958f2f3704aa6f7421c4433f5"
[[package]]
name = "itertools"
version = "0.10.5"
@ -643,6 +791,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "longest-increasing-subsequence"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3bd0dd2cd90571056fdb71f6275fada10131182f84899f4b2a916e565d81d86"
[[package]]
name = "memchr"
version = "2.5.0"
@ -658,6 +812,17 @@ dependencies = [
"syn",
]
[[package]]
name = "num-derive"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "num-integer"
version = "0.1.45"
@ -693,6 +858,29 @@ version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "parking_lot"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-sys",
]
[[package]]
name = "parse-zoneinfo"
version = "0.3.0"
@ -797,6 +985,18 @@ dependencies = [
"uncased",
]
[[package]]
name = "pin-project-lite"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
@ -917,14 +1117,15 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc_codegen_spirv"
version = "0.5.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b803b49618cdde99e1065af9415f489b993374765e30a6b80f2bea2cca65914"
checksum = "48862935255ac76002118e3561c54f3fb413a752ef5cd0bdae693cf5c5ce37a5"
dependencies = [
"ar",
"either",
"hashbrown 0.11.2",
"indexmap",
"itertools",
"lazy_static",
"libc",
"num-traits",
@ -944,9 +1145,9 @@ dependencies = [
[[package]]
name = "rustc_codegen_spirv-types"
version = "0.5.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "330aedc6b09b9bf3c58cc7fb942c1377310a9cff00fae9e4f6cc09a7a28f542e"
checksum = "339309d7fce2e7204decea1b2683c4d439a6ebe4c1518d8134fe10aefeae7362"
dependencies = [
"rspirv",
"serde",
@ -1037,6 +1238,15 @@ version = "0.3.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de"
[[package]]
name = "slab"
version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d"
dependencies = [
"autocfg",
]
[[package]]
name = "slug"
version = "0.1.4"
@ -1092,16 +1302,18 @@ dependencies = [
[[package]]
name = "spirt"
version = "0.1.0"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06834ebbbbc6f86448fd5dc7ccbac80e36f52f8d66838683752e19d3cae9a459"
checksum = "e24fa996f12f3c667efbceaa99c222b8910a295a14d2c43c3880dfab2752def7"
dependencies = [
"arrayvec",
"bytemuck",
"elsa",
"indexmap",
"internal-iterator",
"itertools",
"lazy_static",
"longest-increasing-subsequence",
"rustc-hash",
"serde",
"serde_json",
@ -1120,9 +1332,9 @@ dependencies = [
[[package]]
name = "spirv-builder"
version = "0.5.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f656f97ac742e5603843d2ea3ea644cdf5630635b3320ec87666385766e7ab"
checksum = "0310607328cbb098b681e4580a25871fcf72dd9e1d560e7cc09c409eedef2a27"
dependencies = [
"memchr",
"raw-string",
@ -1132,6 +1344,37 @@ dependencies = [
"serde_json",
]
[[package]]
name = "spirv-std"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3197bd4c021c2dfc0f9dfb356312c8f7842d972d5545c308ad86422c2e2d3e66"
dependencies = [
"bitflags",
"glam",
"num-traits",
"spirv-std-macros",
"spirv-std-types",
]
[[package]]
name = "spirv-std-macros"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbaffad626ab9d3ac61c4b74b5d51cb52f1939a8041d7ac09ec828eb4ad44d72"
dependencies = [
"proc-macro2",
"quote",
"spirv-std-types",
"syn",
]
[[package]]
name = "spirv-std-types"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab83875e851bc803c687024d2d950730f350c0073714b95b3a6b1d22e9eac42a"
[[package]]
name = "spirv-tools"
version = "0.9.0"
@ -1434,6 +1677,72 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "xxhash-rust"
version = "0.8.6"

View File

@ -13,18 +13,25 @@ serde = ["graphene-core/serde", "glam/serde"]
[dependencies]
graphene-core = { path = "../gcore", features = ["async", "std", "alloc"] }
graph-craft = {path = "../graph-craft", features = ["serde"] }
dyn-any = { path = "../../libraries/dyn-any", features = ["log-bad-types", "rc", "glam"] }
graph-craft = { path = "../graph-craft", features = ["serde"] }
gpu-executor = { path = "../gpu-executor" }
dyn-any = { path = "../../libraries/dyn-any", features = [
"log-bad-types",
"rc",
"glam",
] }
num-traits = "0.2"
log = "0.4"
serde = { version = "1", features = ["derive", "rc"]}
serde = { version = "1", features = ["derive", "rc"] }
glam = { version = "0.22" }
base64 = "0.13"
bytemuck = { version = "1.8" }
nvtx = { version = "1.1.1", optional = true }
tempfile = "3"
spirv-builder = { version = "0.5", default-features = false, features=["use-installed-tools"] }
spirv-builder = { version = "0.7", default-features = false, features = [
"use-installed-tools",
] }
tera = { version = "1.17.1" }
anyhow = "1.0.66"
serde_json = "1.0.91"

View File

@ -11,7 +11,8 @@ profiling = []
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
graph-craft = {path = "../../graph-craft" , features = ["serde"]}
graph-craft = { path = "../../graph-craft", features = ["serde"] }
gpu-executor = { path = "../../gpu-executor" }
log = "0.4"
anyhow = "1.0.66"
serde_json = "1.0.91"

View File

@ -1,8 +1,15 @@
use gpu_executor::ShaderIO;
use graph_craft::{proto::ProtoNetwork, Type};
use serde::{Deserialize, Serialize};
use std::io::Write;
pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &str, output_type: &str, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
let serialized_graph = serde_json::to_string(&network)?;
pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
let serialized_graph = serde_json::to_string(&gpu_executor::CompileRequest {
network: request.network.clone(),
io: request.shader_io.clone(),
})?;
let features = "";
#[cfg(feature = "profiling")]
let features = "profiling";
@ -19,9 +26,6 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
.envs(non_cargo_env_vars)
.arg("--features")
.arg(features)
.arg("--")
.arg(input_type)
.arg(output_type)
// TODO: handle None case properly
.arg(compile_dir.unwrap())
.stdin(std::process::Stdio::piped())
@ -38,16 +42,27 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct CompileRequest {
network: graph_craft::document::NodeNetwork,
input_type: String,
output_type: String,
network: graph_craft::proto::ProtoNetwork,
input_types: Vec<Type>,
output_type: Type,
shader_io: ShaderIO,
}
impl CompileRequest {
pub fn new(network: graph_craft::document::NodeNetwork, input_type: String, output_type: String) -> Self {
Self { network, input_type, output_type }
pub fn new(network: ProtoNetwork, input_types: Vec<Type>, output_type: Type, io: ShaderIO) -> Self {
// TODO: add type checking
// for (input, buffer) in input_types.iter().zip(io.inputs.iter()) {
// assert_eq!(input, &buffer.ty());
// }
// assert_eq!(output_type, io.output.ty());
Self {
network,
input_types,
output_type,
shader_io: io,
}
}
pub fn compile(&self, compile_dir: &str, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
compile_spirv(&self.network, &self.input_type, &self.output_type, Some(compile_dir), manifest_path)
compile_spirv(self, Some(compile_dir), manifest_path)
}
}

View File

@ -1,3 +1,10 @@
[toolchain]
channel = "nightly-2022-12-18"
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]
channel = "nightly-2023-03-04"
components = [
"rust-src",
"rustc-dev",
"llvm-tools-preview",
"clippy",
"rustfmt",
"rustc",
]

View File

@ -1,6 +1,8 @@
use std::path::Path;
use gpu_executor::{GPUConstant, ShaderIO, ShaderInput, SpirVCompiler};
use graph_craft::proto::*;
use graphene_core::Cow;
use std::path::{Path, PathBuf};
use tera::Context;
fn create_cargo_toml(metadata: &Metadata) -> Result<String, tera::Error> {
@ -24,10 +26,10 @@ impl Metadata {
}
}
pub fn create_files(matadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, input_type: &str, output_type: &str) -> anyhow::Result<()> {
pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> {
let src = compile_dir.join("src");
let cargo_file = compile_dir.join("Cargo.toml");
let cargo_toml = create_cargo_toml(matadata)?;
let cargo_toml = create_cargo_toml(metadata)?;
std::fs::write(cargo_file, cargo_toml)?;
let toolchain_file = compile_dir.join("rust-toolchain.toml");
@ -44,26 +46,100 @@ pub fn create_files(matadata: &Metadata, network: &ProtoNetwork, compile_dir: &P
}
}
let lib = src.join("lib.rs");
let shader = serialize_gpu(network, input_type, output_type)?;
println!("{}", shader);
let shader = serialize_gpu(network, io)?;
eprintln!("{}", shader);
std::fs::write(lib, shader)?;
Ok(())
}
pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str) -> anyhow::Result<String> {
assert_eq!(network.inputs.len(), 1);
/// Maps a [`GPUConstant`] to the name written inside `#[spirv(...)]` on the
/// generated shader parameter.
///
/// # Panics
///
/// `SubGroupSize`, `WorkGroupSize` and `GlobalSize` have no mapping yet and
/// hit `todo!()`.
fn constant_attribute(constant: &GPUConstant) -> &'static str {
    match constant {
        GPUConstant::SubGroupId => "subgroup_id",
        GPUConstant::SubGroupInvocationId => "subgroup_local_invocation_id",
        GPUConstant::NumSubGroups => "num_subgroups",
        GPUConstant::WorkGroupId => "workgroup_id",
        GPUConstant::WorkGroupInvocationId => "local_invocation_id",
        GPUConstant::NumWorkGroups => "num_workgroups",
        GPUConstant::GlobalInvocationId => "global_invocation_id",
        // Not mapped to an attribute name yet.
        GPUConstant::SubGroupSize | GPUConstant::WorkGroupSize | GPUConstant::GlobalSize => todo!(),
    }
}
/// Renders the source text of one shader entry-point parameter for `input`.
///
/// `position` is used both as the descriptor binding index (for buffer inputs)
/// and as the suffix of the parameter name (`i{position}`).
///
/// Every rendered parameter ends with a trailing comma: the parameters are
/// emitted one after another without any separator being added between them,
/// so a missing comma would produce an invalid signature as soon as there is
/// more than one input. A trailing comma after the last parameter is valid Rust.
pub fn construct_argument(input: &ShaderInput<()>, position: u32) -> String {
    match input {
        ShaderInput::Constant(constant) => {
            let attribute = constant_attribute(constant);
            let ty = constant.ty();
            format!("#[spirv({attribute})] i{position}: {ty},")
        }
        ShaderInput::UniformBuffer(_, ty) => {
            format!("#[spirv(uniform, descriptor_set = 0, binding = {position})] i{position}: &[{ty}],")
        }
        // Read-back buffers are declared exactly like read-only storage buffers.
        ShaderInput::StorageBuffer(_, ty) | ShaderInput::ReadBackBuffer(_, ty) => {
            format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {position})] i{position}: &[{ty}],")
        }
        ShaderInput::OutputBuffer(_, ty) => {
            format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {position})] i{position}: &mut[{ty}],")
        }
        // The attribute previously lacked its closing parenthesis.
        // NOTE(review): rust-gpu spells the workgroup storage class `workgroup`; confirm whether
        // `workgroup_memory` is accepted by the toolchain in use before relying on this arm.
        ShaderInput::WorkGroupMemory(_, ty) => format!("#[spirv(workgroup_memory)] i{position}: {ty},"),
    }
}
/// [`SpirVCompiler`] implementation that generates a shader crate on disk and builds it.
struct GpuCompiler {
// Directory the generated crate files are written to and compiled from.
compile_dir: PathBuf,
}
impl SpirVCompiler for GpuCompiler {
    /// Generates the shader crate for `network` in `self.compile_dir`, builds it,
    /// and loads the resulting module as a sequence of 32-bit words.
    ///
    /// # Errors
    ///
    /// Fails if writing the crate files, compiling, or reading the module fails,
    /// or if the module's byte length is not a multiple of four (previously this
    /// last case panicked).
    fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> anyhow::Result<gpu_executor::Shader> {
        let metadata = Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
        create_files(&metadata, &network, &self.compile_dir, io)?;
        let result = compile(&self.compile_dir)?;

        let bytes = std::fs::read(result.module.unwrap_single())?;
        // A truncated or corrupt module must surface as an error, not a panic.
        anyhow::ensure!(
            bytes.len() % 4 == 0,
            "compiled shader module is {} bytes long, which is not a whole number of 32-bit words",
            bytes.len()
        );
        let words = bytes
            .chunks_exact(4)
            .map(|chunk| u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect::<Vec<_>>();

        Ok(gpu_executor::Shader {
            source: Cow::Owned(words),
            name: "",
            io: io.clone(),
        })
    }
}
pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result<String> {
fn nid(id: &u64) -> String {
format!("n{id}")
}
dbg!(&network);
dbg!(&io);
let inputs = io.inputs.iter().enumerate().map(|(i, input)| construct_argument(input, i as u32)).collect::<Vec<_>>();
let mut nodes = Vec::new();
let mut input_nodes = Vec::new();
#[derive(serde::Serialize)]
struct Node {
id: String,
fqn: String,
args: Vec<String>,
}
for id in network.inputs.iter() {
let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else {
anyhow::bail!("Input node not found");
};
let fqn = &node.identifier.name;
let id = nid(id);
input_nodes.push(Node {
id,
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
args: node.construction_args.new_function_args(),
});
}
for (ref id, node) in network.nodes.iter() {
if network.inputs.contains(id) {
continue;
}
let fqn = &node.identifier.name;
let id = nid(id);
@ -78,8 +154,8 @@ pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str
let mut tera = tera::Tera::default();
tera.add_raw_template("spirv", template)?;
let mut context = Context::new();
context.insert("input_type", &input_type);
context.insert("output_type", &output_type);
context.insert("inputs", &inputs);
context.insert("input_nodes", &input_nodes);
context.insert("nodes", &nodes);
context.insert("last_node", &nid(&network.output));
context.insert("compute_threads", &64);
@ -89,14 +165,15 @@ pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str
use spirv_builder::{MetadataPrintout, SpirvBuilder, SpirvMetadata};
pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder::SpirvBuilderError> {
dbg!(&dir);
let result = SpirvBuilder::new(dir, "spirv-unknown-spv1.5")
let result = SpirvBuilder::new(dir, "spirv-unknown-vulkan1.2")
.print_metadata(MetadataPrintout::DependencyOnly)
.multimodule(false)
.preserve_bindings(true)
.release(true)
//.relax_struct_store(true)
//.relax_block_layout(true)
.spirv_metadata(SpirvMetadata::Full)
.extra_arg("no-early-report-zombies")
.extra_arg("no-infer-storage-classes")
.extra_arg("spirt-passes=qptr")
.build()?;
Ok(result)
@ -104,7 +181,6 @@ pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder
#[cfg(test)]
mod test {
#[test]
fn test_create_cargo_toml() {
let cargo_toml = super::create_cargo_toml(&super::Metadata {

View File

@ -1,4 +1,5 @@
use gpu_compiler as compiler;
use gpu_executor::CompileRequest;
use graph_craft::document::NodeNetwork;
use std::io::Write;
@ -6,17 +7,13 @@ fn main() -> anyhow::Result<()> {
println!("Starting GPU Compiler!");
let mut stdin = std::io::stdin();
let mut stdout = std::io::stdout();
let input_type = std::env::args().nth(1).expect("input type arg missing");
let output_type = std::env::args().nth(2).expect("output type arg missing");
let compile_dir = std::env::args().nth(3).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path());
let network: NodeNetwork = serde_json::from_reader(&mut stdin)?;
let compiler = graph_craft::executor::Compiler {};
let proto_network = compiler.compile_single(network, true).unwrap();
let compile_dir = std::env::args().nth(1).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path());
let request: CompileRequest = serde_json::from_reader(&mut stdin)?;
dbg!(&compile_dir);
let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
compiler::create_files(&metadata, &proto_network, &compile_dir, &input_type, &output_type)?;
compiler::create_files(&metadata, &request.network, &compile_dir, &request.io)?;
let result = compiler::compile(&compile_dir)?;
let bytes = std::fs::read(result.module.unwrap_single())?;

View File

@ -1,7 +1,7 @@
[package]
authors = [{% for author in authors %}"{{author}}", {% endfor %}]
name = "{{name}}-node"
version = "0.1.0"
authors = [{%for author in authors%}"{{author}}", {%endfor%}]
edition = "2021"
license = "MIT OR Apache-2.0"
publish = false
@ -13,5 +13,7 @@ crate-type = ["dylib", "lib"]
libm = { git = "https://github.com/rust-lang/libm", tag = "0.2.5" }
[dependencies]
spirv-std = { version = "0.5" , features= ["glam"]}
graphene-core = {path = "{{gcore_path}}", default-features = false, features = ["gpu"]}
spirv-std = { version = "0.7" }
graphene-core = { path = "{{gcore_path}}", default-features = false, features = [
"gpu",
] }

View File

@ -1,3 +1,10 @@
[toolchain]
channel = "nightly-2022-12-18"
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]
channel = "nightly-2023-03-04"
components = [
"rust-src",
"rustc-dev",
"llvm-tools-preview",
"clippy",
"rustfmt",
"rustc",
]

View File

@ -1,6 +1,5 @@
#![no_std]
#![feature(unchecked_math)]
#![deny(warnings)]
#[cfg(target_arch = "spirv")]
extern crate spirv_std;
@ -14,25 +13,23 @@ pub mod gpu {
#[allow(unused)]
#[spirv(compute(threads({{compute_threads}})))]
pub fn eval (
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}],
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}],
//#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
{% for input in inputs %}
{{input}}
{% endfor %}
) {
let gid = global_id.x as usize;
// Only process up to n, which is the length of the buffers.
//if global_id.x < push_consts.n {
y[gid] = node_graph(a[gid]);
//}
}
fn node_graph(input: {{input_type}}) -> {{output_type}} {
use graphene_core::Node;
{% for input in input_nodes %}
let i{{loop.index0}} = graphene_core::value::CopiedNode::new(i{{loop.index0}});
let _{{input.id}} = {{input.fqn}}::new({% for arg in input.args %}{{arg}}, {% endfor %});
let {{input.id}} = graphene_core::structural::ComposeNode::new(i{{loop.index0}}, _{{input.id}});
{% endfor %}
{% for node in nodes %}
let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %});
{% endfor %}
{{last_node}}.eval(input)
}
let output = {{last_node}}.eval(());
// TODO: Write output to buffer
}
}

View File

@ -0,0 +1,36 @@
[package]
name = "gpu-executor"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"
[features]
default = []
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
graphene-core = { path = "../gcore", features = [
"async",
"std",
"alloc",
"gpu",
] }
graph-craft = { path = "../graph-craft", features = ["serde"] }
node-macro = { path = "../node-macro" }
dyn-any = { path = "../../libraries/dyn-any", features = [
"log-bad-types",
"rc",
"glam",
] }
num-traits = "0.2"
log = "0.4"
serde = { version = "1", features = ["derive", "rc"] }
glam = "0.22"
base64 = "0.13"
bytemuck = { version = "1.8" }
anyhow = "1.0.66"
spirv = "0.2.0"
futures-intrusive = "0.5.0"
futures = "0.3.25"

View File

@ -0,0 +1,258 @@
use graph_craft::proto::ProtoNetwork;
use graphene_core::*;
use anyhow::Result;
use dyn_any::StaticType;
use futures::Future;
use glam::UVec3;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::pin::Pin;
type ReadBackFuture = Pin<Box<dyn Future<Output = Result<Vec<u8>>>>>;
/// Interface implemented by a GPU backend.
///
/// NOTE(review): no implementation is visible from this file, so the method
/// descriptions below are derived from the signatures only — confirm against
/// the concrete executors.
pub trait GpuExecutor {
/// Backend-specific handle returned by [`GpuExecutor::load_shader`].
type ShaderHandle;
/// Backend-specific handle carried inside the [`ShaderInput`] values this trait produces.
type BufferHandle;
/// Value produced by [`GpuExecutor::create_compute_pass`] and consumed by
/// [`GpuExecutor::execute_compute_pipeline`].
type CommandBuffer;
/// Hands a [`Shader`] to the backend and returns its handle.
fn load_shader(&self, shader: Shader) -> Result<Self::ShaderHandle>;
/// Creates a shader input from `data`; presumably a `ShaderInput::UniformBuffer` — confirm.
fn create_uniform_buffer<T: ToUniformBuffer>(&self, data: T) -> Result<ShaderInput<Self::BufferHandle>>;
/// Creates a shader input from `data` using the given access `options`.
fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>>;
/// Creates a shader input for results, described by a length, an element type and a CPU-readability flag.
fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>>;
/// Builds a command buffer from a pipeline layout, an optional read-back input and an instance count.
fn create_compute_pass(&self, layout: &PipelineLayout<Self>, read_back: Option<ShaderInput<Self::BufferHandle>>, instances: u32) -> Result<Self::CommandBuffer>;
/// Consumes a previously created command buffer.
fn execute_compute_pipeline(&self, encoder: Self::CommandBuffer) -> Result<()>;
/// Returns a boxed future that resolves to a `Result<Vec<u8>>` for the given buffer.
fn read_output_buffer(&self, buffer: ShaderInput<Self::BufferHandle>) -> Result<ReadBackFuture>;
}
/// Something that can turn a proto network into a [`Shader`].
pub trait SpirVCompiler {
/// Compiles `network` into a shader whose inputs and output are described by `io`.
fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> Result<Shader>;
}
/// Serializable pairing of a proto network with the description of its shader inputs and output.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompileRequest {
// The network to compile.
pub network: ProtoNetwork,
// Inputs and output the compiled shader should expose.
pub io: ShaderIO,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
/// GPU constants that can be used as inputs to a shader.
///
/// The type a shader parameter must have for each constant is reported by
/// [`GPUConstant::ty`].
pub enum GPUConstant {
SubGroupId,
SubGroupInvocationId,
SubGroupSize,
NumSubGroups,
WorkGroupId,
WorkGroupInvocationId,
WorkGroupSize,
NumWorkGroups,
GlobalInvocationId,
GlobalSize,
}
impl GPUConstant {
    /// Returns the type a shader parameter bound to this constant must have.
    ///
    /// Subgroup-related constants are scalars (`u32`); work-group and global
    /// constants are three-component vectors (`UVec3`). `WorkGroupSize` and
    /// `NumWorkGroups` were previously reported as `u32`, but the corresponding
    /// SPIR-V built-ins are three-component vectors of 32-bit integers.
    pub fn ty(&self) -> Type {
        match self {
            GPUConstant::SubGroupId | GPUConstant::SubGroupInvocationId | GPUConstant::SubGroupSize | GPUConstant::NumSubGroups => concrete!(u32),
            GPUConstant::WorkGroupId
            | GPUConstant::WorkGroupInvocationId
            | GPUConstant::WorkGroupSize
            | GPUConstant::NumWorkGroups
            | GPUConstant::GlobalInvocationId
            | GPUConstant::GlobalSize => concrete!(UVec3),
        }
    }
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
/// All the possible inputs to a shader.
///
/// `BufferHandle` is the backend's buffer type; `ShaderIO` uses `()` in its
/// place to describe inputs without holding any buffer.
pub enum ShaderInput<BufferHandle> {
/// A buffer handle together with a type.
UniformBuffer(BufferHandle, Type),
/// A buffer handle together with a type.
StorageBuffer(BufferHandle, Type),
/// A struct representing a work group memory buffer. This cannot be accessed by the CPU.
// NOTE(review): the meaning of the `usize` is not visible here — presumably a size; confirm.
WorkGroupMemory(usize, Type),
/// One of the built-in [`GPUConstant`] values; carries no buffer.
Constant(GPUConstant),
/// A buffer handle together with a type.
OutputBuffer(BufferHandle, Type),
/// A buffer handle together with a type.
ReadBackBuffer(BufferHandle, Type),
}
/// Extract the buffer handle from a shader input.
impl<BufferHandle> ShaderInput<BufferHandle> {
pub fn buffer(&self) -> Option<&BufferHandle> {
match self {
ShaderInput::UniformBuffer(buffer, _) => Some(buffer),
ShaderInput::StorageBuffer(buffer, _) => Some(buffer),
ShaderInput::WorkGroupMemory(_, _) => None,
ShaderInput::Constant(_) => None,
ShaderInput::OutputBuffer(buffer, _) => Some(buffer),
ShaderInput::ReadBackBuffer(buffer, _) => Some(buffer),
}
}
pub fn ty(&self) -> Type {
match self {
ShaderInput::UniformBuffer(_, ty) => ty.clone(),
ShaderInput::StorageBuffer(_, ty) => ty.clone(),
ShaderInput::WorkGroupMemory(_, ty) => ty.clone(),
ShaderInput::Constant(c) => c.ty(),
ShaderInput::OutputBuffer(_, ty) => ty.clone(),
ShaderInput::ReadBackBuffer(_, ty) => ty.clone(),
}
}
}
/// A compiled shader together with the description of its inputs and output.
pub struct Shader<'a> {
/// The shader code as a sequence of `u32` words.
pub source: Cow<'a, [u32]>,
/// Name of the shader.
pub name: &'a str,
/// Description of the shader's inputs and output.
pub io: ShaderIO,
}
/// Describes the inputs and the output of a shader.
///
/// The buffer handle type is `()`, so only the kind and type of each input are recorded,
/// not a concrete buffer.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ShaderIO {
/// The shader's inputs.
pub inputs: Vec<ShaderInput<()>>,
/// The shader's output.
pub output: ShaderInput<()>,
}
/// Access options passed to `GpuExecutor::create_storage_buffer`.
///
/// How each flag is translated into concrete buffer usage is up to the executor implementation.
pub struct StorageBufferOptions {
/// Request that the CPU can write to the buffer.
pub cpu_writable: bool,
/// Request that the GPU can write to the buffer.
pub gpu_writable: bool,
/// Request that the CPU can read from the buffer.
pub cpu_readable: bool,
}
/// Data that can be passed to `GpuExecutor::create_uniform_buffer`.
pub trait ToUniformBuffer: StaticType {
// NOTE(review): this associated type is not referenced by the code visible here — confirm it is still needed.
type UniformBufferHandle;
/// Returns the bytes that make up the contents of the buffer.
fn to_bytes(&self) -> Cow<[u8]>;
}
/// Data that can be passed to `GpuExecutor::create_storage_buffer`.
pub trait ToStorageBuffer: StaticType {
// NOTE(review): this associated type is not referenced by the code visible here — confirm it is still needed.
type StorageBufferHandle;
/// Returns the bytes that make up the contents of the buffer.
fn to_bytes(&self) -> Cow<[u8]>;
}
/// Collection of all arguments that are passed to the shader.
pub struct Bindgroup<E: GpuExecutor + ?Sized> {
/// The shader inputs, expressed with the executor's buffer handle type.
pub buffers: Vec<ShaderInput<E::BufferHandle>>,
}
/// A struct representing a compute pipeline.
///
/// It is consumed by `GpuExecutor::create_compute_pass`.
pub struct PipelineLayout<E: GpuExecutor + ?Sized> {
/// Handle of the loaded shader.
pub shader: E::ShaderHandle,
/// Name of the shader entry point to run.
pub entry_point: String,
/// The arguments passed to the shader.
pub bind_group: Bindgroup<E>,
/// The buffer the shader writes its output to.
pub output_buffer: ShaderInput<E::BufferHandle>,
}
/// Extracts arguments from the function arguments and wraps them in a node.
///
/// The node takes `()` as input and evaluates to a reference to the wrapped value.
pub struct ShaderInputNode<T> {
// The wrapped value that `eval` hands out by reference.
data: T,
}
impl<'i, T: 'i> Node<'i, ()> for ShaderInputNode<T> {
type Output = &'i T;
/// Returns a reference to the wrapped value; the `()` input is ignored.
fn eval(&'i self, _: ()) -> Self::Output {
&self.data
}
}
impl<T> ShaderInputNode<T> {
/// Wraps `data` in a node.
pub fn new(data: T) -> Self {
Self { data }
}
}
/// Node that turns its input into a uniform buffer using the executor stored in `executor`.
pub struct UniformNode<Executor> {
    executor: Executor,
}

// Creates a uniform buffer holding `data` on `executor`.
// Panics if the executor reports an error while creating the buffer.
#[node_macro::node_fn(UniformNode)]
fn uniform_node<T: ToUniformBuffer, E: GpuExecutor>(data: T, executor: &'any_input E) -> ShaderInput<E::BufferHandle> {
    executor.create_uniform_buffer(data).unwrap()
}
/// Node that turns its input into a storage buffer using the executor stored in `executor`.
pub struct StorageNode<Executor> {
    executor: Executor,
}

// Creates a storage buffer holding `data` on `executor`.
// The buffer is requested as GPU writable but neither CPU writable nor CPU readable.
// Panics if the executor reports an error while creating the buffer.
#[node_macro::node_fn(StorageNode)]
fn storage_node<T: ToStorageBuffer, E: GpuExecutor>(data: T, executor: &'any_input E) -> ShaderInput<E::BufferHandle> {
    let gpu_only = StorageBufferOptions {
        cpu_writable: false,
        gpu_writable: true,
        cpu_readable: false,
    };
    executor.create_storage_buffer(data, gpu_only).unwrap()
}
/// Node that appends the value provided by `value` to the vector it receives as input.
pub struct PushNode<Value> {
    value: Value,
}

// Appends `value` to `vec` and returns the extended vector.
//
// The vector is taken by value, so it has to be handed back to the caller:
// without the return value the pushed element would be dropped together with
// the vector and the node would have no observable effect.
#[node_macro::node_fn(PushNode)]
fn push_node<T>(mut vec: Vec<T>, value: T) -> Vec<T> {
    vec.push(value);
    vec
}
/// Node that creates an output buffer on a GPU executor.
///
/// The node input is the buffer length; `executor` and `ty` provide the executor and the element type.
pub struct CreateOutputBufferNode<Executor, Ty> {
executor: Executor,
ty: Ty,
}
// Creates an output buffer of `size` elements of type `ty`.
// The third argument of `create_output_buffer` is `cpu_readable`, so the buffer is always requested as CPU readable.
// Panics if the executor reports an error.
#[node_macro::node_fn(CreateOutputBufferNode)]
fn create_output_buffer_node<E: GpuExecutor>(size: usize, executor: &'any_input E, ty: Type) -> ShaderInput<E::BufferHandle> {
executor.create_output_buffer(size, ty, true).unwrap()
}
/// Node that records a compute pass for a pipeline layout.
///
/// The node input is the pipeline layout; `executor`, `output` and `instances` provide the remaining arguments.
pub struct CreateComputePassNode<Executor, Output, Instances> {
executor: Executor,
output: Output,
instances: Instances,
}
// Builds the command buffer for `layout` by calling `create_compute_pass`.
// `output` is passed as the `read_back` argument and `instances` is forwarded unchanged.
// Panics if the executor reports an error.
#[node_macro::node_fn(CreateComputePassNode)]
fn create_compute_pass_node<E: GpuExecutor>(layout: PipelineLayout<E>, executor: &'any_input E, output: ShaderInput<E::BufferHandle>, instances: u32) -> E::CommandBuffer {
executor.create_compute_pass(&layout, Some(output), instances).unwrap()
}
/// Node that assembles a `PipelineLayout` from a shader handle and the values provided by its fields.
///
/// `_E` is the executor type; it is only needed at the type level and therefore stored as `PhantomData`.
pub struct CreatePipelineLayoutNode<_E, EntryPoint, Bindgroup, OutputBuffer> {
entry_point: EntryPoint,
bind_group: Bindgroup,
output_buffer: OutputBuffer,
_e: std::marker::PhantomData<_E>,
}
// Moves the shader handle (the node input) and the three field values into a `PipelineLayout`.
#[node_macro::node_fn(CreatePipelineLayoutNode<_E>)]
fn create_pipeline_layout_node<_E: GpuExecutor>(shader: _E::ShaderHandle, entry_point: String, bind_group: Bindgroup<_E>, output_buffer: ShaderInput<_E::BufferHandle>) -> PipelineLayout<_E> {
PipelineLayout {
shader,
entry_point,
bind_group,
output_buffer,
}
}
/// Node that submits a recorded command buffer to a GPU executor.
pub struct ExecuteComputePipelineNode<Executor> {
executor: Executor,
}
// Runs `encoder` on `executor` and panics if the executor reports an error.
// NOTE(review): `execute_compute_pipeline` takes `&self`, and the sibling nodes borrow the executor
// immutably — confirm whether the `&mut` borrow here is required.
#[node_macro::node_fn(ExecuteComputePipelineNode)]
fn execute_compute_pipeline_node<E: GpuExecutor>(encoder: E::CommandBuffer, executor: &'any_input mut E) {
executor.execute_compute_pipeline(encoder).unwrap();
}
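// TODO: Wrap `GpuExecutor::read_output_buffer` in a node. The draft below is out of date: the method takes a `ShaderInput` (not a raw buffer handle) and returns a future wrapped in a `Result`, which cannot be awaited inside a non-async node function.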
// pub struct ReadOutputBufferNode<Executor> {
// executor: Executor,
// }
// #[node_macro::node_fn(ReadOutputBufferNode)]
// fn read_output_buffer_node<E: GpuExecutor>(buffer: E::BufferHandle, executor: &'any_input mut E) -> Vec<u8> {
// executor.read_output_buffer(buffer).await.unwrap()
// }

View File

@ -1,3 +1,4 @@
use crate::document::value::TaggedValue;
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
use graphene_core::{NodeIdentifier, Type};
@ -20,7 +21,7 @@ fn merge_ids(a: u64, b: u64) -> u64 {
hasher.finish()
}
#[derive(Clone, Debug, PartialEq, Default, specta::Type)]
#[derive(Clone, Debug, PartialEq, Default, specta::Type, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DocumentNodeMetadata {
pub position: IVec2,
@ -32,7 +33,7 @@ impl DocumentNodeMetadata {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DocumentNode {
pub name: String,
@ -156,7 +157,7 @@ impl DocumentNode {
///
/// In this case the Cache node actually consumes its input and then manually forwards it to its parameter Node.
/// This is necessary because the Cache Node needs to short-circuit the actual node evaluation.
#[derive(Debug, Clone, PartialEq, Hash)]
#[derive(Debug, Clone, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum NodeInput {
Node {
@ -165,7 +166,7 @@ pub enum NodeInput {
lambda: bool,
},
Value {
tagged_value: crate::document::value::TaggedValue,
tagged_value: TaggedValue,
exposed: bool,
},
Network(Type),
@ -182,7 +183,7 @@ impl NodeInput {
pub const fn lambda(node_id: NodeId, output_index: usize) -> Self {
Self::Node { node_id, output_index, lambda: true }
}
pub const fn value(tagged_value: crate::document::value::TaggedValue, exposed: bool) -> Self {
pub const fn value(tagged_value: TaggedValue, exposed: bool) -> Self {
Self::Value { tagged_value, exposed }
}
fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
@ -212,11 +213,12 @@ impl NodeInput {
}
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DocumentNodeImplementation {
Network(NodeNetwork),
Unresolved(NodeIdentifier),
Extract,
}
impl Default for DocumentNodeImplementation {
@ -227,23 +229,21 @@ impl Default for DocumentNodeImplementation {
impl DocumentNodeImplementation {
pub fn get_network(&self) -> Option<&NodeNetwork> {
if let DocumentNodeImplementation::Network(n) = self {
Some(n)
} else {
None
match self {
DocumentNodeImplementation::Network(n) => Some(n),
_ => None,
}
}
pub fn get_network_mut(&mut self) -> Option<&mut NodeNetwork> {
if let DocumentNodeImplementation::Network(n) = self {
Some(n)
} else {
None
match self {
DocumentNodeImplementation::Network(n) => Some(n),
_ => None,
}
}
}
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type)]
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeOutput {
pub node_id: NodeId,
@ -267,6 +267,21 @@ pub struct NodeNetwork {
pub previous_outputs: Option<Vec<NodeOutput>>,
}
impl std::hash::Hash for NodeNetwork {
    /// Feeds the inputs, outputs, nodes, disabled list and previous outputs into `state`.
    ///
    /// The nodes are visited in ascending id order so that the resulting hash
    /// does not depend on the iteration order of the node map.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.inputs.hash(state);
        self.outputs.hash(state);

        let mut ordered_nodes = self.nodes.iter().collect::<Vec<_>>();
        // Ids are unique map keys, so an unstable sort yields the same order as a stable one.
        ordered_nodes.sort_unstable_by_key(|entry| entry.0);
        for (node_id, document_node) in ordered_nodes {
            node_id.hash(state);
            document_node.hash(state);
        }

        self.disabled.hash(state);
        self.previous_outputs.hash(state);
    }
}
/// Graph modification functions
impl NodeNetwork {
/// Get the original output nodes of this network, ignoring any preview node
@ -701,12 +716,43 @@ impl NodeNetwork {
self.flatten_with_fns(node_id, map_ids, gen_id);
}
}
DocumentNodeImplementation::Unresolved(_) => {}
DocumentNodeImplementation::Unresolved(_) => (),
DocumentNodeImplementation::Extract => {
panic!("Extract nodes should have been removed before flattening");
}
}
assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict");
self.nodes.insert(id, node);
}
pub fn resolve_extract_nodes(&mut self) {
let mut extraction_nodes = self
.nodes
.iter()
.filter(|(_, node)| matches!(node.implementation, DocumentNodeImplementation::Extract))
.map(|(id, node)| (*id, node.clone()))
.collect::<Vec<_>>();
self.nodes.retain(|_, node| !matches!(node.implementation, DocumentNodeImplementation::Extract));
for (_, node) in &mut extraction_nodes {
match node.implementation {
DocumentNodeImplementation::Extract => {
assert_eq!(node.inputs.len(), 1);
let NodeInput::Node { node_id, output_index, lambda } = node.inputs.pop().unwrap() else {
panic!("Extract node has no input");
};
assert_eq!(output_index, 0);
assert!(lambda);
let input_node = self.nodes.get_mut(&node_id).unwrap();
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into());
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node.clone()), false)];
}
_ => (),
}
}
self.nodes.extend(extraction_nodes);
}
pub fn into_proto_networks(self) -> impl Iterator<Item = ProtoNetwork> {
let mut nodes: Vec<_> = self.nodes.into_iter().map(|(id, node)| (id, node.resolve_proto_node())).collect();
nodes.sort_unstable_by_key(|(i, _)| *i);
@ -798,6 +844,39 @@ mod test {
assert_eq!(network, maped_add);
}
// Checks that `resolve_extract_nodes` replaces the input of an extract node
// with a copy of the node it points at.
#[test]
fn extract_node() {
// The node that is going to be extracted.
let id_node = DocumentNode {
name: "Id".into(),
inputs: vec![],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
};
// Node 0 is the id node, node 1 is an extract node with a lambda input pointing at node 0.
let mut extraction_network = NodeNetwork {
inputs: vec![],
outputs: vec![NodeOutput::new(1, 0)],
nodes: [
id_node.clone(),
DocumentNode {
name: "Extract".into(),
inputs: vec![NodeInput::lambda(0, 0)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Extract,
},
]
.into_iter()
.enumerate()
.map(|(id, node)| (id as NodeId, node))
.collect(),
..Default::default()
};
extraction_network.resolve_extract_nodes();
// Both nodes are still present after the extract node has been resolved.
assert_eq!(extraction_network.nodes.len(), 2);
let inputs = extraction_network.nodes.get(&1).unwrap().inputs.clone();
assert_eq!(inputs.len(), 1);
// The single remaining input is a value holding a copy of the id node.
assert!(matches!(&inputs[0], &NodeInput::Value{ tagged_value: TaggedValue::DocumentNode(ref network), ..} if network == &id_node));
}
#[test]
fn flatten_add() {
let mut network = NodeNetwork {
@ -810,7 +889,7 @@ mod test {
inputs: vec![
NodeInput::Network(concrete!(u32)),
NodeInput::Value {
tagged_value: crate::document::value::TaggedValue::U32(2),
tagged_value: TaggedValue::U32(2),
exposed: false,
},
],
@ -876,7 +955,7 @@ mod test {
construction_args: ConstructionArgs::Nodes(vec![]),
},
),
(14, ProtoNode::value(ConstructionArgs::Value(crate::document::value::TaggedValue::U32(2)))),
(14, ProtoNode::value(ConstructionArgs::Value(TaggedValue::U32(2)))),
]
.into_iter()
.collect(),
@ -917,7 +996,7 @@ mod test {
DocumentNode {
name: "Value".into(),
inputs: vec![NodeInput::Value {
tagged_value: crate::document::value::TaggedValue::U32(2),
tagged_value: TaggedValue::U32(2),
exposed: false,
}],
metadata: DocumentNodeMetadata::default(),
@ -979,10 +1058,7 @@ mod test {
10,
DocumentNode {
name: "Nested network".into(),
inputs: vec![
NodeInput::value(crate::document::value::TaggedValue::F32(1.), false),
NodeInput::value(crate::document::value::TaggedValue::F32(2.), false),
],
inputs: vec![NodeInput::value(TaggedValue::F32(1.), false), NodeInput::value(TaggedValue::F32(2.), false)],
metadata: DocumentNodeMetadata::default(),
implementation: DocumentNodeImplementation::Network(two_node_identity()),
},
@ -1015,11 +1091,7 @@ mod test {
assert_eq!(result.nodes.keys().copied().collect::<Vec<_>>(), vec![101], "Should just call nested network");
let nested_network_node = result.nodes.get(&101).unwrap();
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
assert_eq!(
nested_network_node.inputs,
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
"Input should be 2"
);
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
@ -1038,11 +1110,7 @@ mod test {
for (node_id, input_value, inner_id) in [(10, 1., 1), (101, 2., 2)] {
let nested_network_node = result.nodes.get(&node_id).unwrap();
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
assert_eq!(
nested_network_node.inputs,
vec![NodeInput::value(crate::document::value::TaggedValue::F32(input_value), false)],
"Input should be stable"
);
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(input_value), false)], "Input should be stable");
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
assert_eq!(inner_network.inputs, vec![inner_id], "The input should be sent to the second node");
assert_eq!(inner_network.outputs, vec![NodeOutput::new(inner_id, 0)], "The output should be node id");
@ -1061,11 +1129,7 @@ mod test {
assert_eq!(result_node.inputs, vec![NodeInput::node(101, 0)], "Result node should refer to duplicate node as input");
let nested_network_node = result.nodes.get(&101).unwrap();
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
assert_eq!(
nested_network_node.inputs,
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
"Input should be 2"
);
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");

View File

@ -1,15 +1,17 @@
use super::DocumentNode;
use crate::executor::Any;
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
use graphene_core::raster::{BlendMode, LuminanceCalculation};
use graphene_core::{Color, Node, Type};
pub use dyn_any::StaticType;
use dyn_any::{DynAny, Upcast};
use dyn_clone::DynClone;
pub use glam::{DAffine2, DVec2};
use graphene_core::raster::{BlendMode, LuminanceCalculation};
use graphene_core::{Color, Node, Type};
use std::hash::Hash;
pub use std::sync::Arc;
use crate::executor::Any;
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
/// A type that is known, allowing serialization (serde::Deserialize is not object safe)
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -52,6 +54,7 @@ pub enum TaggedValue {
ManipulatorGroupIds(Vec<graphene_core::uuid::ManipulatorGroupId>),
VecDVec2(Vec<DVec2>),
Segments(Vec<graphene_core::raster::ImageFrame<Color>>),
DocumentNode(DocumentNode),
}
#[allow(clippy::derived_hash_with_manual_eq)]
@ -119,11 +122,13 @@ impl Hash for TaggedValue {
}
}
Self::Segments(segments) => {
32.hash(state);
for segment in segments {
segment.hash(state)
}
}
Self::DocumentNode(document_node) => {
document_node.hash(state);
}
}
}
}
@ -170,6 +175,19 @@ impl<'a> TaggedValue {
TaggedValue::ManipulatorGroupIds(x) => Box::new(x),
TaggedValue::VecDVec2(x) => Box::new(x),
TaggedValue::Segments(x) => Box::new(x),
TaggedValue::DocumentNode(x) => Box::new(x),
}
}
/// Renders a primitive value as a string.
///
/// `None` becomes `"()"`, strings are returned unchanged, and numbers and
/// booleans use their `to_string` representation.
///
/// # Panics
/// Panics with "Cannot convert to primitive string" for every other variant.
// NOTE(review): `ConstructionArgs::new_function_args` places this string next to `&n<id>` node references.
// If the result is meant to be pasted into generated source code, note that whole floats lose their
// decimal point (`1.0` is rendered as `1`) and strings are emitted without quotes — confirm this is intended.
pub fn to_primitive_string(&self) -> String {
match self {
TaggedValue::None => "()".to_string(),
TaggedValue::String(x) => x.clone(),
TaggedValue::U32(x) => x.to_string(),
TaggedValue::F32(x) => x.to_string(),
TaggedValue::F64(x) => x.to_string(),
TaggedValue::Bool(x) => x.to_string(),
_ => panic!("Cannot convert to primitive string"),
}
}
@ -215,6 +233,7 @@ impl<'a> TaggedValue {
TaggedValue::ManipulatorGroupIds(_) => concrete!(Vec<graphene_core::uuid::ManipulatorGroupId>),
TaggedValue::VecDVec2(_) => concrete!(Vec<DVec2>),
TaggedValue::Segments(_) => concrete!(graphene_core::raster::IndexNode<Vec<graphene_core::raster::ImageFrame<Color>>>),
TaggedValue::DocumentNode(_) => concrete!(crate::document::DocumentNode),
}
}
}

View File

@ -10,6 +10,7 @@ pub struct Compiler {}
impl Compiler {
pub fn compile(&self, mut network: NodeNetwork, resolve_inputs: bool) -> impl Iterator<Item = ProtoNetwork> {
let node_ids = network.nodes.keys().copied().collect::<Vec<_>>();
network.resolve_extract_nodes();
println!("flattening");
for id in node_ids {
network.flatten(id);

View File

@ -7,6 +7,8 @@ use crate::document::value;
use crate::document::NodeId;
use dyn_any::DynAny;
use graphene_core::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::pin::Pin;
pub type Any<'n> = Box<dyn DynAny<'n> + 'n>;
@ -16,7 +18,8 @@ pub type TypeErasedPinned<'n> = Pin<Box<dyn for<'i> NodeIO<'i, Any<'i>, Output =
pub type NodeConstructor = for<'a> fn(Vec<TypeErasedPinnedRef<'static>>) -> TypeErasedPinned<'static>;
#[derive(Debug, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, PartialEq, Clone)]
pub struct ProtoNetwork {
// Should a proto Network even allow inputs? Don't think so
pub inputs: Vec<NodeId>,
@ -29,7 +32,7 @@ impl core::fmt::Display for ProtoNetwork {
f.write_str("Proto Network with nodes: ")?;
fn write_node(f: &mut core::fmt::Formatter<'_>, network: &ProtoNetwork, id: NodeId, indent: usize) -> core::fmt::Result {
f.write_str(&"\t".repeat(indent))?;
let Some((_, node)) = network.nodes.iter().find(|(node_id, _)|*node_id == id) else{
let Some((_, node)) = network.nodes.iter().find(|(node_id, _)|*node_id == id) else {
return f.write_str("{{Unknown Node}}");
};
f.write_str("Node: ")?;
@ -70,6 +73,7 @@ impl core::fmt::Display for ProtoNetwork {
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum ConstructionArgs {
Value(value::TaggedValue),
@ -104,12 +108,13 @@ impl Hash for ConstructionArgs {
impl ConstructionArgs {
pub fn new_function_args(&self) -> Vec<String> {
match self {
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("n{}", n.0)).collect(),
ConstructionArgs::Value(value) => vec![format!("{:?}", value)],
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("&n{}", n.0)).collect(),
ConstructionArgs::Value(value) => vec![value.to_primitive_string()],
}
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Clone)]
pub struct ProtoNode {
pub construction_args: ConstructionArgs,
@ -121,6 +126,7 @@ pub struct ProtoNode {
/// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum
/// in the `document` module
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ProtoNodeInput {
None,
Network(Type),

View File

@ -11,41 +11,58 @@ license = "MIT OR Apache-2.0"
[features]
memoization = ["once_cell"]
default = ["memoization"]
gpu = ["graphene-core/gpu", "gpu-compiler-bin-wrapper", "compilation-client"]
gpu = [
"graphene-core/gpu",
"gpu-compiler-bin-wrapper",
"compilation-client",
"gpu-executor",
]
vulkan = ["gpu", "vulkan-executor"]
wgpu = ["gpu", "wgpu-executor"]
quantization = ["autoquant"]
[dependencies]
autoquant = { git = "https://github.com/truedoctor/autoquant", optional = true, features = ["fitting"] }
graphene-core = {path = "../gcore", features = ["async", "std", "serde" ], default-features = false}
borrow_stack = {path = "../borrow_stack"}
dyn-any = {path = "../../libraries/dyn-any", features = ["derive"]}
graph-craft = {path = "../graph-craft"}
vulkan-executor = {path = "../vulkan-executor", optional = true}
wgpu-executor = {path = "../wgpu-executor", optional = true}
gpu-compiler-bin-wrapper = {path = "../gpu-compiler/gpu-compiler-bin-wrapper", optional = true}
compilation-client = {path = "../compilation-client", optional = true}
bytemuck = {version = "1.8" }
autoquant = { git = "https://github.com/truedoctor/autoquant", optional = true, features = [
"fitting",
] }
graphene-core = { path = "../gcore", features = [
"async",
"std",
"serde",
], default-features = false }
borrow_stack = { path = "../borrow_stack" }
dyn-any = { path = "../../libraries/dyn-any", features = ["derive"] }
graph-craft = { path = "../graph-craft" }
vulkan-executor = { path = "../vulkan-executor", optional = true }
wgpu-executor = { path = "../wgpu-executor", optional = true, version = "0.1.0" }
gpu-executor = { path = "../gpu-executor", optional = true }
gpu-compiler-bin-wrapper = { path = "../gpu-compiler/gpu-compiler-bin-wrapper", optional = true }
compilation-client = { path = "../compilation-client", optional = true }
bytemuck = { version = "1.8" }
tempfile = "3"
once_cell = {version= "1.10", optional = true}
once_cell = { version = "1.10", optional = true }
#pretty-token-stream = {path = "../../pretty-token-stream"}
syn = {version = "1.0", default-features = false, features = ["parsing", "printing"]}
proc-macro2 = {version = "1.0", default-features = false, features = ["proc-macro"]}
quote = {version = "1.0", default-features = false }
syn = { version = "1.0", default-features = false, features = [
"parsing",
"printing",
] }
proc-macro2 = { version = "1.0", default-features = false, features = [
"proc-macro",
] }
quote = { version = "1.0", default-features = false }
image = { version = "*", default-features = false }
dyn-clone = "1.0"
log = "0.4"
bezier-rs = { path = "../../libraries/bezier-rs" , features = ["serde"] }
bezier-rs = { path = "../../libraries/bezier-rs", features = ["serde"] }
kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
"serde",
] }
glam = { version = "0.22", features = ["serde"] }
node-macro = { path="../node-macro" }
node-macro = { path = "../node-macro" }
boxcar = "0.1.0"
xxhash-rust = {workspace = true}
xxhash-rust = { workspace = true }
[dependencies.serde]
version = "1.0"

View File

@ -1,24 +1,48 @@
use gpu_executor::{GpuExecutor, ShaderIO, ShaderInput};
use graph_craft::document::*;
use graph_craft::proto::*;
use graphene_core::raster::*;
use graphene_core::value::ValueNode;
use graphene_core::*;
use wgpu_executor::NewExecutor;
use bytemuck::Pod;
use core::marker::PhantomData;
use dyn_any::StaticTypeSized;
pub struct MapGpuNode<O, Network> {
network: Network,
_o: PhantomData<O>,
pub struct GpuCompiler<TypingContext, ShaderIO> {
typing_context: TypingContext,
io: ShaderIO,
}
#[node_macro::node_fn(MapGpuNode<_O>)]
fn map_gpu<I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, _O: StaticTypeSized + Sync + Send + Pod>(input: I, network: &'any_input NodeNetwork) -> Vec<_O> {
// TODO: Move to graph-craft
// Compiles the network behind `node` into a shader.
//
// The network is compiled to a proto network, the typing context is updated
// with it to resolve the input and output types, and everything is handed to
// the compilation client together with the shader `io` description.
//
// Panics if `node` is not implemented as a network, or if compiling or typing fails.
#[node_macro::node_fn(GpuCompiler)]
fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader {
    let compiler = graph_craft::executor::Compiler {};
    // `DocumentNodeImplementation` has more than one variant, so the pattern is refutable and needs an
    // `else` branch. Matching on a reference avoids moving the network out of the borrowed node.
    let DocumentNodeImplementation::Network(network) = &node.implementation else {
        panic!("GpuCompiler can only compile document nodes that are implemented as a network");
    };
    let proto_network = compiler.compile_single(network.clone(), true).unwrap();
    typing_context.update(&proto_network);
    let input_types = proto_network.inputs.iter().map(|id| typing_context.get_type(*id).unwrap()).map(|node_io| node_io.output).collect();
    let output_type = typing_context.get_type(proto_network.output).unwrap().output;
    compilation_client::compile_sync(proto_network, input_types, output_type, io).unwrap()
}
pub struct MapGpuNode<Shader> {
shader: Shader,
}
#[node_macro::node_fn(MapGpuNode)]
fn map_gpu(inputs: Vec<ShaderInput<<NewExecutor as GpuExecutor>::BufferHandle>>, shader: &'any_input compilation_client::Shader) {
use graph_craft::executor::Executor;
let bytes = compilation_client::compile_sync::<S, _O>(network.clone()).unwrap();
let words = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const u32, bytes.len() / 4) };
use wgpu_executor::{Context, GpuExecutor};
let executor: GpuExecutor<S, _O> = GpuExecutor::new(Context::new_sync().unwrap(), words.into(), "gpu::eval".into()).unwrap();
let executor = NewExecutor::new().unwrap();
for input in shader.inputs.iter() {
let buffer = executor.create_buffer(input.size).unwrap();
executor.write_buffer(buffer, input.data).unwrap();
}
todo!();
let executor: GpuExecutor = GpuExecutor::new(Context::new_sync().unwrap(), shader.into(), "gpu::eval".into()).unwrap();
let data: Vec<_> = input.into_iter().collect();
let result = executor.execute(Box::new(data)).unwrap();
let result = dyn_any::downcast::<Vec<_O>>(result).unwrap();
@ -30,7 +54,7 @@ pub struct MapGpuSingleImageNode<N> {
}
#[node_macro::node_fn(MapGpuSingleImageNode)]
fn map_gpu_single_image(input: Image, node: String) -> Image {
fn map_gpu_single_image(input: Image<Color>, node: String) -> Image<Color> {
use graph_craft::document::*;
use graph_craft::NodeIdentifier;

View File

@ -12,6 +12,7 @@ default = []
[dependencies]
graphene-core = { path = "../gcore", features = ["async", "std", "alloc", "gpu"] }
graph-craft = {path = "../graph-craft" }
gpu-executor = { path = "../gpu-executor" }
dyn-any = { path = "../../libraries/dyn-any", features = ["log-bad-types", "rc", "glam"] }
future-executor = { path = "../future-executor" }
num-traits = "0.2"

View File

@ -1,7 +1,7 @@
use std::sync::Arc;
use wgpu::{Device, Instance, Queue};
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Context {
pub device: Arc<Device>,
pub queue: Arc<Queue>,

View File

@ -3,3 +3,189 @@ mod executor;
pub use context::Context;
pub use executor::GpuExecutor;
use gpu_executor::{Shader, ShaderInput, StorageBufferOptions, ToStorageBuffer, ToUniformBuffer};
use graph_craft::Type;
use anyhow::{bail, Result};
use futures::Future;
use std::pin::Pin;
use wgpu::util::DeviceExt;
use wgpu::{Buffer, BufferDescriptor, CommandBuffer, ShaderModule};
/// `gpu_executor::GpuExecutor` implementation backed by wgpu.
///
/// All GPU work goes through the device and queue of the wrapped [`Context`].
#[derive(Debug, Clone)]
pub struct NewExecutor {
context: Context,
}
impl gpu_executor::GpuExecutor for NewExecutor {
type ShaderHandle = ShaderModule;
type BufferHandle = Buffer;
type CommandBuffer = CommandBuffer;
/// Creates a wgpu shader module from the SPIR-V words in `shader`, labelled with the shader's name.
fn load_shader(&self, shader: Shader) -> Result<Self::ShaderHandle> {
    let descriptor = wgpu::ShaderModuleDescriptor {
        label: Some(shader.name),
        source: wgpu::ShaderSource::SpirV(shader.source),
    };
    Ok(self.context.device.create_shader_module(descriptor))
}
/// Creates a uniform buffer initialized with the bytes of `data`.
///
/// The returned input records `T` as the type of the buffer contents.
fn create_uniform_buffer<T: ToUniformBuffer>(&self, data: T) -> Result<ShaderInput<Self::BufferHandle>> {
    let contents = data.to_bytes();
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: None,
        contents: contents.as_ref(),
        usage: wgpu::BufferUsages::UNIFORM,
    };
    let uniform_buffer = self.context.device.create_buffer_init(&descriptor);
    Ok(ShaderInput::UniformBuffer(uniform_buffer, Type::new::<T>()))
}
fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>> {
let bytes = data.to_bytes();
let mut usage = wgpu::BufferUsages::STORAGE;
if options.gpu_writable {
usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
}
if options.cpu_readable {
usage |= wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
}
if options.cpu_writable {
usage |= wgpu::BufferUsages::MAP_WRITE | wgpu::BufferUsages::COPY_SRC;
}
let buffer = self.context.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytes.as_ref(),
usage,
});
Ok(ShaderInput::StorageBuffer(buffer, Type::new::<T>()))
}
/// Creates an uninitialized buffer with room for `len` values of type `ty`.
///
/// With `cpu_readable` set, the result is a `ReadBackBuffer` with `STORAGE | COPY_DST | MAP_READ`
/// usage; otherwise it is an `OutputBuffer` with `STORAGE | COPY_SRC` usage.
///
/// # Errors
/// Returns an error if the size of `ty` is not known.
fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>> {
    let element_size = ty.size().ok_or_else(|| anyhow::anyhow!("Cannot create buffer of type {:?}", ty))?;
    let usage = if cpu_readable {
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ
    } else {
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC
    };
    let buffer = self.context.device.create_buffer(&BufferDescriptor {
        label: None,
        size: len as u64 * element_size as u64,
        usage,
        mapped_at_creation: false,
    });

    if cpu_readable {
        Ok(ShaderInput::ReadBackBuffer(buffer, ty))
    } else {
        Ok(ShaderInput::OutputBuffer(buffer, ty))
    }
}
/// Records a compute pass for `layout` and returns the finished command buffer.
///
/// Every shader input that is backed by a buffer, followed by the layout's output buffer,
/// is bound to bind group 0. Binding indices are assigned in that order, counting only the
/// inputs that have a buffer. The pass dispatches `instances` work groups along x.
///
/// If `read_back` is a `ReadBackBuffer`, a copy from the layout's output buffer into it is
/// recorded after the pass. Any other value of `read_back` is ignored.
///
/// # Errors
/// Returns an error if a read back is requested but the layout's output is not backed by a buffer.
///
/// # Panics
/// Panics if the read back buffer and the output buffer differ in size or element type.
fn create_compute_pass(&self, layout: &gpu_executor::PipelineLayout<Self>, read_back: Option<ShaderInput<Self::BufferHandle>>, instances: u32) -> Result<CommandBuffer> {
    let compute_pipeline = self.context.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: None,
        module: &layout.shader,
        entry_point: layout.entry_point.as_str(),
    });
    let bind_group_layout = compute_pipeline.get_bind_group_layout(0);

    let entries = layout
        .bind_group
        .buffers
        .iter()
        .chain(std::iter::once(&layout.output_buffer))
        .filter_map(|input| input.buffer())
        .enumerate()
        .map(|(i, buffer)| wgpu::BindGroupEntry {
            binding: i as u32,
            resource: buffer.as_entire_binding(),
        })
        .collect::<Vec<_>>();

    let bind_group = self.context.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: entries.as_slice(),
    });

    let mut encoder = self.context.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
        cpass.set_pipeline(&compute_pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        cpass.insert_debug_marker("compute node network evaluation");
        cpass.dispatch_workgroups(instances, 1, 1); // Number of work groups to dispatch along (x, y, z)
    }

    // Adds a copy operation to the command encoder that copies the shader output
    // into the read back buffer.
    if let Some(ShaderInput::ReadBackBuffer(output, ty)) = read_back {
        // Resolve the source buffer first, so that a shader output without a buffer is
        // reported as an error instead of panicking in the size check below.
        let source = layout.output_buffer.buffer().ok_or_else(|| anyhow::anyhow!("Tried to use an non buffer as the shader output"))?;
        let size = output.size();
        assert_eq!(size, source.size());
        assert_eq!(ty, layout.output_buffer.ty());
        encoder.copy_buffer_to_buffer(source, 0, &output, 0, size);
    }

    // Finish recording; the command buffer is submitted by `execute_compute_pipeline`.
    Ok(encoder.finish())
}
/// Submits the recorded command buffer to the queue and blocks until the device has finished its work.
fn execute_compute_pipeline(&self, encoder: Self::CommandBuffer) -> Result<()> {
    self.context.queue.submit(std::iter::once(encoder));

    // Block until the submitted work is done. In an actual application the device
    // should be polled from an event loop or another thread instead.
    self.context.device.poll(wgpu::Maintain::Wait);
    Ok(())
}
/// Returns a future that maps `buffer` for reading and resolves to a copy of its bytes.
///
/// # Errors
/// Fails immediately if `buffer` is not a `ReadBackBuffer`. The returned future fails
/// if mapping the buffer does not succeed.
// NOTE(review): the future only waits for the `map_async` callback and does not poll the device
// itself — confirm that something else drives the device while the future is pending.
fn read_output_buffer(&self, buffer: ShaderInput<Self::BufferHandle>) -> Result<Pin<Box<dyn Future<Output = Result<Vec<u8>>>>>> {
    let ShaderInput::ReadBackBuffer(buffer, _) = buffer else {
        bail!("Tried to read a non readback buffer")
    };

    let future = Box::pin(async move {
        let buffer_slice = buffer.slice(..);

        // Request the mapping; the outcome is sent back through the channel once it is known.
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        // Wait for the mapping to finish.
        #[cfg(feature = "profiling")]
        nvtx::range_push!("compute");
        let mapping = receiver.receive().await;
        #[cfg(feature = "profiling")]
        nvtx::range_pop!();

        if mapping != Some(Ok(())) {
            bail!("failed to run compute on gpu!")
        }

        // Copy the mapped bytes out of the buffer.
        let mapped = buffer_slice.get_mapped_range();
        let bytes = mapped.to_vec();

        // All mapped views have to be dropped before the buffer can be unmapped.
        drop(mapped);
        buffer.unmap();

        Ok(bytes)
    });
    Ok(future)
}
}
impl NewExecutor {
    /// Creates an executor with a freshly created [`Context`].
    ///
    /// Returns `None` if no context could be created.
    pub fn new() -> Option<Self> {
        Context::new_sync().map(|context| Self { context })
    }
}