Cord/crates/cord-expr/src/remap.rs

177 lines
9.1 KiB
Rust

use cord_trig::ir::{NodeId, TrigOp};
use crate::parser::ExprParser;
impl<'a> ExprParser<'a> {
    /// Copies every node of `source` into `self.graph` and returns the id of the
    /// copy of `source.output`.
    ///
    /// `InputX` / `InputY` / `InputZ` nodes of `source` are not copied verbatim;
    /// they are mapped to whatever `self.get_x()` / `self.get_y()` / `self.get_z()`
    /// return. Every other node is pushed onto `self.graph` with its operands
    /// translated to their new ids.
    ///
    /// # Panics
    ///
    /// If some node in `source` references an operand that does not precede it in
    /// `source.nodes`, or if `source.output` is out of range.
    pub(crate) fn inline_graph(&mut self, source: &cord_trig::TrigGraph) -> NodeId {
        // map[i] is the id in `self.graph` of source node `i`.
        let mut map: Vec<NodeId> = Vec::with_capacity(source.nodes.len());
        for op in &source.nodes {
            let new_id = match op {
                TrigOp::InputX => self.get_x(),
                TrigOp::InputY => self.get_y(),
                TrigOp::InputZ => self.get_z(),
                // Shares the per-variant copy with `remap_inputs`, so a new
                // `TrigOp` variant only has to be handled in `push_remapped`.
                _ => self.push_remapped(op, &map),
            };
            map.push(new_id);
        }
        map[source.output as usize]
    }

    /// Rebuilds the expression rooted at `root` with its inputs substituted:
    /// every reachable `InputX` / `InputY` / `InputZ` node is replaced by
    /// `new_x` / `new_y` / `new_z` respectively.
    ///
    /// Existing nodes are never modified. Reachable nodes that do not
    /// (transitively) depend on an input are reused by id. Every reachable node
    /// that does depend on one is re-pushed onto `self.graph` with remapped
    /// operands. Returns the id of the rebuilt root. That id is `root` itself when
    /// the expression does not depend on any input.
    ///
    /// Relies on every operand id being smaller than the id of the node that uses
    /// it; both the backward reachability sweep and the forward rebuild depend on
    /// that ordering. Time and scratch memory are proportional to `root`, not to
    /// the size of the reachable subgraph.
    pub(crate) fn remap_inputs(
        &mut self,
        root: NodeId,
        new_x: NodeId,
        new_y: NodeId,
        new_z: NodeId,
    ) -> NodeId {
        let n = root as usize + 1;

        // Pass 1 (backward): mark everything reachable from `root`. Descending ids
        // visit each node only after all of its users have been visited.
        let mut reachable = vec![false; n];
        reachable[root as usize] = true;
        for i in (0..n).rev() {
            if reachable[i] {
                Self::mark_children(&self.graph.nodes[i], &mut reachable);
            }
        }

        // Pass 2 (forward): flag reachable nodes whose value depends on an input.
        // Operands come first, so their flags are final when a user reads them.
        let mut depends_on_input = vec![false; n];
        for i in 0..n {
            if reachable[i] {
                depends_on_input[i] =
                    Self::any_child_depends(&self.graph.nodes[i], &depends_on_input);
            }
        }

        // Pass 3 (forward): substitute inputs and re-push input-dependent nodes.
        // Everything else keeps its identity mapping.
        let mut map: Vec<NodeId> = (0..=root).collect();
        for i in 0..n {
            if !reachable[i] || !depends_on_input[i] {
                continue;
            }
            // Copy the op out so `self.graph` is free to grow in `push_remapped`.
            let op = self.graph.nodes[i].clone();
            map[i] = match &op {
                TrigOp::InputX => new_x,
                TrigOp::InputY => new_y,
                TrigOp::InputZ => new_z,
                other => self.push_remapped(other, &map),
            };
        }
        map[root as usize]
    }

    /// Marks every direct operand of `op` as reachable.
    fn mark_children(op: &TrigOp, reachable: &mut [bool]) {
        let (ids, len) = operands(op);
        for &child in &ids[..len] {
            reachable[child as usize] = true;
        }
    }

    /// Returns `true` if `op` is itself an input node, or if any of its direct
    /// operands is flagged in `deps`. Operands are checked in order, and the check
    /// stops at the first flagged one.
    fn any_child_depends(op: &TrigOp, deps: &[bool]) -> bool {
        if matches!(op, TrigOp::InputX | TrigOp::InputY | TrigOp::InputZ) {
            return true;
        }
        let (ids, len) = operands(op);
        ids[..len].iter().any(|&child| deps[child as usize])
    }

    /// Pushes a copy of `op` onto `self.graph` with each operand id `a` replaced by
    /// `map[a]`, and returns the new node's id.
    ///
    /// Input nodes have no sensible copy here; callers must substitute them
    /// before calling, so reaching that arm is a bug.
    fn push_remapped(&mut self, op: &TrigOp, map: &[NodeId]) -> NodeId {
        match op {
            TrigOp::InputX | TrigOp::InputY | TrigOp::InputZ => unreachable!(),
            TrigOp::Const(c) => self.graph.push(TrigOp::Const(*c)),
            TrigOp::Add(a, b) => self.graph.push(TrigOp::Add(map[*a as usize], map[*b as usize])),
            TrigOp::Sub(a, b) => self.graph.push(TrigOp::Sub(map[*a as usize], map[*b as usize])),
            TrigOp::Mul(a, b) => self.graph.push(TrigOp::Mul(map[*a as usize], map[*b as usize])),
            TrigOp::Div(a, b) => self.graph.push(TrigOp::Div(map[*a as usize], map[*b as usize])),
            TrigOp::Neg(a) => self.graph.push(TrigOp::Neg(map[*a as usize])),
            TrigOp::Abs(a) => self.graph.push(TrigOp::Abs(map[*a as usize])),
            TrigOp::Sin(a) => self.graph.push(TrigOp::Sin(map[*a as usize])),
            TrigOp::Cos(a) => self.graph.push(TrigOp::Cos(map[*a as usize])),
            TrigOp::Tan(a) => self.graph.push(TrigOp::Tan(map[*a as usize])),
            TrigOp::Asin(a) => self.graph.push(TrigOp::Asin(map[*a as usize])),
            TrigOp::Acos(a) => self.graph.push(TrigOp::Acos(map[*a as usize])),
            TrigOp::Atan(a) => self.graph.push(TrigOp::Atan(map[*a as usize])),
            TrigOp::Sinh(a) => self.graph.push(TrigOp::Sinh(map[*a as usize])),
            TrigOp::Cosh(a) => self.graph.push(TrigOp::Cosh(map[*a as usize])),
            TrigOp::Tanh(a) => self.graph.push(TrigOp::Tanh(map[*a as usize])),
            TrigOp::Asinh(a) => self.graph.push(TrigOp::Asinh(map[*a as usize])),
            TrigOp::Acosh(a) => self.graph.push(TrigOp::Acosh(map[*a as usize])),
            TrigOp::Atanh(a) => self.graph.push(TrigOp::Atanh(map[*a as usize])),
            TrigOp::Sqrt(a) => self.graph.push(TrigOp::Sqrt(map[*a as usize])),
            TrigOp::Exp(a) => self.graph.push(TrigOp::Exp(map[*a as usize])),
            TrigOp::Ln(a) => self.graph.push(TrigOp::Ln(map[*a as usize])),
            TrigOp::Hypot(a, b) => self.graph.push(TrigOp::Hypot(map[*a as usize], map[*b as usize])),
            TrigOp::Atan2(a, b) => self.graph.push(TrigOp::Atan2(map[*a as usize], map[*b as usize])),
            TrigOp::Min(a, b) => self.graph.push(TrigOp::Min(map[*a as usize], map[*b as usize])),
            TrigOp::Max(a, b) => self.graph.push(TrigOp::Max(map[*a as usize], map[*b as usize])),
            TrigOp::Clamp { val, lo, hi } => self.graph.push(TrigOp::Clamp {
                val: map[*val as usize],
                lo: map[*lo as usize],
                hi: map[*hi as usize],
            }),
        }
    }
}

/// Direct operand ids of `op`, in declaration order.
///
/// Only the first `len` entries of the returned array are meaningful; the rest
/// are padding. Leaves (inputs and constants) have no operands. This is the
/// single arity table shared by the graph-walking helpers above.
fn operands(op: &TrigOp) -> ([NodeId; 3], usize) {
    match op {
        TrigOp::InputX | TrigOp::InputY | TrigOp::InputZ | TrigOp::Const(_) => ([0; 3], 0),
        TrigOp::Add(a, b) | TrigOp::Sub(a, b) | TrigOp::Mul(a, b)
        | TrigOp::Div(a, b) | TrigOp::Hypot(a, b) | TrigOp::Atan2(a, b)
        | TrigOp::Min(a, b) | TrigOp::Max(a, b) => ([*a, *b, 0], 2),
        TrigOp::Neg(a) | TrigOp::Abs(a) | TrigOp::Sin(a) | TrigOp::Cos(a)
        | TrigOp::Tan(a) | TrigOp::Asin(a) | TrigOp::Acos(a) | TrigOp::Atan(a)
        | TrigOp::Sinh(a) | TrigOp::Cosh(a) | TrigOp::Tanh(a)
        | TrigOp::Asinh(a) | TrigOp::Acosh(a) | TrigOp::Atanh(a)
        | TrigOp::Sqrt(a) | TrigOp::Exp(a) | TrigOp::Ln(a) => ([*a, 0, 0], 1),
        TrigOp::Clamp { val, lo, hi } => ([*val, *lo, *hi], 3),
    }
}