Restructure GPU execution to model GPU pipelines in the node graph (#1088)
* Start implementing GpuExecutor for wgpu * Implement read_output_buffer function * Implement extraction node in the compiler * Generate type annotations during shader compilation * Start adding node wrapprs for graph execution api * Wrap more of the api in nodes * Restructure Pipeline to accept arbitrary shader inputs * Adapt nodes to new trait definitions * Start implementing gpu-compiler trait * Adapt shader generation * Hardstuck on pointer casts * Pass nodes as references in gpu code to avoid zsts * Update gcore to compile on the gpu * Fix color doc tests * Impl Node for node refs
This commit is contained in:
parent
161bbc62b4
commit
bdc1ef926a
File diff suppressed because it is too large
Load Diff
|
|
@ -15,6 +15,7 @@ members = [
|
|||
"node-graph/compilation-client",
|
||||
"node-graph/vulkan-executor",
|
||||
"node-graph/wgpu-executor",
|
||||
"node-graph/gpu-executor",
|
||||
"node-graph/future-executor",
|
||||
"node-graph/gpu-compiler/gpu-compiler-bin-wrapper",
|
||||
"libraries/dyn-any",
|
||||
|
|
|
|||
66
deny.toml
66
deny.toml
|
|
@ -18,13 +18,13 @@
|
|||
# dependencies not shared by any other crates, would be ignored, as the target
|
||||
# list here is effectively saying which targets you are building for.
|
||||
targets = [
|
||||
# The triple can be any string, but only the target triples built in to
|
||||
# rustc (as of 1.40) can be checked against actual config expressions
|
||||
#{ triple = "x86_64-unknown-linux-musl" },
|
||||
# You can also specify which target_features you promise are enabled for a
|
||||
# particular target. target_features are currently not validated against
|
||||
# the actual valid features supported by the target architecture.
|
||||
#{ triple = "wasm32-unknown-unknown", features = ["atomics"] },
|
||||
# The triple can be any string, but only the target triples built in to
|
||||
# rustc (as of 1.40) can be checked against actual config expressions
|
||||
#{ triple = "x86_64-unknown-linux-musl" },
|
||||
# You can also specify which target_features you promise are enabled for a
|
||||
# particular target. target_features are currently not validated against
|
||||
# the actual valid features supported by the target architecture.
|
||||
#{ triple = "wasm32-unknown-unknown", features = ["atomics"] },
|
||||
]
|
||||
|
||||
# This section is considered when running `cargo deny check advisories`
|
||||
|
|
@ -48,7 +48,7 @@ notice = "warn"
|
|||
# A list of advisory IDs to ignore. Note that ignored advisories will still
|
||||
# output a note when they are encountered.
|
||||
ignore = [
|
||||
#"RUSTSEC-0000-0000",
|
||||
#"RUSTSEC-0000-0000",
|
||||
"RUSTSEC-2020-0071", # This has been fixed in the version of chrono we use
|
||||
]
|
||||
# Threshold for security vulnerabilities, any vulnerability with a CVSS score
|
||||
|
|
@ -71,25 +71,25 @@ unlicensed = "deny"
|
|||
# See https://spdx.org/licenses/ for list of possible licenses
|
||||
# [possible values: any SPDX 3.11 short identifier (+ optional exception)].
|
||||
allow = [
|
||||
"MIT",
|
||||
"MIT-0",
|
||||
"Apache-2.0",
|
||||
"BSD-3-Clause",
|
||||
"BSD-2-Clause",
|
||||
"Zlib",
|
||||
"MIT",
|
||||
"MIT-0",
|
||||
"Apache-2.0",
|
||||
"BSD-3-Clause",
|
||||
"BSD-2-Clause",
|
||||
"Zlib",
|
||||
"Unicode-DFS-2016",
|
||||
"ISC",
|
||||
"MPL-2.0",
|
||||
"CC0-1.0",
|
||||
"OpenSSL",
|
||||
"BSL-1.0",
|
||||
#"Apache-2.0 WITH LLVM-exception",
|
||||
"Apache-2.0 WITH LLVM-exception",
|
||||
]
|
||||
# List of explicitly disallowed licenses
|
||||
# See https://spdx.org/licenses/ for list of possible licenses
|
||||
# [possible values: any SPDX 3.11 short identifier (+ optional exception)].
|
||||
deny = [
|
||||
#"Nokia",
|
||||
#"Nokia",
|
||||
]
|
||||
# Lint level for licenses considered copyleft
|
||||
copyleft = "deny"
|
||||
|
|
@ -113,9 +113,9 @@ confidence-threshold = 0.8
|
|||
# Allow 1 or more licenses on a per-crate basis, so that particular licenses
|
||||
# aren't accepted for every possible crate as with the normal allow list
|
||||
exceptions = [
|
||||
# Each entry is the crate and version constraint, and its specific allow
|
||||
# list
|
||||
#{ allow = ["Zlib"], name = "adler32", version = "*" },
|
||||
# Each entry is the crate and version constraint, and its specific allow
|
||||
# list
|
||||
#{ allow = ["Zlib"], name = "adler32", version = "*" },
|
||||
]
|
||||
|
||||
# Some crates don't have (easily) machine readable licensing information,
|
||||
|
|
@ -134,8 +134,8 @@ expression = "MIT AND ISC AND OpenSSL"
|
|||
# and the crate will be checked normally, which may produce warnings or errors
|
||||
# depending on the rest of your configuration
|
||||
license-files = [
|
||||
# Each entry is a crate relative path, and the (opaque) hash of its contents
|
||||
{ path = "LICENSE", hash = 0xbd0eed23 }
|
||||
# Each entry is a crate relative path, and the (opaque) hash of its contents
|
||||
{ path = "LICENSE", hash = 0xbd0eed23 }
|
||||
]
|
||||
|
||||
[licenses.private]
|
||||
|
|
@ -146,7 +146,7 @@ ignore = false
|
|||
# is only published to private registries, and ignore is true, the crate will
|
||||
# not have its license(s) checked
|
||||
registries = [
|
||||
#"https://sekretz.com/registry
|
||||
#"https://sekretz.com/registry
|
||||
]
|
||||
|
||||
# This section is considered when running `cargo deny check bans`.
|
||||
|
|
@ -165,29 +165,29 @@ wildcards = "allow"
|
|||
highlight = "all"
|
||||
# List of crates that are allowed. Use with care!
|
||||
allow = [
|
||||
#{ name = "ansi_term", version = "=0.11.0" },
|
||||
#{ name = "ansi_term", version = "=0.11.0" },
|
||||
]
|
||||
# List of crates to deny
|
||||
deny = [
|
||||
# Each entry the name of a crate and a version range. If version is
|
||||
# not specified, all versions will be matched.
|
||||
#{ name = "ansi_term", version = "=0.11.0" },
|
||||
#
|
||||
# Wrapper crates can optionally be specified to allow the crate when it
|
||||
# is a direct dependency of the otherwise banned crate
|
||||
#{ name = "ansi_term", version = "=0.11.0", wrappers = [] },
|
||||
# Each entry the name of a crate and a version range. If version is
|
||||
# not specified, all versions will be matched.
|
||||
#{ name = "ansi_term", version = "=0.11.0" },
|
||||
#
|
||||
# Wrapper crates can optionally be specified to allow the crate when it
|
||||
# is a direct dependency of the otherwise banned crate
|
||||
#{ name = "ansi_term", version = "=0.11.0", wrappers = [] },
|
||||
]
|
||||
# Certain crates/versions that will be skipped when doing duplicate detection.
|
||||
skip = [
|
||||
#{ name = "ansi_term", version = "=0.11.0" },
|
||||
{ name = "cfg-if", version = "=0.1.10" },
|
||||
#{ name = "ansi_term", version = "=0.11.0" },
|
||||
{ name = "cfg-if", version = "=0.1.10" },
|
||||
]
|
||||
# Similarly to `skip` allows you to skip certain crates during duplicate
|
||||
# detection. Unlike skip, it also includes the entire tree of transitive
|
||||
# dependencies starting at the specified crate, up to a certain depth, which is
|
||||
# by default infinite
|
||||
skip-tree = [
|
||||
#{ name = "ansi_term", version = "=0.11.0", depth = 20 },
|
||||
#{ name = "ansi_term", version = "=0.11.0", depth = 20 },
|
||||
]
|
||||
|
||||
# This section is considered when running `cargo deny check sources`.
|
||||
|
|
|
|||
|
|
@ -169,13 +169,13 @@ impl<'a> ModifyInputsContext<'a> {
|
|||
let NodeInput::Value {
|
||||
tagged_value: TaggedValue::Subpaths(subpaths),
|
||||
..
|
||||
} = subpaths else{
|
||||
} = subpaths else {
|
||||
return;
|
||||
};
|
||||
let NodeInput::Value {
|
||||
tagged_value: TaggedValue::ManipulatorGroupIds(mirror_angle_groups),
|
||||
..
|
||||
} = mirror_angle_groups else{
|
||||
} = mirror_angle_groups else {
|
||||
return;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1018,8 +1018,8 @@ pub fn wrap_network_in_scope(network: NodeNetwork) -> NodeNetwork {
|
|||
|
||||
// if the network has no network inputs, it doesn't need to be wrapped in a scope either
|
||||
let Some(input_type) = input else {
|
||||
return network;
|
||||
};
|
||||
return network;
|
||||
};
|
||||
|
||||
let inner_network = DocumentNode {
|
||||
name: "Scope".to_string(),
|
||||
|
|
|
|||
|
|
@ -160,13 +160,13 @@ unsafe impl<'a, T: StaticTypeSized> StaticType for &'a [T] {
|
|||
type Static = &'static [<T as StaticTypeSized>::Static];
|
||||
}
|
||||
macro_rules! impl_slice {
|
||||
($($id:ident),*) => {
|
||||
$(
|
||||
unsafe impl<'a, T: StaticTypeSized> StaticType for $id<'a, T> {
|
||||
type Static = $id<'static, <T as StaticTypeSized>::Static>;
|
||||
}
|
||||
)*
|
||||
};
|
||||
($($id:ident),*) => {
|
||||
$(
|
||||
unsafe impl<'a, T: StaticTypeSized> StaticType for $id<'a, T> {
|
||||
type Static = $id<'static, <T as StaticTypeSized>::Static>;
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
mod slice {
|
||||
|
|
|
|||
|
|
@ -9,9 +9,18 @@ license = "MIT OR Apache-2.0"
|
|||
[dependencies]
|
||||
#tokio = { version = "1.0", features = ["full"] }
|
||||
serde_json = "1.0"
|
||||
graph-craft = { version = "0.1.0", path = "../graph-craft", features = ["serde"] }
|
||||
graph-craft = { version = "0.1.0", path = "../graph-craft", features = [
|
||||
"serde",
|
||||
] }
|
||||
gpu-executor = { version = "0.1.0", path = "../gpu-executor" }
|
||||
gpu-compiler-bin-wrapper = { version = "0.1.0", path = "../gpu-compiler/gpu-compiler-bin-wrapper" }
|
||||
tempfile = "3.3.0"
|
||||
anyhow = "1.0.68"
|
||||
reqwest = { version = "0.11", features = ["blocking", "serde_json", "json", "rustls", "rustls-tls"] }
|
||||
future-executor = {path = "../future-executor"}
|
||||
reqwest = { version = "0.11", features = [
|
||||
"blocking",
|
||||
"serde_json",
|
||||
"json",
|
||||
"rustls",
|
||||
"rustls-tls",
|
||||
] }
|
||||
future-executor = { path = "../future-executor" }
|
||||
|
|
|
|||
|
|
@ -1,15 +1,30 @@
|
|||
use gpu_compiler_bin_wrapper::CompileRequest;
|
||||
use graph_craft::document::*;
|
||||
use gpu_executor::ShaderIO;
|
||||
use graph_craft::{proto::ProtoNetwork, Type};
|
||||
|
||||
pub async fn compile<I, O>(network: NodeNetwork) -> Result<Vec<u8>, reqwest::Error> {
|
||||
pub async fn compile(network: ProtoNetwork, inputs: Vec<Type>, output: Type, io: ShaderIO) -> Result<Shader, reqwest::Error> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let compile_request = CompileRequest::new(network, std::any::type_name::<I>().to_owned(), std::any::type_name::<O>().to_owned());
|
||||
let compile_request = CompileRequest::new(network, inputs.clone(), output.clone(), io.clone());
|
||||
let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send();
|
||||
let response = response.await?;
|
||||
response.bytes().await.map(|b| b.to_vec())
|
||||
response.bytes().await.map(|b| Shader {
|
||||
spirv_binary: b.windows(4).map(|x| u32::from_le_bytes(x.try_into().unwrap())).collect(),
|
||||
input_types: inputs,
|
||||
output_type: output,
|
||||
io,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn compile_sync<I: 'static, O: 'static>(network: NodeNetwork) -> Result<Vec<u8>, reqwest::Error> {
|
||||
future_executor::block_on(compile::<I, O>(network))
|
||||
pub fn compile_sync(network: ProtoNetwork, inputs: Vec<Type>, output: Type, io: ShaderIO) -> Result<Shader, reqwest::Error> {
|
||||
future_executor::block_on(compile(network, inputs, output, io))
|
||||
}
|
||||
|
||||
// TODO: should we add the entry point as a field?
|
||||
/// A compiled shader with type annotations.
|
||||
pub struct Shader {
|
||||
pub spirv_binary: Vec<u32>,
|
||||
pub input_types: Vec<Type>,
|
||||
pub output_type: Type,
|
||||
pub io: ShaderIO,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,40 +1,55 @@
|
|||
use gpu_compiler_bin_wrapper::CompileRequest;
|
||||
use gpu_executor::{ShaderIO, ShaderInput};
|
||||
use graph_craft::concrete;
|
||||
use graph_craft::document::*;
|
||||
|
||||
use graph_craft::*;
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::time::Duration;
|
||||
|
||||
fn main() {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
|
||||
let network = NodeNetwork {
|
||||
inputs: vec![0],
|
||||
outputs: vec![NodeOutput::new(0, 0)],
|
||||
disabled: vec![],
|
||||
previous_outputs: None,
|
||||
nodes: [(
|
||||
0,
|
||||
DocumentNode {
|
||||
name: "Inc".into(),
|
||||
inputs: vec![NodeInput::Network(concrete!(u32))],
|
||||
implementation: DocumentNodeImplementation::Network(add_network()),
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
},
|
||||
)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
// let network = NodeNetwork {
|
||||
// inputs: vec![0],
|
||||
// outputs: vec![NodeOutput::new(0, 0)],
|
||||
// disabled: vec![],
|
||||
// previous_outputs: None,
|
||||
// nodes: [(
|
||||
// 0,
|
||||
// DocumentNode {
|
||||
// name: "Inc".into(),
|
||||
// inputs: vec![NodeInput::Network(concrete!(u32))],
|
||||
// implementation: DocumentNodeImplementation::Network(add_network()),
|
||||
// metadata: DocumentNodeMetadata::default(),
|
||||
// },
|
||||
// )]
|
||||
// .into_iter()
|
||||
// .collect(),
|
||||
// };
|
||||
let network = add_network();
|
||||
let compiler = graph_craft::executor::Compiler {};
|
||||
let proto_network = compiler.compile_single(network, true).unwrap();
|
||||
|
||||
let io = ShaderIO {
|
||||
inputs: vec![ShaderInput::StorageBuffer((), concrete!(u32))],
|
||||
output: ShaderInput::OutputBuffer((), concrete!(&mut [u32])),
|
||||
};
|
||||
|
||||
let compile_request = CompileRequest::new(network, "u32".to_owned(), "u32".to_owned());
|
||||
let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send().unwrap();
|
||||
let compile_request = CompileRequest::new(proto_network, vec![concrete!(u32)], concrete!(u32), io);
|
||||
let response = client
|
||||
.post("http://localhost:3000/compile/spirv")
|
||||
.timeout(Duration::from_secs(30))
|
||||
.json(&compile_request)
|
||||
.send()
|
||||
.unwrap();
|
||||
println!("response: {:?}", response);
|
||||
}
|
||||
|
||||
fn add_network() -> NodeNetwork {
|
||||
NodeNetwork {
|
||||
inputs: vec![0],
|
||||
outputs: vec![NodeOutput::new(1, 0)],
|
||||
inputs: vec![],
|
||||
outputs: vec![NodeOutput::new(0, 0)],
|
||||
disabled: vec![],
|
||||
previous_outputs: None,
|
||||
nodes: [
|
||||
|
|
@ -42,20 +57,20 @@ fn add_network() -> NodeNetwork {
|
|||
0,
|
||||
DocumentNode {
|
||||
name: "Dup".into(),
|
||||
inputs: vec![NodeInput::Network(concrete!(u32))],
|
||||
inputs: vec![NodeInput::value(value::TaggedValue::U32(5u32), false)],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::DupNode")),
|
||||
},
|
||||
),
|
||||
(
|
||||
1,
|
||||
DocumentNode {
|
||||
name: "Add".into(),
|
||||
inputs: vec![NodeInput::node(0, 0)],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode")),
|
||||
implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")),
|
||||
},
|
||||
),
|
||||
// (
|
||||
// 1,
|
||||
// DocumentNode {
|
||||
// name: "Add".into(),
|
||||
// inputs: vec![NodeInput::node(0, 0)],
|
||||
// metadata: DocumentNodeMetadata::default(),
|
||||
// implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::AddNode")),
|
||||
// },
|
||||
// ),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
|
|
|||
|
|
@ -9,10 +9,23 @@ license = "MIT OR Apache-2.0"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
std = ["dyn-any", "dyn-any/std", "alloc", "glam/std", "specta"]
|
||||
default = ["async", "serde", "kurbo", "log", "std"]
|
||||
std = [
|
||||
"dyn-any",
|
||||
"dyn-any/std",
|
||||
"alloc",
|
||||
"glam/std",
|
||||
"specta",
|
||||
"num-traits/std",
|
||||
]
|
||||
default = ["async", "serde", "kurbo", "log", "std", "rand_chacha"]
|
||||
log = ["dep:log"]
|
||||
serde = ["dep:serde", "glam/serde", "bezier-rs/serde", "base64"]
|
||||
serde = [
|
||||
"dep:serde",
|
||||
"glam/serde",
|
||||
"bezier-rs/serde",
|
||||
"bezier-rs/serde",
|
||||
"base64",
|
||||
]
|
||||
gpu = ["spirv-std", "glam/bytemuck", "dyn-any", "glam/libm"]
|
||||
async = ["async-trait", "alloc"]
|
||||
nightly = []
|
||||
|
|
@ -25,7 +38,7 @@ dyn-any = { path = "../../libraries/dyn-any", features = [
|
|||
"glam",
|
||||
], optional = true, default-features = false }
|
||||
|
||||
spirv-std = { version = "0.5", features = ["glam"], optional = true }
|
||||
spirv-std = { version = "0.7", optional = true }
|
||||
bytemuck = { version = "1.8", features = ["derive"] }
|
||||
async-trait = { version = "0.1", optional = true }
|
||||
serde = { version = "1.0", features = [
|
||||
|
|
@ -33,11 +46,11 @@ serde = { version = "1.0", features = [
|
|||
], optional = true, default-features = false }
|
||||
log = { version = "0.4", optional = true }
|
||||
|
||||
rand_chacha = { version = "0.3.1", optional = true }
|
||||
bezier-rs = { path = "../../libraries/bezier-rs", optional = true }
|
||||
kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
|
||||
"serde",
|
||||
], optional = true }
|
||||
rand_chacha = "0.3.1"
|
||||
spin = "0.9.2"
|
||||
glam = { version = "^0.22", default-features = false, features = [
|
||||
"scalar-math",
|
||||
|
|
@ -47,6 +60,7 @@ base64 = { version = "0.13", optional = true }
|
|||
specta.workspace = true
|
||||
specta.optional = true
|
||||
once_cell = { version = "1.17.0", default-features = false, optional = true }
|
||||
num = "0.4.0"
|
||||
num-derive = "0.3.3"
|
||||
num-traits = "0.2.15"
|
||||
num-derive = { version = "0.3.3" }
|
||||
num-traits = { version = "0.2.15", default-features = false, features = [
|
||||
"i128",
|
||||
] }
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
use crate::{raster::Sample, Color};
|
||||
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use spirv_std::image::{Image2d, SampledImage};
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Pod, Zeroable)]
|
||||
|
|
@ -6,3 +9,12 @@ pub struct PushConstants {
|
|||
pub n: u32,
|
||||
pub node: u32,
|
||||
}
|
||||
|
||||
impl Sample for SampledImage<Image2d> {
|
||||
type Pixel = Color;
|
||||
|
||||
fn sample(&self, pos: glam::DVec2) -> Option<Self::Pixel> {
|
||||
let color = self.sample(pos);
|
||||
Color::from_rgbaf32(color.x, color.y, color.z, color.w)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ pub mod value;
|
|||
#[cfg(feature = "gpu")]
|
||||
pub mod gpu;
|
||||
|
||||
pub mod storage;
|
||||
|
||||
pub mod raster;
|
||||
#[cfg(feature = "alloc")]
|
||||
pub mod transform;
|
||||
|
|
@ -44,8 +46,8 @@ pub use types::*;
|
|||
|
||||
pub trait NodeIO<'i, Input: 'i>: 'i + Node<'i, Input>
|
||||
where
|
||||
Self::Output: 'i + StaticType,
|
||||
Input: 'i + StaticType,
|
||||
Self::Output: 'i + StaticTypeSized,
|
||||
Input: 'i + StaticTypeSized,
|
||||
{
|
||||
fn input_type(&self) -> TypeId {
|
||||
TypeId::of::<Input::Static>()
|
||||
|
|
@ -54,7 +56,7 @@ where
|
|||
core::any::type_name::<Input>()
|
||||
}
|
||||
fn output_type(&self) -> core::any::TypeId {
|
||||
TypeId::of::<<Self::Output as StaticType>::Static>()
|
||||
TypeId::of::<<Self::Output as StaticTypeSized>::Static>()
|
||||
}
|
||||
fn output_type_name(&self) -> &'static str {
|
||||
core::any::type_name::<Self::Output>()
|
||||
|
|
@ -62,8 +64,8 @@ where
|
|||
#[cfg(feature = "alloc")]
|
||||
fn to_node_io(&self, parameters: Vec<Type>) -> NodeIOTypes {
|
||||
NodeIOTypes {
|
||||
input: concrete!(<Input as StaticType>::Static),
|
||||
output: concrete!(<Self::Output as StaticType>::Static),
|
||||
input: concrete!(<Input as StaticTypeSized>::Static),
|
||||
output: concrete!(<Self::Output as StaticTypeSized>::Static),
|
||||
parameters,
|
||||
}
|
||||
}
|
||||
|
|
@ -71,8 +73,8 @@ where
|
|||
|
||||
impl<'i, N: Node<'i, I>, I> NodeIO<'i, I> for N
|
||||
where
|
||||
N::Output: 'i + StaticType,
|
||||
I: 'i + StaticType,
|
||||
N::Output: 'i + StaticTypeSized,
|
||||
I: 'i + StaticTypeSized,
|
||||
{
|
||||
}
|
||||
|
||||
|
|
@ -83,6 +85,13 @@ where
|
|||
(**self).eval(input)
|
||||
}
|
||||
}*/
|
||||
impl<'i, 's: 'i, I: 'i, O: 'i, N: Node<'i, I, Output = O>> Node<'i, I> for &'s N {
|
||||
type Output = O;
|
||||
|
||||
fn eval(&'i self, input: I) -> Self::Output {
|
||||
(**self).eval(input)
|
||||
}
|
||||
}
|
||||
impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> {
|
||||
type Output = O;
|
||||
|
||||
|
|
@ -92,7 +101,7 @@ impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> {
|
|||
}
|
||||
use core::pin::Pin;
|
||||
|
||||
use dyn_any::StaticType;
|
||||
use dyn_any::StaticTypeSized;
|
||||
#[cfg(feature = "alloc")]
|
||||
impl<'i, I: 'i, O: 'i> Node<'i, I> for Pin<Box<dyn for<'a> Node<'a, I, Output = O> + 'i>> {
|
||||
type Output = O;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use core::marker::PhantomData;
|
||||
use core::ops::Add;
|
||||
use core::ops::{Add, Mul};
|
||||
|
||||
use crate::Node;
|
||||
|
||||
|
|
@ -30,6 +30,27 @@ where
|
|||
first + second
|
||||
}
|
||||
|
||||
pub struct MulParameterNode<Second> {
|
||||
second: Second,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(MulParameterNode)]
|
||||
fn flat_map<U, T>(first: U, second: T) -> <U as Mul<T>>::Output
|
||||
where
|
||||
U: Mul<T>,
|
||||
{
|
||||
first * second
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
struct SizeOfNode {}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
#[node_macro::node_fn(SizeOfNode)]
|
||||
fn flat_map(ty: crate::Type) -> Option<usize> {
|
||||
ty.size()
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct SomeNode;
|
||||
#[node_macro::node_fn(SomeNode)]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ use crate::raster::Color;
|
|||
use crate::Node;
|
||||
use dyn_any::{DynAny, StaticType};
|
||||
|
||||
#[cfg(target_arch = "spirv")]
|
||||
use spirv_std::num_traits::Float;
|
||||
|
||||
#[derive(Clone, Debug, DynAny, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Quantization {
|
||||
|
|
|
|||
|
|
@ -4,50 +4,52 @@ use crate::Node;
|
|||
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use glam::DVec2;
|
||||
use num::Num;
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
use num_traits::{cast::cast as num_cast, Num, NumCast};
|
||||
#[cfg(target_arch = "spirv")]
|
||||
use spirv_std::num_traits::float::Float;
|
||||
use spirv_std::num_traits::{cast::cast as num_cast, float::Float, FromPrimitive, Num, NumCast, ToPrimitive};
|
||||
|
||||
pub use self::color::Color;
|
||||
|
||||
pub mod adjustments;
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
pub mod brightness_contrast;
|
||||
pub mod color;
|
||||
pub use adjustments::*;
|
||||
|
||||
pub trait Channel: Copy + Debug + num::Num + num::NumCast {
|
||||
pub trait Channel: Copy + Debug + Num + NumCast {
|
||||
fn to_linear<Out: Linear>(self) -> Out;
|
||||
fn from_linear<In: Linear>(linear: In) -> Self;
|
||||
fn to_f32(self) -> f32 {
|
||||
num::cast(self).expect("Failed to convert channel to f32")
|
||||
num_cast(self).expect("Failed to convert channel to f32")
|
||||
}
|
||||
fn from_f32(value: f32) -> Self {
|
||||
num::cast(value).expect("Failed to convert f32 to channel")
|
||||
num_cast(value).expect("Failed to convert f32 to channel")
|
||||
}
|
||||
fn to_f64(self) -> f64 {
|
||||
num::cast(self).expect("Failed to convert channel to f64")
|
||||
num_cast(self).expect("Failed to convert channel to f64")
|
||||
}
|
||||
fn from_f64(value: f64) -> Self {
|
||||
num::cast(value).expect("Failed to convert f64 to channel")
|
||||
num_cast(value).expect("Failed to convert f64 to channel")
|
||||
}
|
||||
fn to_channel<Out: Channel>(self) -> Out {
|
||||
num::cast(self).expect("Failed to convert channel to channel")
|
||||
num_cast(self).expect("Failed to convert channel to channel")
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Linear: num::NumCast + Num {}
|
||||
pub trait Linear: NumCast + Num {}
|
||||
impl Linear for f32 {}
|
||||
impl Linear for f64 {}
|
||||
|
||||
impl<T: Linear + Debug + Copy> Channel for T {
|
||||
#[inline(always)]
|
||||
fn to_linear<Out: Linear>(self) -> Out {
|
||||
num::cast(self).expect("Failed to convert channel to linear")
|
||||
num_cast(self).expect("Failed to convert channel to linear")
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn from_linear<In: Linear>(linear: In) -> Self {
|
||||
num::cast(linear).expect("Failed to convert linear to channel")
|
||||
num_cast(linear).expect("Failed to convert linear to channel")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -58,16 +60,16 @@ struct SRGBGammaFloat(f32);
|
|||
impl Channel for SRGBGammaFloat {
|
||||
#[inline(always)]
|
||||
fn to_linear<Out: Linear>(self) -> Out {
|
||||
let channel = num::cast::<_, f32>(self).expect("Failed to convert srgb to linear");
|
||||
let channel = num_cast::<_, f32>(self).expect("Failed to convert srgb to linear");
|
||||
let out = if channel <= 0.04045 { channel / 12.92 } else { ((channel + 0.055) / 1.055).powf(2.4) };
|
||||
num::cast(out).expect("Failed to convert srgb to linear")
|
||||
num_cast(out).expect("Failed to convert srgb to linear")
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn from_linear<In: Linear>(linear: In) -> Self {
|
||||
let linear = num::cast::<_, f32>(linear).expect("Failed to convert linear to srgb");
|
||||
let linear = num_cast::<_, f32>(linear).expect("Failed to convert linear to srgb");
|
||||
let out = if linear <= 0.0031308 { linear * 12.92 } else { 1.055 * linear.powf(1. / 2.4) - 0.055 };
|
||||
num::cast(out).expect("Failed to convert linear to srgb")
|
||||
num_cast(out).expect("Failed to convert linear to srgb")
|
||||
}
|
||||
}
|
||||
pub trait RGBPrimaries {
|
||||
|
|
@ -98,6 +100,7 @@ impl<T> Serde for T {}
|
|||
|
||||
// TODO: Come up with a better name for this trait
|
||||
pub trait Pixel: Clone + Pod + Zeroable {
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
bytemuck::bytes_of(self).to_vec()
|
||||
}
|
||||
|
|
@ -107,7 +110,7 @@ pub trait Pixel: Clone + Pod + Zeroable {
|
|||
}
|
||||
|
||||
fn byte_size() -> usize {
|
||||
std::mem::size_of::<Self>()
|
||||
core::mem::size_of::<Self>()
|
||||
}
|
||||
}
|
||||
pub trait RGB: Pixel {
|
||||
|
|
@ -448,6 +451,8 @@ pub struct ImageSlice<'a, Pixel> {
|
|||
pub data: &'a [Pixel],
|
||||
#[cfg(target_arch = "spirv")]
|
||||
pub data: &'a (),
|
||||
#[cfg(target_arch = "spirv")]
|
||||
pub _marker: PhantomData<Pixel>,
|
||||
}
|
||||
|
||||
unsafe impl<P: StaticTypeSized> StaticType for ImageSlice<'_, P> {
|
||||
|
|
@ -470,20 +475,17 @@ impl<'a, P> Default for ImageSlice<'a, P> {
|
|||
width: Default::default(),
|
||||
height: Default::default(),
|
||||
data: &NOTHING,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
impl<P: Copy + Debug + Pixel> Raster for ImageSlice<'_, P> {
|
||||
type Pixel = P;
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
fn get_pixel(&self, x: u32, y: u32) -> Option<P> {
|
||||
self.data.get((x + y * self.width) as usize).copied()
|
||||
}
|
||||
#[cfg(target_arch = "spirv")]
|
||||
fn get_pixel(&self, _x: u32, _y: u32) -> P {
|
||||
Color::default()
|
||||
}
|
||||
fn width(&self) -> u32 {
|
||||
self.width
|
||||
}
|
||||
|
|
@ -605,6 +607,8 @@ mod image {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Evaluate if this will be a problem for our use case.
|
||||
/// Warning: This is an approximation of a hash, and is not guaranteed to not collide.
|
||||
impl<P: Hash + Pixel> Hash for Image<P> {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
const HASH_SAMPLES: u64 = 1000;
|
||||
|
|
@ -661,7 +665,7 @@ mod image {
|
|||
let Image { width, height, data } = self;
|
||||
|
||||
let to_gamma = |x| SRGBGammaFloat::from_linear(x);
|
||||
let to_u8 = |x| (num::cast::<_, f32>(x).unwrap() * 255.) as u8;
|
||||
let to_u8 = |x| (num_cast::<_, f32>(x).unwrap() * 255.) as u8;
|
||||
|
||||
let result_bytes = data
|
||||
.into_iter()
|
||||
|
|
@ -670,7 +674,7 @@ mod image {
|
|||
to_u8(to_gamma(color.r() / color.a().to_channel())),
|
||||
to_u8(to_gamma(color.g() / color.a().to_channel())),
|
||||
to_u8(to_gamma(color.b() / color.a().to_channel())),
|
||||
(num::cast::<_, f32>(color.a()).unwrap() * 255.) as u8,
|
||||
(num_cast::<_, f32>(color.a()).unwrap() * 255.) as u8,
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use crate::Node;
|
|||
|
||||
use core::fmt::Debug;
|
||||
use dyn_any::{DynAny, StaticType};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[cfg(target_arch = "spirv")]
|
||||
|
|
@ -457,8 +458,9 @@ fn vibrance_node(color: Color, vibrance: f64) -> Color {
|
|||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, DynAny, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "std", derive(specta::Type))]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, DynAny)]
|
||||
pub enum RedGreenBlue {
|
||||
Red,
|
||||
Green,
|
||||
|
|
@ -542,8 +544,9 @@ fn channel_mixer_node(
|
|||
color.to_linear_srgb()
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, DynAny, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "std", derive(specta::Type))]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, DynAny)]
|
||||
pub enum RelativeAbsolute {
|
||||
Relative,
|
||||
Absolute,
|
||||
|
|
@ -559,7 +562,9 @@ impl core::fmt::Display for RelativeAbsolute {
|
|||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, DynAny, specta::Type)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[cfg_attr(feature = "std", derive(specta::Type))]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, DynAny)]
|
||||
pub enum SelectiveColorChoice {
|
||||
Reds,
|
||||
Yellows,
|
||||
|
|
@ -797,17 +802,26 @@ fn exposure(color: Color, exposure: f64, offset: f64, gamma_correction: f64) ->
|
|||
adjusted.map_rgb(|c: f32| c.clamp(0., 1.))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IndexNode<Index> {
|
||||
pub index: Index,
|
||||
}
|
||||
#[cfg(feature = "alloc")]
|
||||
pub use index_node::IndexNode;
|
||||
|
||||
#[node_macro::node_fn(IndexNode)]
|
||||
pub fn index_node(input: Vec<super::ImageFrame<Color>>, index: u32) -> super::ImageFrame<Color> {
|
||||
if (index as usize) < input.len() {
|
||||
input[index as usize].clone()
|
||||
} else {
|
||||
warn!("The number of segments is {} and the requested segment is {}!", input.len(), index);
|
||||
super::ImageFrame::empty()
|
||||
#[cfg(feature = "alloc")]
|
||||
mod index_node {
|
||||
use crate::raster::{Color, ImageFrame};
|
||||
use crate::Node;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IndexNode<Index> {
|
||||
pub index: Index,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(IndexNode)]
|
||||
pub fn index_node(input: Vec<ImageFrame<Color>>, index: u32) -> ImageFrame<Color> {
|
||||
if (index as usize) < input.len() {
|
||||
input[index as usize].clone()
|
||||
} else {
|
||||
warn!("The number of segments is {} and the requested segment is {}!", input.len(), index);
|
||||
ImageFrame::empty()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ fn brightness_contrast_legacy_node(_primary: (), brightness: f32, contrast: f32)
|
|||
let brightness = brightness / 255.;
|
||||
|
||||
let contrast = contrast / 100.;
|
||||
let contrast = if contrast > 0. { (contrast * std::f32::consts::FRAC_PI_2 - 0.01).tan() } else { contrast };
|
||||
let contrast = if contrast > 0. { (contrast * core::f32::consts::FRAC_PI_2 - 0.01).tan() } else { contrast };
|
||||
|
||||
let combined = brightness * contrast + brightness - contrast / 2.;
|
||||
|
||||
|
|
@ -172,7 +172,7 @@ fn solve_cubic_splines(cubic_spline_values: &CubicSplines) -> [f32; 4] {
|
|||
|
||||
// Eliminate the current column in all rows below the current one
|
||||
for row_below_current in row + 1..4 {
|
||||
assert!(augmented_matrix[row][row].abs() > std::f32::EPSILON);
|
||||
assert!(augmented_matrix[row][row].abs() > core::f32::EPSILON);
|
||||
|
||||
let scale_factor = augmented_matrix[row_below_current][row] / augmented_matrix[row][row];
|
||||
for col in row..5 {
|
||||
|
|
@ -184,7 +184,7 @@ fn solve_cubic_splines(cubic_spline_values: &CubicSplines) -> [f32; 4] {
|
|||
// Gaussian elimination: back substitution
|
||||
let mut solutions = [0.; 4];
|
||||
for col in (0..4).rev() {
|
||||
assert!(augmented_matrix[col][col].abs() > std::f32::EPSILON);
|
||||
assert!(augmented_matrix[col][col].abs() > core::f32::EPSILON);
|
||||
|
||||
solutions[col] = augmented_matrix[col][4] / augmented_matrix[col][col];
|
||||
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ impl RGB for Color {
|
|||
}
|
||||
|
||||
impl Pixel for Color {
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
fn to_bytes(&self) -> Vec<u8> {
|
||||
self.to_rgba8_srgb().to_vec()
|
||||
}
|
||||
|
|
@ -121,7 +122,6 @@ impl Color {
|
|||
/// let color = Color::from_rgbaf32(1.0, 1.0, 1.0, f32::NAN);
|
||||
/// assert!(color == None);
|
||||
/// ```
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
pub fn from_rgbaf32(red: f32, green: f32, blue: f32, alpha: f32) -> Option<Color> {
|
||||
if alpha > 1. || [red, green, blue, alpha].iter().any(|c| c.is_sign_negative() || !c.is_finite()) {
|
||||
return None;
|
||||
|
|
@ -492,7 +492,7 @@ impl Color {
|
|||
/// ```
|
||||
/// use graphene_core::raster::color::Color;
|
||||
/// let color = Color::from_rgbaf32(0.114, 0.103, 0.98, 0.97).unwrap();
|
||||
/// assert!(color.components() == (0.114, 0.103, 0.98, 0.97));
|
||||
/// assert_eq!(color.components(), (0.114, 0.103, 0.98, 0.97));
|
||||
/// ```
|
||||
pub fn components(&self) -> (f32, f32, f32, f32) {
|
||||
(self.red, self.green, self.blue, self.alpha)
|
||||
|
|
@ -585,7 +585,6 @@ impl Color {
|
|||
/// use graphene_core::raster::color::Color;
|
||||
/// let color = Color::from_rgba_str("7C67FA61").unwrap();
|
||||
/// ```
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
pub fn from_rgba_str(color_str: &str) -> Option<Color> {
|
||||
if color_str.len() != 8 {
|
||||
return None;
|
||||
|
|
@ -603,7 +602,6 @@ impl Color {
|
|||
/// use graphene_core::raster::color::Color;
|
||||
/// let color = Color::from_rgb_str("7C67FA").unwrap();
|
||||
/// ```
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
pub fn from_rgb_str(color_str: &str) -> Option<Color> {
|
||||
if color_str.len() != 6 {
|
||||
return None;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
use crate::Node;
|
||||
|
||||
use core::marker::PhantomData;
|
||||
use core::ops::{DerefMut, Index, IndexMut};
|
||||
|
||||
struct SetNode<S, I, Storage, Index> {
|
||||
storage: Storage,
|
||||
index: Index,
|
||||
_s: PhantomData<S>,
|
||||
_i: PhantomData<I>,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(SetNode<_S, _I>)]
|
||||
fn set_node<T, _S, _I>(value: T, storage: &'any_input mut _S, index: _I)
|
||||
where
|
||||
_S: IndexMut<_I>,
|
||||
_S::Output: DerefMut<Target = T> + Sized,
|
||||
{
|
||||
*storage.index_mut(index).deref_mut() = value;
|
||||
}
|
||||
|
||||
struct GetNode<S, Storage> {
|
||||
storage: Storage,
|
||||
_s: PhantomData<S>,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(GetNode<_S>)]
|
||||
fn get_node<_S, I>(index: I, storage: &'any_input _S) -> &'input _S::Output
|
||||
where
|
||||
_S: Index<I>,
|
||||
_S::Output: Sized,
|
||||
{
|
||||
storage.index(index)
|
||||
}
|
||||
|
|
@ -2,60 +2,47 @@ use core::marker::PhantomData;
|
|||
|
||||
use crate::Node;
|
||||
|
||||
pub struct ComposeNode<First: for<'i> Node<'i, I>, Second: for<'i> Node<'i, <First as Node<'i, I>>::Output>, I> {
|
||||
#[derive(Clone)]
|
||||
pub struct ComposeNode<First, Second, I> {
|
||||
first: First,
|
||||
second: Second,
|
||||
phantom: PhantomData<I>,
|
||||
}
|
||||
|
||||
impl<'i, Input: 'i, First, Second> Node<'i, Input> for ComposeNode<First, Second, Input>
|
||||
impl<'i, 'f: 'i, 's: 'i, Input: 'i, First, Second> Node<'i, Input> for ComposeNode<First, Second, Input>
|
||||
where
|
||||
First: for<'a> Node<'a, Input> + 'i,
|
||||
Second: for<'a> Node<'a, <First as Node<'a, Input>>::Output> + 'i,
|
||||
First: Node<'i, Input>,
|
||||
Second: Node<'i, <First as Node<'i, Input>>::Output> + 'i,
|
||||
{
|
||||
type Output = <Second as Node<'i, <First as Node<'i, Input>>::Output>>::Output;
|
||||
fn eval(&'i self, input: Input) -> Self::Output {
|
||||
let arg = self.first.eval(input);
|
||||
self.second.eval(arg)
|
||||
let second = &self.second;
|
||||
second.eval(arg)
|
||||
}
|
||||
}
|
||||
|
||||
impl<First, Second, Input> ComposeNode<First, Second, Input>
|
||||
impl<'i, First, Second, Input: 'i> ComposeNode<First, Second, Input>
|
||||
where
|
||||
First: for<'a> Node<'a, Input>,
|
||||
Second: for<'a> Node<'a, <First as Node<'a, Input>>::Output>,
|
||||
First: Node<'i, Input>,
|
||||
Second: Node<'i, <First as Node<'i, Input>>::Output>,
|
||||
{
|
||||
pub const fn new(first: First, second: Second) -> Self {
|
||||
ComposeNode::<First, Second, Input> { first, second, phantom: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
// impl Clone for ComposeNode<First, Second, Input>
|
||||
impl<First, Second, Input> Clone for ComposeNode<First, Second, Input>
|
||||
where
|
||||
First: for<'a> Node<'a, Input> + Clone,
|
||||
Second: for<'a> Node<'a, <First as Node<'a, Input>>::Output> + Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
ComposeNode::<First, Second, Input> {
|
||||
first: self.first.clone(),
|
||||
second: self.second.clone(),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Then<'i, Input: 'i>: Sized {
|
||||
fn then<Second>(self, second: Second) -> ComposeNode<Self, Second, Input>
|
||||
where
|
||||
Self: for<'a> Node<'a, Input>,
|
||||
Second: for<'a> Node<'a, <Self as Node<'a, Input>>::Output>,
|
||||
Self: Node<'i, Input>,
|
||||
Second: Node<'i, <Self as Node<'i, Input>>::Output>,
|
||||
{
|
||||
ComposeNode::new(self, second)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'i, First: for<'a> Node<'a, Input>, Input: 'i> Then<'i, Input> for First {}
|
||||
impl<'i, First: Node<'i, Input>, Input: 'i> Then<'i, Input> for First {}
|
||||
|
||||
pub struct ConsNode<I: From<()>, Root>(pub Root, PhantomData<I>);
|
||||
|
||||
|
|
@ -89,4 +76,16 @@ mod test {
|
|||
let type_erased = &compose as &dyn for<'i> Node<'i, (), Output = &'i u32>;
|
||||
assert_eq!(type_erased.eval(()), &4u32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ref_eval() {
|
||||
let value = ValueNode::new(5);
|
||||
|
||||
assert_eq!((&value).eval(()), &5);
|
||||
let id = IdNode::new();
|
||||
|
||||
let compose = ComposeNode::new(&value, &id);
|
||||
|
||||
assert_eq!(compose.eval(()), &5);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use core::any::TypeId;
|
|||
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub use alloc::borrow::Cow;
|
||||
use dyn_any::StaticType;
|
||||
#[cfg(feature = "std")]
|
||||
pub use std::borrow::Cow;
|
||||
|
||||
|
|
@ -28,6 +29,8 @@ macro_rules! concrete {
|
|||
Type::Concrete(TypeDescriptor {
|
||||
id: Some(core::any::TypeId::of::<$type>()),
|
||||
name: Cow::Borrowed(core::any::type_name::<$type>()),
|
||||
size: core::mem::size_of::<$type>(),
|
||||
align: core::mem::align_of::<$type>(),
|
||||
})
|
||||
};
|
||||
}
|
||||
|
|
@ -65,6 +68,8 @@ pub struct TypeDescriptor {
|
|||
#[specta(skip)]
|
||||
pub id: Option<TypeId>,
|
||||
pub name: Cow<'static, str>,
|
||||
pub size: usize,
|
||||
pub align: usize,
|
||||
}
|
||||
|
||||
impl core::hash::Hash for TypeDescriptor {
|
||||
|
|
@ -137,6 +142,32 @@ impl Type {
|
|||
}
|
||||
}
|
||||
|
||||
impl Type {
|
||||
pub fn new<T: StaticType + Sized>() -> Self {
|
||||
Self::Concrete(TypeDescriptor {
|
||||
id: Some(TypeId::of::<T::Static>()),
|
||||
name: Cow::Borrowed(core::any::type_name::<T::Static>()),
|
||||
size: core::mem::size_of::<T>(),
|
||||
align: core::mem::align_of::<T>(),
|
||||
})
|
||||
}
|
||||
pub fn size(&self) -> Option<usize> {
|
||||
match self {
|
||||
Self::Generic(_) => None,
|
||||
Self::Concrete(ty) => Some(ty.size),
|
||||
Self::Fn(_, _) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn align(&self) -> Option<usize> {
|
||||
match self {
|
||||
Self::Generic(_) => None,
|
||||
Self::Concrete(ty) => Some(ty.align),
|
||||
Self::Fn(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl core::fmt::Debug for Type {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
match self {
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ impl<T: Clone> Clone for ValueNode<T> {
|
|||
}
|
||||
impl<T: Clone + Copy> Copy for ValueNode<T> {}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct ClonedNode<T: Clone>(pub T);
|
||||
|
||||
impl<'i, T: Clone + 'i> Node<'i, ()> for ClonedNode<T> {
|
||||
|
|
@ -61,7 +61,22 @@ impl<T: Clone> From<T> for ClonedNode<T> {
|
|||
ClonedNode::new(value)
|
||||
}
|
||||
}
|
||||
impl<T: Clone + Copy> Copy for ClonedNode<T> {}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct CopiedNode<T: Copy>(pub T);
|
||||
|
||||
impl<'i, T: Copy + 'i> Node<'i, ()> for CopiedNode<T> {
|
||||
type Output = T;
|
||||
fn eval(&'i self, _input: ()) -> Self::Output {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Copy> CopiedNode<T> {
|
||||
pub const fn new(value: T) -> CopiedNode<T> {
|
||||
CopiedNode(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DefaultNode<T>(PhantomData<T>);
|
||||
|
|
|
|||
|
|
@ -122,6 +122,20 @@ name = "bytemuck"
|
|||
version = "1.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aaa3a8d9a1ca92e282c96a32d6511b695d7d994d1d102ba85d279f9b2756947f"
|
||||
dependencies = [
|
||||
"bytemuck_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck_derive"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aca418a974d83d40a0c1f0c5cba6ff4bc28d8df099109ca459a2118d40b6322"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
|
|
@ -293,7 +307,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "dyn-any"
|
||||
version = "0.2.1"
|
||||
version = "0.3.1"
|
||||
dependencies = [
|
||||
"dyn-any-derive",
|
||||
"glam",
|
||||
|
|
@ -302,7 +316,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "dyn-any-derive"
|
||||
version = "0.2.1"
|
||||
version = "0.3.0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
|
@ -346,6 +360,106 @@ version = "1.0.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "531ac96c6ff5fd7c62263c5e3c67a603af4fcaee2e1a0ae5565ba3a11e69e549"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "164713a5a0dcc3e7b4b1ed7d3b433cabc18025386f9339346e8daf15963cf7ac"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "86d7a0c1aa76363dac491de0ee99faf6941128376f1cf96f07db7603b7de69dd"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1997dd9df74cdac935c76252744c1ed5794fac083242ea4fe77ef3ed60ba0f83"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-intrusive"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"lock_api",
|
||||
"parking_lot",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89d422fa3cbe3b40dca574ab087abb5bc98258ea57eea3fd6f1fa7162c778b91"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3eb14ed937631bd8b8b8977f2c198443447a8355b6e3ca599f38c975e5a963b6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec93083a4aecafb2a80a885c9de1f0ccae9dbd32c2bb54b0c3a65690e0b8d2f2"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd65540d33b37b16542a0438c12e6aeead10d4ac5d05bd3f805b8f35ab592879"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ef6b17e481503ec85211fed8f39d1970f128935ca1f814cd32ac4a6842e84ab"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fxhash"
|
||||
version = "0.2.1"
|
||||
|
|
@ -382,6 +496,8 @@ version = "0.22.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"num-traits",
|
||||
"serde",
|
||||
]
|
||||
|
||||
|
|
@ -418,6 +534,7 @@ dependencies = [
|
|||
"bytemuck",
|
||||
"dyn-any",
|
||||
"glam",
|
||||
"gpu-executor",
|
||||
"graph-craft",
|
||||
"graphene-core",
|
||||
"log",
|
||||
|
|
@ -430,6 +547,26 @@ dependencies = [
|
|||
"tera",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gpu-executor"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64",
|
||||
"bytemuck",
|
||||
"dyn-any",
|
||||
"futures",
|
||||
"futures-intrusive",
|
||||
"glam",
|
||||
"graph-craft",
|
||||
"graphene-core",
|
||||
"log",
|
||||
"node-macro",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"spirv",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graph-craft"
|
||||
version = "0.1.0"
|
||||
|
|
@ -454,17 +591,22 @@ name = "graphene-core"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64",
|
||||
"bezier-rs",
|
||||
"bytemuck",
|
||||
"dyn-any",
|
||||
"glam",
|
||||
"kurbo",
|
||||
"log",
|
||||
"node-macro",
|
||||
"num-derive",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"rand_chacha",
|
||||
"serde",
|
||||
"specta",
|
||||
"spin",
|
||||
"spirv-std",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -549,6 +691,12 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "internal-iterator"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a668ef46056a63366da9d74f48062da9ece1a27958f2f3704aa6f7421c4433f5"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.5"
|
||||
|
|
@ -643,6 +791,12 @@ dependencies = [
|
|||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "longest-increasing-subsequence"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3bd0dd2cd90571056fdb71f6275fada10131182f84899f4b2a916e565d81d86"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.5.0"
|
||||
|
|
@ -658,6 +812,17 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-derive"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.45"
|
||||
|
|
@ -693,6 +858,29 @@ version = "1.17.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"smallvec",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parse-zoneinfo"
|
||||
version = "0.3.0"
|
||||
|
|
@ -797,6 +985,18 @@ dependencies = [
|
|||
"uncased",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
|
|
@ -917,14 +1117,15 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
|||
|
||||
[[package]]
|
||||
name = "rustc_codegen_spirv"
|
||||
version = "0.5.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b803b49618cdde99e1065af9415f489b993374765e30a6b80f2bea2cca65914"
|
||||
checksum = "48862935255ac76002118e3561c54f3fb413a752ef5cd0bdae693cf5c5ce37a5"
|
||||
dependencies = [
|
||||
"ar",
|
||||
"either",
|
||||
"hashbrown 0.11.2",
|
||||
"indexmap",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"num-traits",
|
||||
|
|
@ -944,9 +1145,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustc_codegen_spirv-types"
|
||||
version = "0.5.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "330aedc6b09b9bf3c58cc7fb942c1377310a9cff00fae9e4f6cc09a7a28f542e"
|
||||
checksum = "339309d7fce2e7204decea1b2683c4d439a6ebe4c1518d8134fe10aefeae7362"
|
||||
dependencies = [
|
||||
"rspirv",
|
||||
"serde",
|
||||
|
|
@ -1037,6 +1238,15 @@ version = "0.3.10"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slug"
|
||||
version = "0.1.4"
|
||||
|
|
@ -1092,16 +1302,18 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "spirt"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06834ebbbbc6f86448fd5dc7ccbac80e36f52f8d66838683752e19d3cae9a459"
|
||||
checksum = "e24fa996f12f3c667efbceaa99c222b8910a295a14d2c43c3880dfab2752def7"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"bytemuck",
|
||||
"elsa",
|
||||
"indexmap",
|
||||
"internal-iterator",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
"longest-increasing-subsequence",
|
||||
"rustc-hash",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
@ -1120,9 +1332,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "spirv-builder"
|
||||
version = "0.5.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93f656f97ac742e5603843d2ea3ea644cdf5630635b3320ec87666385766e7ab"
|
||||
checksum = "0310607328cbb098b681e4580a25871fcf72dd9e1d560e7cc09c409eedef2a27"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"raw-string",
|
||||
|
|
@ -1132,6 +1344,37 @@ dependencies = [
|
|||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spirv-std"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3197bd4c021c2dfc0f9dfb356312c8f7842d972d5545c308ad86422c2e2d3e66"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"glam",
|
||||
"num-traits",
|
||||
"spirv-std-macros",
|
||||
"spirv-std-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spirv-std-macros"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbaffad626ab9d3ac61c4b74b5d51cb52f1939a8041d7ac09ec828eb4ad44d72"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"spirv-std-types",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spirv-std-types"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab83875e851bc803c687024d2d950730f350c0073714b95b3a6b1d22e9eac42a"
|
||||
|
||||
[[package]]
|
||||
name = "spirv-tools"
|
||||
version = "0.9.0"
|
||||
|
|
@ -1434,6 +1677,72 @@ version = "0.4.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.45.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
|
||||
|
||||
[[package]]
|
||||
name = "xxhash-rust"
|
||||
version = "0.8.6"
|
||||
|
|
|
|||
|
|
@ -13,18 +13,25 @@ serde = ["graphene-core/serde", "glam/serde"]
|
|||
|
||||
[dependencies]
|
||||
graphene-core = { path = "../gcore", features = ["async", "std", "alloc"] }
|
||||
graph-craft = {path = "../graph-craft", features = ["serde"] }
|
||||
dyn-any = { path = "../../libraries/dyn-any", features = ["log-bad-types", "rc", "glam"] }
|
||||
graph-craft = { path = "../graph-craft", features = ["serde"] }
|
||||
gpu-executor = { path = "../gpu-executor" }
|
||||
dyn-any = { path = "../../libraries/dyn-any", features = [
|
||||
"log-bad-types",
|
||||
"rc",
|
||||
"glam",
|
||||
] }
|
||||
num-traits = "0.2"
|
||||
log = "0.4"
|
||||
serde = { version = "1", features = ["derive", "rc"]}
|
||||
serde = { version = "1", features = ["derive", "rc"] }
|
||||
glam = { version = "0.22" }
|
||||
base64 = "0.13"
|
||||
|
||||
bytemuck = { version = "1.8" }
|
||||
nvtx = { version = "1.1.1", optional = true }
|
||||
tempfile = "3"
|
||||
spirv-builder = { version = "0.5", default-features = false, features=["use-installed-tools"] }
|
||||
spirv-builder = { version = "0.7", default-features = false, features = [
|
||||
"use-installed-tools",
|
||||
] }
|
||||
tera = { version = "1.17.1" }
|
||||
anyhow = "1.0.66"
|
||||
serde_json = "1.0.91"
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ profiling = []
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
graph-craft = {path = "../../graph-craft" , features = ["serde"]}
|
||||
graph-craft = { path = "../../graph-craft", features = ["serde"] }
|
||||
gpu-executor = { path = "../../gpu-executor" }
|
||||
log = "0.4"
|
||||
anyhow = "1.0.66"
|
||||
serde_json = "1.0.91"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,15 @@
|
|||
use gpu_executor::ShaderIO;
|
||||
use graph_craft::{proto::ProtoNetwork, Type};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::Write;
|
||||
|
||||
pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &str, output_type: &str, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
|
||||
let serialized_graph = serde_json::to_string(&network)?;
|
||||
pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
|
||||
let serialized_graph = serde_json::to_string(&gpu_executor::CompileRequest {
|
||||
network: request.network.clone(),
|
||||
io: request.shader_io.clone(),
|
||||
})?;
|
||||
|
||||
let features = "";
|
||||
#[cfg(feature = "profiling")]
|
||||
let features = "profiling";
|
||||
|
|
@ -19,9 +26,6 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
|
|||
.envs(non_cargo_env_vars)
|
||||
.arg("--features")
|
||||
.arg(features)
|
||||
.arg("--")
|
||||
.arg(input_type)
|
||||
.arg(output_type)
|
||||
// TODO: handle None case properly
|
||||
.arg(compile_dir.unwrap())
|
||||
.stdin(std::process::Stdio::piped())
|
||||
|
|
@ -38,16 +42,27 @@ pub fn compile_spirv(network: &graph_craft::document::NodeNetwork, input_type: &
|
|||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct CompileRequest {
|
||||
network: graph_craft::document::NodeNetwork,
|
||||
input_type: String,
|
||||
output_type: String,
|
||||
network: graph_craft::proto::ProtoNetwork,
|
||||
input_types: Vec<Type>,
|
||||
output_type: Type,
|
||||
shader_io: ShaderIO,
|
||||
}
|
||||
|
||||
impl CompileRequest {
|
||||
pub fn new(network: graph_craft::document::NodeNetwork, input_type: String, output_type: String) -> Self {
|
||||
Self { network, input_type, output_type }
|
||||
pub fn new(network: ProtoNetwork, input_types: Vec<Type>, output_type: Type, io: ShaderIO) -> Self {
|
||||
// TODO: add type checking
|
||||
// for (input, buffer) in input_types.iter().zip(io.inputs.iter()) {
|
||||
// assert_eq!(input, &buffer.ty());
|
||||
// }
|
||||
// assert_eq!(output_type, io.output.ty());
|
||||
Self {
|
||||
network,
|
||||
input_types,
|
||||
output_type,
|
||||
shader_io: io,
|
||||
}
|
||||
}
|
||||
pub fn compile(&self, compile_dir: &str, manifest_path: &str) -> anyhow::Result<Vec<u8>> {
|
||||
compile_spirv(&self.network, &self.input_type, &self.output_type, Some(compile_dir), manifest_path)
|
||||
compile_spirv(self, Some(compile_dir), manifest_path)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,10 @@
|
|||
[toolchain]
|
||||
channel = "nightly-2022-12-18"
|
||||
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]
|
||||
channel = "nightly-2023-03-04"
|
||||
components = [
|
||||
"rust-src",
|
||||
"rustc-dev",
|
||||
"llvm-tools-preview",
|
||||
"clippy",
|
||||
"rustfmt",
|
||||
"rustc",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
use std::path::Path;
|
||||
|
||||
use gpu_executor::{GPUConstant, ShaderIO, ShaderInput, SpirVCompiler};
|
||||
use graph_craft::proto::*;
|
||||
use graphene_core::Cow;
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use tera::Context;
|
||||
|
||||
fn create_cargo_toml(metadata: &Metadata) -> Result<String, tera::Error> {
|
||||
|
|
@ -24,10 +26,10 @@ impl Metadata {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn create_files(matadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, input_type: &str, output_type: &str) -> anyhow::Result<()> {
|
||||
pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> {
|
||||
let src = compile_dir.join("src");
|
||||
let cargo_file = compile_dir.join("Cargo.toml");
|
||||
let cargo_toml = create_cargo_toml(matadata)?;
|
||||
let cargo_toml = create_cargo_toml(metadata)?;
|
||||
std::fs::write(cargo_file, cargo_toml)?;
|
||||
|
||||
let toolchain_file = compile_dir.join("rust-toolchain.toml");
|
||||
|
|
@ -44,26 +46,100 @@ pub fn create_files(matadata: &Metadata, network: &ProtoNetwork, compile_dir: &P
|
|||
}
|
||||
}
|
||||
let lib = src.join("lib.rs");
|
||||
let shader = serialize_gpu(network, input_type, output_type)?;
|
||||
println!("{}", shader);
|
||||
let shader = serialize_gpu(network, io)?;
|
||||
eprintln!("{}", shader);
|
||||
std::fs::write(lib, shader)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str) -> anyhow::Result<String> {
|
||||
assert_eq!(network.inputs.len(), 1);
|
||||
fn constant_attribute(constant: &GPUConstant) -> &'static str {
|
||||
match constant {
|
||||
GPUConstant::SubGroupId => "subgroup_id",
|
||||
GPUConstant::SubGroupInvocationId => "subgroup_local_invocation_id",
|
||||
GPUConstant::SubGroupSize => todo!(),
|
||||
GPUConstant::NumSubGroups => "num_subgroups",
|
||||
GPUConstant::WorkGroupId => "workgroup_id",
|
||||
GPUConstant::WorkGroupInvocationId => "local_invocation_id",
|
||||
GPUConstant::WorkGroupSize => todo!(),
|
||||
GPUConstant::NumWorkGroups => "num_workgroups",
|
||||
GPUConstant::GlobalInvocationId => "global_invocation_id",
|
||||
GPUConstant::GlobalSize => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn construct_argument(input: &ShaderInput<()>, position: u32) -> String {
|
||||
match input {
|
||||
ShaderInput::Constant(constant) => format!("#[spirv({})] i{}: {},", constant_attribute(constant), position, constant.ty()),
|
||||
ShaderInput::UniformBuffer(_, ty) => {
|
||||
format!("#[spirv(uniform, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,)
|
||||
}
|
||||
ShaderInput::StorageBuffer(_, ty) | ShaderInput::ReadBackBuffer(_, ty) => {
|
||||
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,)
|
||||
}
|
||||
ShaderInput::OutputBuffer(_, ty) => {
|
||||
format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &mut[{}]", position, position, ty,)
|
||||
}
|
||||
ShaderInput::WorkGroupMemory(_, ty) => format!("#[spirv(workgroup_memory] i{}: {}", position, ty,),
|
||||
}
|
||||
}
|
||||
|
||||
struct GpuCompiler {
|
||||
compile_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl SpirVCompiler for GpuCompiler {
|
||||
fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> anyhow::Result<gpu_executor::Shader> {
|
||||
let metadata = Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
|
||||
|
||||
create_files(&metadata, &network, &self.compile_dir, io)?;
|
||||
let result = compile(&self.compile_dir)?;
|
||||
|
||||
let bytes = std::fs::read(result.module.unwrap_single())?;
|
||||
let words = bytes.chunks(4).map(|chunk| u32::from_ne_bytes(chunk.try_into().unwrap())).collect::<Vec<_>>();
|
||||
|
||||
Ok(gpu_executor::Shader {
|
||||
source: Cow::Owned(words),
|
||||
name: "",
|
||||
io: io.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result<String> {
|
||||
fn nid(id: &u64) -> String {
|
||||
format!("n{id}")
|
||||
}
|
||||
|
||||
dbg!(&network);
|
||||
dbg!(&io);
|
||||
let inputs = io.inputs.iter().enumerate().map(|(i, input)| construct_argument(input, i as u32)).collect::<Vec<_>>();
|
||||
|
||||
let mut nodes = Vec::new();
|
||||
let mut input_nodes = Vec::new();
|
||||
#[derive(serde::Serialize)]
|
||||
struct Node {
|
||||
id: String,
|
||||
fqn: String,
|
||||
args: Vec<String>,
|
||||
}
|
||||
for id in network.inputs.iter() {
|
||||
let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else {
|
||||
anyhow::bail!("Input node not found");
|
||||
};
|
||||
let fqn = &node.identifier.name;
|
||||
let id = nid(id);
|
||||
input_nodes.push(Node {
|
||||
id,
|
||||
fqn: fqn.to_string().split("<").next().unwrap().to_owned(),
|
||||
args: node.construction_args.new_function_args(),
|
||||
});
|
||||
}
|
||||
|
||||
for (ref id, node) in network.nodes.iter() {
|
||||
if network.inputs.contains(id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let fqn = &node.identifier.name;
|
||||
let id = nid(id);
|
||||
|
||||
|
|
@ -78,8 +154,8 @@ pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str
|
|||
let mut tera = tera::Tera::default();
|
||||
tera.add_raw_template("spirv", template)?;
|
||||
let mut context = Context::new();
|
||||
context.insert("input_type", &input_type);
|
||||
context.insert("output_type", &output_type);
|
||||
context.insert("inputs", &inputs);
|
||||
context.insert("input_nodes", &input_nodes);
|
||||
context.insert("nodes", &nodes);
|
||||
context.insert("last_node", &nid(&network.output));
|
||||
context.insert("compute_threads", &64);
|
||||
|
|
@ -89,14 +165,15 @@ pub fn serialize_gpu(network: &ProtoNetwork, input_type: &str, output_type: &str
|
|||
use spirv_builder::{MetadataPrintout, SpirvBuilder, SpirvMetadata};
|
||||
pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder::SpirvBuilderError> {
|
||||
dbg!(&dir);
|
||||
let result = SpirvBuilder::new(dir, "spirv-unknown-spv1.5")
|
||||
let result = SpirvBuilder::new(dir, "spirv-unknown-vulkan1.2")
|
||||
.print_metadata(MetadataPrintout::DependencyOnly)
|
||||
.multimodule(false)
|
||||
.preserve_bindings(true)
|
||||
.release(true)
|
||||
//.relax_struct_store(true)
|
||||
//.relax_block_layout(true)
|
||||
.spirv_metadata(SpirvMetadata::Full)
|
||||
.extra_arg("no-early-report-zombies")
|
||||
.extra_arg("no-infer-storage-classes")
|
||||
.extra_arg("spirt-passes=qptr")
|
||||
.build()?;
|
||||
|
||||
Ok(result)
|
||||
|
|
@ -104,7 +181,6 @@ pub fn compile(dir: &Path) -> Result<spirv_builder::CompileResult, spirv_builder
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
#[test]
|
||||
fn test_create_cargo_toml() {
|
||||
let cargo_toml = super::create_cargo_toml(&super::Metadata {
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use gpu_compiler as compiler;
|
||||
use gpu_executor::CompileRequest;
|
||||
use graph_craft::document::NodeNetwork;
|
||||
use std::io::Write;
|
||||
|
||||
|
|
@ -6,17 +7,13 @@ fn main() -> anyhow::Result<()> {
|
|||
println!("Starting GPU Compiler!");
|
||||
let mut stdin = std::io::stdin();
|
||||
let mut stdout = std::io::stdout();
|
||||
let input_type = std::env::args().nth(1).expect("input type arg missing");
|
||||
let output_type = std::env::args().nth(2).expect("output type arg missing");
|
||||
let compile_dir = std::env::args().nth(3).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path());
|
||||
let network: NodeNetwork = serde_json::from_reader(&mut stdin)?;
|
||||
let compiler = graph_craft::executor::Compiler {};
|
||||
let proto_network = compiler.compile_single(network, true).unwrap();
|
||||
let compile_dir = std::env::args().nth(1).map(|x| std::path::PathBuf::from(&x)).unwrap_or(tempfile::tempdir()?.into_path());
|
||||
let request: CompileRequest = serde_json::from_reader(&mut stdin)?;
|
||||
dbg!(&compile_dir);
|
||||
|
||||
let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]);
|
||||
|
||||
compiler::create_files(&metadata, &proto_network, &compile_dir, &input_type, &output_type)?;
|
||||
compiler::create_files(&metadata, &request.network, &compile_dir, &request.io)?;
|
||||
let result = compiler::compile(&compile_dir)?;
|
||||
|
||||
let bytes = std::fs::read(result.module.unwrap_single())?;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
authors = [{% for author in authors %}"{{author}}", {% endfor %}]
|
||||
name = "{{name}}-node"
|
||||
version = "0.1.0"
|
||||
authors = [{%for author in authors%}"{{author}}", {%endfor%}]
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
publish = false
|
||||
|
|
@ -13,5 +13,7 @@ crate-type = ["dylib", "lib"]
|
|||
libm = { git = "https://github.com/rust-lang/libm", tag = "0.2.5" }
|
||||
|
||||
[dependencies]
|
||||
spirv-std = { version = "0.5" , features= ["glam"]}
|
||||
graphene-core = {path = "{{gcore_path}}", default-features = false, features = ["gpu"]}
|
||||
spirv-std = { version = "0.7" }
|
||||
graphene-core = { path = "{{gcore_path}}", default-features = false, features = [
|
||||
"gpu",
|
||||
] }
|
||||
|
|
|
|||
|
|
@ -1,3 +1,10 @@
|
|||
[toolchain]
|
||||
channel = "nightly-2022-12-18"
|
||||
components = ["rust-src", "rustc-dev", "llvm-tools-preview", "clippy", "cargofmt", "rustc"]
|
||||
channel = "nightly-2023-03-04"
|
||||
components = [
|
||||
"rust-src",
|
||||
"rustc-dev",
|
||||
"llvm-tools-preview",
|
||||
"clippy",
|
||||
"rustfmt",
|
||||
"rustc",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
#![no_std]
|
||||
#![feature(unchecked_math)]
|
||||
#![deny(warnings)]
|
||||
|
||||
#[cfg(target_arch = "spirv")]
|
||||
extern crate spirv_std;
|
||||
|
|
@ -14,25 +13,23 @@ pub mod gpu {
|
|||
#[allow(unused)]
|
||||
#[spirv(compute(threads({{compute_threads}})))]
|
||||
pub fn eval (
|
||||
#[spirv(global_invocation_id)] global_id: UVec3,
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] a: &[{{input_type}}],
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 1)] y: &mut [{{output_type}}],
|
||||
//#[spirv(push_constant)] push_consts: &graphene_core::gpu::PushConstants,
|
||||
{% for input in inputs %}
|
||||
{{input}}
|
||||
{% endfor %}
|
||||
) {
|
||||
let gid = global_id.x as usize;
|
||||
// Only process up to n, which is the length of the buffers.
|
||||
//if global_id.x < push_consts.n {
|
||||
y[gid] = node_graph(a[gid]);
|
||||
//}
|
||||
}
|
||||
|
||||
fn node_graph(input: {{input_type}}) -> {{output_type}} {
|
||||
use graphene_core::Node;
|
||||
|
||||
{% for input in input_nodes %}
|
||||
let i{{loop.index0}} = graphene_core::value::CopiedNode::new(i{{loop.index0}});
|
||||
let _{{input.id}} = {{input.fqn}}::new({% for arg in input.args %}{{arg}}, {% endfor %});
|
||||
let {{input.id}} = graphene_core::structural::ComposeNode::new(i{{loop.index0}}, _{{input.id}});
|
||||
{% endfor %}
|
||||
|
||||
{% for node in nodes %}
|
||||
let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %});
|
||||
{% endfor %}
|
||||
{{last_node}}.eval(input)
|
||||
}
|
||||
let output = {{last_node}}.eval(());
|
||||
// TODO: Write output to buffer
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
[package]
|
||||
name = "gpu-executor"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
graphene-core = { path = "../gcore", features = [
|
||||
"async",
|
||||
"std",
|
||||
"alloc",
|
||||
"gpu",
|
||||
] }
|
||||
graph-craft = { path = "../graph-craft", features = ["serde"] }
|
||||
node-macro = { path = "../node-macro" }
|
||||
dyn-any = { path = "../../libraries/dyn-any", features = [
|
||||
"log-bad-types",
|
||||
"rc",
|
||||
"glam",
|
||||
] }
|
||||
num-traits = "0.2"
|
||||
log = "0.4"
|
||||
serde = { version = "1", features = ["derive", "rc"] }
|
||||
glam = "0.22"
|
||||
base64 = "0.13"
|
||||
|
||||
bytemuck = { version = "1.8" }
|
||||
anyhow = "1.0.66"
|
||||
spirv = "0.2.0"
|
||||
futures-intrusive = "0.5.0"
|
||||
futures = "0.3.25"
|
||||
|
|
@ -0,0 +1,258 @@
|
|||
use graph_craft::proto::ProtoNetwork;
|
||||
use graphene_core::*;
|
||||
|
||||
use anyhow::Result;
|
||||
use dyn_any::StaticType;
|
||||
use futures::Future;
|
||||
use glam::UVec3;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use std::pin::Pin;
|
||||
|
||||
type ReadBackFuture = Pin<Box<dyn Future<Output = Result<Vec<u8>>>>>;
|
||||
|
||||
pub trait GpuExecutor {
|
||||
type ShaderHandle;
|
||||
type BufferHandle;
|
||||
type CommandBuffer;
|
||||
|
||||
fn load_shader(&self, shader: Shader) -> Result<Self::ShaderHandle>;
|
||||
fn create_uniform_buffer<T: ToUniformBuffer>(&self, data: T) -> Result<ShaderInput<Self::BufferHandle>>;
|
||||
fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>>;
|
||||
fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>>;
|
||||
fn create_compute_pass(&self, layout: &PipelineLayout<Self>, read_back: Option<ShaderInput<Self::BufferHandle>>, instances: u32) -> Result<Self::CommandBuffer>;
|
||||
fn execute_compute_pipeline(&self, encoder: Self::CommandBuffer) -> Result<()>;
|
||||
fn read_output_buffer(&self, buffer: ShaderInput<Self::BufferHandle>) -> Result<ReadBackFuture>;
|
||||
}
|
||||
|
||||
pub trait SpirVCompiler {
|
||||
fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> Result<Shader>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CompileRequest {
|
||||
pub network: ProtoNetwork,
|
||||
pub io: ShaderIO,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
/// GPU constants that can be used as inputs to a shader.
|
||||
pub enum GPUConstant {
|
||||
SubGroupId,
|
||||
SubGroupInvocationId,
|
||||
SubGroupSize,
|
||||
NumSubGroups,
|
||||
WorkGroupId,
|
||||
WorkGroupInvocationId,
|
||||
WorkGroupSize,
|
||||
NumWorkGroups,
|
||||
GlobalInvocationId,
|
||||
GlobalSize,
|
||||
}
|
||||
|
||||
impl GPUConstant {
|
||||
pub fn ty(&self) -> Type {
|
||||
match self {
|
||||
GPUConstant::SubGroupId => concrete!(u32),
|
||||
GPUConstant::SubGroupInvocationId => concrete!(u32),
|
||||
GPUConstant::SubGroupSize => concrete!(u32),
|
||||
GPUConstant::NumSubGroups => concrete!(u32),
|
||||
GPUConstant::WorkGroupId => concrete!(UVec3),
|
||||
GPUConstant::WorkGroupInvocationId => concrete!(UVec3),
|
||||
GPUConstant::WorkGroupSize => concrete!(u32),
|
||||
GPUConstant::NumWorkGroups => concrete!(u32),
|
||||
GPUConstant::GlobalInvocationId => concrete!(UVec3),
|
||||
GPUConstant::GlobalSize => concrete!(UVec3),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
/// All the possible inputs to a shader.
|
||||
pub enum ShaderInput<BufferHandle> {
|
||||
UniformBuffer(BufferHandle, Type),
|
||||
StorageBuffer(BufferHandle, Type),
|
||||
/// A struct representing a work group memory buffer. This cannot be accessed by the CPU.
|
||||
WorkGroupMemory(usize, Type),
|
||||
Constant(GPUConstant),
|
||||
OutputBuffer(BufferHandle, Type),
|
||||
ReadBackBuffer(BufferHandle, Type),
|
||||
}
|
||||
|
||||
/// Extract the buffer handle from a shader input.
|
||||
impl<BufferHandle> ShaderInput<BufferHandle> {
|
||||
pub fn buffer(&self) -> Option<&BufferHandle> {
|
||||
match self {
|
||||
ShaderInput::UniformBuffer(buffer, _) => Some(buffer),
|
||||
ShaderInput::StorageBuffer(buffer, _) => Some(buffer),
|
||||
ShaderInput::WorkGroupMemory(_, _) => None,
|
||||
ShaderInput::Constant(_) => None,
|
||||
ShaderInput::OutputBuffer(buffer, _) => Some(buffer),
|
||||
ShaderInput::ReadBackBuffer(buffer, _) => Some(buffer),
|
||||
}
|
||||
}
|
||||
pub fn ty(&self) -> Type {
|
||||
match self {
|
||||
ShaderInput::UniformBuffer(_, ty) => ty.clone(),
|
||||
ShaderInput::StorageBuffer(_, ty) => ty.clone(),
|
||||
ShaderInput::WorkGroupMemory(_, ty) => ty.clone(),
|
||||
ShaderInput::Constant(c) => c.ty(),
|
||||
ShaderInput::OutputBuffer(_, ty) => ty.clone(),
|
||||
ShaderInput::ReadBackBuffer(_, ty) => ty.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Shader<'a> {
|
||||
pub source: Cow<'a, [u32]>,
|
||||
pub name: &'a str,
|
||||
pub io: ShaderIO,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct ShaderIO {
|
||||
pub inputs: Vec<ShaderInput<()>>,
|
||||
pub output: ShaderInput<()>,
|
||||
}
|
||||
|
||||
pub struct StorageBufferOptions {
|
||||
pub cpu_writable: bool,
|
||||
pub gpu_writable: bool,
|
||||
pub cpu_readable: bool,
|
||||
}
|
||||
|
||||
pub trait ToUniformBuffer: StaticType {
|
||||
type UniformBufferHandle;
|
||||
fn to_bytes(&self) -> Cow<[u8]>;
|
||||
}
|
||||
|
||||
pub trait ToStorageBuffer: StaticType {
|
||||
type StorageBufferHandle;
|
||||
fn to_bytes(&self) -> Cow<[u8]>;
|
||||
}
|
||||
|
||||
/// Collection of all arguments that are passed to the shader.
|
||||
pub struct Bindgroup<E: GpuExecutor + ?Sized> {
|
||||
pub buffers: Vec<ShaderInput<E::BufferHandle>>,
|
||||
}
|
||||
|
||||
/// A struct representing a compute pipeline.
|
||||
pub struct PipelineLayout<E: GpuExecutor + ?Sized> {
|
||||
pub shader: E::ShaderHandle,
|
||||
pub entry_point: String,
|
||||
pub bind_group: Bindgroup<E>,
|
||||
pub output_buffer: ShaderInput<E::BufferHandle>,
|
||||
}
|
||||
|
||||
/// Extracts arguments from the function arguments and wraps them in a node.
|
||||
pub struct ShaderInputNode<T> {
|
||||
data: T,
|
||||
}
|
||||
|
||||
impl<'i, T: 'i> Node<'i, ()> for ShaderInputNode<T> {
|
||||
type Output = &'i T;
|
||||
|
||||
fn eval(&'i self, _: ()) -> Self::Output {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ShaderInputNode<T> {
|
||||
pub fn new(data: T) -> Self {
|
||||
Self { data }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UniformNode<Executor> {
|
||||
executor: Executor,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(UniformNode)]
|
||||
fn uniform_node<T: ToUniformBuffer, E: GpuExecutor>(data: T, executor: &'any_input E) -> ShaderInput<E::BufferHandle> {
|
||||
let handle = executor.create_uniform_buffer(data).unwrap();
|
||||
handle
|
||||
}
|
||||
|
||||
pub struct StorageNode<Executor> {
|
||||
executor: Executor,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(StorageNode)]
|
||||
fn storage_node<T: ToStorageBuffer, E: GpuExecutor>(data: T, executor: &'any_input E) -> ShaderInput<E::BufferHandle> {
|
||||
let handle = executor
|
||||
.create_storage_buffer(
|
||||
data,
|
||||
StorageBufferOptions {
|
||||
cpu_writable: false,
|
||||
gpu_writable: true,
|
||||
cpu_readable: false,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
handle
|
||||
}
|
||||
|
||||
pub struct PushNode<Value> {
|
||||
value: Value,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(PushNode)]
|
||||
fn push_node<T>(mut vec: Vec<T>, value: T) {
|
||||
vec.push(value);
|
||||
}
|
||||
|
||||
pub struct CreateOutputBufferNode<Executor, Ty> {
|
||||
executor: Executor,
|
||||
ty: Ty,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(CreateOutputBufferNode)]
|
||||
fn create_output_buffer_node<E: GpuExecutor>(size: usize, executor: &'any_input E, ty: Type) -> ShaderInput<E::BufferHandle> {
|
||||
executor.create_output_buffer(size, ty, true).unwrap()
|
||||
}
|
||||
|
||||
pub struct CreateComputePassNode<Executor, Output, Instances> {
|
||||
executor: Executor,
|
||||
output: Output,
|
||||
instances: Instances,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(CreateComputePassNode)]
|
||||
fn create_compute_pass_node<E: GpuExecutor>(layout: PipelineLayout<E>, executor: &'any_input E, output: ShaderInput<E::BufferHandle>, instances: u32) -> E::CommandBuffer {
|
||||
executor.create_compute_pass(&layout, Some(output), instances).unwrap()
|
||||
}
|
||||
|
||||
pub struct CreatePipelineLayoutNode<_E, EntryPoint, Bindgroup, OutputBuffer> {
|
||||
entry_point: EntryPoint,
|
||||
bind_group: Bindgroup,
|
||||
output_buffer: OutputBuffer,
|
||||
_e: std::marker::PhantomData<_E>,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(CreatePipelineLayoutNode<_E>)]
|
||||
fn create_pipeline_layout_node<_E: GpuExecutor>(shader: _E::ShaderHandle, entry_point: String, bind_group: Bindgroup<_E>, output_buffer: ShaderInput<_E::BufferHandle>) -> PipelineLayout<_E> {
|
||||
PipelineLayout {
|
||||
shader,
|
||||
entry_point,
|
||||
bind_group,
|
||||
output_buffer,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ExecuteComputePipelineNode<Executor> {
|
||||
executor: Executor,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(ExecuteComputePipelineNode)]
|
||||
fn execute_compute_pipeline_node<E: GpuExecutor>(encoder: E::CommandBuffer, executor: &'any_input mut E) {
|
||||
executor.execute_compute_pipeline(encoder).unwrap();
|
||||
}
|
||||
|
||||
// TODO
|
||||
// pub struct ReadOutputBufferNode<Executor> {
|
||||
// executor: Executor,
|
||||
// }
|
||||
// #[node_macro::node_fn(ReadOutputBufferNode)]
|
||||
// fn read_output_buffer_node<E: GpuExecutor>(buffer: E::BufferHandle, executor: &'any_input mut E) -> Vec<u8> {
|
||||
// executor.read_output_buffer(buffer).await.unwrap()
|
||||
// }
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
use crate::document::value::TaggedValue;
|
||||
use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};
|
||||
use graphene_core::{NodeIdentifier, Type};
|
||||
|
||||
|
|
@ -20,7 +21,7 @@ fn merge_ids(a: u64, b: u64) -> u64 {
|
|||
hasher.finish()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Default, specta::Type)]
|
||||
#[derive(Clone, Debug, PartialEq, Default, specta::Type, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct DocumentNodeMetadata {
|
||||
pub position: IVec2,
|
||||
|
|
@ -32,7 +33,7 @@ impl DocumentNodeMetadata {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct DocumentNode {
|
||||
pub name: String,
|
||||
|
|
@ -156,7 +157,7 @@ impl DocumentNode {
|
|||
///
|
||||
/// In this case the Cache node actually consumes its input and then manually forwards it to its parameter Node.
|
||||
/// This is necessary because the Cache Node needs to short-circut the actual node evaluation.
|
||||
#[derive(Debug, Clone, PartialEq, Hash)]
|
||||
#[derive(Debug, Clone, PartialEq, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum NodeInput {
|
||||
Node {
|
||||
|
|
@ -165,7 +166,7 @@ pub enum NodeInput {
|
|||
lambda: bool,
|
||||
},
|
||||
Value {
|
||||
tagged_value: crate::document::value::TaggedValue,
|
||||
tagged_value: TaggedValue,
|
||||
exposed: bool,
|
||||
},
|
||||
Network(Type),
|
||||
|
|
@ -182,7 +183,7 @@ impl NodeInput {
|
|||
pub const fn lambda(node_id: NodeId, output_index: usize) -> Self {
|
||||
Self::Node { node_id, output_index, lambda: true }
|
||||
}
|
||||
pub const fn value(tagged_value: crate::document::value::TaggedValue, exposed: bool) -> Self {
|
||||
pub const fn value(tagged_value: TaggedValue, exposed: bool) -> Self {
|
||||
Self::Value { tagged_value, exposed }
|
||||
}
|
||||
fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId) {
|
||||
|
|
@ -212,11 +213,12 @@ impl NodeInput {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, PartialEq, Hash, DynAny)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum DocumentNodeImplementation {
|
||||
Network(NodeNetwork),
|
||||
Unresolved(NodeIdentifier),
|
||||
Extract,
|
||||
}
|
||||
|
||||
impl Default for DocumentNodeImplementation {
|
||||
|
|
@ -227,23 +229,21 @@ impl Default for DocumentNodeImplementation {
|
|||
|
||||
impl DocumentNodeImplementation {
|
||||
pub fn get_network(&self) -> Option<&NodeNetwork> {
|
||||
if let DocumentNodeImplementation::Network(n) = self {
|
||||
Some(n)
|
||||
} else {
|
||||
None
|
||||
match self {
|
||||
DocumentNodeImplementation::Network(n) => Some(n),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_network_mut(&mut self) -> Option<&mut NodeNetwork> {
|
||||
if let DocumentNodeImplementation::Network(n) = self {
|
||||
Some(n)
|
||||
} else {
|
||||
None
|
||||
match self {
|
||||
DocumentNodeImplementation::Network(n) => Some(n),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, DynAny, specta::Type, Hash)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct NodeOutput {
|
||||
pub node_id: NodeId,
|
||||
|
|
@ -267,6 +267,21 @@ pub struct NodeNetwork {
|
|||
pub previous_outputs: Option<Vec<NodeOutput>>,
|
||||
}
|
||||
|
||||
impl std::hash::Hash for NodeNetwork {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.inputs.hash(state);
|
||||
self.outputs.hash(state);
|
||||
let mut nodes: Vec<_> = self.nodes.iter().collect();
|
||||
nodes.sort_by_key(|(id, _)| *id);
|
||||
for (id, node) in nodes {
|
||||
id.hash(state);
|
||||
node.hash(state);
|
||||
}
|
||||
self.disabled.hash(state);
|
||||
self.previous_outputs.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
/// Graph modification functions
|
||||
impl NodeNetwork {
|
||||
/// Get the original output nodes of this network, ignoring any preview node
|
||||
|
|
@ -701,12 +716,43 @@ impl NodeNetwork {
|
|||
self.flatten_with_fns(node_id, map_ids, gen_id);
|
||||
}
|
||||
}
|
||||
DocumentNodeImplementation::Unresolved(_) => {}
|
||||
DocumentNodeImplementation::Unresolved(_) => (),
|
||||
DocumentNodeImplementation::Extract => {
|
||||
panic!("Extract nodes should have been removed before flattening");
|
||||
}
|
||||
}
|
||||
assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict");
|
||||
self.nodes.insert(id, node);
|
||||
}
|
||||
|
||||
pub fn resolve_extract_nodes(&mut self) {
|
||||
let mut extraction_nodes = self
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|(_, node)| matches!(node.implementation, DocumentNodeImplementation::Extract))
|
||||
.map(|(id, node)| (*id, node.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
self.nodes.retain(|_, node| !matches!(node.implementation, DocumentNodeImplementation::Extract));
|
||||
|
||||
for (_, node) in &mut extraction_nodes {
|
||||
match node.implementation {
|
||||
DocumentNodeImplementation::Extract => {
|
||||
assert_eq!(node.inputs.len(), 1);
|
||||
let NodeInput::Node { node_id, output_index, lambda } = node.inputs.pop().unwrap() else {
|
||||
panic!("Extract node has no input");
|
||||
};
|
||||
assert_eq!(output_index, 0);
|
||||
assert!(lambda);
|
||||
let input_node = self.nodes.get_mut(&node_id).unwrap();
|
||||
node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into());
|
||||
node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node.clone()), false)];
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
self.nodes.extend(extraction_nodes);
|
||||
}
|
||||
|
||||
pub fn into_proto_networks(self) -> impl Iterator<Item = ProtoNetwork> {
|
||||
let mut nodes: Vec<_> = self.nodes.into_iter().map(|(id, node)| (id, node.resolve_proto_node())).collect();
|
||||
nodes.sort_unstable_by_key(|(i, _)| *i);
|
||||
|
|
@ -798,6 +844,39 @@ mod test {
|
|||
assert_eq!(network, maped_add);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_node() {
|
||||
let id_node = DocumentNode {
|
||||
name: "Id".into(),
|
||||
inputs: vec![],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()),
|
||||
};
|
||||
let mut extraction_network = NodeNetwork {
|
||||
inputs: vec![],
|
||||
outputs: vec![NodeOutput::new(1, 0)],
|
||||
nodes: [
|
||||
id_node.clone(),
|
||||
DocumentNode {
|
||||
name: "Extract".into(),
|
||||
inputs: vec![NodeInput::lambda(0, 0)],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Extract,
|
||||
},
|
||||
]
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(id, node)| (id as NodeId, node))
|
||||
.collect(),
|
||||
..Default::default()
|
||||
};
|
||||
extraction_network.resolve_extract_nodes();
|
||||
assert_eq!(extraction_network.nodes.len(), 2);
|
||||
let inputs = extraction_network.nodes.get(&1).unwrap().inputs.clone();
|
||||
assert_eq!(inputs.len(), 1);
|
||||
assert!(matches!(&inputs[0], &NodeInput::Value{ tagged_value: TaggedValue::DocumentNode(ref network), ..} if network == &id_node));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn flatten_add() {
|
||||
let mut network = NodeNetwork {
|
||||
|
|
@ -810,7 +889,7 @@ mod test {
|
|||
inputs: vec![
|
||||
NodeInput::Network(concrete!(u32)),
|
||||
NodeInput::Value {
|
||||
tagged_value: crate::document::value::TaggedValue::U32(2),
|
||||
tagged_value: TaggedValue::U32(2),
|
||||
exposed: false,
|
||||
},
|
||||
],
|
||||
|
|
@ -876,7 +955,7 @@ mod test {
|
|||
construction_args: ConstructionArgs::Nodes(vec![]),
|
||||
},
|
||||
),
|
||||
(14, ProtoNode::value(ConstructionArgs::Value(crate::document::value::TaggedValue::U32(2)))),
|
||||
(14, ProtoNode::value(ConstructionArgs::Value(TaggedValue::U32(2)))),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
|
|
@ -917,7 +996,7 @@ mod test {
|
|||
DocumentNode {
|
||||
name: "Value".into(),
|
||||
inputs: vec![NodeInput::Value {
|
||||
tagged_value: crate::document::value::TaggedValue::U32(2),
|
||||
tagged_value: TaggedValue::U32(2),
|
||||
exposed: false,
|
||||
}],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
|
|
@ -979,10 +1058,7 @@ mod test {
|
|||
10,
|
||||
DocumentNode {
|
||||
name: "Nested network".into(),
|
||||
inputs: vec![
|
||||
NodeInput::value(crate::document::value::TaggedValue::F32(1.), false),
|
||||
NodeInput::value(crate::document::value::TaggedValue::F32(2.), false),
|
||||
],
|
||||
inputs: vec![NodeInput::value(TaggedValue::F32(1.), false), NodeInput::value(TaggedValue::F32(2.), false)],
|
||||
metadata: DocumentNodeMetadata::default(),
|
||||
implementation: DocumentNodeImplementation::Network(two_node_identity()),
|
||||
},
|
||||
|
|
@ -1015,11 +1091,7 @@ mod test {
|
|||
assert_eq!(result.nodes.keys().copied().collect::<Vec<_>>(), vec![101], "Should just call nested network");
|
||||
let nested_network_node = result.nodes.get(&101).unwrap();
|
||||
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
|
||||
assert_eq!(
|
||||
nested_network_node.inputs,
|
||||
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
|
||||
"Input should be 2"
|
||||
);
|
||||
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
|
||||
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
|
||||
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
|
||||
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
|
||||
|
|
@ -1038,11 +1110,7 @@ mod test {
|
|||
for (node_id, input_value, inner_id) in [(10, 1., 1), (101, 2., 2)] {
|
||||
let nested_network_node = result.nodes.get(&node_id).unwrap();
|
||||
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
|
||||
assert_eq!(
|
||||
nested_network_node.inputs,
|
||||
vec![NodeInput::value(crate::document::value::TaggedValue::F32(input_value), false)],
|
||||
"Input should be stable"
|
||||
);
|
||||
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(input_value), false)], "Input should be stable");
|
||||
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
|
||||
assert_eq!(inner_network.inputs, vec![inner_id], "The input should be sent to the second node");
|
||||
assert_eq!(inner_network.outputs, vec![NodeOutput::new(inner_id, 0)], "The output should be node id");
|
||||
|
|
@ -1061,11 +1129,7 @@ mod test {
|
|||
assert_eq!(result_node.inputs, vec![NodeInput::node(101, 0)], "Result node should refer to duplicate node as input");
|
||||
let nested_network_node = result.nodes.get(&101).unwrap();
|
||||
assert_eq!(nested_network_node.name, "Nested network".to_string(), "Name should not change");
|
||||
assert_eq!(
|
||||
nested_network_node.inputs,
|
||||
vec![NodeInput::value(crate::document::value::TaggedValue::F32(2.), false)],
|
||||
"Input should be 2"
|
||||
);
|
||||
assert_eq!(nested_network_node.inputs, vec![NodeInput::value(TaggedValue::F32(2.), false)], "Input should be 2");
|
||||
let inner_network = nested_network_node.implementation.get_network().expect("Implementation should be network");
|
||||
assert_eq!(inner_network.inputs, vec![2], "The input should be sent to the second node");
|
||||
assert_eq!(inner_network.outputs, vec![NodeOutput::new(2, 0)], "The output should be node id 2");
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
use super::DocumentNode;
|
||||
use crate::executor::Any;
|
||||
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
|
||||
|
||||
use graphene_core::raster::{BlendMode, LuminanceCalculation};
|
||||
use graphene_core::{Color, Node, Type};
|
||||
|
||||
pub use dyn_any::StaticType;
|
||||
use dyn_any::{DynAny, Upcast};
|
||||
use dyn_clone::DynClone;
|
||||
pub use glam::{DAffine2, DVec2};
|
||||
use graphene_core::raster::{BlendMode, LuminanceCalculation};
|
||||
use graphene_core::{Color, Node, Type};
|
||||
use std::hash::Hash;
|
||||
pub use std::sync::Arc;
|
||||
|
||||
use crate::executor::Any;
|
||||
pub use crate::imaginate_input::{ImaginateMaskStartingFill, ImaginateSamplingMethod, ImaginateStatus};
|
||||
|
||||
/// A type that is known, allowing serialization (serde::Deserialize is not object safe)
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
|
|
@ -52,6 +54,7 @@ pub enum TaggedValue {
|
|||
ManipulatorGroupIds(Vec<graphene_core::uuid::ManipulatorGroupId>),
|
||||
VecDVec2(Vec<DVec2>),
|
||||
Segments(Vec<graphene_core::raster::ImageFrame<Color>>),
|
||||
DocumentNode(DocumentNode),
|
||||
}
|
||||
|
||||
#[allow(clippy::derived_hash_with_manual_eq)]
|
||||
|
|
@ -119,11 +122,13 @@ impl Hash for TaggedValue {
|
|||
}
|
||||
}
|
||||
Self::Segments(segments) => {
|
||||
32.hash(state);
|
||||
for segment in segments {
|
||||
segment.hash(state)
|
||||
}
|
||||
}
|
||||
Self::DocumentNode(document_node) => {
|
||||
document_node.hash(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -170,6 +175,19 @@ impl<'a> TaggedValue {
|
|||
TaggedValue::ManipulatorGroupIds(x) => Box::new(x),
|
||||
TaggedValue::VecDVec2(x) => Box::new(x),
|
||||
TaggedValue::Segments(x) => Box::new(x),
|
||||
TaggedValue::DocumentNode(x) => Box::new(x),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_primitive_string(&self) -> String {
|
||||
match self {
|
||||
TaggedValue::None => "()".to_string(),
|
||||
TaggedValue::String(x) => x.clone(),
|
||||
TaggedValue::U32(x) => x.to_string(),
|
||||
TaggedValue::F32(x) => x.to_string(),
|
||||
TaggedValue::F64(x) => x.to_string(),
|
||||
TaggedValue::Bool(x) => x.to_string(),
|
||||
_ => panic!("Cannot convert to primitive string"),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -215,6 +233,7 @@ impl<'a> TaggedValue {
|
|||
TaggedValue::ManipulatorGroupIds(_) => concrete!(Vec<graphene_core::uuid::ManipulatorGroupId>),
|
||||
TaggedValue::VecDVec2(_) => concrete!(Vec<DVec2>),
|
||||
TaggedValue::Segments(_) => concrete!(graphene_core::raster::IndexNode<Vec<graphene_core::raster::ImageFrame<Color>>>),
|
||||
TaggedValue::DocumentNode(_) => concrete!(crate::document::DocumentNode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ pub struct Compiler {}
|
|||
impl Compiler {
|
||||
pub fn compile(&self, mut network: NodeNetwork, resolve_inputs: bool) -> impl Iterator<Item = ProtoNetwork> {
|
||||
let node_ids = network.nodes.keys().copied().collect::<Vec<_>>();
|
||||
network.resolve_extract_nodes();
|
||||
println!("flattening");
|
||||
for id in node_ids {
|
||||
network.flatten(id);
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ use crate::document::value;
|
|||
use crate::document::NodeId;
|
||||
use dyn_any::DynAny;
|
||||
use graphene_core::*;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::pin::Pin;
|
||||
|
||||
pub type Any<'n> = Box<dyn DynAny<'n> + 'n>;
|
||||
|
|
@ -16,7 +18,8 @@ pub type TypeErasedPinned<'n> = Pin<Box<dyn for<'i> NodeIO<'i, Any<'i>, Output =
|
|||
|
||||
pub type NodeConstructor = for<'a> fn(Vec<TypeErasedPinnedRef<'static>>) -> TypeErasedPinned<'static>;
|
||||
|
||||
#[derive(Debug, Default, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, PartialEq, Clone)]
|
||||
pub struct ProtoNetwork {
|
||||
// Should a proto Network even allow inputs? Don't think so
|
||||
pub inputs: Vec<NodeId>,
|
||||
|
|
@ -29,7 +32,7 @@ impl core::fmt::Display for ProtoNetwork {
|
|||
f.write_str("Proto Network with nodes: ")?;
|
||||
fn write_node(f: &mut core::fmt::Formatter<'_>, network: &ProtoNetwork, id: NodeId, indent: usize) -> core::fmt::Result {
|
||||
f.write_str(&"\t".repeat(indent))?;
|
||||
let Some((_, node)) = network.nodes.iter().find(|(node_id, _)|*node_id == id) else{
|
||||
let Some((_, node)) = network.nodes.iter().find(|(node_id, _)|*node_id == id) else {
|
||||
return f.write_str("{{Unknown Node}}");
|
||||
};
|
||||
f.write_str("Node: ")?;
|
||||
|
|
@ -70,6 +73,7 @@ impl core::fmt::Display for ProtoNetwork {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ConstructionArgs {
|
||||
Value(value::TaggedValue),
|
||||
|
|
@ -104,12 +108,13 @@ impl Hash for ConstructionArgs {
|
|||
impl ConstructionArgs {
|
||||
pub fn new_function_args(&self) -> Vec<String> {
|
||||
match self {
|
||||
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("n{}", n.0)).collect(),
|
||||
ConstructionArgs::Value(value) => vec![format!("{:?}", value)],
|
||||
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("&n{}", n.0)).collect(),
|
||||
ConstructionArgs::Value(value) => vec![value.to_primitive_string()],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
pub struct ProtoNode {
|
||||
pub construction_args: ConstructionArgs,
|
||||
|
|
@ -121,6 +126,7 @@ pub struct ProtoNode {
|
|||
/// For documentation on the meaning of the variants, see the documentation of the `NodeInput` enum
|
||||
/// in the `document` module
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub enum ProtoNodeInput {
|
||||
None,
|
||||
Network(Type),
|
||||
|
|
|
|||
|
|
@ -11,41 +11,58 @@ license = "MIT OR Apache-2.0"
|
|||
[features]
|
||||
memoization = ["once_cell"]
|
||||
default = ["memoization"]
|
||||
gpu = ["graphene-core/gpu", "gpu-compiler-bin-wrapper", "compilation-client"]
|
||||
gpu = [
|
||||
"graphene-core/gpu",
|
||||
"gpu-compiler-bin-wrapper",
|
||||
"compilation-client",
|
||||
"gpu-executor",
|
||||
]
|
||||
vulkan = ["gpu", "vulkan-executor"]
|
||||
wgpu = ["gpu", "wgpu-executor"]
|
||||
quantization = ["autoquant"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
autoquant = { git = "https://github.com/truedoctor/autoquant", optional = true, features = ["fitting"] }
|
||||
graphene-core = {path = "../gcore", features = ["async", "std", "serde" ], default-features = false}
|
||||
borrow_stack = {path = "../borrow_stack"}
|
||||
dyn-any = {path = "../../libraries/dyn-any", features = ["derive"]}
|
||||
graph-craft = {path = "../graph-craft"}
|
||||
vulkan-executor = {path = "../vulkan-executor", optional = true}
|
||||
wgpu-executor = {path = "../wgpu-executor", optional = true}
|
||||
gpu-compiler-bin-wrapper = {path = "../gpu-compiler/gpu-compiler-bin-wrapper", optional = true}
|
||||
compilation-client = {path = "../compilation-client", optional = true}
|
||||
bytemuck = {version = "1.8" }
|
||||
autoquant = { git = "https://github.com/truedoctor/autoquant", optional = true, features = [
|
||||
"fitting",
|
||||
] }
|
||||
graphene-core = { path = "../gcore", features = [
|
||||
"async",
|
||||
"std",
|
||||
"serde",
|
||||
], default-features = false }
|
||||
borrow_stack = { path = "../borrow_stack" }
|
||||
dyn-any = { path = "../../libraries/dyn-any", features = ["derive"] }
|
||||
graph-craft = { path = "../graph-craft" }
|
||||
vulkan-executor = { path = "../vulkan-executor", optional = true }
|
||||
wgpu-executor = { path = "../wgpu-executor", optional = true, version = "0.1.0" }
|
||||
gpu-executor = { path = "../gpu-executor", optional = true }
|
||||
gpu-compiler-bin-wrapper = { path = "../gpu-compiler/gpu-compiler-bin-wrapper", optional = true }
|
||||
compilation-client = { path = "../compilation-client", optional = true }
|
||||
bytemuck = { version = "1.8" }
|
||||
tempfile = "3"
|
||||
once_cell = {version= "1.10", optional = true}
|
||||
once_cell = { version = "1.10", optional = true }
|
||||
#pretty-token-stream = {path = "../../pretty-token-stream"}
|
||||
syn = {version = "1.0", default-features = false, features = ["parsing", "printing"]}
|
||||
proc-macro2 = {version = "1.0", default-features = false, features = ["proc-macro"]}
|
||||
quote = {version = "1.0", default-features = false }
|
||||
syn = { version = "1.0", default-features = false, features = [
|
||||
"parsing",
|
||||
"printing",
|
||||
] }
|
||||
proc-macro2 = { version = "1.0", default-features = false, features = [
|
||||
"proc-macro",
|
||||
] }
|
||||
quote = { version = "1.0", default-features = false }
|
||||
image = { version = "*", default-features = false }
|
||||
dyn-clone = "1.0"
|
||||
|
||||
log = "0.4"
|
||||
bezier-rs = { path = "../../libraries/bezier-rs" , features = ["serde"] }
|
||||
bezier-rs = { path = "../../libraries/bezier-rs", features = ["serde"] }
|
||||
kurbo = { git = "https://github.com/linebender/kurbo.git", features = [
|
||||
"serde",
|
||||
] }
|
||||
glam = { version = "0.22", features = ["serde"] }
|
||||
node-macro = { path="../node-macro" }
|
||||
node-macro = { path = "../node-macro" }
|
||||
boxcar = "0.1.0"
|
||||
xxhash-rust = {workspace = true}
|
||||
xxhash-rust = { workspace = true }
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1.0"
|
||||
|
|
|
|||
|
|
@ -1,24 +1,48 @@
|
|||
use gpu_executor::{GpuExecutor, ShaderIO, ShaderInput};
|
||||
use graph_craft::document::*;
|
||||
use graph_craft::proto::*;
|
||||
use graphene_core::raster::*;
|
||||
use graphene_core::value::ValueNode;
|
||||
use graphene_core::*;
|
||||
use wgpu_executor::NewExecutor;
|
||||
|
||||
use bytemuck::Pod;
|
||||
use core::marker::PhantomData;
|
||||
use dyn_any::StaticTypeSized;
|
||||
|
||||
pub struct MapGpuNode<O, Network> {
|
||||
network: Network,
|
||||
_o: PhantomData<O>,
|
||||
pub struct GpuCompiler<TypingContext, ShaderIO> {
|
||||
typing_context: TypingContext,
|
||||
io: ShaderIO,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(MapGpuNode<_O>)]
|
||||
fn map_gpu<I: IntoIterator<Item = S>, S: StaticTypeSized + Sync + Send + Pod, _O: StaticTypeSized + Sync + Send + Pod>(input: I, network: &'any_input NodeNetwork) -> Vec<_O> {
|
||||
// TODO: Move to graph-craft
|
||||
#[node_macro::node_fn(GpuCompiler)]
|
||||
fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader {
|
||||
let compiler = graph_craft::executor::Compiler {};
|
||||
let DocumentNodeImplementation::Network(network) = node.implementation;
|
||||
let proto_network = compiler.compile_single(network, true).unwrap();
|
||||
typing_context.update(&proto_network);
|
||||
let input_types = proto_network.inputs.iter().map(|id| typing_context.get_type(*id).unwrap()).map(|node_io| node_io.output).collect();
|
||||
let output_type = typing_context.get_type(proto_network.output).unwrap().output;
|
||||
|
||||
let bytes = compilation_client::compile_sync(proto_network, input_types, output_type, io).unwrap();
|
||||
bytes
|
||||
}
|
||||
|
||||
pub struct MapGpuNode<Shader> {
|
||||
shader: Shader,
|
||||
}
|
||||
|
||||
#[node_macro::node_fn(MapGpuNode)]
|
||||
fn map_gpu(inputs: Vec<ShaderInput<<NewExecutor as GpuExecutor>::BufferHandle>>, shader: &'any_input compilation_client::Shader) {
|
||||
use graph_craft::executor::Executor;
|
||||
let bytes = compilation_client::compile_sync::<S, _O>(network.clone()).unwrap();
|
||||
let words = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const u32, bytes.len() / 4) };
|
||||
use wgpu_executor::{Context, GpuExecutor};
|
||||
let executor: GpuExecutor<S, _O> = GpuExecutor::new(Context::new_sync().unwrap(), words.into(), "gpu::eval".into()).unwrap();
|
||||
let executor = NewExecutor::new().unwrap();
|
||||
for input in shader.inputs.iter() {
|
||||
let buffer = executor.create_buffer(input.size).unwrap();
|
||||
executor.write_buffer(buffer, input.data).unwrap();
|
||||
}
|
||||
todo!();
|
||||
let executor: GpuExecutor = GpuExecutor::new(Context::new_sync().unwrap(), shader.into(), "gpu::eval".into()).unwrap();
|
||||
let data: Vec<_> = input.into_iter().collect();
|
||||
let result = executor.execute(Box::new(data)).unwrap();
|
||||
let result = dyn_any::downcast::<Vec<_O>>(result).unwrap();
|
||||
|
|
@ -30,7 +54,7 @@ pub struct MapGpuSingleImageNode<N> {
|
|||
}
|
||||
|
||||
#[node_macro::node_fn(MapGpuSingleImageNode)]
|
||||
fn map_gpu_single_image(input: Image, node: String) -> Image {
|
||||
fn map_gpu_single_image(input: Image<Color>, node: String) -> Image<Color> {
|
||||
use graph_craft::document::*;
|
||||
use graph_craft::NodeIdentifier;
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ default = []
|
|||
[dependencies]
|
||||
graphene-core = { path = "../gcore", features = ["async", "std", "alloc", "gpu"] }
|
||||
graph-craft = {path = "../graph-craft" }
|
||||
gpu-executor = { path = "../gpu-executor" }
|
||||
dyn-any = { path = "../../libraries/dyn-any", features = ["log-bad-types", "rc", "glam"] }
|
||||
future-executor = { path = "../future-executor" }
|
||||
num-traits = "0.2"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
use wgpu::{Device, Instance, Queue};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Context {
|
||||
pub device: Arc<Device>,
|
||||
pub queue: Arc<Queue>,
|
||||
|
|
|
|||
|
|
@ -3,3 +3,189 @@ mod executor;
|
|||
|
||||
pub use context::Context;
|
||||
pub use executor::GpuExecutor;
|
||||
use gpu_executor::{Shader, ShaderInput, StorageBufferOptions, ToStorageBuffer, ToUniformBuffer};
|
||||
use graph_craft::Type;
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use futures::Future;
|
||||
use std::pin::Pin;
|
||||
use wgpu::util::DeviceExt;
|
||||
use wgpu::{Buffer, BufferDescriptor, CommandBuffer, ShaderModule};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NewExecutor {
|
||||
context: Context,
|
||||
}
|
||||
|
||||
impl gpu_executor::GpuExecutor for NewExecutor {
|
||||
type ShaderHandle = ShaderModule;
|
||||
type BufferHandle = Buffer;
|
||||
type CommandBuffer = CommandBuffer;
|
||||
|
||||
fn load_shader(&self, shader: Shader) -> Result<Self::ShaderHandle> {
|
||||
let shader_module = self.context.device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some(shader.name),
|
||||
source: wgpu::ShaderSource::SpirV(shader.source),
|
||||
});
|
||||
Ok(shader_module)
|
||||
}
|
||||
|
||||
fn create_uniform_buffer<T: ToUniformBuffer>(&self, data: T) -> Result<ShaderInput<Self::BufferHandle>> {
|
||||
let bytes = data.to_bytes();
|
||||
let buffer = self.context.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: None,
|
||||
contents: bytes.as_ref(),
|
||||
usage: wgpu::BufferUsages::UNIFORM,
|
||||
});
|
||||
Ok(ShaderInput::UniformBuffer(buffer, Type::new::<T>()))
|
||||
}
|
||||
|
||||
fn create_storage_buffer<T: ToStorageBuffer>(&self, data: T, options: StorageBufferOptions) -> Result<ShaderInput<Self::BufferHandle>> {
|
||||
let bytes = data.to_bytes();
|
||||
let mut usage = wgpu::BufferUsages::STORAGE;
|
||||
|
||||
if options.gpu_writable {
|
||||
usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
|
||||
}
|
||||
if options.cpu_readable {
|
||||
usage |= wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
|
||||
}
|
||||
if options.cpu_writable {
|
||||
usage |= wgpu::BufferUsages::MAP_WRITE | wgpu::BufferUsages::COPY_SRC;
|
||||
}
|
||||
|
||||
let buffer = self.context.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: None,
|
||||
contents: bytes.as_ref(),
|
||||
usage,
|
||||
});
|
||||
Ok(ShaderInput::StorageBuffer(buffer, Type::new::<T>()))
|
||||
}
|
||||
|
||||
fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result<ShaderInput<Self::BufferHandle>> {
|
||||
let create_buffer = |usage| {
|
||||
Ok::<_, anyhow::Error>(self.context.device.create_buffer(&BufferDescriptor {
|
||||
label: None,
|
||||
size: len as u64 * ty.size().ok_or_else(|| anyhow::anyhow!("Cannot create buffer of type {:?}", ty))? as u64,
|
||||
usage,
|
||||
mapped_at_creation: false,
|
||||
}))
|
||||
};
|
||||
let buffer = match cpu_readable {
|
||||
true => ShaderInput::ReadBackBuffer(create_buffer(wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ)?, ty),
|
||||
false => ShaderInput::OutputBuffer(create_buffer(wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC)?, ty),
|
||||
};
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
fn create_compute_pass(&self, layout: &gpu_executor::PipelineLayout<Self>, read_back: Option<ShaderInput<Self::BufferHandle>>, instances: u32) -> Result<CommandBuffer> {
|
||||
let compute_pipeline = self.context.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &layout.shader,
|
||||
entry_point: layout.entry_point.as_str(),
|
||||
});
|
||||
let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
|
||||
|
||||
let entries = layout
|
||||
.bind_group
|
||||
.buffers
|
||||
.iter()
|
||||
.chain(std::iter::once(&layout.output_buffer))
|
||||
.flat_map(|input| input.buffer())
|
||||
.enumerate()
|
||||
.map(|(i, buffer)| wgpu::BindGroupEntry {
|
||||
binding: i as u32,
|
||||
resource: buffer.as_entire_binding(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let bind_group = self.context.device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: None,
|
||||
layout: &bind_group_layout,
|
||||
entries: entries.as_slice(),
|
||||
});
|
||||
|
||||
let mut encoder = self.context.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||
{
|
||||
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
|
||||
cpass.set_pipeline(&compute_pipeline);
|
||||
cpass.set_bind_group(0, &bind_group, &[]);
|
||||
cpass.insert_debug_marker("compute node network evaluation");
|
||||
cpass.dispatch_workgroups(instances, 1, 1); // Number of cells to run, the (x,y,z) size of item being processed
|
||||
}
|
||||
// Sets adds copy operation to command encoder.
|
||||
// Will copy data from storage buffer on GPU to staging buffer on CPU.
|
||||
if let Some(ShaderInput::ReadBackBuffer(output, ty)) = read_back {
|
||||
let size = output.size();
|
||||
assert_eq!(size, layout.output_buffer.buffer().unwrap().size());
|
||||
assert_eq!(ty, layout.output_buffer.ty());
|
||||
encoder.copy_buffer_to_buffer(
|
||||
layout.output_buffer.buffer().ok_or_else(|| anyhow::anyhow!("Tried to use an non buffer as the shader output"))?,
|
||||
0,
|
||||
&output,
|
||||
0,
|
||||
size,
|
||||
);
|
||||
}
|
||||
|
||||
// Submits command encoder for processing
|
||||
Ok(encoder.finish())
|
||||
}
|
||||
|
||||
fn execute_compute_pipeline(&self, encoder: Self::CommandBuffer) -> Result<()> {
|
||||
self.context.queue.submit(Some(encoder));
|
||||
|
||||
// Poll the device in a blocking manner so that our future resolves.
|
||||
// In an actual application, `device.poll(...)` should
|
||||
// be called in an event loop or on another thread.
|
||||
self.context.device.poll(wgpu::Maintain::Wait);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_output_buffer(&self, buffer: ShaderInput<Self::BufferHandle>) -> Result<Pin<Box<dyn Future<Output = Result<Vec<u8>>>>>> {
|
||||
if let ShaderInput::ReadBackBuffer(buffer, _) = buffer {
|
||||
let future = Box::pin(async move {
|
||||
let buffer_slice = buffer.slice(..);
|
||||
|
||||
// Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished.
|
||||
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
|
||||
|
||||
// Wait for the mapping to finish.
|
||||
#[cfg(feature = "profiling")]
|
||||
nvtx::range_push!("compute");
|
||||
let result = receiver.receive().await;
|
||||
#[cfg(feature = "profiling")]
|
||||
nvtx::range_pop!();
|
||||
|
||||
if result == Some(Ok(())) {
|
||||
// Gets contents of buffer
|
||||
let data = buffer_slice.get_mapped_range();
|
||||
// Since contents are got in bytes, this converts these bytes back to u32
|
||||
let result = bytemuck::cast_slice(&data).to_vec();
|
||||
|
||||
// With the current interface, we have to make sure all mapped views are
|
||||
// dropped before we unmap the buffer.
|
||||
drop(data);
|
||||
buffer.unmap(); // Unmaps buffer from memory
|
||||
|
||||
// Returns data from buffer
|
||||
Ok(result)
|
||||
} else {
|
||||
bail!("failed to run compute on gpu!")
|
||||
}
|
||||
});
|
||||
Ok(future)
|
||||
} else {
|
||||
bail!("Tried to read a non readback buffer")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NewExecutor {
|
||||
pub fn new() -> Option<Self> {
|
||||
let context = Context::new_sync()?;
|
||||
Some(Self { context })
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue