use crate::ir::{NodeId, TrigGraph, TrigOp}; use std::collections::HashSet; /// Parallelism classification for a subexpression. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ParallelClass { /// Commutative, associative: union, add. Branches evaluate independently. Additive, /// Same operation applied N times: scale, repetition. Multiplicative, /// difference, intersection. Parallelizable if operands are independent. Divisive, /// True data dependency — must evaluate sequentially. Sequential, } /// An independent subtree identified in the DAG. #[derive(Debug, Clone)] pub struct Subtree { pub root: NodeId, pub nodes: HashSet, pub inputs: HashSet, } /// Identify independent subtrees in a TrigGraph. /// /// Two subtrees are independent if they share no intermediate nodes /// (they may share inputs like InputX/Y/Z and constants). pub fn find_independent_subtrees(graph: &TrigGraph) -> Vec { let output = graph.output; let root_op = &graph.nodes[output as usize]; // Only split at union/intersection/add (commutative operations) match root_op { TrigOp::Min(a, b) | TrigOp::Max(a, b) | TrigOp::Add(a, b) => { let left = collect_subtree(graph, *a); let right = collect_subtree(graph, *b); // Check independence: no shared non-input nodes let shared: HashSet = left.nodes.intersection(&right.nodes) .copied() .filter(|&id| !is_shared_input(&graph.nodes[id as usize])) .collect(); if shared.is_empty() { return vec![left, right]; } } _ => {} } // Entire graph is one subtree vec![collect_subtree(graph, output)] } /// Classify the parallelism of the root operation. pub fn classify_root(graph: &TrigGraph) -> ParallelClass { match &graph.nodes[graph.output as usize] { TrigOp::Min(_, _) | TrigOp::Add(_, _) => ParallelClass::Additive, TrigOp::Mul(_, _) | TrigOp::Div(_, _) => ParallelClass::Multiplicative, TrigOp::Max(_, _) | TrigOp::Sub(_, _) => ParallelClass::Divisive, _ => ParallelClass::Sequential, } } /// Recursively count the maximum parallelism depth. /// Returns how many independent branches exist at each level. pub fn parallelism_depth(graph: &TrigGraph) -> Vec { let mut levels: Vec = Vec::new(); count_branches(graph, graph.output, 0, &mut levels); levels } fn count_branches(graph: &TrigGraph, node: NodeId, depth: usize, levels: &mut Vec) { while levels.len() <= depth { levels.push(0); } match &graph.nodes[node as usize] { // Splittable operations — both children are independent branches TrigOp::Min(a, b) | TrigOp::Add(a, b) => { levels[depth] += 2; count_branches(graph, *a, depth + 1, levels); count_branches(graph, *b, depth + 1, levels); } // Non-commutative but still two-operand TrigOp::Max(a, b) | TrigOp::Sub(a, b) | TrigOp::Mul(a, b) | TrigOp::Div(a, b) | TrigOp::Hypot(a, b) | TrigOp::Atan2(a, b) => { levels[depth] += 1; count_branches(graph, *a, depth + 1, levels); count_branches(graph, *b, depth + 1, levels); } // Single-operand TrigOp::Neg(a) | TrigOp::Abs(a) | TrigOp::Sin(a) | TrigOp::Cos(a) | TrigOp::Tan(a) | TrigOp::Asin(a) | TrigOp::Acos(a) | TrigOp::Atan(a) | TrigOp::Sinh(a) | TrigOp::Cosh(a) | TrigOp::Tanh(a) | TrigOp::Asinh(a) | TrigOp::Acosh(a) | TrigOp::Atanh(a) | TrigOp::Sqrt(a) | TrigOp::Exp(a) | TrigOp::Ln(a) => { levels[depth] += 1; count_branches(graph, *a, depth + 1, levels); } TrigOp::Clamp { val, lo, hi } => { levels[depth] += 1; count_branches(graph, *val, depth + 1, levels); count_branches(graph, *lo, depth + 1, levels); count_branches(graph, *hi, depth + 1, levels); } // Leaves _ => { levels[depth] += 1; } } } fn collect_subtree(graph: &TrigGraph, root: NodeId) -> Subtree { let mut nodes = HashSet::new(); let mut inputs = HashSet::new(); let mut stack = vec![root]; while let Some(id) = stack.pop() { if !nodes.insert(id) { continue; } let op = &graph.nodes[id as usize]; if is_shared_input(op) { inputs.insert(id); } match op { TrigOp::Add(a, b) | TrigOp::Sub(a, b) | TrigOp::Mul(a, b) | TrigOp::Div(a, b) | TrigOp::Hypot(a, b) | TrigOp::Atan2(a, b) | TrigOp::Min(a, b) | TrigOp::Max(a, b) => { stack.push(*a); stack.push(*b); } TrigOp::Neg(a) | TrigOp::Abs(a) | TrigOp::Sin(a) | TrigOp::Cos(a) | TrigOp::Tan(a) | TrigOp::Asin(a) | TrigOp::Acos(a) | TrigOp::Atan(a) | TrigOp::Sinh(a) | TrigOp::Cosh(a) | TrigOp::Tanh(a) | TrigOp::Asinh(a) | TrigOp::Acosh(a) | TrigOp::Atanh(a) | TrigOp::Sqrt(a) | TrigOp::Exp(a) | TrigOp::Ln(a) => { stack.push(*a); } TrigOp::Clamp { val, lo, hi } => { stack.push(*val); stack.push(*lo); stack.push(*hi); } _ => {} } } Subtree { root, nodes, inputs } } fn is_shared_input(op: &TrigOp) -> bool { matches!(op, TrigOp::InputX | TrigOp::InputY | TrigOp::InputZ | TrigOp::Const(_)) } #[cfg(test)] mod tests { use super::*; #[test] fn union_splits_into_two() { let mut g = TrigGraph::new(); let x = g.push(TrigOp::InputX); let y = g.push(TrigOp::InputY); let z = g.push(TrigOp::InputZ); // sphere(1): hypot(hypot(x,y),z) - 1 let xy = g.push(TrigOp::Hypot(x, y)); let mag = g.push(TrigOp::Hypot(xy, z)); let r1 = g.push(TrigOp::Const(1.0)); let s1 = g.push(TrigOp::Sub(mag, r1)); // sphere(1) translated: same but with offset let ox = g.push(TrigOp::Const(3.0)); let dx = g.push(TrigOp::Sub(x, ox)); let xy2 = g.push(TrigOp::Hypot(dx, y)); let mag2 = g.push(TrigOp::Hypot(xy2, z)); let r2 = g.push(TrigOp::Const(1.0)); let s2 = g.push(TrigOp::Sub(mag2, r2)); // union let u = g.push(TrigOp::Min(s1, s2)); g.set_output(u); let subtrees = find_independent_subtrees(&g); assert_eq!(subtrees.len(), 2, "union of two spheres should split into 2 subtrees"); } #[test] fn classify_union() { let mut g = TrigGraph::new(); let a = g.push(TrigOp::Const(1.0)); let b = g.push(TrigOp::Const(2.0)); let u = g.push(TrigOp::Min(a, b)); g.set_output(u); assert_eq!(classify_root(&g), ParallelClass::Additive); } #[test] fn classify_difference() { let mut g = TrigGraph::new(); let a = g.push(TrigOp::Const(1.0)); let b = g.push(TrigOp::Const(2.0)); let d = g.push(TrigOp::Max(a, b)); g.set_output(d); assert_eq!(classify_root(&g), ParallelClass::Divisive); } }