From a58d51d6852de15a23ed3ea33cb0ec59ec6dc942 Mon Sep 17 00:00:00 2001 From: 0HyperCube <78500760+0HyperCube@users.noreply.github.com> Date: Sun, 9 Apr 2023 23:30:57 +0100 Subject: [PATCH] Add cache clearing to stop the memory leak (#1106) * Add cache clearing * Add TODO comment --------- Co-authored-by: Keavon Chambers --- node-graph/gcore/src/lib.rs | 1 + node-graph/gstd/src/any.rs | 5 +++ node-graph/gstd/src/memo.rs | 24 ++++++++++-- .../interpreted-executor/src/executor.rs | 39 ++++++++++++------- 4 files changed, 51 insertions(+), 18 deletions(-) diff --git a/node-graph/gcore/src/lib.rs b/node-graph/gcore/src/lib.rs index 902556ba..d0e8a9bd 100644 --- a/node-graph/gcore/src/lib.rs +++ b/node-graph/gcore/src/lib.rs @@ -34,6 +34,7 @@ pub use raster::Color; pub trait Node<'i, Input: 'i>: 'i { type Output: 'i; fn eval<'s: 'i>(&'s self, input: Input) -> Self::Output; + fn reset(self: Pin<&mut Self>) {} } #[cfg(feature = "alloc")] diff --git a/node-graph/gstd/src/any.rs b/node-graph/gstd/src/any.rs index e9ed6a86..2f0a549a 100644 --- a/node-graph/gstd/src/any.rs +++ b/node-graph/gstd/src/any.rs @@ -34,7 +34,12 @@ where Box::new(self.node.eval(*input)) } } + fn reset(self: std::pin::Pin<&mut Self>) { + let wrapped_node = unsafe { self.map_unchecked_mut(|e| &mut e.node) }; + Node::reset(wrapped_node); + } } + impl<_I, _O, S0> DynAnyRefNode<_I, _O, S0> { pub const fn new(node: S0) -> Self { Self { node, _i: core::marker::PhantomData } diff --git a/node-graph/gstd/src/memo.rs b/node-graph/gstd/src/memo.rs index 77cb85d8..40007dde 100644 --- a/node-graph/gstd/src/memo.rs +++ b/node-graph/gstd/src/memo.rs @@ -2,6 +2,8 @@ use graphene_core::Node; use std::hash::{Hash, Hasher}; use std::marker::PhantomData; +use std::pin::Pin; +use std::sync::atomic::AtomicBool; use xxhash_rust::xxh3::Xxh3; /// Caches the output of a given Node and acts as a proxy @@ -9,7 +11,7 @@ use xxhash_rust::xxh3::Xxh3; pub struct CacheNode { // We have to use an append only data structure to make sure the references // to the cache entries are always valid - cache: boxcar::Vec<(u64, T)>, + cache: boxcar::Vec<(u64, T, AtomicBool)>, node: CachedNode, } impl<'i, T: 'i, I: 'i + Hash, CachedNode: 'i> Node<'i, I> for CacheNode @@ -22,17 +24,25 @@ where input.hash(&mut hasher); let hash = hasher.finish(); - if let Some((_, cached_value)) = self.cache.iter().find(|(h, _)| *h == hash) { + if let Some((_, cached_value, keep)) = self.cache.iter().find(|(h, _, _)| *h == hash) { + keep.store(true, std::sync::atomic::Ordering::Relaxed); return cached_value; } else { trace!("Cache miss"); let output = self.node.eval(input); - let index = self.cache.push((hash, output)); + let index = self.cache.push((hash, output, AtomicBool::new(true))); return &self.cache[index].1; } } + + fn reset(mut self: Pin<&mut Self>) { + let old_cache = std::mem::take(&mut self.cache); + self.cache = old_cache.into_iter().filter(|(_, _, keep)| keep.swap(false, std::sync::atomic::Ordering::Relaxed)).collect(); + } } +impl std::marker::Unpin for CacheNode {} + impl CacheNode { pub fn new(node: CachedNode) -> CacheNode { CacheNode { cache: boxcar::Vec::new(), node } @@ -72,8 +82,16 @@ impl<'i, T: 'i + Hash> Node<'i, Option> for LetNode { None => &self.cache.iter().last().expect("Let node was not initialized").1, } } + + fn reset(mut self: Pin<&mut Self>) { + if let Some(last) = std::mem::take(&mut self.cache).into_iter().last() { + self.cache = boxcar::vec![last]; + } + } } +impl std::marker::Unpin for LetNode {} + impl LetNode { pub fn new() -> LetNode { LetNode { cache: boxcar::Vec::new() } diff --git a/node-graph/interpreted-executor/src/executor.rs b/node-graph/interpreted-executor/src/executor.rs index 9755a822..eb48745a 100644 --- a/node-graph/interpreted-executor/src/executor.rs +++ b/node-graph/interpreted-executor/src/executor.rs @@ -1,6 +1,6 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::error::Error; -use std::{collections::HashMap, sync::Arc}; +use std::sync::{Arc, RwLock}; use dyn_any::StaticType; use graph_craft::document::value::UpcastNode; @@ -59,7 +59,7 @@ impl Executor for DynamicExecutor { pub struct NodeContainer<'n> { pub node: TypeErasedPinned<'n>, // the dependencies are only kept to ensure that the nodes are not dropped while still in use - _dependencies: Vec>>, + _dependencies: Vec>>>, } impl<'a> core::fmt::Debug for NodeContainer<'a> { @@ -69,7 +69,7 @@ impl<'a> core::fmt::Debug for NodeContainer<'a> { } impl<'a> NodeContainer<'a> { - pub fn new(node: TypeErasedPinned<'a>, _dependencies: Vec>>) -> Self { + pub fn new(node: TypeErasedPinned<'a>, _dependencies: Vec>>>) -> Self { Self { node, _dependencies } } @@ -89,7 +89,7 @@ impl NodeContainer<'static> { #[derive(Default, Debug, Clone)] pub struct BorrowTree { - nodes: HashMap>>, + nodes: HashMap>>>, } impl BorrowTree { @@ -107,6 +107,11 @@ impl BorrowTree { for (id, node) in proto_network.nodes { if !self.nodes.contains_key(&id) { self.push_node(id, node, typing_context)?; + } else { + let Some(node_container) = self.nodes.get_mut(&id) else { continue }; + let mut node_container_writer = node_container.write().unwrap(); + let node = node_container_writer.node.as_mut(); + node.reset(); } old_nodes.remove(&id); } @@ -114,29 +119,33 @@ impl BorrowTree { } fn node_refs(&self, nodes: &[NodeId]) -> Vec> { - self.node_deps(nodes).into_iter().map(|node| unsafe { node.as_ref().static_ref() }).collect() + self.node_deps(nodes).into_iter().map(|node| unsafe { node.read().unwrap().static_ref() }).collect() } - fn node_deps(&self, nodes: &[NodeId]) -> Vec>> { + fn node_deps(&self, nodes: &[NodeId]) -> Vec>>> { nodes.iter().map(|node| self.nodes.get(node).unwrap().clone()).collect() } - fn store_node(&mut self, node: Arc>, id: NodeId) -> Arc> { + fn store_node(&mut self, node: Arc>>, id: NodeId) -> Arc>> { self.nodes.insert(id, node.clone()); node } - pub fn get(&self, id: NodeId) -> Option>> { + pub fn get(&self, id: NodeId) -> Option>>> { self.nodes.get(&id).cloned() } - pub fn eval<'i, I: StaticType + 'i, O: StaticType + 'i>(&self, id: NodeId, input: I) -> Option { + pub fn eval<'i, I: StaticType + 'i, O: StaticType + 'i>(&'i self, id: NodeId, input: I) -> Option { let node = self.nodes.get(&id).cloned()?; - let output = node.node.eval(Box::new(input)); + let reader = node.read().unwrap(); + let output = reader.node.eval(Box::new(input)); dyn_any::downcast::(output).ok().map(|o| *o) } - pub fn eval_any<'i, 's: 'i>(&'s self, id: NodeId, input: Any<'i>) -> Option> { + pub fn eval_any<'i>(&'i self, id: NodeId, input: Any<'i>) -> Option> { let node = self.nodes.get(&id)?; - Some(node.node.eval(input)) + // TODO: Comments by @TrueDoctor before this was merged: + // TODO: Oof I dislike the evaluation being an unsafe operation but I guess its fine because it only is a lifetime extension + // TODO: We should ideally let miri run on a test that evaluates the nodegraph multiple times to check if this contains any subtle UB but this looks fine for now + Some(unsafe { (*((&*node.read().unwrap()) as *const NodeContainer)).node.eval(input) }) } pub fn free_node(&mut self, id: NodeId) { @@ -152,7 +161,7 @@ impl BorrowTree { let node = Box::pin(upcasted) as TypeErasedPinned<'_>; let node = NodeContainer { node, _dependencies: vec![] }; let node = unsafe { node.erase_lifetime() }; - self.store_node(Arc::new(node), id); + self.store_node(Arc::new(node.into()), id); } ConstructionArgs::Nodes(ids) => { let ids: Vec<_> = ids.iter().map(|(id, _)| *id).collect(); @@ -164,7 +173,7 @@ impl BorrowTree { _dependencies: self.node_deps(&ids), }; let node = unsafe { node.erase_lifetime() }; - self.store_node(Arc::new(node), id); + self.store_node(Arc::new(node.into()), id); } }; Ok(())