216 lines
7.2 KiB
Rust
216 lines
7.2 KiB
Rust
use crate::ir::{NodeId, TrigGraph, TrigOp};
|
|
use std::collections::HashSet;
|
|
|
|
/// How the root operation of a subexpression can be split for parallel
/// evaluation (see `classify_root` for the operation -> class mapping).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ParallelClass {
    /// Commutative, associative combination (union via `Min`, or `Add`):
    /// each branch can be evaluated on its own and merged afterwards.
    Additive,

    /// The same operation applied repeatedly, such as scaling or repetition
    /// (`Mul` / `Div` at the root).
    Multiplicative,

    /// Difference or intersection (`Max` / `Sub` at the root): parallelizable
    /// only when the two operands are independent of each other.
    Divisive,

    /// A true data dependency; evaluation has to proceed sequentially.
    Sequential,
}
|
|
|
|
/// An independent subtree identified in the DAG.
|
|
#[derive(Debug, Clone)]
|
|
pub struct Subtree {
|
|
pub root: NodeId,
|
|
pub nodes: HashSet<NodeId>,
|
|
pub inputs: HashSet<NodeId>,
|
|
}
|
|
|
|
/// Identify independent subtrees in a TrigGraph.
|
|
///
|
|
/// Two subtrees are independent if they share no intermediate nodes
|
|
/// (they may share inputs like InputX/Y/Z and constants).
|
|
pub fn find_independent_subtrees(graph: &TrigGraph) -> Vec<Subtree> {
|
|
let output = graph.output;
|
|
let root_op = &graph.nodes[output as usize];
|
|
|
|
// Only split at union/intersection/add (commutative operations)
|
|
match root_op {
|
|
TrigOp::Min(a, b) | TrigOp::Max(a, b)
|
|
| TrigOp::Add(a, b) => {
|
|
let left = collect_subtree(graph, *a);
|
|
let right = collect_subtree(graph, *b);
|
|
|
|
// Check independence: no shared non-input nodes
|
|
let shared: HashSet<NodeId> = left.nodes.intersection(&right.nodes)
|
|
.copied()
|
|
.filter(|&id| !is_shared_input(&graph.nodes[id as usize]))
|
|
.collect();
|
|
|
|
if shared.is_empty() {
|
|
return vec![left, right];
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
// Entire graph is one subtree
|
|
vec![collect_subtree(graph, output)]
|
|
}
|
|
|
|
/// Classify the parallelism of the root operation.
|
|
pub fn classify_root(graph: &TrigGraph) -> ParallelClass {
|
|
match &graph.nodes[graph.output as usize] {
|
|
TrigOp::Min(_, _) | TrigOp::Add(_, _) => ParallelClass::Additive,
|
|
TrigOp::Mul(_, _) | TrigOp::Div(_, _) => ParallelClass::Multiplicative,
|
|
TrigOp::Max(_, _) | TrigOp::Sub(_, _) => ParallelClass::Divisive,
|
|
_ => ParallelClass::Sequential,
|
|
}
|
|
}
|
|
|
|
/// Recursively count the maximum parallelism depth.
|
|
/// Returns how many independent branches exist at each level.
|
|
pub fn parallelism_depth(graph: &TrigGraph) -> Vec<usize> {
|
|
let mut levels: Vec<usize> = Vec::new();
|
|
count_branches(graph, graph.output, 0, &mut levels);
|
|
levels
|
|
}
|
|
|
|
fn count_branches(graph: &TrigGraph, node: NodeId, depth: usize, levels: &mut Vec<usize>) {
|
|
while levels.len() <= depth {
|
|
levels.push(0);
|
|
}
|
|
|
|
match &graph.nodes[node as usize] {
|
|
// Splittable operations — both children are independent branches
|
|
TrigOp::Min(a, b) | TrigOp::Add(a, b) => {
|
|
levels[depth] += 2;
|
|
count_branches(graph, *a, depth + 1, levels);
|
|
count_branches(graph, *b, depth + 1, levels);
|
|
}
|
|
// Non-commutative but still two-operand
|
|
TrigOp::Max(a, b) | TrigOp::Sub(a, b) | TrigOp::Mul(a, b) | TrigOp::Div(a, b)
|
|
| TrigOp::Hypot(a, b) | TrigOp::Atan2(a, b) => {
|
|
levels[depth] += 1;
|
|
count_branches(graph, *a, depth + 1, levels);
|
|
count_branches(graph, *b, depth + 1, levels);
|
|
}
|
|
// Single-operand
|
|
TrigOp::Neg(a) | TrigOp::Abs(a)
|
|
| TrigOp::Sin(a) | TrigOp::Cos(a) | TrigOp::Tan(a)
|
|
| TrigOp::Asin(a) | TrigOp::Acos(a) | TrigOp::Atan(a)
|
|
| TrigOp::Sinh(a) | TrigOp::Cosh(a) | TrigOp::Tanh(a)
|
|
| TrigOp::Asinh(a) | TrigOp::Acosh(a) | TrigOp::Atanh(a)
|
|
| TrigOp::Sqrt(a) | TrigOp::Exp(a) | TrigOp::Ln(a) => {
|
|
levels[depth] += 1;
|
|
count_branches(graph, *a, depth + 1, levels);
|
|
}
|
|
TrigOp::Clamp { val, lo, hi } => {
|
|
levels[depth] += 1;
|
|
count_branches(graph, *val, depth + 1, levels);
|
|
count_branches(graph, *lo, depth + 1, levels);
|
|
count_branches(graph, *hi, depth + 1, levels);
|
|
}
|
|
// Leaves
|
|
_ => {
|
|
levels[depth] += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
fn collect_subtree(graph: &TrigGraph, root: NodeId) -> Subtree {
|
|
let mut nodes = HashSet::new();
|
|
let mut inputs = HashSet::new();
|
|
let mut stack = vec![root];
|
|
|
|
while let Some(id) = stack.pop() {
|
|
if !nodes.insert(id) {
|
|
continue;
|
|
}
|
|
|
|
let op = &graph.nodes[id as usize];
|
|
if is_shared_input(op) {
|
|
inputs.insert(id);
|
|
}
|
|
|
|
match op {
|
|
TrigOp::Add(a, b) | TrigOp::Sub(a, b) | TrigOp::Mul(a, b) | TrigOp::Div(a, b)
|
|
| TrigOp::Hypot(a, b) | TrigOp::Atan2(a, b)
|
|
| TrigOp::Min(a, b) | TrigOp::Max(a, b) => {
|
|
stack.push(*a);
|
|
stack.push(*b);
|
|
}
|
|
TrigOp::Neg(a) | TrigOp::Abs(a)
|
|
| TrigOp::Sin(a) | TrigOp::Cos(a) | TrigOp::Tan(a)
|
|
| TrigOp::Asin(a) | TrigOp::Acos(a) | TrigOp::Atan(a)
|
|
| TrigOp::Sinh(a) | TrigOp::Cosh(a) | TrigOp::Tanh(a)
|
|
| TrigOp::Asinh(a) | TrigOp::Acosh(a) | TrigOp::Atanh(a)
|
|
| TrigOp::Sqrt(a) | TrigOp::Exp(a) | TrigOp::Ln(a) => {
|
|
stack.push(*a);
|
|
}
|
|
TrigOp::Clamp { val, lo, hi } => {
|
|
stack.push(*val);
|
|
stack.push(*lo);
|
|
stack.push(*hi);
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
Subtree { root, nodes, inputs }
|
|
}
|
|
|
|
fn is_shared_input(op: &TrigOp) -> bool {
|
|
matches!(op, TrigOp::InputX | TrigOp::InputY | TrigOp::InputZ | TrigOp::Const(_))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn union_splits_into_two() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        let y = g.push(TrigOp::InputY);
        let z = g.push(TrigOp::InputZ);

        // sphere(1): hypot(hypot(x,y),z) - 1
        let xy = g.push(TrigOp::Hypot(x, y));
        let mag = g.push(TrigOp::Hypot(xy, z));
        let r1 = g.push(TrigOp::Const(1.0));
        let s1 = g.push(TrigOp::Sub(mag, r1));

        // sphere(1) translated: same but with offset
        let ox = g.push(TrigOp::Const(3.0));
        let dx = g.push(TrigOp::Sub(x, ox));
        let xy2 = g.push(TrigOp::Hypot(dx, y));
        let mag2 = g.push(TrigOp::Hypot(xy2, z));
        let r2 = g.push(TrigOp::Const(1.0));
        let s2 = g.push(TrigOp::Sub(mag2, r2));

        // union
        let u = g.push(TrigOp::Min(s1, s2));
        g.set_output(u);

        let subtrees = find_independent_subtrees(&g);
        assert_eq!(subtrees.len(), 2, "union of two spheres should split into 2 subtrees");
        // Operand order is preserved: left operand first, then right.
        assert_eq!(subtrees[0].root, s1);
        assert_eq!(subtrees[1].root, s2);
    }

    #[test]
    fn shared_intermediate_prevents_split() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        let y = g.push(TrigOp::InputY);
        // `xy` is an intermediate (non-leaf) node used by both union operands.
        let xy = g.push(TrigOp::Hypot(x, y));
        let one = g.push(TrigOp::Const(1.0));
        let two = g.push(TrigOp::Const(2.0));
        let a = g.push(TrigOp::Sub(xy, one));
        let b = g.push(TrigOp::Sub(xy, two));
        let u = g.push(TrigOp::Min(a, b));
        g.set_output(u);

        let subtrees = find_independent_subtrees(&g);
        assert_eq!(subtrees.len(), 1, "shared intermediate node must keep the graph whole");
        assert_eq!(subtrees[0].root, u);
        assert_eq!(subtrees[0].nodes.len(), 8);
    }

    #[test]
    fn classify_union() {
        let mut g = TrigGraph::new();
        let a = g.push(TrigOp::Const(1.0));
        let b = g.push(TrigOp::Const(2.0));
        let u = g.push(TrigOp::Min(a, b));
        g.set_output(u);
        assert_eq!(classify_root(&g), ParallelClass::Additive);
    }

    #[test]
    fn classify_difference() {
        let mut g = TrigGraph::new();
        let a = g.push(TrigOp::Const(1.0));
        let b = g.push(TrigOp::Const(2.0));
        let d = g.push(TrigOp::Max(a, b));
        g.set_output(d);
        assert_eq!(classify_root(&g), ParallelClass::Divisive);
    }

    #[test]
    fn classify_leaf_is_sequential() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        g.set_output(x);
        assert_eq!(classify_root(&g), ParallelClass::Sequential);
    }

    #[test]
    fn parallelism_depth_of_union() {
        let mut g = TrigGraph::new();
        let a = g.push(TrigOp::Const(1.0));
        let b = g.push(TrigOp::Const(2.0));
        let u = g.push(TrigOp::Min(a, b));
        g.set_output(u);
        // Min contributes 2 at depth 0; each constant contributes 1 at depth 1.
        assert_eq!(parallelism_depth(&g), vec![2, 2]);
    }

    #[test]
    fn parallelism_depth_counts_shared_node_per_path() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        let s = g.push(TrigOp::Sin(x));
        // Both operands are the same node: it is reached along two paths.
        let sum = g.push(TrigOp::Add(s, s));
        g.set_output(sum);
        assert_eq!(parallelism_depth(&g), vec![2, 2, 2]);
    }
}
|