Implement node composition and Cache node

This commit is contained in:
Dennis 2022-03-27 23:12:12 +02:00 committed by Keavon Chambers
parent ab727de684
commit 1174fadfaf
5 changed files with 349 additions and 132 deletions

99
node-graph/Cargo.lock generated
View File

@ -98,7 +98,7 @@ version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "328b822bdcba4d4e402be8d9adb6eebf269f969f8eadef977a553ff3c4fbcb58"
dependencies = [
"dashmap",
"dashmap 4.0.2",
"once_cell",
"rustc-hash",
]
@ -163,6 +163,17 @@ dependencies = [
"num_cpus",
]
[[package]]
name = "dashmap"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c8858831f7781322e539ea39e72449c46b059638250c14344fec8d0aa6e539c"
dependencies = [
"cfg-if",
"num_cpus",
"parking_lot 0.12.0",
]
[[package]]
name = "dissimilar"
version = "1.0.2"
@ -304,9 +315,9 @@ checksum = "12b8adadd720df158f4d70dfe7ccc6adb0472d7c55ca83445f6a5ab3e36f8fb6"
[[package]]
name = "lock_api"
version = "0.4.4"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb"
checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b"
dependencies = [
"scopeguard",
]
@ -354,16 +365,18 @@ dependencies = [
name = "nodegraph-experiments"
version = "0.1.0"
dependencies = [
"dashmap 5.2.0",
"graph-proc-macros",
"once_cell",
"ra_ap_ide",
"ra_ap_ide_db",
]
[[package]]
name = "num_cpus"
version = "1.13.0"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
dependencies = [
"hermit-abi",
"libc",
@ -371,9 +384,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.8.0"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56"
checksum = "87f3e037eac156d1775da914196f0f37741a274155e34a0b7e427c35d2a2ecb9"
[[package]]
name = "oorandom"
@ -389,7 +402,17 @@ checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb"
dependencies = [
"instant",
"lock_api",
"parking_lot_core",
"parking_lot_core 0.8.3",
]
[[package]]
name = "parking_lot"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58"
dependencies = [
"lock_api",
"parking_lot_core 0.9.1",
]
[[package]]
@ -406,6 +429,19 @@ dependencies = [
"winapi",
]
[[package]]
name = "parking_lot_core"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-sys",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@ -544,7 +580,7 @@ checksum = "70a5c4623546813f0c970e72591face7602f88df6cd29c41ac73e9fc8de4f1a9"
dependencies = [
"anymap",
"cov-mark",
"dashmap",
"dashmap 4.0.2",
"drop_bomb",
"either",
"fst",
@ -939,7 +975,7 @@ dependencies = [
"lock_api",
"log",
"oorandom",
"parking_lot",
"parking_lot 0.11.1",
"rustc-hash",
"salsa-macros",
"smallvec",
@ -1158,3 +1194,46 @@ name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6"
dependencies = [
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_msvc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5"
[[package]]
name = "windows_i686_gnu"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615"
[[package]]
name = "windows_i686_msvc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172"
[[package]]
name = "windows_x86_64_gnu"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc"
[[package]]
name = "windows_x86_64_msvc"
version = "0.32.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316"

View File

@ -12,3 +12,5 @@ rust_analyzer = ["ide", "ide_db"]
ide = { version = "*", package = "ra_ap_ide", optional = true }
ide_db = { version = "*", package = "ra_ap_ide_db" , optional = true }
graph-proc-macros = {path = "proc-macro"}
once_cell = "1.10"
dashmap = "5.2"

38
node-graph/src/iter.rs Normal file
View File

@ -0,0 +1,38 @@
#[derive(Clone)]
pub struct InsertAfterNth<A>
where
A: Iterator,
{
n: usize,
iter: A,
value: Option<A::Item>,
}
impl<A> Iterator for InsertAfterNth<A>
where
A: Iterator,
{
type Item = A::Item;
fn next(&mut self) -> Option<Self::Item> {
match self.n {
1.. => {
self.n -= 1;
self.iter.next()
}
0 if self.value.is_some() => self.value.take(),
_ => self.iter.next(),
}
}
}
pub fn insert_after_nth<A>(n: usize, iter: A, value: A::Item) -> InsertAfterNth<A>
where
A: Iterator,
{
InsertAfterNth {
n,
iter,
value: Some(value),
}
}

View File

@ -1,131 +1,29 @@
use std::{any::Any, iter::Sum, ops::Add};
#![deny(rust_2018_idioms)]
use std::any::Any;
pub struct InsertAfterNth<A>
where
A: Iterator,
{
n: usize,
iter: A,
value: Option<A::Item>,
}
mod iter;
mod nodes;
use iter::insert_after_nth;
use nodes::*;
impl<A> Iterator for InsertAfterNth<A>
where
A: Iterator,
{
type Item = A::Item;
fn next(&mut self) -> Option<Self::Item> {
match self.n {
1.. => {
self.n -= 1;
self.iter.next()
}
0 if self.value.is_some() => self.value.take(),
_ => self.iter.next(),
}
}
}
pub fn insert_after_nth<A>(n: usize, iter: A, value: A::Item) -> InsertAfterNth<A>
where
A: Iterator,
{
InsertAfterNth {
n,
iter,
value: Some(value),
}
}
trait Node<O> {
fn eval<'a>(&'a self, input: impl Iterator<Item = &'a dyn Any>) -> O;
pub trait Node<'n, OUT> {
fn eval(&'n self, input: impl Iterator<Item = &'n dyn Any> + Clone) -> OUT;
// fn source code
// positon
}
struct IntNode;
impl Node<u32> for IntNode {
fn eval<'a>(&'a self, _input: impl Iterator<Item = &'a dyn Any>) -> u32 {
42
}
}
struct AddNode;
impl<T: Sum + 'static + Copy> Node<T> for AddNode {
fn eval<'a>(&'a self, input: impl Iterator<Item = &'a dyn Any>) -> T {
input
.take(2)
.map(|x| *(x.downcast_ref::<T>().unwrap()))
.sum::<T>()
}
}
struct CurryNthArgNode<'a, T: Node<O>, A, O, const N: usize> {
node: &'a T,
arg: A,
_phantom_data: std::marker::PhantomData<O>,
}
impl<'a, T: Node<O>, A: 'static, O, const N: usize> Node<O> for CurryNthArgNode<'a, T, A, O, N> {
fn eval<'b>(&'b self, input: impl Iterator<Item = &'b dyn Any>) -> O {
self.node
.eval(insert_after_nth(N, input, &self.arg as &dyn Any))
}
}
impl<'a, T: Node<O>, A: 'static, O, const N: usize> CurryNthArgNode<'a, T, A, O, N> {
fn new(node: &'a T, arg: A) -> Self {
CurryNthArgNode::<'a, T, A, O, N> {
node,
arg,
_phantom_data: std::marker::PhantomData::default(),
}
}
}
struct ComposeNode<'a, L, R, B>
where
L: Node<B>,
{
first: &'a L,
second: &'a R,
_phantom_data: std::marker::PhantomData<B>,
}
impl<'a, B: 'static, L, R, O> Node<O> for ComposeNode<'a, L, R, B>
where
L: Node<B>,
R: Node<O>,
{
fn eval<'b>(&'b self, input: impl Iterator<Item = &'b dyn Any>) -> O {
let curry = CurryNthArgNode::<'a, R, B, O, 0> {
node: self.second,
arg: self.first.eval(input),
_phantom_data: std::marker::PhantomData::default(),
};
let result: O = curry.eval([].into_iter());
result
}
}
impl<'a, L, R, B: 'static> ComposeNode<'a, L, R, B>
where
L: Node<B>,
{
fn new(first: &'a L, second: &'a R) -> Self {
ComposeNode::<'a, L, R, B> {
first,
second,
_phantom_data: std::marker::PhantomData::default(),
}
}
trait After<'n, OUT, SECOND: Node<'n, OUT>> {
fn after<INTERMEDIATE, FIRST: Node<'n, INTERMEDIATE>>(
&'n self,
first: &'n FIRST,
) -> ComposeNode<'n, FIRST, SECOND, INTERMEDIATE>;
}
fn main() {
let int = IntNode;
let curry: CurryNthArgNode<_, u32, u32, 0> =
CurryNthArgNode::new(&AddNode, int.eval(std::iter::empty()));
let composition = ComposeNode::new(&curry, &curry);
let curry: CurryNthArgNode<_, u32, _, 0> = CurryNthArgNode::new(&composition, 10);
println!("{}", curry.eval(std::iter::empty()))
use std::iter;
let int = IntNode::<32>;
let curry: CurryNthArgNode<'_, _, _, u32, u32, 0> = CurryNthArgNode::new(&AddNode, &int);
let composition = curry.after(&curry);
let n = ValueNode::new(10_u32);
let curry: CurryNthArgNode<'_, _, _, u32, _, 0> = CurryNthArgNode::new(&composition, &n);
println!("{}", curry.eval(iter::empty()))
}

200
node-graph/src/nodes.rs Normal file
View File

@ -0,0 +1,200 @@
use std::{
any::Any, collections::hash_map::DefaultHasher, hash::Hasher, iter, iter::Sum,
marker::PhantomData,
};
use crate::{insert_after_nth, After, Node};
use once_cell::sync::OnceCell;
pub struct IntNode<const N: u32>;
impl<'n, const N: u32> Node<'n, u32> for IntNode<N> {
fn eval(&'n self, _input: impl Iterator<Item = &'n dyn Any>) -> u32 {
N
}
}
#[derive(Default)]
pub struct ValueNode<T>(T);
impl<'n, T> Node<'n, &'n T> for ValueNode<T> {
fn eval(&'n self, _input: impl Iterator<Item = &'n dyn Any>) -> &T {
&self.0
}
}
impl<'n, T: Copy> Node<'n, T> for ValueNode<T> {
fn eval(&'n self, _input: impl Iterator<Item = &'n dyn Any>) -> T {
self.0
}
}
impl<T> ValueNode<T> {
pub fn new(value: T) -> ValueNode<T> {
ValueNode(value)
}
}
pub struct AddNode;
impl<'n, T: Sum + 'static + Copy> Node<'n, T> for AddNode {
fn eval(&'n self, input: impl Iterator<Item = &'n dyn Any>) -> T {
input.map(|x| *(x.downcast_ref::<T>().unwrap())).sum::<T>()
}
}
/// Caches the output of a given Node and acts as a proxy
pub struct CacheNode<'n, NODE: Node<'n, OUT>, OUT: Clone> {
node: &'n NODE,
cache: OnceCell<OUT>,
}
impl<'n, NODE: Node<'n, OUT>, OUT: Clone> Node<'n, &'n OUT> for CacheNode<'n, NODE, OUT> {
fn eval(&'n self, input: impl Iterator<Item = &'n dyn Any> + Clone) -> &'n OUT {
self.cache.get_or_init(|| self.node.eval(input))
}
}
impl<'n, NODE: Node<'n, OUT>, OUT: Clone> CacheNode<'n, NODE, OUT> {
fn clear(&'n mut self) {
self.cache = OnceCell::new();
}
fn new(node: &'n NODE) -> CacheNode<'n, NODE, OUT> {
CacheNode {
node,
cache: OnceCell::new(),
}
}
}
/*
/// Caches the output of a given Node and acts as a proxy
/// Automatically resets if it receives different input
pub struct SmartCacheNode<'n, NODE: Node<'n, OUT>, OUT: Clone> {
node: &'n NODE,
map: dashmap::DashMap<u64, CacheNode<'n, NODE, OUT>>,
}
impl<'n, NODE: for<'a> Node<'a, OUT>, OUT: Clone> Node<'n, &'n CacheNode<'n, NODE, OUT>>
for SmartCacheNode<'n, NODE, OUT>
{
fn eval(
&'n self,
input: impl Iterator<Item = &'n dyn Any> + Clone,
) -> &'n CacheNode<'n, NODE, OUT> {
let mut hasher = DefaultHasher::new();
input.clone().for_each(|value| unsafe {
hasher.write(std::slice::from_raw_parts(
value as *const dyn Any as *const u8,
std::mem::size_of_val(value),
))
});
let hash = hasher.finish();
self.map.entry(hash).or_insert(CacheNode::new(self.node));
fn map<'a, 'c, 'd, N, OUT: Clone>(
_key: &'a u64,
node: &'c CacheNode<'d, N, OUT>,
) -> &'c CacheNode<'b, N, OUT>
where
N: for<'b> Node<'b, OUT>,
{
node
}
let foo: Option<&CacheNode<'n, NODE, OUT>> = self.map.view(&hash, map);
foo.unwrap()
}
}
impl<'n, NODE: Node<'n, OUT>, OUT: Clone> SmartCacheNode<'n, NODE, OUT> {
fn clear(&'n mut self) {
self.map.clear();
}
fn new(node: &'n NODE) -> SmartCacheNode<'n, NODE, OUT> {
SmartCacheNode {
node,
map: dashmap::DashMap::new(),
}
}
}*/
pub struct CurryNthArgNode<
'n,
CurryNode: Node<'n, OUT>,
ArgNode: Node<'n, ARG>,
ARG: Clone,
OUT,
const NTH: usize,
> {
node: &'n CurryNode,
arg: CacheNode<'n, ArgNode, ARG>,
_phantom_out: std::marker::PhantomData<OUT>,
_phantom_arg: std::marker::PhantomData<ARG>,
}
impl<
'n,
CurryNode: Node<'n, OUT>,
ArgNode: Node<'n, ARG>,
ARG: 'static + Clone,
OUT,
const NTH: usize,
> Node<'n, OUT> for CurryNthArgNode<'n, CurryNode, ArgNode, ARG, OUT, NTH>
{
fn eval(&'n self, input: impl Iterator<Item = &'n dyn Any> + Clone) -> OUT {
let arg = self.arg.eval(iter::empty());
let arg: &dyn Any = arg as &dyn Any;
self.node.eval(insert_after_nth(NTH, input, arg))
}
}
impl<'n, CurryNode: Node<'n, Out>, ArgNode: Node<'n, Arg>, Arg: Clone, Out, const Nth: usize>
CurryNthArgNode<'n, CurryNode, ArgNode, Arg, Out, Nth>
{
pub fn new(node: &'n CurryNode, arg: &'n ArgNode) -> Self {
CurryNthArgNode::<'n, CurryNode, ArgNode, Arg, Out, Nth> {
node,
arg: CacheNode::new(arg),
_phantom_out: PhantomData::default(),
_phantom_arg: PhantomData::default(),
}
}
}
pub struct ComposeNode<'n, FIRST, SECOND, INTERMEDIATE>
where
FIRST: Node<'n, INTERMEDIATE>,
{
first: &'n FIRST,
second: &'n SECOND,
_phantom_data: PhantomData<INTERMEDIATE>,
}
impl<'n, FIRST, SECOND, OUT: 'n, INTERMEDIATE: 'static + Clone> Node<'n, OUT>
for ComposeNode<'n, FIRST, SECOND, INTERMEDIATE>
where
FIRST: Node<'n, INTERMEDIATE>,
SECOND: Node<'n, OUT>,
{
fn eval(&'n self, input: impl Iterator<Item = &'n dyn Any> + Clone) -> OUT {
let curry = CurryNthArgNode::<'_, _, _, _, _, 0>::new(self.second, self.first);
CurryNthArgNode::<'_, _, _, _, _, 0>::new(curry, ValueNode::new(input)).eval(input)
}
}
impl<'n, FIRST, SECOND, INTERMEDIATE: 'static> ComposeNode<'n, FIRST, SECOND, INTERMEDIATE>
where
FIRST: Node<'n, INTERMEDIATE>,
{
pub fn new(first: &'n FIRST, second: &'n SECOND) -> Self {
ComposeNode::<'n, FIRST, SECOND, INTERMEDIATE> {
first,
second,
_phantom_data: PhantomData::default(),
}
}
}
impl<'n, OUT, SECOND: Node<'n, OUT>> After<'n, OUT, SECOND> for SECOND {
fn after<INTERMEDIATE, FIRST: Node<'n, INTERMEDIATE>>(
&'n self,
first: &'n FIRST,
) -> ComposeNode<'n, FIRST, SECOND, INTERMEDIATE> {
ComposeNode::<'n, FIRST, SECOND, INTERMEDIATE> {
first,
second: self,
_phantom_data: PhantomData::default(),
}
}
}