Cord/crates/cord-trig/src/parallel.rs

216 lines
7.2 KiB
Rust

use crate::ir::{NodeId, TrigGraph, TrigOp};
use std::collections::{HashMap, HashSet};
/// Parallelism classification for a subexpression.
///
/// Produced by [`classify_root`] from the operation at a graph's output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParallelClass {
    /// Commutative, associative (union via `Min`, `Add`): branches evaluate
    /// independently.
    Additive,
    /// Same operation applied N times: scale, repetition (`Mul`, `Div`).
    Multiplicative,
    /// Difference or intersection (`Max`, `Sub`): parallelizable if the
    /// operands are independent.
    Divisive,
    /// True data dependency — must evaluate sequentially.
    Sequential,
}
/// An independent subtree identified in the DAG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Subtree {
    /// Node the subtree was collected from (its top-most node).
    pub root: NodeId,
    /// Every node reachable from `root`, including `root` itself and any
    /// shared inputs.
    pub nodes: HashSet<NodeId>,
    /// The subset of `nodes` whose op is a shared input
    /// (`InputX` / `InputY` / `InputZ` / `Const`).
    pub inputs: HashSet<NodeId>,
}
/// Identify independent subtrees in a TrigGraph.
///
/// Two subtrees are independent if they share no intermediate nodes
/// (they may share inputs like InputX/Y/Z and constants).
pub fn find_independent_subtrees(graph: &TrigGraph) -> Vec<Subtree> {
let output = graph.output;
let root_op = &graph.nodes[output as usize];
// Only split at union/intersection/add (commutative operations)
match root_op {
TrigOp::Min(a, b) | TrigOp::Max(a, b)
| TrigOp::Add(a, b) => {
let left = collect_subtree(graph, *a);
let right = collect_subtree(graph, *b);
// Check independence: no shared non-input nodes
let shared: HashSet<NodeId> = left.nodes.intersection(&right.nodes)
.copied()
.filter(|&id| !is_shared_input(&graph.nodes[id as usize]))
.collect();
if shared.is_empty() {
return vec![left, right];
}
}
_ => {}
}
// Entire graph is one subtree
vec![collect_subtree(graph, output)]
}
/// Classify the parallelism of the graph's output (root) operation.
///
/// - `Min`, `Add` → [`ParallelClass::Additive`]
/// - `Mul`, `Div` → [`ParallelClass::Multiplicative`]
/// - `Max`, `Sub` → [`ParallelClass::Divisive`]
/// - anything else, including leaves → [`ParallelClass::Sequential`]
pub fn classify_root(graph: &TrigGraph) -> ParallelClass {
    let root = &graph.nodes[graph.output as usize];
    if matches!(root, TrigOp::Min(..) | TrigOp::Add(..)) {
        ParallelClass::Additive
    } else if matches!(root, TrigOp::Mul(..) | TrigOp::Div(..)) {
        ParallelClass::Multiplicative
    } else if matches!(root, TrigOp::Max(..) | TrigOp::Sub(..)) {
        ParallelClass::Divisive
    } else {
        ParallelClass::Sequential
    }
}
/// Per-depth branch counts for the whole graph, starting from its output.
///
/// Entry `d` of the result is the accumulated branch count of every node
/// reached at depth `d` below the output; the output itself is depth 0.
/// `count_branches` defines the weight each operation contributes. A node
/// reachable along several paths is counted once per path.
pub fn parallelism_depth(graph: &TrigGraph) -> Vec<usize> {
    let start = graph.output;
    let mut per_level = Vec::<usize>::new();
    count_branches(graph, start, 0, &mut per_level);
    per_level
}
fn count_branches(graph: &TrigGraph, node: NodeId, depth: usize, levels: &mut Vec<usize>) {
while levels.len() <= depth {
levels.push(0);
}
match &graph.nodes[node as usize] {
// Splittable operations — both children are independent branches
TrigOp::Min(a, b) | TrigOp::Add(a, b) => {
levels[depth] += 2;
count_branches(graph, *a, depth + 1, levels);
count_branches(graph, *b, depth + 1, levels);
}
// Non-commutative but still two-operand
TrigOp::Max(a, b) | TrigOp::Sub(a, b) | TrigOp::Mul(a, b) | TrigOp::Div(a, b)
| TrigOp::Hypot(a, b) | TrigOp::Atan2(a, b) => {
levels[depth] += 1;
count_branches(graph, *a, depth + 1, levels);
count_branches(graph, *b, depth + 1, levels);
}
// Single-operand
TrigOp::Neg(a) | TrigOp::Abs(a)
| TrigOp::Sin(a) | TrigOp::Cos(a) | TrigOp::Tan(a)
| TrigOp::Asin(a) | TrigOp::Acos(a) | TrigOp::Atan(a)
| TrigOp::Sinh(a) | TrigOp::Cosh(a) | TrigOp::Tanh(a)
| TrigOp::Asinh(a) | TrigOp::Acosh(a) | TrigOp::Atanh(a)
| TrigOp::Sqrt(a) | TrigOp::Exp(a) | TrigOp::Ln(a) => {
levels[depth] += 1;
count_branches(graph, *a, depth + 1, levels);
}
TrigOp::Clamp { val, lo, hi } => {
levels[depth] += 1;
count_branches(graph, *val, depth + 1, levels);
count_branches(graph, *lo, depth + 1, levels);
count_branches(graph, *hi, depth + 1, levels);
}
// Leaves
_ => {
levels[depth] += 1;
}
}
}
/// Gather every node reachable from `root` by following operand edges.
///
/// In the returned [`Subtree`]:
/// - `nodes` holds each reachable id, including `root` itself and any leaves.
/// - `inputs` holds the subset whose op is a shared input (`InputX`, `InputY`,
///   `InputZ`, or `Const`).
///
/// The walk uses an explicit work list rather than recursion. A node already
/// in `nodes` is not expanded again.
fn collect_subtree(graph: &TrigGraph, root: NodeId) -> Subtree {
    let mut nodes = HashSet::new();
    let mut inputs = HashSet::new();
    let mut pending: Vec<NodeId> = vec![root];

    while let Some(current) = pending.pop() {
        let first_visit = nodes.insert(current);
        if !first_visit {
            continue;
        }

        let op = &graph.nodes[current as usize];
        if is_shared_input(op) {
            inputs.insert(current);
        }

        // Queue this node's operands; leaves contribute nothing.
        match op {
            TrigOp::Add(lhs, rhs)
            | TrigOp::Sub(lhs, rhs)
            | TrigOp::Mul(lhs, rhs)
            | TrigOp::Div(lhs, rhs)
            | TrigOp::Hypot(lhs, rhs)
            | TrigOp::Atan2(lhs, rhs)
            | TrigOp::Min(lhs, rhs)
            | TrigOp::Max(lhs, rhs) => pending.extend([*lhs, *rhs]),
            TrigOp::Neg(arg)
            | TrigOp::Abs(arg)
            | TrigOp::Sin(arg)
            | TrigOp::Cos(arg)
            | TrigOp::Tan(arg)
            | TrigOp::Asin(arg)
            | TrigOp::Acos(arg)
            | TrigOp::Atan(arg)
            | TrigOp::Sinh(arg)
            | TrigOp::Cosh(arg)
            | TrigOp::Tanh(arg)
            | TrigOp::Asinh(arg)
            | TrigOp::Acosh(arg)
            | TrigOp::Atanh(arg)
            | TrigOp::Sqrt(arg)
            | TrigOp::Exp(arg)
            | TrigOp::Ln(arg) => pending.push(*arg),
            TrigOp::Clamp { val, lo, hi } => pending.extend([*val, *lo, *hi]),
            _ => {}
        }
    }

    Subtree { root, nodes, inputs }
}
/// Whether `op` is a leaf that any subtree may reference without creating a
/// dependency between subtrees: one of the coordinate inputs or a constant.
fn is_shared_input(op: &TrigOp) -> bool {
    let is_coordinate = matches!(op, TrigOp::InputX | TrigOp::InputY | TrigOp::InputZ);
    is_coordinate || matches!(op, TrigOp::Const(_))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn union_splits_into_two() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        let y = g.push(TrigOp::InputY);
        let z = g.push(TrigOp::InputZ);
        // sphere(1): hypot(hypot(x,y),z) - 1
        let xy = g.push(TrigOp::Hypot(x, y));
        let mag = g.push(TrigOp::Hypot(xy, z));
        let r1 = g.push(TrigOp::Const(1.0));
        let s1 = g.push(TrigOp::Sub(mag, r1));
        // sphere(1) translated: same but with offset
        let ox = g.push(TrigOp::Const(3.0));
        let dx = g.push(TrigOp::Sub(x, ox));
        let xy2 = g.push(TrigOp::Hypot(dx, y));
        let mag2 = g.push(TrigOp::Hypot(xy2, z));
        let r2 = g.push(TrigOp::Const(1.0));
        let s2 = g.push(TrigOp::Sub(mag2, r2));
        // union
        let u = g.push(TrigOp::Min(s1, s2));
        g.set_output(u);
        let subtrees = find_independent_subtrees(&g);
        assert_eq!(subtrees.len(), 2, "union of two spheres should split into 2 subtrees");
        // Left operand first, right operand second.
        assert_eq!(subtrees[0].root, s1);
        assert_eq!(subtrees[1].root, s2);
        // Coordinate inputs may be shared without breaking independence.
        assert!(subtrees[0].inputs.contains(&x));
        assert!(subtrees[1].inputs.contains(&x));
    }

    #[test]
    fn shared_intermediate_prevents_split() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        let y = g.push(TrigOp::InputY);
        // Both union operands depend on the same non-input node.
        let xy = g.push(TrigOp::Hypot(x, y));
        let r = g.push(TrigOp::Const(1.0));
        let a = g.push(TrigOp::Sub(xy, r));
        let b = g.push(TrigOp::Neg(xy));
        let u = g.push(TrigOp::Min(a, b));
        g.set_output(u);
        let subtrees = find_independent_subtrees(&g);
        assert_eq!(subtrees.len(), 1, "shared intermediate node must prevent a split");
        assert_eq!(subtrees[0].root, u);
    }

    #[test]
    fn depth_counts_branches_per_level() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        let y = g.push(TrigOp::InputY);
        let s = g.push(TrigOp::Sin(y));
        let u = g.push(TrigOp::Min(x, s));
        g.set_output(u);
        // Level 0: Min (2 branches); level 1: x and sin(y); level 2: y.
        assert_eq!(parallelism_depth(&g), vec![2, 2, 1]);
    }

    #[test]
    fn classify_union() {
        let mut g = TrigGraph::new();
        let a = g.push(TrigOp::Const(1.0));
        let b = g.push(TrigOp::Const(2.0));
        let u = g.push(TrigOp::Min(a, b));
        g.set_output(u);
        assert_eq!(classify_root(&g), ParallelClass::Additive);
    }

    #[test]
    fn classify_difference() {
        let mut g = TrigGraph::new();
        let a = g.push(TrigOp::Const(1.0));
        let b = g.push(TrigOp::Const(2.0));
        let d = g.push(TrigOp::Max(a, b));
        g.set_output(d);
        assert_eq!(classify_root(&g), ParallelClass::Divisive);
    }

    #[test]
    fn classify_product() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        let y = g.push(TrigOp::InputY);
        let m = g.push(TrigOp::Mul(x, y));
        g.set_output(m);
        assert_eq!(classify_root(&g), ParallelClass::Multiplicative);
    }

    #[test]
    fn classify_leaf_is_sequential() {
        let mut g = TrigGraph::new();
        let x = g.push(TrigOp::InputX);
        g.set_output(x);
        assert_eq!(classify_root(&g), ParallelClass::Sequential);
    }
}