From 0586d52f3a3d6a447dfe24200bb91c8f794e6645 Mon Sep 17 00:00:00 2001 From: Dennis Kobert Date: Sat, 27 May 2023 19:27:46 +0200 Subject: [PATCH] Implement experimental WebGPU support (#1238) * Web gpu execution MVP Ready infrastructure for wgpu experimentation Start implementing simple gpu test case Fix Extract Node not working with nested networks Convert inputs for extracted node to network inputs Fix missing cors headers Feature gate gcore to make it once again no-std compatible Add skeleton structure gpu shader Work on gpu node graph output saving Fix Get and Set nodes Fix storage nodes Fix shader construction errors -> spirv errors Add unsafe version Add once cell node Web gpu execution MVP --- Cargo.lock | 292 +++++++++++++----- .../document_node_types.rs | 40 ++- frontend/wasm/.cargo/Config.toml | 5 - frontend/wasm/.cargo/config.toml | 6 + frontend/wasm/Cargo.toml | 4 +- node-graph/compilation-client/src/lib.rs | 14 +- node-graph/compilation-client/src/main.rs | 2 +- node-graph/compilation-server/Cargo.toml | 1 + node-graph/compilation-server/src/main.rs | 2 + node-graph/gcore/src/lib.rs | 9 + node-graph/gcore/src/raster.rs | 46 +++ node-graph/gcore/src/raster/adjustments.rs | 13 + node-graph/gcore/src/storage.rs | 164 ++++++++-- node-graph/gcore/src/structural.rs | 47 +++ node-graph/gcore/src/value.rs | 67 +++- node-graph/gpu-compiler/Cargo.lock | 47 +++ .../gpu-compiler-bin-wrapper/src/lib.rs | 12 +- node-graph/gpu-compiler/src/lib.rs | 121 +++++--- node-graph/gpu-compiler/src/main.rs | 2 +- .../src/templates/spirv-template.rs | 24 +- node-graph/gpu-executor/src/lib.rs | 39 ++- node-graph/graph-craft/src/document.rs | 105 ++++++- node-graph/graph-craft/src/document/value.rs | 2 +- node-graph/graph-craft/src/executor.rs | 6 +- node-graph/graph-craft/src/proto.rs | 10 +- node-graph/gstd/Cargo.toml | 9 +- node-graph/gstd/src/executor.rs | 178 ++++++++++- .../interpreted-executor/src/executor.rs | 1 + node-graph/interpreted-executor/src/lib.rs | 2 +- .../interpreted-executor/src/node_registry.rs | 3 +- node-graph/wgpu-executor/Cargo.toml | 2 +- node-graph/wgpu-executor/src/context.rs | 2 +- node-graph/wgpu-executor/src/lib.rs | 42 ++- 33 files changed, 1080 insertions(+), 239 deletions(-) delete mode 100644 frontend/wasm/.cargo/Config.toml create mode 100644 frontend/wasm/.cargo/config.toml diff --git a/Cargo.lock b/Cargo.lock index 588c0c2b..e609439e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,15 @@ version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" +[[package]] +name = "addr2line" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97" +dependencies = [ + "gimli", +] + [[package]] name = "adler" version = "1.0.2" @@ -145,7 +154,7 @@ version = "0.37.2+1.3.238" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28bf19c1f0a470be5fbf7522a308a05df06610252c5bcf5143e1b23f629a9a03" dependencies = [ - "libloading", + "libloading 0.7.4", ] [[package]] @@ -166,7 +175,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c3d816ce6f0e2909a96830d6911c2aff044370b1ef92d7f267b43bae5addedd" dependencies = [ "atk-sys", - "bitflags", + "bitflags 1.3.2", "glib", "libc", ] @@ -238,7 +247,7 @@ checksum = "b70caf9f1b0c045f7da350636435b775a9733adf2df56e8aa2a29210fbc335d4" dependencies = [ "async-trait", "axum-core", - "bitflags", + "bitflags 1.3.2", "bytes", "futures-util", "http", @@ -279,6 +288,21 @@ dependencies = [ "tower-service", ] +[[package]] +name = "backtrace" +version = "0.3.67" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide 0.6.2", + "object", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.13.1" @@ -345,6 +369,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6776fc96284a0bb647b615056fc496d1fe1644a7ab01829818a6d91cae888b84" + [[package]] name = "block" version = "0.1.6" @@ -448,7 +478,7 @@ version = "0.15.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c76ee391b03d35510d9fa917357c7f1855bd9a6659c95a1b392e33f49b3369bc" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cairo-sys-rs", "glib", "libc", @@ -523,12 +553,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - [[package]] name = "chrono" version = "0.4.24" @@ -550,7 +574,7 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f425db7937052c684daec3bd6375c8abe2d146dca4b8b143d6db777c39138f3a" dependencies = [ - "bitflags", + "bitflags 1.3.2", "block", "cocoa-foundation", "core-foundation", @@ -566,7 +590,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "931d3837c286f56e3c58423ce4eba12d08db2374461a785c86f672b08b5650d6" dependencies = [ - "bitflags", + "bitflags 1.3.2", "block", "core-foundation", "core-graphics-types", @@ -602,6 +626,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "com-rs" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf43edc576402991846b093a7ca18a3477e0ef9c588cde84964b5d3e43016642" + [[package]] name = "combine" version = "4.6.6" @@ -639,6 +669,7 @@ dependencies = [ "serde_json", "tempfile", "tokio", + "tower-http", ] [[package]] @@ -679,7 +710,7 @@ version = "0.22.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "core-graphics-types", "foreign-types", @@ -692,7 +723,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a68b68b3446082644c91ac778bf50cd4104bfb002b5a6a7c44cca5a2c70788b" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "foreign-types", "libc", @@ -844,12 +875,12 @@ dependencies = [ [[package]] name = "d3d12" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "827914e1f53b1e0e025ecd3d967a7836b7bcb54520f90e21ef8df7b4d88a2759" +checksum = "d8f0de2f5a8e7bd4a9eec0e3c781992a4ce1724f68aec7d7a3715344de8b39da" dependencies = [ - "bitflags", - "libloading", + "bitflags 1.3.2", + "libloading 0.7.4", "winapi", ] @@ -1127,7 +1158,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" dependencies = [ "crc32fast", - "miniz_oxide", + "miniz_oxide 0.7.1", ] [[package]] @@ -1294,7 +1325,7 @@ version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6e05c1f572ab0e1f15be94217f0dc29088c248b14f792a5ff0af0d84bcda9e8" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cairo-rs", "gdk-pixbuf", "gdk-sys", @@ -1310,7 +1341,7 @@ version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad38dd9cc8b099cceecdf41375bb6d481b1b5a7cd5cd603e10a69a9383f8619a" dependencies = [ - "bitflags", + "bitflags 1.3.2", "gdk-pixbuf-sys", "gio", "glib", @@ -1405,13 +1436,19 @@ dependencies = [ "wasi 0.11.0+wasi-snapshot-preview1", ] +[[package]] +name = "gimli" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" + [[package]] name = "gio" version = "0.15.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68fdbc90312d462781a395f7a16d96a2b379bb6ef8cd6310a2df272771c4283b" dependencies = [ - "bitflags", + "bitflags 1.3.2", "futures-channel", "futures-core", "futures-io", @@ -1452,7 +1489,7 @@ version = "0.15.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb0306fbad0ab5428b0ca674a23893db909a98582969c9b537be4ced78c505d" dependencies = [ - "bitflags", + "bitflags 1.3.2", "futures-channel", "futures-core", "futures-executor", @@ -1512,9 +1549,9 @@ dependencies = [ [[package]] name = "glow" -version = "0.11.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8bd5877156a19b8ac83a29b2306fe20537429d318f3ff0a1a2119f8d9c61919" +checksum = "4e007a07a24de5ecae94160f141029e9a347282cfe25d1d58d85d845cf3130f1" dependencies = [ "js-sys", "slotmap", @@ -1539,7 +1576,7 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fc59e5f710e310e76e6707f86c561dd646f69a8876da9131703b2f717de818d" dependencies = [ - "bitflags", + "bitflags 1.3.2", "gpu-alloc-types", ] @@ -1549,7 +1586,20 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54804d0d6bc9d7f26db4eaec1ad10def69b599315f487d32c334a80d1efe67a5" dependencies = [ - "bitflags", + "bitflags 1.3.2", +] + +[[package]] +name = "gpu-allocator" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce95f9e2e11c2c6fadfce42b5af60005db06576f231f5c92550fdded43c423e8" +dependencies = [ + "backtrace", + "log", + "thiserror", + "winapi", + "windows 0.44.0", ] [[package]] @@ -1570,7 +1620,7 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b0c02e1ba0bdb14e965058ca34e09c020f8e507a760df1121728e0aef68d57a" dependencies = [ - "bitflags", + "bitflags 1.3.2", "gpu-descriptor-types", "hashbrown", ] @@ -1581,7 +1631,7 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "363e3677e55ad168fef68cf9de3a4a310b53124c5e784c53a1d70e92d23f2126" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -1723,7 +1773,7 @@ name = "graphite-editor" version = "0.0.0" dependencies = [ "bezier-rs", - "bitflags", + "bitflags 1.3.2", "borrow_stack", "derivative", "dyn-any", @@ -1785,7 +1835,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92e3004a2d5d6d8b5057d2b57b3712c9529b62e82c77f25c1fecde1fd5c23bd0" dependencies = [ "atk", - "bitflags", + "bitflags 1.3.2", "cairo-rs", "field-offset", "futures-channel", @@ -1870,6 +1920,21 @@ dependencies = [ "ahash 0.7.6", ] +[[package]] +name = "hassle-rs" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1397650ee315e8891a0df210707f0fc61771b0cc518c3023896064c5407cb3b0" +dependencies = [ + "bitflags 1.3.2", + "com-rs", + "libc", + "libloading 0.7.4", + "thiserror", + "widestring", + "winapi", +] + [[package]] name = "heck" version = "0.3.3" @@ -1957,6 +2022,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21dec9db110f5f872ed9699c3ecf50cf16f423502706ba5c72462e28d3157573" +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -2189,7 +2260,7 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf053e7843f2812ff03ef5afe34bb9c06ffee120385caad4f6b9967fcd37d41c" dependencies = [ - "bitflags", + "bitflags 1.3.2", "glib", "javascriptcore-rs-sys", ] @@ -2253,7 +2324,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c2352bd1d0bceb871cb9d40f24360c8133c11d7486b68b5381c1dd1a32015e3" dependencies = [ "libc", - "libloading", + "libloading 0.7.4", "pkg-config", ] @@ -2321,6 +2392,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "libloading" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d580318f95776505201b28cf98eb1fa5e4be3b689633ba6a3e6cd880ff22d8cb" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + [[package]] name = "libm" version = "0.2.6" @@ -2484,7 +2565,7 @@ version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de11355d1f6781482d027a3b4d4de7825dcedb197bf573e0596d00008402d060" dependencies = [ - "bitflags", + "bitflags 1.3.2", "block", "core-graphics-types", "foreign-types", @@ -2498,6 +2579,15 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "miniz_oxide" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" +dependencies = [ + "adler", +] + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -2522,12 +2612,12 @@ dependencies = [ [[package]] name = "naga" -version = "0.10.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "262d2840e72dbe250e8cf2f522d080988dfca624c4112c096238a4845f591707" +checksum = "94d3edd593521f4a1dfd9b25193ed0224764572905f013d30ca5fbb85e010876" dependencies = [ "bit-set", - "bitflags", + "bitflags 1.3.2", "codespan-reporting", "hexf-parse", "indexmap", @@ -2609,7 +2699,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2032c77e030ddee34a6787a64166008da93f6a352b629261d0fee232b8742dd4" dependencies = [ - "bitflags", + "bitflags 1.3.2", "jni-sys", "ndk-sys", "num_enum", @@ -2842,6 +2932,15 @@ dependencies = [ "objc", ] +[[package]] +name = "object" +version = "0.30.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.17.1" @@ -2864,7 +2963,7 @@ version = "0.10.52" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cfg-if", "foreign-types", "libc", @@ -2935,7 +3034,7 @@ version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22e4045548659aee5313bde6c582b0d83a627b7904dd20dc2d9ef0895d414e4f" dependencies = [ - "bitflags", + "bitflags 1.3.2", "glib", "libc", "once_cell", @@ -3161,11 +3260,11 @@ version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aaeebc51f9e7d2c150d3f3bfeb667f2aa985db5ef1e3d212847bdedb488beeaa" dependencies = [ - "bitflags", + "bitflags 1.3.2", "crc32fast", "fdeflate", "flate2", - "miniz_oxide", + "miniz_oxide 0.7.1", ] [[package]] @@ -3379,7 +3478,7 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -3388,7 +3487,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -3447,9 +3546,9 @@ dependencies = [ [[package]] name = "renderdoc-sys" -version = "0.7.1" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1382d1f0a252c4bf97dc20d979a2fdd05b024acd7c2ed0f7595d7817666a157" +checksum = "216080ab382b992234dda86873c18d4c48358f5cfcb70fd693d7f6f2131b628b" [[package]] name = "reqwest" @@ -3539,10 +3638,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "300a51053b1cb55c80b7a9fde4120726ddf25ca241a1cbb926626f62fb136bff" dependencies = [ "base64 0.13.1", - "bitflags", + "bitflags 1.3.2", "serde", ] +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -3564,7 +3669,7 @@ version = "0.37.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0661814f891c57c930a610266415528da53c4933e6dea5fb350cbfe048a9ece" dependencies = [ - "bitflags", + "bitflags 1.3.2", "errno", "io-lifetimes", "libc", @@ -3605,7 +3710,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab9e34ecf6900625412355a61bda0bd68099fe674de707c67e5e4aed2c05e489" dependencies = [ - "bitflags", + "bitflags 1.3.2", "bytemuck", "smallvec", "ttf-parser", @@ -3688,7 +3793,7 @@ version = "2.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", @@ -3711,7 +3816,7 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df320f1889ac4ba6bc0cdc9c9af7af4bd64bb927bccdf32d81140dc1f9be12fe" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cssparser", "derive_more", "fxhash", @@ -4018,7 +4123,7 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2b4d76501d8ba387cf0fefbe055c3e0a59891d09f0f995ae4e4b16f6b60f3c0" dependencies = [ - "bitflags", + "bitflags 1.3.2", "gio", "glib", "libc", @@ -4032,7 +4137,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "009ef427103fcb17f802871647a7fa6c60cbb654b4c4e4c0ac60a31c5f6dc9cf" dependencies = [ - "bitflags", + "bitflags 1.3.2", "gio-sys", "glib-sys", "gobject-sys", @@ -4087,7 +4192,7 @@ version = "0.2.0+1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830" dependencies = [ - "bitflags", + "bitflags 1.3.2", "num-traits", ] @@ -4097,7 +4202,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3197bd4c021c2dfc0f9dfb356312c8f7842d972d5545c308ad86422c2e2d3e66" dependencies = [ - "bitflags", + "bitflags 1.3.2", "glam", "num-traits", "spirv-std-macros", @@ -4262,7 +4367,7 @@ version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac8e6399427c8494f9849b58694754d7cc741293348a6836b6c8d2c5aa82d8e6" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cairo-rs", "cc", "cocoa", @@ -4788,6 +4893,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658" +dependencies = [ + "bitflags 1.3.2", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.2" @@ -5073,7 +5196,7 @@ dependencies = [ "heck 0.4.1", "indexmap", "lazy_static", - "libloading", + "libloading 0.7.4", "objc", "parking_lot", "proc-macro2", @@ -5242,7 +5365,7 @@ version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8f859735e4a452aeb28c6c56a852967a8a76c8eb1cc32dbf931ad28a13d6370" dependencies = [ - "bitflags", + "bitflags 1.3.2", "cairo-rs", "gdk", "gdk-sys", @@ -5267,7 +5390,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d76ca6ecc47aeba01ec61e480139dda143796abcae6f83bcddf50d6b5b1dcf3" dependencies = [ "atk-sys", - "bitflags", + "bitflags 1.3.2", "cairo-sys-rs", "gdk-pixbuf-sys", "gdk-sys", @@ -5342,15 +5465,17 @@ dependencies = [ [[package]] name = "wgpu" -version = "0.14.2" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81f643110d228fd62a60c5ed2ab56c4d5b3704520bd50561174ec4ec74932937" +checksum = "3059ea4ddec41ca14f356833e2af65e7e38c0a8f91273867ed526fb9bafcca95" dependencies = [ "arrayvec", + "cfg-if", "js-sys", "log", "naga", "parking_lot", + "profiling", "raw-window-handle", "smallvec", "static_assertions", @@ -5364,21 +5489,20 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "0.14.2" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6000d1284ef8eec6076fd5544a73125fd7eb9b635f18dceeb829d826f41724ca" +checksum = "8f478237b4bf0d5b70a39898a66fa67ca3a007d79f2520485b8b0c3dfc46f8c2" dependencies = [ "arrayvec", "bit-vec", - "bitflags", - "cfg_aliases", + "bitflags 2.3.1", "codespan-reporting", - "fxhash", "log", "naga", "parking_lot", "profiling", "raw-window-handle", + "rustc-hash", "smallvec", "thiserror", "web-sys", @@ -5410,26 +5534,28 @@ dependencies = [ [[package]] name = "wgpu-hal" -version = "0.14.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cc320a61acb26be4f549c9b1b53405c10a223fbfea363ec39474c32c348d12f" +checksum = "41af2ea7d87bd41ad0a37146252d5f7c26490209f47f544b2ee3b3ff34c7732e" dependencies = [ "android_system_properties", "arrayvec", "ash", "bit-set", - "bitflags", + "bitflags 2.3.1", "block", "core-graphics-types", "d3d12", "foreign-types", - "fxhash", "glow", "gpu-alloc", + "gpu-allocator", "gpu-descriptor", + "hassle-rs", "js-sys", "khronos-egl", - "libloading", + "libc", + "libloading 0.8.0", "log", "metal", "naga", @@ -5439,6 +5565,7 @@ dependencies = [ "range-alloc", "raw-window-handle", "renderdoc-sys", + "rustc-hash", "smallvec", "thiserror", "wasm-bindgen", @@ -5449,11 +5576,13 @@ dependencies = [ [[package]] name = "wgpu-types" -version = "0.14.1" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb6b28ef22cac17b9109b25b3bf8c9a103eeb293d7c5f78653979b09140375f6" +checksum = "5bd33a976130f03dcdcd39b3810c0c3fc05daf86f0aaf867db14bfb7c4a9a32b" dependencies = [ - "bitflags", + "bitflags 2.3.1", + "js-sys", + "web-sys", ] [[package]] @@ -5466,6 +5595,12 @@ dependencies = [ "safe_arch", ] +[[package]] +name = "widestring" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" + [[package]] name = "winapi" version = "0.3.9" @@ -5524,6 +5659,15 @@ dependencies = [ "windows_x86_64_msvc 0.39.0", ] +[[package]] +name = "windows" +version = "0.44.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e745dab35a0c4c77aa3ce42d595e13d2003d6902d6b08c9ef5fc326d08da12b" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows" version = "0.48.0" diff --git a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs index b5c8ee9c..f8d5f320 100644 --- a/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs +++ b/editor/src/messages/portfolio/document/node_graph/node_graph_message_handler/document_node_types.rs @@ -69,6 +69,7 @@ pub struct NodePropertiesContext<'a> { pub enum NodeImplementation { ProtoNode(NodeIdentifier), DocumentNode(NodeNetwork), + Extract, } impl NodeImplementation { @@ -718,14 +719,27 @@ fn static_nodes() -> Vec { inputs: vec![ DocumentInputType::value("Image", TaggedValue::ImageFrame(ImageFrame::empty()), true), DocumentInputType { - name: "Path", - data_type: FrontendGraphDataType::Text, - default: NodeInput::value(TaggedValue::String(String::new()), false), + name: "Node", + data_type: FrontendGraphDataType::General, + default: NodeInput::value(TaggedValue::DocumentNode(DocumentNode::default()), true), }, ], outputs: vec![DocumentOutputType::new("Image", FrontendGraphDataType::Raster)], - properties: node_properties::gpu_map_properties, + properties: node_properties::no_properties, }, + DocumentNodeType { + name: "Extract", + category: "Macros", + identifier: NodeImplementation::Extract, + inputs: vec![DocumentInputType { + name: "Node", + data_type: FrontendGraphDataType::General, + default: NodeInput::value(TaggedValue::DocumentNode(DocumentNode::default()), true), + }], + outputs: vec![DocumentOutputType::new("DocumentNode", FrontendGraphDataType::General)], + properties: node_properties::no_properties, + }, + #[cfg(feature = "quantization")] #[cfg(feature = "quantization")] DocumentNodeType { name: "Generate Quantization", @@ -1156,6 +1170,7 @@ impl DocumentNodeType { let num_inputs = self.inputs.len(); let inner_network = match &self.identifier { + NodeImplementation::DocumentNode(network) => network.clone(), NodeImplementation::ProtoNode(ident) => { NodeNetwork { inputs: (0..num_inputs).map(|_| 0).collect(), @@ -1175,7 +1190,22 @@ impl DocumentNodeType { ..Default::default() } } - NodeImplementation::DocumentNode(network) => network.clone(), + NodeImplementation::Extract => NodeNetwork { + inputs: (0..num_inputs).map(|_| 0).collect(), + outputs: vec![NodeOutput::new(0, 0)], + nodes: [( + 0, + DocumentNode { + name: "ExtractNode".to_string(), + implementation: DocumentNodeImplementation::Extract, + inputs: self.inputs.iter().map(|i| NodeInput::Network(i.default.ty())).collect(), + ..Default::default() + }, + )] + .into_iter() + .collect(), + ..Default::default() + }, }; DocumentNodeImplementation::Network(inner_network) diff --git a/frontend/wasm/.cargo/Config.toml b/frontend/wasm/.cargo/Config.toml deleted file mode 100644 index 8d1b8393..00000000 --- a/frontend/wasm/.cargo/Config.toml +++ /dev/null @@ -1,5 +0,0 @@ -[target.wasm32-unknown-unknown] -rustflags = ["-C", "target-feature=+simd128,+atomics,+bulk-memory,+mutable-globals"] - -[unstable] -build-std = ["panic_abort", "std"] diff --git a/frontend/wasm/.cargo/config.toml b/frontend/wasm/.cargo/config.toml new file mode 100644 index 00000000..a7fd7be9 --- /dev/null +++ b/frontend/wasm/.cargo/config.toml @@ -0,0 +1,6 @@ +[target.wasm32-unknown-unknown] +#rustflags = ["-C", "target-feature=+simd128,+atomics,+bulk-memory,+mutable-globals","--cfg=web_sys_unstable_apis"] +rustflags = ["-C", "target-feature=+simd128","--cfg=web_sys_unstable_apis"] + +[unstable] +build-std = ["panic_abort", "std"] diff --git a/frontend/wasm/Cargo.toml b/frontend/wasm/Cargo.toml index 1e3f9a95..38634b15 100644 --- a/frontend/wasm/Cargo.toml +++ b/frontend/wasm/Cargo.toml @@ -13,7 +13,7 @@ license = "Apache-2.0" [features] tauri = ["ron"] gpu = ["editor/gpu"] -default = [] +default = ["gpu"] [lib] crate-type = ["cdylib", "rlib"] @@ -38,7 +38,7 @@ bezier-rs = { path = "../../libraries/bezier-rs" } [dependencies.web-sys] version = "0.3.4" -features = ['Window'] +features = ["Window"] [dev-dependencies] wasm-bindgen-test = "0.3.22" diff --git a/node-graph/compilation-client/src/lib.rs b/node-graph/compilation-client/src/lib.rs index ca7bfc36..c3d43e8d 100644 --- a/node-graph/compilation-client/src/lib.rs +++ b/node-graph/compilation-client/src/lib.rs @@ -2,22 +2,22 @@ use gpu_compiler_bin_wrapper::CompileRequest; use gpu_executor::ShaderIO; use graph_craft::{proto::ProtoNetwork, Type}; -pub async fn compile(network: ProtoNetwork, inputs: Vec, output: Type, io: ShaderIO) -> Result { +pub async fn compile(networks: Vec, inputs: Vec, outputs: Vec, io: ShaderIO) -> Result { let client = reqwest::Client::new(); - let compile_request = CompileRequest::new(network, inputs.clone(), output.clone(), io.clone()); + let compile_request = CompileRequest::new(networks, inputs.clone(), outputs.clone(), io.clone()); let response = client.post("http://localhost:3000/compile/spirv").json(&compile_request).send(); let response = response.await?; response.bytes().await.map(|b| Shader { - spirv_binary: b.windows(4).map(|x| u32::from_le_bytes(x.try_into().unwrap())).collect(), + spirv_binary: b.chunks(4).map(|x| u32::from_le_bytes(x.try_into().unwrap())).collect(), input_types: inputs, - output_type: output, + output_types: outputs, io, }) } -pub fn compile_sync(network: ProtoNetwork, inputs: Vec, output: Type, io: ShaderIO) -> Result { - future_executor::block_on(compile(network, inputs, output, io)) +pub fn compile_sync(networks: Vec, inputs: Vec, outputs: Vec, io: ShaderIO) -> Result { + future_executor::block_on(compile(networks, inputs, outputs, io)) } // TODO: should we add the entry point as a field? @@ -25,6 +25,6 @@ pub fn compile_sync(network: ProtoNetwork, inputs: Vec, output: Type, io: pub struct Shader { pub spirv_binary: Vec, pub input_types: Vec, - pub output_type: Type, + pub output_types: Vec, pub io: ShaderIO, } diff --git a/node-graph/compilation-client/src/main.rs b/node-graph/compilation-client/src/main.rs index 723581cc..d5e6c9c8 100644 --- a/node-graph/compilation-client/src/main.rs +++ b/node-graph/compilation-client/src/main.rs @@ -36,7 +36,7 @@ fn main() { output: ShaderInput::OutputBuffer((), concrete!(&mut [u32])), }; - let compile_request = CompileRequest::new(proto_network, vec![concrete!(u32)], concrete!(u32), io); + let compile_request = CompileRequest::new(vec![proto_network], vec![concrete!(u32)], vec![concrete!(u32)], io); let response = client .post("http://localhost:3000/compile/spirv") .timeout(Duration::from_secs(30)) diff --git a/node-graph/compilation-server/Cargo.toml b/node-graph/compilation-server/Cargo.toml index 97dea8a6..1f82e546 100644 --- a/node-graph/compilation-server/Cargo.toml +++ b/node-graph/compilation-server/Cargo.toml @@ -16,3 +16,4 @@ serde = { version = "1.0", features = ["derive"] } tempfile = "3.3.0" anyhow = "1.0.68" futures = "0.3" +tower-http = { version = "0.4.0", features = ["cors"] } diff --git a/node-graph/compilation-server/src/main.rs b/node-graph/compilation-server/src/main.rs index c908d47c..7dc20066 100644 --- a/node-graph/compilation-server/src/main.rs +++ b/node-graph/compilation-server/src/main.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use gpu_compiler_bin_wrapper::CompileRequest; +use tower_http::cors::CorsLayer; use axum::{ extract::{Json, State}, @@ -24,6 +25,7 @@ async fn main() { .route("/", get(|| async { "Hello from compilation server!" })) .route("/compile", get(|| async { "Supported targets: spirv" })) .route("/compile/spirv", post(post_compile_spirv)) + .layer(CorsLayer::permissive()) .with_state(shared_state); // run it with hyper on localhost:3000 diff --git a/node-graph/gcore/src/lib.rs b/node-graph/gcore/src/lib.rs index 9c0edecc..520bf000 100644 --- a/node-graph/gcore/src/lib.rs +++ b/node-graph/gcore/src/lib.rs @@ -107,6 +107,15 @@ impl<'i, 's: 'i, I: 'i, O: 'i, N: Node<'i, I, Output = O>> Node<'i, I> for &'s N (**self).eval(input) } } +#[cfg(feature = "alloc")] +impl<'i, 's: 'i, I: 'i, O: 'i, N: Node<'i, I, Output = O>> Node<'i, I> for Box { + type Output = O; + + fn eval(&'i self, input: I) -> Self::Output { + (**self).eval(input) + } +} + impl<'i, I: 'i, O: 'i> Node<'i, I> for &'i dyn for<'a> Node<'a, I, Output = O> { type Output = O; diff --git a/node-graph/gcore/src/raster.rs b/node-graph/gcore/src/raster.rs index 075abeb0..842814ae 100644 --- a/node-graph/gcore/src/raster.rs +++ b/node-graph/gcore/src/raster.rs @@ -192,6 +192,14 @@ pub trait Sample { fn sample(&self, pos: DVec2, area: DVec2) -> Option; } +impl<'i, T: Sample> Sample for &'i T { + type Pixel = T::Pixel; + + fn sample(&self, pos: DVec2, area: DVec2) -> Option { + (**self).sample(pos, area) + } +} + // TODO: We might rename this to Bitmap at some point pub trait Raster { type Pixel: Pixel; @@ -200,6 +208,38 @@ pub trait Raster { fn get_pixel(&self, x: u32, y: u32) -> Option; } +impl<'i, T: Raster> Raster for &'i T { + type Pixel = T::Pixel; + + fn width(&self) -> u32 { + (**self).width() + } + + fn height(&self) -> u32 { + (**self).height() + } + + fn get_pixel(&self, x: u32, y: u32) -> Option { + (**self).get_pixel(x, y) + } +} + +impl<'i, T: Raster> Raster for &'i mut T { + type Pixel = T::Pixel; + + fn width(&self) -> u32 { + (**self).width() + } + + fn height(&self) -> u32 { + (**self).height() + } + + fn get_pixel(&self, x: u32, y: u32) -> Option { + (**self).get_pixel(x, y) + } +} + pub trait RasterMut: Raster { fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut Self::Pixel>; fn set_pixel(&mut self, x: u32, y: u32, pixel: Self::Pixel) { @@ -215,6 +255,12 @@ pub trait RasterMut: Raster { } } +impl<'i, T: RasterMut + Raster> RasterMut for &'i mut T { + fn get_pixel_mut(&mut self, x: u32, y: u32) -> Option<&mut Self::Pixel> { + (*self).get_pixel_mut(x, y) + } +} + #[derive(Debug, Default)] pub struct MapNode { map_fn: MapFn, diff --git a/node-graph/gcore/src/raster/adjustments.rs b/node-graph/gcore/src/raster/adjustments.rs index 0c353029..8c5457d2 100644 --- a/node-graph/gcore/src/raster/adjustments.rs +++ b/node-graph/gcore/src/raster/adjustments.rs @@ -363,6 +363,19 @@ fn invert_image(color: Color) -> Color { color.to_linear_srgb() } +// TODO replace with trait based implementation +impl<'i> Node<'i, &'i Color> for InvertRGBNode { + type Output = Color; + + fn eval(&'i self, color: &'i Color) -> Self::Output { + let color = color.to_gamma_srgb(); + + let color = color.map_rgb(|c| color.a() - c); + + color.to_linear_srgb() + } +} + #[derive(Debug, Clone, Copy)] pub struct ThresholdNode { min_luminance: MinLuminance, diff --git a/node-graph/gcore/src/storage.rs b/node-graph/gcore/src/storage.rs index 3396ec83..ae4bfb61 100644 --- a/node-graph/gcore/src/storage.rs +++ b/node-graph/gcore/src/storage.rs @@ -1,34 +1,158 @@ use crate::Node; +use core::cell::RefMut; use core::marker::PhantomData; -use core::ops::{DerefMut, Index, IndexMut}; +use core::ops::{Deref, DerefMut, Index, IndexMut}; -struct SetNode { +pub struct SetNode { storage: Storage, - index: Index, - _s: PhantomData, - _i: PhantomData, } - -#[node_macro::node_fn(SetNode<_S, _I>)] -fn set_node(value: T, storage: &'input mut _S, index: _I) +/* +#[node_macro::node_fn(SetNode)] +fn set_node<_T, _I, A: 'input>(input: (_T, _I), mut storage: RefMut<'input, A>) where - _S: IndexMut<_I>, - _S::Output: DerefMut + Sized, + A: DerefMut, + A::Target: IndexMut<_I, Output = _T>, { - *storage.index_mut(index).deref_mut() = value; + let (value, index) = input; + *storage.deref_mut().index_mut(index).deref_mut() = value; +}*/ +impl<'input, T: 'input, I: 'input, A: 'input + 'input, S0: 'input> Node<'input, (T, I)> for SetNode +where + A: DerefMut, + A::Target: IndexMut, + S0: for<'any_input> Node<'input, (), Output = A>, +{ + type Output = (); + #[inline] + fn eval(&'input self, input: (T, I)) -> Self::Output { + let mut storage = self.storage.eval(()); + let (value, index) = input; + *storage.deref_mut().index_mut(index).deref_mut() = value; + } +} +impl<'input, S0: 'input> SetNode { + pub const fn new(storage: S0) -> Self { + Self { storage } + } } -struct GetNode { +pub struct ExtractXNode {} + +#[node_macro::node_fn(ExtractXNode)] +fn extract_x_node(input: glam::UVec3) -> usize { + input.x as usize +} + +pub struct SetOwnedNode { + storage: core::cell::RefCell, +} + +impl SetOwnedNode { + pub fn new(storage: Storage) -> Self { + Self { + storage: core::cell::RefCell::new(storage), + } + } +} + +impl<'input, I: 'input, T: 'input, Storage, A: ?Sized> Node<'input, (T, I)> for SetOwnedNode +where + Storage: DerefMut + 'input, + A: IndexMut + 'input, +{ + type Output = (); + fn eval(&'input self, input: (T, I)) -> Self::Output { + let (value, index) = input; + *self.storage.borrow_mut().index_mut(index) = value; + } +} + +pub struct GetNode { storage: Storage, - _s: PhantomData, } -#[node_macro::node_fn(GetNode<_S>)] -fn get_node<_S, I>(index: I, storage: &'input _S) -> &'input _S::Output -where - _S: Index, - _S::Output: Sized, -{ - storage.index(index) +impl GetNode { + pub fn new(storage: Storage) -> Self { + Self { storage } + } +} + +impl<'input, I: 'input, T: 'input, Storage, SNode, A: ?Sized> Node<'input, I> for GetNode +where + SNode: Node<'input, (), Output = Storage>, + Storage: Deref + 'input, + A: Index + 'input, + T: Clone, +{ + type Output = T; + fn eval(&'input self, index: I) -> Self::Output { + let storage = self.storage.eval(()); + storage.deref().index(index).deref().clone() + } +} + +#[cfg(test)] +mod test { + use crate::value::{CopiedNode, OnceCellNode, RefCellMutNode, UnsafeMutValueNode, ValueNode}; + use crate::Node; + + use super::*; + #[test] + fn get_node_array() { + let storage = [1, 2, 3]; + let node = GetNode::new(CopiedNode::new(&storage)); + assert_eq!((&node as &dyn Node<'_, usize, Output = i32>).eval(1), 2); + } + + #[test] + fn get_node_vec() { + let storage = vec![1, 2, 3]; + let node = GetNode::new(CopiedNode::new(&storage)); + assert_eq!(node.eval(1), 2); + } + + #[test] + fn get_node_slice() { + let storage: &[i32] = &[1, 2, 3]; + let node = GetNode::new(CopiedNode::new(storage)); + let _ = &node as &dyn Node<'_, usize, Output = i32>; + assert_eq!(node.eval(1), 2); + } + + #[test] + fn set_node_slice() { + let mut backing_storage = [1, 2, 3]; + let storage: &mut [i32] = &mut backing_storage; + let storage_node = OnceCellNode::new(storage); + let node = SetNode::new(storage_node); + node.eval((4, 1)); + assert_eq!(backing_storage, [1, 4, 3]); + } + + #[test] + fn set_owned_node_array() { + let mut storage = [1, 2, 3]; + let node = SetOwnedNode::new(&mut storage); + node.eval((4, 1)); + assert_eq!(storage, [1, 4, 3]); + } + + #[test] + fn set_owned_node_vec() { + let mut storage = vec![1, 2, 3]; + let node = SetOwnedNode::new(&mut storage); + node.eval((4, 1)); + assert_eq!(storage, [1, 4, 3]); + } + + #[test] + fn set_owned_node_slice() { + let mut backing_storage = [1, 2, 3]; + let storage: &mut [i32] = &mut backing_storage; + let node = SetOwnedNode::new(storage); + let node = &node as &dyn Node<'_, (i32, usize), Output = ()>; + node.eval((4, 1)); + assert_eq!(backing_storage, [1, 4, 3]); + } } diff --git a/node-graph/gcore/src/structural.rs b/node-graph/gcore/src/structural.rs index 57a8b9f4..a3c053cf 100644 --- a/node-graph/gcore/src/structural.rs +++ b/node-graph/gcore/src/structural.rs @@ -39,6 +39,7 @@ pub struct AsyncComposeNode { phantom: PhantomData, } +#[cfg(feature = "alloc")] impl<'i, 'f: 'i, 's: 'i, Input: 'static, First, Second> Node<'i, Input> for AsyncComposeNode where First: Node<'i, Input>, @@ -54,6 +55,7 @@ where } } +#[cfg(feature = "alloc")] impl<'i, First, Second, Input: 'i> AsyncComposeNode where First: Node<'i, Input>, @@ -77,6 +79,7 @@ pub trait Then<'i, Input: 'i>: Sized { impl<'i, First: Node<'i, Input>, Input: 'i> Then<'i, Input> for First {} +#[cfg(feature = "alloc")] pub trait AndThen<'i, Input: 'i>: Sized { fn and_then(self, second: Second) -> AsyncComposeNode where @@ -88,6 +91,7 @@ pub trait AndThen<'i, Input: 'i>: Sized { } } +#[cfg(feature = "alloc")] impl<'i, First: Node<'i, Input>, Input: 'i> AndThen<'i, Input> for First {} pub struct ConsNode, Root>(pub Root, PhantomData); @@ -108,6 +112,38 @@ impl<'i, Root: Node<'i, I>, I: 'i + From<()>> ConsNode { } } +pub struct ApplyNode { + pub node: N, + _o: PhantomData, +} +/* +#[node_macro::node_fn(ApplyNode)] +fn apply(input: In, node: &'any_input N) -> () +where + // TODO: try to allows this to return output other than () + N: for<'any_input> Node<'any_input, In, Output = ()>, +{ + node.eval(input) +} +*/ +impl<'input, In: 'input, N: 'input, S0: 'input, O: 'input> Node<'input, In> for ApplyNode +where + N: Node<'input, In, Output = O>, + S0: Node<'input, (), Output = &'input N>, +{ + type Output = >::Output; + #[inline] + fn eval(&'input self, input: In) -> Self::Output { + let node = self.node.eval(()); + node.eval(input) + } +} +impl<'input, S0: 'input, O: 'static> ApplyNode { + pub const fn new(node: S0) -> Self { + Self { node, _o: PhantomData } + } +} + #[cfg(test)] mod test { use crate::{ops::IdNode, value::ValueNode}; @@ -134,4 +170,15 @@ mod test { assert_eq!(compose.eval(()), &5); } + + #[test] + fn test_apply() { + let mut array = [1, 2, 3]; + let slice = &mut array; + let set_node = crate::storage::SetOwnedNode::new(slice); + + let apply = ApplyNode::new(ValueNode::new(set_node)); + + assert_eq!(apply.eval((1, 2)), ()); + } } diff --git a/node-graph/gcore/src/value.rs b/node-graph/gcore/src/value.rs index 2acd3b2a..7c100f92 100644 --- a/node-graph/gcore/src/value.rs +++ b/node-graph/gcore/src/value.rs @@ -1,6 +1,10 @@ use crate::Node; -use core::marker::PhantomData; +use core::{ + borrow::BorrowMut, + cell::{Cell, RefCell, RefMut}, + marker::PhantomData, +}; #[derive(Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] pub struct IntNode; @@ -13,7 +17,7 @@ impl<'i, const N: u32> Node<'i, ()> for IntNode { } } -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone, Copy)] pub struct ValueNode(pub T); impl<'i, T: 'i> Node<'i, ()> for ValueNode { @@ -35,12 +39,62 @@ impl From for ValueNode { ValueNode::new(value) } } -impl Clone for ValueNode { - fn clone(&self) -> Self { - Self(self.0.clone()) + +#[derive(Default, Debug, Clone)] +pub struct RefCellMutNode(pub RefCell); + +impl<'i, T: 'i> Node<'i, ()> for RefCellMutNode { + type Output = RefMut<'i, T>; + #[inline(always)] + fn eval(&'i self, _input: ()) -> Self::Output { + #[cfg(not(target_arch = "spirv"))] + let a = self.0.borrow_mut(); + #[cfg(target_arch = "spirv")] + let a = unsafe { self.0.try_borrow_mut().unwrap_unchecked() }; + a + } +} + +impl RefCellMutNode { + pub const fn new(value: T) -> RefCellMutNode { + RefCellMutNode(RefCell::new(value)) + } +} +/// #Safety: Never use this as it is unsound. +#[derive(Default, Debug)] +pub struct UnsafeMutValueNode(pub T); + +/// #Safety: Never use this as it is unsound. +impl<'i, T: 'i> Node<'i, ()> for UnsafeMutValueNode { + type Output = &'i mut T; + #[inline(always)] + fn eval(&'i self, _input: ()) -> Self::Output { + unsafe { &mut *(&self.0 as &T as *const T as *mut T) } + } +} + +impl UnsafeMutValueNode { + pub const fn new(value: T) -> UnsafeMutValueNode { + UnsafeMutValueNode(value) + } +} + +#[derive(Default)] +pub struct OnceCellNode(pub Cell); + +impl<'i, T: Default + 'i> Node<'i, ()> for OnceCellNode { + type Output = T; + #[inline(always)] + fn eval(&'i self, _input: ()) -> Self::Output { + self.0.replace(T::default()) + } +} + +impl OnceCellNode { + pub const fn new(value: T) -> OnceCellNode { + OnceCellNode(Cell::new(value)) } } -impl Copy for ValueNode {} #[derive(Clone, Copy)] pub struct ClonedNode(pub T); @@ -75,6 +129,7 @@ impl<'i, T: Clone + 'i> Node<'i, ()> for DebugClonedNode { #[inline(always)] fn eval(&'i self, _input: ()) -> Self::Output { // KEEP THIS `debug!()` - It acts as the output for the debug node itself + #[cfg(not(target_arch = "spirv"))] log::debug!("DebugClonedNode::eval"); self.0.clone() diff --git a/node-graph/gpu-compiler/Cargo.lock b/node-graph/gpu-compiler/Cargo.lock index af703458..fdc405a7 100644 --- a/node-graph/gpu-compiler/Cargo.lock +++ b/node-graph/gpu-compiler/Cargo.lock @@ -603,6 +603,7 @@ dependencies = [ "num-traits", "once_cell", "rand_chacha", + "rustybuzz", "serde", "specta", "spin", @@ -1153,6 +1154,22 @@ dependencies = [ "serde", ] +[[package]] +name = "rustybuzz" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab9e34ecf6900625412355a61bda0bd68099fe674de707c67e5e4aed2c05e489" +dependencies = [ + "bitflags", + "bytemuck", + "smallvec", + "ttf-parser", + "unicode-bidi-mirroring", + "unicode-ccc", + "unicode-general-category", + "unicode-script", +] + [[package]] name = "ryu" version = "1.0.12" @@ -1486,6 +1503,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "ttf-parser" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff" + [[package]] name = "typenum" version = "1.16.0" @@ -1557,12 +1580,36 @@ dependencies = [ "unic-common", ] +[[package]] +name = "unicode-bidi-mirroring" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56d12260fb92d52f9008be7e4bca09f584780eb2266dc8fecc6a192bec561694" + +[[package]] +name = "unicode-ccc" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2520efa644f8268dce4dcd3050eaa7fc044fca03961e9998ac7e2e92b77cf1" + +[[package]] +name = "unicode-general-category" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2281c8c1d221438e373249e065ca4989c4c36952c211ff21a0ee91c44a3869e7" + [[package]] name = "unicode-ident" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc" +[[package]] +name = "unicode-script" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d817255e1bed6dfd4ca47258685d14d2bdcfbc64fdc9e3819bd5848057b8ecc" + [[package]] name = "unicode-width" version = "0.1.10" diff --git a/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs b/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs index 95d1e2a8..43ce67e9 100644 --- a/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs +++ b/node-graph/gpu-compiler/gpu-compiler-bin-wrapper/src/lib.rs @@ -6,7 +6,7 @@ use std::io::Write; pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manifest_path: &str) -> anyhow::Result> { let serialized_graph = serde_json::to_string(&gpu_executor::CompileRequest { - network: request.network.clone(), + networks: request.networks.clone(), io: request.shader_io.clone(), })?; @@ -43,23 +43,23 @@ pub fn compile_spirv(request: &CompileRequest, compile_dir: Option<&str>, manife #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct CompileRequest { - network: graph_craft::proto::ProtoNetwork, + networks: Vec, input_types: Vec, - output_type: Type, + output_types: Vec, shader_io: ShaderIO, } impl CompileRequest { - pub fn new(network: ProtoNetwork, input_types: Vec, output_type: Type, io: ShaderIO) -> Self { + pub fn new(networks: Vec, input_types: Vec, output_types: Vec, io: ShaderIO) -> Self { // TODO: add type checking // for (input, buffer) in input_types.iter().zip(io.inputs.iter()) { // assert_eq!(input, &buffer.ty()); // } // assert_eq!(output_type, io.output.ty()); Self { - network, + networks, input_types, - output_type, + output_types, shader_io: io, } } diff --git a/node-graph/gpu-compiler/src/lib.rs b/node-graph/gpu-compiler/src/lib.rs index ef24bf27..aa00b4ae 100644 --- a/node-graph/gpu-compiler/src/lib.rs +++ b/node-graph/gpu-compiler/src/lib.rs @@ -26,7 +26,7 @@ impl Metadata { } } -pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> { +pub fn create_files(metadata: &Metadata, networks: &[ProtoNetwork], compile_dir: &Path, io: &ShaderIO) -> anyhow::Result<()> { let src = compile_dir.join("src"); let cargo_file = compile_dir.join("Cargo.toml"); let cargo_toml = create_cargo_toml(metadata)?; @@ -46,7 +46,7 @@ pub fn create_files(metadata: &Metadata, network: &ProtoNetwork, compile_dir: &P } } let lib = src.join("lib.rs"); - let shader = serialize_gpu(network, io)?; + let shader = serialize_gpu(networks, io)?; eprintln!("{}", shader); std::fs::write(lib, shader)?; Ok(()) @@ -67,20 +67,21 @@ fn constant_attribute(constant: &GPUConstant) -> &'static str { } } -pub fn construct_argument(input: &ShaderInput<()>, position: u32) -> String { - match input { - ShaderInput::Constant(constant) => format!("#[spirv({})] i{}: {},", constant_attribute(constant), position, constant.ty()), +pub fn construct_argument(input: &ShaderInput<()>, position: u32, binding_offset: u32) -> String { + let line = match input { + ShaderInput::Constant(constant) => format!("#[spirv({})] i{}: {}", constant_attribute(constant), position, constant.ty()), ShaderInput::UniformBuffer(_, ty) => { - format!("#[spirv(uniform, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,) + format!("#[spirv(uniform, descriptor_set = 0, binding = {})] i{}: &[{}]", position + binding_offset, position, ty,) } ShaderInput::StorageBuffer(_, ty) | ShaderInput::ReadBackBuffer(_, ty) => { - format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &[{}]", position, position, ty,) + format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &[{}]", position + binding_offset, position, ty,) } ShaderInput::OutputBuffer(_, ty) => { - format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] i{}: &mut[{}]", position, position, ty,) + format!("#[spirv(storage_buffer, descriptor_set = 0, binding = {})] o{}: &mut[{}]", position + binding_offset, position, ty,) } ShaderInput::WorkGroupMemory(_, ty) => format!("#[spirv(workgroup_memory] i{}: {}", position, ty,), - } + }; + line.replace("glam::u32::uvec3::UVec3", "spirv_std::glam::UVec3") } struct GpuCompiler { @@ -88,10 +89,10 @@ struct GpuCompiler { } impl SpirVCompiler for GpuCompiler { - fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> anyhow::Result { + fn compile(&self, networks: &[ProtoNetwork], io: &ShaderIO) -> anyhow::Result { let metadata = Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]); - create_files(&metadata, &network, &self.compile_dir, io)?; + create_files(&metadata, networks, &self.compile_dir, io)?; let result = compile(&self.compile_dir)?; let bytes = std::fs::read(result.module.unwrap_single())?; @@ -105,50 +106,80 @@ impl SpirVCompiler for GpuCompiler { } } -pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result { +pub fn serialize_gpu(networks: &[ProtoNetwork], io: &ShaderIO) -> anyhow::Result { fn nid(id: &u64) -> String { format!("n{id}") } - dbg!(&network); dbg!(&io); - let inputs = io.inputs.iter().enumerate().map(|(i, input)| construct_argument(input, i as u32)).collect::>(); + let mut inputs = io + .inputs + .iter() + .filter(|x| !x.is_output()) + .enumerate() + .map(|(i, input)| construct_argument(input, i as u32, 0)) + .collect::>(); + let offset = inputs.len() as u32; + + inputs.extend(io.inputs.iter().filter(|x| x.is_output()).enumerate().map(|(i, input)| construct_argument(input, i as u32, offset))); let mut nodes = Vec::new(); let mut input_nodes = Vec::new(); - #[derive(serde::Serialize)] - struct Node { - id: String, - fqn: String, - args: Vec, - } - for id in network.inputs.iter() { - let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else { + let mut output_nodes = Vec::new(); + for network in networks { + dbg!(&network); + //assert_eq!(network.inputs.len(), io.inputs.iter().filter(|x| !x.is_output()).count()); + #[derive(serde::Serialize, Debug)] + struct Node { + id: String, + index: usize, + fqn: String, + args: Vec, + } + for (i, id) in network.inputs.iter().enumerate() { + let Some((_, node)) = network.nodes.iter().find(|(i, _)| i == id) else { anyhow::bail!("Input node not found"); }; - let fqn = &node.identifier.name; - let id = nid(id); - input_nodes.push(Node { - id, - fqn: fqn.to_string().split("<").next().unwrap().to_owned(), - args: node.construction_args.new_function_args(), - }); - } - - for (ref id, node) in network.nodes.iter() { - if network.inputs.contains(id) { - continue; + let fqn = &node.identifier.name; + let id = nid(id); + let node = Node { + id: id.clone(), + index: i, + fqn: fqn.to_string().split('<').next().unwrap().to_owned(), + args: node.construction_args.new_function_args(), + }; + dbg!(&node); + if !io.inputs[i].is_output() { + if input_nodes.iter().any(|x: &Node| x.id == id) { + continue; + } + input_nodes.push(node); + } } - let fqn = &node.identifier.name; - let id = nid(id); + for (ref id, node) in network.nodes.iter() { + if network.inputs.contains(id) { + continue; + } - nodes.push(Node { - id, - fqn: fqn.to_string().split("<").next().unwrap().to_owned(), - args: node.construction_args.new_function_args(), - }); + let fqn = &node.identifier.name; + let id = nid(id); + + if nodes.iter().any(|x: &Node| x.id == id) { + continue; + } + nodes.push(Node { + id, + index: 0, + fqn: fqn.to_string().split("<").next().unwrap().to_owned(), + args: node.construction_args.new_function_args(), + }); + } + + let output = nid(&network.output); + output_nodes.push(output); } + dbg!(&input_nodes); let template = include_str!("templates/spirv-template.rs"); let mut tera = tera::Tera::default(); @@ -156,8 +187,8 @@ pub fn serialize_gpu(network: &ProtoNetwork, io: &ShaderIO) -> anyhow::Result Result anyhow::Result<()> { let metadata = compiler::Metadata::new("project".to_owned(), vec!["test@example.com".to_owned()]); - compiler::create_files(&metadata, &request.network, &compile_dir, &request.io)?; + compiler::create_files(&metadata, &request.networks, &compile_dir, &request.io)?; let result = compiler::compile(&compile_dir)?; let bytes = std::fs::read(result.module.unwrap_single())?; diff --git a/node-graph/gpu-compiler/src/templates/spirv-template.rs b/node-graph/gpu-compiler/src/templates/spirv-template.rs index 4853867b..006327b4 100644 --- a/node-graph/gpu-compiler/src/templates/spirv-template.rs +++ b/node-graph/gpu-compiler/src/templates/spirv-template.rs @@ -4,32 +4,38 @@ #[cfg(target_arch = "spirv")] extern crate spirv_std; -#[cfg(target_arch = "spirv")] -pub mod gpu { - use super::*; +//#[cfg(target_arch = "spirv")] +//pub mod gpu { +//use super::*; use spirv_std::spirv; use spirv_std::glam::UVec3; #[allow(unused)] #[spirv(compute(threads({{compute_threads}})))] pub fn eval ( + #[spirv(global_invocation_id)] _global_index: UVec3, {% for input in inputs %} - {{input}} + {{input}}, {% endfor %} ) { use graphene_core::Node; + /* {% for input in input_nodes %} - let i{{loop.index0}} = graphene_core::value::CopiedNode::new(i{{loop.index0}}); + let i{{input.index}} = graphene_core::value::CopiedNode::new(i{{input.index}}); let _{{input.id}} = {{input.fqn}}::new({% for arg in input.args %}{{arg}}, {% endfor %}); - let {{input.id}} = graphene_core::structural::ComposeNode::new(i{{loop.index0}}, _{{input.id}}); + let {{input.id}} = graphene_core::structural::ComposeNode::new(i{{input.index}}, _{{input.id}}); {% endfor %} + */ {% for node in nodes %} let {{node.id}} = {{node.fqn}}::new({% for arg in node.args %}{{arg}}, {% endfor %}); {% endfor %} - let output = {{last_node}}.eval(()); - // TODO: Write output to buffer + {% for output in output_nodes %} + let v = {{output}}.eval(()); + o{{loop.index0}}[_global_index.x as usize] = v; + {% endfor %} + // TODO: Write output to buffer } -} +//} diff --git a/node-graph/gpu-executor/src/lib.rs b/node-graph/gpu-executor/src/lib.rs index b13e0783..8ec9d65d 100644 --- a/node-graph/gpu-executor/src/lib.rs +++ b/node-graph/gpu-executor/src/lib.rs @@ -1,13 +1,15 @@ +use bytemuck::{Pod, Zeroable}; use graph_craft::proto::ProtoNetwork; use graphene_core::*; use anyhow::Result; -use dyn_any::StaticType; +use dyn_any::{StaticType, StaticTypeSized}; use futures::Future; use glam::UVec3; use serde::{Deserialize, Serialize}; use std::borrow::Cow; use std::pin::Pin; +use std::sync::Arc; type ReadBackFuture = Pin>>>>; @@ -20,18 +22,18 @@ pub trait GpuExecutor { fn create_uniform_buffer(&self, data: T) -> Result>; fn create_storage_buffer(&self, data: T, options: StorageBufferOptions) -> Result>; fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result>; - fn create_compute_pass(&self, layout: &PipelineLayout, read_back: Option>, instances: u32) -> Result; + fn create_compute_pass(&self, layout: &PipelineLayout, read_back: Option>>, instances: u32) -> Result; fn execute_compute_pipeline(&self, encoder: Self::CommandBuffer) -> Result<()>; - fn read_output_buffer(&self, buffer: ShaderInput) -> Result; + fn read_output_buffer(&self, buffer: Arc>) -> ReadBackFuture; } pub trait SpirVCompiler { - fn compile(&self, network: ProtoNetwork, io: &ShaderIO) -> Result; + fn compile(&self, network: &[ProtoNetwork], io: &ShaderIO) -> Result; } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct CompileRequest { - pub network: ProtoNetwork, + pub networks: Vec, pub io: ShaderIO, } @@ -101,6 +103,10 @@ impl ShaderInput { ShaderInput::ReadBackBuffer(_, ty) => ty.clone(), } } + + pub fn is_output(&self) -> bool { + matches!(self, ShaderInput::OutputBuffer(_, _)) + } } pub struct Shader<'a> { @@ -119,6 +125,7 @@ pub struct StorageBufferOptions { pub cpu_writable: bool, pub gpu_writable: bool, pub cpu_readable: bool, + pub storage: bool, } pub trait ToUniformBuffer: StaticType { @@ -127,13 +134,22 @@ pub trait ToUniformBuffer: StaticType { } pub trait ToStorageBuffer: StaticType { - type StorageBufferHandle; fn to_bytes(&self) -> Cow<[u8]>; + fn ty(&self) -> Type; +} + +impl ToStorageBuffer for Vec { + fn to_bytes(&self) -> Cow<[u8]> { + Cow::Borrowed(bytemuck::cast_slice(self.as_slice())) + } + fn ty(&self) -> Type { + concrete!(T) + } } /// Collection of all arguments that are passed to the shader. pub struct Bindgroup { - pub buffers: Vec>, + pub buffers: Vec>>, } /// A struct representing a compute pipeline. @@ -141,7 +157,7 @@ pub struct PipelineLayout { pub shader: E::ShaderHandle, pub entry_point: String, pub bind_group: Bindgroup, - pub output_buffer: ShaderInput, + pub output_buffer: Arc>, } /// Extracts arguments from the function arguments and wraps them in a node. @@ -185,6 +201,7 @@ fn storage_node(data: T, executor: &'input E cpu_writable: false, gpu_writable: true, cpu_readable: false, + storage: true, }, ) .unwrap() @@ -216,8 +233,8 @@ pub struct CreateComputePassNode { } #[node_macro::node_fn(CreateComputePassNode)] -fn create_compute_pass_node(layout: PipelineLayout, executor: &'input E, output: ShaderInput, instances: u32) -> E::CommandBuffer { - executor.create_compute_pass(&layout, Some(output), instances).unwrap() +fn create_compute_pass_node(layout: PipelineLayout, executor: &'input E, output: ShaderInput, instances: u32) -> E::CommandBuffer { + executor.create_compute_pass(&layout, Some(output.into()), instances).unwrap() } pub struct CreatePipelineLayoutNode<_E, EntryPoint, Bindgroup, OutputBuffer> { @@ -228,7 +245,7 @@ pub struct CreatePipelineLayoutNode<_E, EntryPoint, Bindgroup, OutputBuffer> { } #[node_macro::node_fn(CreatePipelineLayoutNode<_E>)] -fn create_pipeline_layout_node<_E: GpuExecutor>(shader: _E::ShaderHandle, entry_point: String, bind_group: Bindgroup<_E>, output_buffer: ShaderInput<_E::BufferHandle>) -> PipelineLayout<_E> { +fn create_pipeline_layout_node<_E: GpuExecutor>(shader: _E::ShaderHandle, entry_point: String, bind_group: Bindgroup<_E>, output_buffer: Arc>) -> PipelineLayout<_E> { PipelineLayout { shader, entry_point, diff --git a/node-graph/graph-craft/src/document.rs b/node-graph/graph-craft/src/document.rs index 29cc652d..ab7c647a 100644 --- a/node-graph/graph-craft/src/document.rs +++ b/node-graph/graph-craft/src/document.rs @@ -72,6 +72,7 @@ impl DocumentNode { } NodeInput::Network(ty) => (ProtoNodeInput::Network(ty), ConstructionArgs::Nodes(vec![])), NodeInput::ShortCircut(ty) => (ProtoNodeInput::ShortCircut(ty), ConstructionArgs::Nodes(vec![])), + NodeInput::Inline(inline) => (ProtoNodeInput::None, ConstructionArgs::Inline(inline)), }; assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::Network(_))), "recieved non resolved parameter"); assert!(!self.inputs.iter().any(|input| matches!(input, NodeInput::ShortCircut(_))), "recieved non resolved parameter"); @@ -82,6 +83,10 @@ impl DocumentNode { &args ); + // If we have one parameter of the type inline, set it as the construction args + if let &[NodeInput::Inline(ref inline)] = &self.inputs[..] { + args = ConstructionArgs::Inline(inline.clone()); + } if let ConstructionArgs::Nodes(nodes) = &mut args { nodes.extend(self.inputs.iter().map(|input| match input { NodeInput::Node { node_id, lambda, .. } => (*node_id, *lambda), @@ -176,6 +181,20 @@ pub enum NodeInput { /// but actually consuming the provided input instead of passing it to its predecessor. /// See [NodeInput] docs for more explanation. ShortCircut(Type), + Inline(InlineRust), +} + +#[derive(Debug, Clone, PartialEq, Hash, DynAny)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct InlineRust { + pub expr: String, + pub ty: Type, +} + +impl InlineRust { + pub fn new(expr: String, ty: Type) -> Self { + Self { expr, ty } + } } impl NodeInput { @@ -203,6 +222,7 @@ impl NodeInput { NodeInput::Value { exposed, .. } => *exposed, NodeInput::Network(_) => false, NodeInput::ShortCircut(_) => false, + NodeInput::Inline(_) => false, } } pub fn ty(&self) -> Type { @@ -211,6 +231,7 @@ impl NodeInput { NodeInput::Value { tagged_value, .. } => tagged_value.ty(), NodeInput::Network(ty) => ty.clone(), NodeInput::ShortCircut(ty) => ty.clone(), + NodeInput::Inline(_) => panic!("ty() called on NodeInput::Inline"), } } } @@ -225,7 +246,7 @@ pub enum DocumentNodeImplementation { impl Default for DocumentNodeImplementation { fn default() -> Self { - Self::Unresolved(NodeIdentifier::new("graphene_cored::ops::IdNode")) + Self::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")) } } @@ -299,10 +320,9 @@ impl NodeNetwork { self.inputs.iter().map(move |id| self.nodes[id].inputs.get(0).map(|i| i.ty()).unwrap_or(concrete!(()))) } - /// An empty graph pub fn value_network(node: DocumentNode) -> Self { Self { - inputs: vec![0], + inputs: node.inputs.iter().filter(|input| matches!(input, NodeInput::Network(_))).map(|_| 0).collect(), outputs: vec![NodeOutput::new(0, 0)], nodes: [(0, node)].into_iter().collect(), disabled: vec![], @@ -754,6 +774,7 @@ impl NodeNetwork { } NodeInput::ShortCircut(_) => (), NodeInput::Value { .. } => unreachable!("Value inputs should have been replaced with value nodes"), + NodeInput::Inline(_) => (), } } node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()); @@ -772,14 +793,69 @@ impl NodeNetwork { } } DocumentNodeImplementation::Unresolved(_) => (), - DocumentNodeImplementation::Extract => { - panic!("Extract nodes should have been removed before flattening"); - } + DocumentNodeImplementation::Extract => (), } assert!(!self.nodes.contains_key(&id), "Trying to insert a node into the network caused an id conflict"); self.nodes.insert(id, node); } + fn remove_id_node(&mut self, id: NodeId) -> Result<(), String> { + let node = self.nodes.get(&id).ok_or_else(|| format!("Node with id {} does not exist", id))?.clone(); + if let DocumentNodeImplementation::Unresolved(ident) = &node.implementation { + if ident.name == "graphene_core::ops::IdNode" { + assert_eq!(node.inputs.len(), 1, "Id node has more than one input"); + if let NodeInput::Node { node_id, output_index, .. } = node.inputs[0] { + let input_node_id = node_id; + for output in self.nodes.values_mut() { + for input in &mut output.inputs { + if let NodeInput::Node { + node_id: output_node_id, + output_index: output_output_index, + .. + } = input + { + if *output_node_id == id { + *output_node_id = input_node_id; + *output_output_index = output_index; + } + } + } + for NodeOutput { + ref mut node_id, + ref mut node_output_index, + } in self.outputs.iter_mut() + { + if *node_id == id { + *node_id = input_node_id; + *node_output_index = output_index; + } + } + } + } + self.nodes.remove(&id); + } + } + Ok(()) + } + + pub fn remove_redundant_id_nodes(&mut self) { + let id_nodes = self + .nodes + .iter() + .filter(|(_, node)| { + matches!(&node.implementation, DocumentNodeImplementation::Unresolved(ident) if ident == &NodeIdentifier::new("graphene_core::ops::IdNode")) + && node.inputs.len() == 1 + && matches!(node.inputs[0], NodeInput::Node { .. }) + }) + .map(|(id, _)| *id) + .collect::>(); + for id in id_nodes { + if let Err(e) = self.remove_id_node(id) { + log::warn!("{}", e) + } + } + } + pub fn resolve_extract_nodes(&mut self) { let mut extraction_nodes = self .nodes @@ -792,14 +868,20 @@ impl NodeNetwork { for (_, node) in &mut extraction_nodes { if let DocumentNodeImplementation::Extract = node.implementation { assert_eq!(node.inputs.len(), 1); - let NodeInput::Node { node_id, output_index, lambda } = node.inputs.pop().unwrap() else { + let NodeInput::Node { node_id, output_index, .. } = node.inputs.pop().unwrap() else { panic!("Extract node has no input"); }; assert_eq!(output_index, 0); - assert!(lambda); - let input_node = self.nodes.get_mut(&node_id).unwrap(); + // TODO: check if we can readd lambda checking + let mut input_node = self.nodes.remove(&node_id).unwrap(); node.implementation = DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()); - node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node.clone()), false)]; + for input in input_node.inputs.iter_mut() { + match input { + NodeInput::Node { .. } | NodeInput::Value { .. } => *input = NodeInput::Network(generic!(T)), + _ => (), + } + } + node.inputs = vec![NodeInput::value(TaggedValue::DocumentNode(input_node), false)]; } } self.nodes.extend(extraction_nodes); @@ -926,6 +1008,7 @@ mod test { implementation: DocumentNodeImplementation::Unresolved("graphene_core::ops::IdNode".into()), ..Default::default() }; + // TODO: Extend test cases to test nested network let mut extraction_network = NodeNetwork { inputs: vec![], outputs: vec![NodeOutput::new(1, 0)], @@ -945,7 +1028,7 @@ mod test { ..Default::default() }; extraction_network.resolve_extract_nodes(); - assert_eq!(extraction_network.nodes.len(), 2); + assert_eq!(extraction_network.nodes.len(), 1); let inputs = extraction_network.nodes.get(&1).unwrap().inputs.clone(); assert_eq!(inputs.len(), 1); assert!(matches!(&inputs[0], &NodeInput::Value{ tagged_value: TaggedValue::DocumentNode(ref network), ..} if network == &id_node)); diff --git a/node-graph/graph-craft/src/document/value.rs b/node-graph/graph-craft/src/document/value.rs index 267fdae5..5c5fee9d 100644 --- a/node-graph/graph-craft/src/document/value.rs +++ b/node-graph/graph-craft/src/document/value.rs @@ -189,7 +189,7 @@ impl<'a> TaggedValue { pub fn to_primitive_string(&self) -> String { match self { TaggedValue::None => "()".to_string(), - TaggedValue::String(x) => x.clone(), + TaggedValue::String(x) => format!("\"{}\"", x), TaggedValue::U32(x) => x.to_string(), TaggedValue::F32(x) => x.to_string(), TaggedValue::F64(x) => x.to_string(), diff --git a/node-graph/graph-craft/src/executor.rs b/node-graph/graph-craft/src/executor.rs index 76f1a7f0..bbbcd9a2 100644 --- a/node-graph/graph-craft/src/executor.rs +++ b/node-graph/graph-craft/src/executor.rs @@ -10,19 +10,23 @@ pub struct Compiler {} impl Compiler { pub fn compile(&self, mut network: NodeNetwork, resolve_inputs: bool) -> impl Iterator { let node_ids = network.nodes.keys().copied().collect::>(); - network.resolve_extract_nodes(); println!("flattening"); for id in node_ids { network.flatten(id); } + network.remove_redundant_id_nodes(); + network.resolve_extract_nodes(); + network.remove_dead_nodes(); let proto_networks = network.into_proto_networks(); proto_networks.map(move |mut proto_network| { if resolve_inputs { println!("resolving inputs"); + log::debug!("resolving inputs"); proto_network.resolve_inputs(); } proto_network.reorder_ids(); proto_network.generate_stable_node_ids(); + log::debug!("proto network: {:?}", proto_network); proto_network }) } diff --git a/node-graph/graph-craft/src/proto.rs b/node-graph/graph-craft/src/proto.rs index 19e60e8b..a4905d63 100644 --- a/node-graph/graph-craft/src/proto.rs +++ b/node-graph/graph-craft/src/proto.rs @@ -4,8 +4,8 @@ use std::collections::{HashMap, HashSet}; use std::hash::Hash; use xxhash_rust::xxh3::Xxh3; -use crate::document::value; use crate::document::NodeId; +use crate::document::{value, InlineRust}; use dyn_any::DynAny; use graphene_core::*; #[cfg(feature = "serde")] @@ -66,6 +66,10 @@ impl core::fmt::Display for ProtoNetwork { write_node(f, network, id.0, indent + 1)?; } } + ConstructionArgs::Inline(inline) => { + f.write_str(&"\t".repeat(indent + 1))?; + f.write_fmt(format_args!("Inline construction argument: {inline:?}"))? + } } f.write_str(&"\t".repeat(indent))?; f.write_str("}\n")?; @@ -83,6 +87,7 @@ pub enum ConstructionArgs { Value(value::TaggedValue), // the bool indicates whether to treat the node as lambda node Nodes(Vec<(NodeId, bool)>), + Inline(InlineRust), } impl PartialEq for ConstructionArgs { @@ -105,6 +110,7 @@ impl Hash for ConstructionArgs { } } Self::Value(value) => value.hash(state), + Self::Inline(inline) => inline.hash(state), } } } @@ -114,6 +120,7 @@ impl ConstructionArgs { match self { ConstructionArgs::Nodes(nodes) => nodes.iter().map(|n| format!("&n{}", n.0)).collect(), ConstructionArgs::Value(value) => vec![value.to_primitive_string()], + ConstructionArgs::Inline(inline) => vec![inline.expr.clone()], } } } @@ -453,6 +460,7 @@ impl TypingContext { .map(|node| node.ty()) }) .collect::, String>>()?, + ConstructionArgs::Inline(ref inline) => vec![inline.ty.clone()], }; // Get the node input type from the proto node declaration diff --git a/node-graph/gstd/Cargo.toml b/node-graph/gstd/Cargo.toml index 71f8a650..a9e56156 100644 --- a/node-graph/gstd/Cargo.toml +++ b/node-graph/gstd/Cargo.toml @@ -10,8 +10,13 @@ license = "MIT OR Apache-2.0" [features] memoization = ["once_cell"] -default = ["memoization"] -gpu = ["graphene-core/gpu", "gpu-compiler-bin-wrapper", "compilation-client", "gpu-executor"] +default = ["memoization", "wgpu"] +gpu = [ + "graphene-core/gpu", + "gpu-compiler-bin-wrapper", + "compilation-client", + "gpu-executor", +] vulkan = ["gpu", "vulkan-executor"] wgpu = ["gpu", "wgpu-executor"] quantization = ["autoquant"] diff --git a/node-graph/gstd/src/executor.rs b/node-graph/gstd/src/executor.rs index f085973e..6f9235d5 100644 --- a/node-graph/gstd/src/executor.rs +++ b/node-graph/gstd/src/executor.rs @@ -1,4 +1,7 @@ +use glam::UVec3; +use gpu_executor::{Bindgroup, PipelineLayout, StorageBufferOptions}; use gpu_executor::{GpuExecutor, ShaderIO, ShaderInput}; +use graph_craft::document::value::TaggedValue; use graph_craft::document::*; use graph_craft::proto::*; use graphene_core::raster::*; @@ -9,6 +12,7 @@ use wgpu_executor::NewExecutor; use bytemuck::Pod; use core::marker::PhantomData; use dyn_any::StaticTypeSized; +use std::sync::Arc; pub struct GpuCompiler { typing_context: TypingContext, @@ -19,25 +23,177 @@ pub struct GpuCompiler { #[node_macro::node_fn(GpuCompiler)] async fn compile_gpu(node: &'input DocumentNode, mut typing_context: TypingContext, io: ShaderIO) -> compilation_client::Shader { let compiler = graph_craft::executor::Compiler {}; - let DocumentNodeImplementation::Network(network) = node.implementation; - let proto_network = compiler.compile_single(network, true).unwrap(); - typing_context.update(&proto_network); - let input_types = proto_network.inputs.iter().map(|id| typing_context.get_type(*id).unwrap()).map(|node_io| node_io.output).collect(); - let output_type = typing_context.get_type(proto_network.output).unwrap().output; + let DocumentNodeImplementation::Network(ref network) = node.implementation else { panic!() }; + let proto_networks: Vec<_> = compiler.compile(network.clone(), true).collect(); - let bytes = compilation_client::compile(proto_network, input_types, output_type, io).await.unwrap(); - bytes + for network in proto_networks.iter() { + typing_context.update(network).expect("Failed to type check network"); + } + // TODO: do a proper union + let input_types = proto_networks[0] + .inputs + .iter() + .map(|id| typing_context.type_of(*id).unwrap()) + .map(|node_io| node_io.output.clone()) + .collect(); + let output_types = proto_networks.iter().map(|network| typing_context.type_of(network.output).unwrap().output.clone()).collect(); + + compilation_client::compile(proto_networks, input_types, output_types, io).await.unwrap() } -pub struct MapGpuNode { - shader: Shader, +pub struct MapGpuNode { + node: Node, } +#[node_macro::node_fn(MapGpuNode)] +async fn map_gpu(image: ImageFrame, node: DocumentNode) -> ImageFrame { + log::debug!("Executing gpu node"); + let compiler = graph_craft::executor::Compiler {}; + let inner_network = NodeNetwork::value_network(node); + + log::debug!("inner_network: {:?}", inner_network); + let network = NodeNetwork { + inputs: vec![], //vec![0, 1], + outputs: vec![NodeOutput::new(1, 0)], + nodes: [ + DocumentNode { + name: "Slice".into(), + inputs: vec![NodeInput::Inline(InlineRust::new("i0[_global_index.x as usize]".into(), concrete![Color]))], + implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::CopiedNode".into()), + ..Default::default() + }, + /*DocumentNode { + name: "Index".into(), + //inputs: vec![NodeInput::Network(concrete!(UVec3))], + inputs: vec![NodeInput::Inline(InlineRust::new("i1.x as usize".into(), concrete![u32]))], + implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::CopiedNode".into()), + ..Default::default() + },*/ + /* + DocumentNode { + name: "GetNode".into(), + inputs: vec![NodeInput::node(1, 0), NodeInput::node(0, 0)], + implementation: DocumentNodeImplementation::Unresolved("graphene_core::storage::GetNode".into()), + ..Default::default() + },*/ + DocumentNode { + name: "MapNode".into(), + inputs: vec![NodeInput::node(0, 0)], + implementation: DocumentNodeImplementation::Network(inner_network), + ..Default::default() + }, + /* + DocumentNode { + name: "SaveNode".into(), + inputs: vec![ + //NodeInput::node(0, 0), + NodeInput::Inline(InlineRust::new( + "o0[_global_index.x as usize] = i0[_global_index.x as usize]".into(), + Type::Fn(Box::new(concrete!(Color)), Box::new(concrete!(()))), + )), + ], + implementation: DocumentNodeImplementation::Unresolved("graphene_core::value::ValueNode".into()), + ..Default::default() + }, + */ + ] + .into_iter() + .enumerate() + .map(|(i, n)| (i as u64, n)) + .collect(), + ..Default::default() + }; + log::debug!("compiling network"); + let proto_networks = compiler.compile(network.clone(), true).collect(); + log::debug!("compiling shader"); + let shader = compilation_client::compile( + proto_networks, + vec![concrete!(Color)], //, concrete!(u32)], + vec![concrete!(Color)], + ShaderIO { + inputs: vec![ + ShaderInput::StorageBuffer((), concrete!(Color)), + //ShaderInput::Constant(gpu_executor::GPUConstant::GlobalInvocationId), + ShaderInput::OutputBuffer((), concrete!(Color)), + ], + output: ShaderInput::OutputBuffer((), concrete!(Color)), + }, + ) + .await + .unwrap(); + //return ImageFrame::empty(); + let len = image.image.data.len(); + log::debug!("instances: {}", len); + + let executor = NewExecutor::new().await.unwrap(); + log::debug!("creating buffer"); + let storage_buffer = executor + .create_storage_buffer( + image.image.data.clone(), + StorageBufferOptions { + cpu_writable: false, + gpu_writable: true, + cpu_readable: false, + storage: true, + }, + ) + .unwrap(); + let storage_buffer = Arc::new(storage_buffer); + let output_buffer = executor.create_output_buffer(len, concrete!(Color), false).unwrap(); + let output_buffer = Arc::new(output_buffer); + let readback_buffer = executor.create_output_buffer(len, concrete!(Color), true).unwrap(); + let readback_buffer = Arc::new(readback_buffer); + log::debug!("created buffer"); + let bind_group = Bindgroup { + buffers: vec![storage_buffer.clone()], + }; + + let shader = gpu_executor::Shader { + source: shader.spirv_binary.into(), + name: "gpu::eval", + io: shader.io, + }; + log::debug!("loading shader"); + log::debug!("shader: {:?}", shader.source); + let shader = executor.load_shader(shader).unwrap(); + log::debug!("loaded shader"); + let pipeline = PipelineLayout { + shader, + entry_point: "eval".to_string(), + bind_group, + output_buffer: output_buffer.clone(), + }; + log::debug!("created pipeline"); + let compute_pass = executor.create_compute_pass(&pipeline, Some(readback_buffer.clone()), len.min(65535) as u32).unwrap(); + executor.execute_compute_pipeline(compute_pass).unwrap(); + log::debug!("executed pipeline"); + log::debug!("reading buffer"); + let result = executor.read_output_buffer(readback_buffer).await.unwrap(); + let colors = bytemuck::pod_collect_to_vec::(result.as_slice()); + ImageFrame { + image: Image { + data: colors, + width: image.image.width, + height: image.image.height, + }, + transform: image.transform, + } + + /* + let executor: GpuExecutor = GpuExecutor::new(Context::new().await.unwrap(), shader.into(), "gpu::eval".into()).unwrap(); + let data: Vec<_> = input.into_iter().collect(); + let result = executor.execute(Box::new(data)).unwrap(); + let result = dyn_any::downcast::>(result).unwrap(); + *result + */ +} +/* #[node_macro::node_fn(MapGpuNode)] async fn map_gpu(inputs: Vec::BufferHandle>>, shader: &'any_input compilation_client::Shader) { use graph_craft::executor::Executor; let executor = NewExecutor::new().unwrap(); - for input in shader.inputs.iter() { + for input in shader.io.inputs.iter() { + let buffer = executor.create_storage_buffer(&self, data, options) let buffer = executor.create_buffer(input.size).unwrap(); executor.write_buffer(buffer, input.data).unwrap(); } @@ -74,6 +230,7 @@ fn map_gpu_single_image(input: Image, node: String) -> Image { inputs: vec![NodeInput::Network(concrete!(Color))], implementation: DocumentNodeImplementation::Unresolved(identifier), metadata: DocumentNodeMetadata::default(), + ..Default::default() }, )] .into_iter() @@ -85,3 +242,4 @@ fn map_gpu_single_image(input: Image, node: String) -> Image { let data = map_node.eval(input.data.clone()); Image { data, ..input } } +*/ diff --git a/node-graph/interpreted-executor/src/executor.rs b/node-graph/interpreted-executor/src/executor.rs index 6be76a2c..6c5fb50a 100644 --- a/node-graph/interpreted-executor/src/executor.rs +++ b/node-graph/interpreted-executor/src/executor.rs @@ -203,6 +203,7 @@ impl BorrowTree { let node = unsafe { node.erase_lifetime() }; self.store_node(Arc::new(node.into()), id); } + ConstructionArgs::Inline(_) => unimplemented!("Inline nodes are not supported yet"), ConstructionArgs::Nodes(ids) => { let ids: Vec<_> = ids.iter().map(|(id, _)| *id).collect(); let construction_nodes = self.node_refs(&ids); diff --git a/node-graph/interpreted-executor/src/lib.rs b/node-graph/interpreted-executor/src/lib.rs index 8a043220..d15e24ae 100644 --- a/node-graph/interpreted-executor/src/lib.rs +++ b/node-graph/interpreted-executor/src/lib.rs @@ -131,7 +131,7 @@ mod tests { 0, DocumentNode { name: "id".into(), - inputs: vec![NodeInput::Network(concrete!(u32))], + inputs: vec![NodeInput::ShortCircut(concrete!(u32))], implementation: DocumentNodeImplementation::Unresolved(NodeIdentifier::new("graphene_core::ops::IdNode")), ..Default::default() }, diff --git a/node-graph/interpreted-executor/src/node_registry.rs b/node-graph/interpreted-executor/src/node_registry.rs index e73ea1b8..51404a5b 100644 --- a/node-graph/interpreted-executor/src/node_registry.rs +++ b/node-graph/interpreted-executor/src/node_registry.rs @@ -1,5 +1,6 @@ use glam::{DAffine2, DVec2}; +use graph_craft::document::DocumentNode; use graphene_core::ops::IdNode; use graphene_core::vector::VectorData; use once_cell::sync::Lazy; @@ -219,7 +220,7 @@ fn node_registry() -> HashMap = DowncastBothNode::new(args[0]); - let document_node = ClonedNode::new(document_node.eval(()).await); + //let document_node = ClonedNode::new(document_node.eval(())); let node = graphene_std::executor::MapGpuNode::new(document_node); let any: DynAnyNode, _, _> = graphene_std::any::DynAnyNode::new(graphene_core::value::ValueNode::new(node)); Box::pin(any) as TypeErasedPinned diff --git a/node-graph/wgpu-executor/Cargo.toml b/node-graph/wgpu-executor/Cargo.toml index 3bc5de28..974a26e2 100644 --- a/node-graph/wgpu-executor/Cargo.toml +++ b/node-graph/wgpu-executor/Cargo.toml @@ -23,7 +23,7 @@ base64 = "0.13" bytemuck = {version = "1.8" } anyhow = "1.0.66" -wgpu = { version = "0.14.2", features = ["spirv"] } +wgpu = { version = "0.16", features = ["spirv"] } spirv = "0.2.0" futures-intrusive = "0.5.0" futures = "0.3.25" diff --git a/node-graph/wgpu-executor/src/context.rs b/node-graph/wgpu-executor/src/context.rs index 488b0bdf..5695c75b 100644 --- a/node-graph/wgpu-executor/src/context.rs +++ b/node-graph/wgpu-executor/src/context.rs @@ -11,7 +11,7 @@ pub struct Context { impl Context { pub async fn new() -> Option { // Instantiates instance of WebGPU - let instance = wgpu::Instance::new(wgpu::Backends::all()); + let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default()); // `request_adapter` instantiates the general connection to the GPU let adapter = instance.request_adapter(&wgpu::RequestAdapterOptions::default()).await?; diff --git a/node-graph/wgpu-executor/src/lib.rs b/node-graph/wgpu-executor/src/lib.rs index 8d18cc8e..46d6908c 100644 --- a/node-graph/wgpu-executor/src/lib.rs +++ b/node-graph/wgpu-executor/src/lib.rs @@ -9,6 +9,7 @@ use graph_craft::Type; use anyhow::{bail, Result}; use futures::Future; use std::pin::Pin; +use std::sync::Arc; use wgpu::util::DeviceExt; use wgpu::{Buffer, BufferDescriptor, CommandBuffer, ShaderModule}; @@ -42,8 +43,11 @@ impl gpu_executor::GpuExecutor for NewExecutor { fn create_storage_buffer(&self, data: T, options: StorageBufferOptions) -> Result> { let bytes = data.to_bytes(); - let mut usage = wgpu::BufferUsages::STORAGE; + let mut usage = wgpu::BufferUsages::empty(); + if options.storage { + usage |= wgpu::BufferUsages::STORAGE; + } if options.gpu_writable { usage |= wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST; } @@ -54,15 +58,17 @@ impl gpu_executor::GpuExecutor for NewExecutor { usage |= wgpu::BufferUsages::MAP_WRITE | wgpu::BufferUsages::COPY_SRC; } + log::debug!("Creating storage buffer with usage {:?} and len: {}", usage, bytes.len()); let buffer = self.context.device.create_buffer_init(&wgpu::util::BufferInitDescriptor { label: None, contents: bytes.as_ref(), usage, }); - Ok(ShaderInput::StorageBuffer(buffer, Type::new::())) + Ok(ShaderInput::StorageBuffer(buffer, data.ty())) } fn create_output_buffer(&self, len: usize, ty: Type, cpu_readable: bool) -> Result> { + log::debug!("Creating output buffer with len: {}", len); let create_buffer = |usage| { Ok::<_, anyhow::Error>(self.context.device.create_buffer(&BufferDescriptor { label: None, @@ -72,13 +78,12 @@ impl gpu_executor::GpuExecutor for NewExecutor { })) }; let buffer = match cpu_readable { - true => ShaderInput::ReadBackBuffer(create_buffer(wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ)?, ty), + true => ShaderInput::ReadBackBuffer(create_buffer(wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ)?, ty), false => ShaderInput::OutputBuffer(create_buffer(wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC)?, ty), }; Ok(buffer) } - - fn create_compute_pass(&self, layout: &gpu_executor::PipelineLayout, read_back: Option>, instances: u32) -> Result { + fn create_compute_pass(&self, layout: &gpu_executor::PipelineLayout, read_back: Option>>, instances: u32) -> Result { let compute_pipeline = self.context.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, layout: None, @@ -116,10 +121,13 @@ impl gpu_executor::GpuExecutor for NewExecutor { } // Sets adds copy operation to command encoder. // Will copy data from storage buffer on GPU to staging buffer on CPU. - if let Some(ShaderInput::ReadBackBuffer(output, ty)) = read_back { + if let Some(buffer) = read_back { + let ShaderInput::ReadBackBuffer(output, ty) = buffer.as_ref() else { + bail!("Tried to read back from a non read back buffer"); + }; let size = output.size(); assert_eq!(size, layout.output_buffer.buffer().unwrap().size()); - assert_eq!(ty, layout.output_buffer.ty()); + assert_eq!(ty, &layout.output_buffer.ty()); encoder.copy_buffer_to_buffer( layout.output_buffer.buffer().ok_or_else(|| anyhow::anyhow!("Tried to use an non buffer as the shader output"))?, 0, @@ -143,9 +151,9 @@ impl gpu_executor::GpuExecutor for NewExecutor { Ok(()) } - fn read_output_buffer(&self, buffer: ShaderInput) -> Result>>>>> { - if let ShaderInput::ReadBackBuffer(buffer, _) = buffer { - let future = Box::pin(async move { + fn read_output_buffer(&self, buffer: Arc>) -> Pin>>>> { + let future = Box::pin(async move { + if let ShaderInput::ReadBackBuffer(buffer, _) = buffer.as_ref() { let buffer_slice = buffer.slice(..); // Sets the buffer up for mapping, sending over the result of the mapping back to us when it is finished. @@ -175,17 +183,17 @@ impl gpu_executor::GpuExecutor for NewExecutor { } else { bail!("failed to run compute on gpu!") } - }); - Ok(future) - } else { - bail!("Tried to read a non readback buffer") - } + } else { + bail!("Tried to read a non readback buffer") + } + }); + future } } impl NewExecutor { - pub fn new() -> Option { - let context = Context::new_sync()?; + pub async fn new() -> Option { + let context = Context::new().await?; Some(Self { context }) } }