Cord/crates/cord-expr/src/userfunc.rs

227 lines
16 KiB
Rust

use cord_trig::ir::{NodeId, TrigOp};
use crate::token::Token;
use crate::parser::{ExprParser, UserFunc, Schematic};
impl<'a> ExprParser<'a> {
    /// Returns `true` when the tokens at the cursor look like a user function definition,
    /// i.e. `name(...) = ...`: an identifier, `(`, a balanced parameter list, then `=`.
    ///
    /// Pure lookahead; the cursor is not moved.
    pub(crate) fn is_func_def(&self) -> bool {
        let starts_like_call = matches!(self.tokens.get(self.pos), Some(Token::Ident(_)))
            && matches!(self.tokens.get(self.pos + 1), Some(Token::LParen));
        if !starts_like_call {
            return false;
        }
        // Find the `)` that closes the parameter list, then look for a trailing `=`.
        let mut depth = 1u32;
        for (i, tok) in self.tokens.iter().enumerate().skip(self.pos + 2) {
            match tok {
                Token::LParen => depth += 1,
                Token::RParen => {
                    depth -= 1;
                    if depth == 0 {
                        return matches!(self.tokens.get(i + 1), Some(Token::Eq));
                    }
                }
                _ => {}
            }
        }
        false
    }

    /// Parses a parameter list up to (but not including) the closing `)`.
    ///
    /// Each parameter is an identifier, optionally followed by `= expr` or `: expr` giving a
    /// default value. Defaults are stored as raw token runs and evaluated at call time (see
    /// `resolve_defaults`). Separators between parameters are skipped, and a trailing comma
    /// before `)` is accepted.
    ///
    /// Returns `(names, defaults)` with one `defaults` entry per name (`None` = required).
    ///
    /// # Errors
    /// A missing parameter name, a repeated parameter name, or an empty default value.
    pub(crate) fn parse_param_list_with_defaults(
        &mut self,
    ) -> Result<(Vec<String>, Vec<Option<Vec<Token>>>), String> {
        let mut params: Vec<String> = Vec::new();
        let mut defaults = Vec::new();
        self.skip_separators();
        if matches!(self.peek(), Some(Token::RParen)) {
            return Ok((params, defaults));
        }
        loop {
            self.skip_separators();
            let pname = match self.advance().cloned() {
                Some(Token::Ident(p)) => p,
                _ => return Err(self.err_at("expected parameter name".into())),
            };
            // A repeated name would make one binding silently shadow the other in the body.
            if params.contains(&pname) {
                return Err(self.err_at(format!("duplicate parameter name '{pname}'")));
            }
            let default = if matches!(self.peek(), Some(Token::Colon) | Some(Token::Eq)) {
                self.advance();
                let tokens = self.scan_default_value();
                if tokens.is_empty() {
                    return Err(
                        self.err_at(format!("expected default value for parameter '{pname}'"))
                    );
                }
                Some(tokens)
            } else {
                None
            };
            params.push(pname);
            defaults.push(default);

            self.skip_separators();
            if !matches!(self.peek(), Some(Token::Comma)) {
                break;
            }
            self.advance();
            // Tolerate a trailing comma, e.g. in a multi-line list `f(\n a,\n b,\n)`.
            self.skip_separators();
            if matches!(self.peek(), Some(Token::RParen)) {
                break;
            }
        }
        self.skip_separators();
        Ok((params, defaults))
    }

    /// Advances over a default-value expression and returns its tokens.
    ///
    /// Stops, without consuming it, at the first `,`, `)`, `;` or newline that is not
    /// nested inside parentheses.
    fn scan_default_value(&mut self) -> Vec<Token> {
        let start = self.pos;
        let mut depth = 0u32;
        while let Some(tok) = self.tokens.get(self.pos) {
            match tok {
                Token::Comma | Token::RParen | Token::Semi | Token::Newline if depth == 0 => break,
                Token::LParen => depth += 1,
                Token::RParen => depth -= 1,
                _ => {}
            }
            self.pos += 1;
        }
        self.tokens[start..self.pos].to_vec()
    }

    /// Advances over a single-line expression body (for `f(..) = body` and `sch S(..) = body`)
    /// and returns its tokens.
    ///
    /// The body ends at the first `;` or newline outside parentheses. Trailing separators are
    /// consumed.
    fn scan_expr_body(&mut self) -> Vec<Token> {
        let start = self.pos;
        let mut depth = 0u32;
        while let Some(tok) = self.tokens.get(self.pos) {
            match tok {
                Token::Semi | Token::Newline if depth == 0 => break,
                Token::LParen => depth += 1,
                Token::RParen if depth > 0 => depth -= 1,
                _ => {}
            }
            self.pos += 1;
        }
        let body = self.tokens[start..self.pos].to_vec();
        self.skip_separators();
        body
    }

    /// Matches call arguments to parameters, filling missing trailing arguments from defaults.
    ///
    /// Default expressions are evaluated now, in the caller's current variable scope. The
    /// parameters themselves are not yet bound at this point.
    ///
    /// # Errors
    /// Wrong argument count, a missing required argument, or a failing default expression.
    fn resolve_defaults(
        &mut self,
        params: &[String],
        defaults: &[Option<Vec<Token>>],
        args: &[NodeId],
        name: &str,
    ) -> Result<Vec<NodeId>, String> {
        let required = params
            .iter()
            .zip(defaults.iter())
            .filter(|(_, d)| d.is_none())
            .count();
        if args.len() < required || args.len() > params.len() {
            let range_prefix = if required < params.len() {
                format!("{required}..")
            } else {
                String::new()
            };
            return Err(format!(
                "{name}() takes {}{} argument(s), got {}",
                range_prefix,
                params.len(),
                args.len()
            ));
        }
        let mut resolved = Vec::with_capacity(params.len());
        for (i, default) in defaults.iter().enumerate() {
            if let Some(&arg) = args.get(i) {
                resolved.push(arg);
            } else if let Some(def_body) = default {
                resolved.push(self.eval_default_expr(def_body.clone())?);
            } else {
                return Err(format!("{name}(): missing required argument '{}'", params[i]));
            }
        }
        Ok(resolved)
    }

    /// Parses a stored default-value token run as a single additive expression.
    fn eval_default_expr(&mut self, body: Vec<Token>) -> Result<NodeId, String> {
        self.parse_in_token_stream(body, |p| p.parse_additive())
    }

    /// Temporarily makes `tokens` the parser's input, runs `parse`, and then restores the
    /// previous token stream, position, and line tables.
    ///
    /// `token_lines` and `source_lines` are emptied while `parse` runs, because they describe
    /// the outer stream and cannot be indexed by positions within `tokens`. Restoration
    /// happens whether `parse` succeeds or fails.
    fn parse_in_token_stream<T>(
        &mut self,
        tokens: Vec<Token>,
        parse: impl FnOnce(&mut Self) -> Result<T, String>,
    ) -> Result<T, String> {
        let saved_tokens = std::mem::replace(&mut self.tokens, &[]);
        let saved_token_lines = std::mem::replace(&mut self.token_lines, &[]);
        let saved_source_lines = std::mem::replace(&mut self.source_lines, &[]);
        let saved_pos = self.pos;

        let raw: *mut [Token] = Box::into_raw(tokens.into_boxed_slice());
        // SAFETY: `raw` comes from `Box::into_raw`, so it is non-null, aligned, and valid
        // until it is freed below. `self.tokens` is pointed back at the caller's slice
        // *before* the allocation is freed, so the parser never holds a dangling reference.
        // If `parse` panics, the allocation is leaked (not freed), which keeps it sound.
        // NOTE(review): this assumes no parse routine retains a `&Token` borrowed from
        // `self.tokens` past the call (code in this file copies via `to_vec`); confirm for
        // parse code outside this module.
        self.tokens = unsafe { &*raw };
        self.pos = 0;

        let result = parse(self);

        self.tokens = saved_tokens;
        self.token_lines = saved_token_lines;
        self.source_lines = saved_source_lines;
        self.pos = saved_pos;
        // SAFETY: `raw` was produced by `Box::into_raw` above, is freed exactly once here,
        // and `self.tokens` no longer refers to it.
        drop(unsafe { Box::from_raw(raw) });
        result
    }

    /// Binds each parameter name to its value in `self.vars`.
    ///
    /// Returns the shadowed previous values, in binding order, for `unbind_call_args`.
    fn bind_call_args(
        &mut self,
        params: &[String],
        values: &[NodeId],
    ) -> Vec<(String, Option<NodeId>)> {
        let mut saved = Vec::with_capacity(params.len());
        for (p, &v) in params.iter().zip(values) {
            saved.push((p.clone(), self.vars.get(p).copied()));
            self.vars.insert(p.clone(), v);
        }
        saved
    }

    /// Undoes `bind_call_args`.
    ///
    /// Restoration runs in reverse binding order, so a name bound twice ends up with its
    /// original outer value.
    fn unbind_call_args(&mut self, saved: Vec<(String, Option<NodeId>)>) {
        for (p, old) in saved.into_iter().rev() {
            match old {
                Some(v) => {
                    self.vars.insert(p, v);
                }
                None => {
                    self.vars.remove(&p);
                }
            }
        }
    }

    /// Parses `name(params) = body` at the cursor and registers it in `self.funcs`.
    ///
    /// The body is stored unparsed and is re-parsed on every call.
    ///
    /// # Errors
    /// A missing name, a malformed parameter list, or a missing `)` / `=`.
    pub(crate) fn parse_func_def(&mut self) -> Result<(), String> {
        let name = match self.advance().cloned() {
            Some(Token::Ident(n)) => n,
            // Callers are expected to check `is_func_def` first; report rather than panic.
            _ => return Err(self.err_at("expected function name".into())),
        };
        self.expect(&Token::LParen)?;
        let (params, defaults) = self.parse_param_list_with_defaults()?;
        self.expect(&Token::RParen)?;
        self.expect(&Token::Eq)?;
        let body = self.scan_expr_body();
        self.funcs.insert(name, UserFunc { params, defaults, body });
        Ok(())
    }

    /// Calls a user function.
    ///
    /// Arguments are resolved (with defaults), parameters are bound over any outer variables
    /// of the same name, the stored body is parsed as one additive expression, and then the
    /// outer bindings are restored.
    pub(crate) fn call_user_func_inner(
        &mut self,
        params: Vec<String>,
        defaults: Vec<Option<Vec<Token>>>,
        body: Vec<Token>,
        args: &[NodeId],
        name: &str,
    ) -> Result<NodeId, String> {
        let resolved = self.resolve_defaults(&params, &defaults, args, name)?;
        let saved = self.bind_call_args(&params, &resolved);
        let result = self.parse_in_token_stream(body, |p| p.parse_additive());
        self.unbind_call_args(saved);
        result
    }

    /// Parses a schematic definition at the cursor, which must be on the `sch` keyword.
    ///
    /// Two forms are accepted:
    /// - `sch Name(params) = expr` — value-returning, single expression;
    /// - `sch Name(params) { ... }` — block body, see `parse_block_body`.
    pub(crate) fn parse_sch_def(&mut self) -> Result<(), String> {
        self.advance(); // consume `sch`
        let name = match self.advance().cloned() {
            Some(Token::Ident(n)) => n,
            _ => return Err(self.err_at("expected schematic name after 'sch'".into())),
        };
        self.expect(&Token::LParen)?;
        let (params, defaults) = self.parse_param_list_with_defaults()?;
        self.expect(&Token::RParen)?;
        let (body, value_returning) = if matches!(self.peek(), Some(Token::Eq)) {
            self.advance();
            (self.scan_expr_body(), true)
        } else {
            self.expect(&Token::LBrace)?;
            (self.collect_brace_body()?, false)
        };
        self.schematics.insert(name, Schematic { params, defaults, body, value_returning });
        Ok(())
    }

    /// Collects the tokens of a `{ ... }` body.
    ///
    /// The cursor must be just past the opening `{`. The closing `}` and any trailing
    /// separators are consumed; nested braces are included in the body.
    ///
    /// # Errors
    /// Returns an error if the input ends before the matching `}`.
    fn collect_brace_body(&mut self) -> Result<Vec<Token>, String> {
        let start = self.pos;
        let mut depth = 1u32;
        while let Some(tok) = self.tokens.get(self.pos) {
            match tok {
                Token::LBrace => depth += 1,
                Token::RBrace => {
                    depth -= 1;
                    if depth == 0 {
                        let body = self.tokens[start..self.pos].to_vec();
                        self.pos += 1; // consume the closing `}`
                        self.skip_separators();
                        return Ok(body);
                    }
                }
                _ => {}
            }
            self.pos += 1;
        }
        Err("unclosed '{'".into())
    }

    /// Instantiates a schematic with the given arguments.
    ///
    /// Parameters are bound as in `call_user_func_inner`. Function and schematic definitions
    /// made inside the body are discarded afterwards, by restoring clones of `funcs` and
    /// `schematics`.
    ///
    /// NOTE(review): variables introduced by `let` inside a block body are *not* removed
    /// afterwards and remain visible (and may overwrite same-named variables) in the caller's
    /// scope. Confirm whether that is intended.
    pub(crate) fn call_schematic(
        &mut self,
        params: Vec<String>,
        defaults: Vec<Option<Vec<Token>>>,
        body: Vec<Token>,
        value_returning: bool,
        args: &[NodeId],
        name: &str,
    ) -> Result<NodeId, String> {
        let resolved = self.resolve_defaults(&params, &defaults, args, name)?;
        let saved = self.bind_call_args(&params, &resolved);
        let saved_funcs = self.funcs.clone();
        let saved_schematics = self.schematics.clone();
        let result = self.parse_in_token_stream(body, |p| {
            if value_returning {
                p.parse_additive()
            } else {
                p.parse_block_body()
            }
        });
        self.funcs = saved_funcs;
        self.schematics = saved_schematics;
        self.unbind_call_args(saved);
        result
    }

    /// Parses the statements of a block body until the end of the current token stream.
    ///
    /// Statements are nested function definitions, `sch` definitions, `let` bindings, or bare
    /// expressions. Returns the value of the last `let` or expression statement.
    ///
    /// # Errors
    /// Returns `"empty block"` if no statement produced a value.
    fn parse_block_body(&mut self) -> Result<NodeId, String> {
        let mut last = None;
        loop {
            self.skip_separators();
            if self.pos >= self.tokens.len() {
                break;
            }
            if self.is_func_def() {
                self.parse_func_def()?;
                continue;
            }
            if matches!(self.peek(), Some(Token::Ident(s)) if s == "sch") {
                self.parse_sch_def()?;
                continue;
            }
            let node = if matches!(self.peek(), Some(Token::Ident(s)) if s == "let") {
                self.parse_block_let()?
            } else {
                self.parse_additive()?
            };
            last = Some(node);
            self.skip_separators();
        }
        last.ok_or_else(|| "empty block".into())
    }

    /// Parses `let name[: Type] = expr` inside a block and binds `name` in `self.vars`.
    ///
    /// A type of `Obj` or `obj` also records the binding in `objects` / `object_nodes`.
    /// Returns the bound value.
    fn parse_block_let(&mut self) -> Result<NodeId, String> {
        self.advance(); // consume `let`
        let name = match self.advance().cloned() {
            Some(Token::Ident(n)) => n,
            _ => return Err(self.err_at("expected variable name after 'let'".into())),
        };
        let mut is_obj = false;
        if matches!(self.peek(), Some(Token::Colon)) {
            self.advance();
            match self.advance().cloned() {
                Some(Token::Ident(ty)) => is_obj = ty == "Obj" || ty == "obj",
                _ => return Err(self.err_at("expected type name after ':'".into())),
            }
        }
        self.expect(&Token::Eq)?;
        let val = self.parse_additive()?;
        self.vars.insert(name.clone(), val);
        if is_obj {
            self.objects.push(name.clone());
            self.object_nodes.insert(name, val);
        }
        Ok(val)
    }

    /// Parses `map(var, start..end) { body }` and unions (`Min`) one body instance per integer
    /// in `start..end`.
    ///
    /// The bounds must fold to finite constants; they are rounded to integers. The range
    /// must be non-empty and span at most 1024 iterations. `var` is bound to each integer
    /// while the body is parsed, and its previous binding is restored afterwards, even on
    /// error. If any instance is an object node, the result is marked as one too.
    pub(crate) fn parse_map(&mut self) -> Result<NodeId, String> {
        const MAX_ITERATIONS: i64 = 1024;

        self.expect(&Token::LParen)?;
        let iter_var = match self.advance().cloned() {
            Some(Token::Ident(n)) => n,
            _ => return Err("map: expected iteration variable name".into()),
        };
        self.expect(&Token::Comma)?;
        let start_node = self.parse_additive()?;
        self.expect(&Token::DotDot)?;
        let end_node = self.parse_additive()?;
        self.expect(&Token::RParen)?;

        let start = self.eval_const(start_node)?;
        let end = self.eval_const(end_node)?;
        // e.g. `1/0` or `0/0` fold to inf/NaN in `eval_const`. `as i64` would saturate them or
        // map NaN to 0, producing a bogus range, so reject them explicitly.
        if !start.is_finite() || !end.is_finite() {
            return Err("map: range bounds must be finite".into());
        }
        let si = start.round() as i64;
        let ei = end.round() as i64;
        if ei <= si {
            return Err(format!("map: empty range {}..{}", si, ei));
        }
        // Huge finite bounds saturate to i64::MIN/MAX, so the subtraction itself can overflow.
        if ei.checked_sub(si).map_or(true, |n| n > MAX_ITERATIONS) {
            return Err(format!("map: range too large (max {MAX_ITERATIONS} iterations)"));
        }

        self.expect(&Token::LBrace)?;
        let body = self.collect_brace_body()?;

        let saved_var = self.vars.get(&iter_var).copied();
        let collected: Result<Vec<NodeId>, String> = (si..ei)
            .map(|i| {
                let i_node = self.graph.push(TrigOp::Const(i as f64));
                self.vars.insert(iter_var.clone(), i_node);
                self.parse_in_token_stream(body.clone(), |p| p.parse_block_body())
            })
            .collect();
        // Restore the iteration variable whether or not the body parsed successfully.
        match saved_var {
            Some(v) => {
                self.vars.insert(iter_var, v);
            }
            None => {
                self.vars.remove(&iter_var);
            }
        }
        let nodes = collected?;

        let any_obj = nodes.iter().any(|&n| self.is_obj_node(n));
        let mut instances = nodes.into_iter();
        let first = match instances.next() {
            Some(n) => n,
            None => return Err("map: produced no results".into()),
        };
        let result = instances.fold(first, |acc, n| self.graph.push(TrigOp::Min(acc, n)));
        if any_obj {
            self.mark_obj(result);
        }
        Ok(result)
    }

    /// Folds `node` to a constant.
    ///
    /// `Const`, `Add`, `Sub`, `Mul`, `Div` and `Neg` are folded structurally; no finiteness
    /// check is applied to them, so e.g. `1/0` yields infinity. Any other op is evaluated by
    /// running a clone of the graph at the origin, and is rejected if the result is not
    /// finite.
    ///
    /// NOTE(review): an expression that depends on x/y/z but is finite at the origin is
    /// accepted by the fallback path.
    fn eval_const(&self, node: NodeId) -> Result<f64, String> {
        match &self.graph.nodes[node as usize] {
            TrigOp::Const(v) => Ok(*v),
            TrigOp::Add(a, b) => Ok(self.eval_const(*a)? + self.eval_const(*b)?),
            TrigOp::Sub(a, b) => Ok(self.eval_const(*a)? - self.eval_const(*b)?),
            TrigOp::Mul(a, b) => Ok(self.eval_const(*a)? * self.eval_const(*b)?),
            TrigOp::Div(a, b) => Ok(self.eval_const(*a)? / self.eval_const(*b)?),
            TrigOp::Neg(a) => Ok(-self.eval_const(*a)?),
            _ => {
                let mut g = self.graph.clone();
                g.set_output(node);
                let val = cord_trig::eval::evaluate(&g, 0.0, 0.0, 0.0);
                if val.is_finite() {
                    Ok(val)
                } else {
                    Err("map: range bounds must be compile-time constants".into())
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{parse_expr, parse_expr_scene};
    use cord_trig::eval::evaluate;

    /// Asserts that `actual` lies strictly within `tol` of `expected`; fails on NaN.
    fn assert_near(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (tol {tol}), got {actual}"
        );
    }

    // ---- user functions -------------------------------------------------

    #[test]
    fn user_func_basic() {
        let g = parse_expr("f(a) = a^2\nf(3)").unwrap();
        assert_near(evaluate(&g, 0.0, 0.0, 0.0), 9.0, 1e-10);
    }

    #[test]
    fn user_func_two_params() {
        let g = parse_expr("f(a, b) = a + b\nf(3, 4)").unwrap();
        assert_near(evaluate(&g, 0.0, 0.0, 0.0), 7.0, 1e-10);
    }

    #[test]
    fn user_func_with_xyz() {
        let g = parse_expr("f(r) = sphere(r)\nf(3)").unwrap();
        assert_near(evaluate(&g, 3.0, 0.0, 0.0), 0.0, 1e-6);
    }

    #[test]
    fn user_func_composition() {
        let g = parse_expr("f(a) = a * 2\ng(b) = b + 1\ng(f(3))").unwrap();
        assert_near(evaluate(&g, 0.0, 0.0, 0.0), 7.0, 1e-10);
    }

    #[test]
    fn user_func_with_let() {
        let g = parse_expr("f(v) = v^2 + 1\nlet a = f(x)\na").unwrap();
        assert_near(evaluate(&g, 3.0, 0.0, 0.0), 10.0, 1e-10);
    }

    #[test]
    fn user_func_default_value() {
        let g = parse_expr("f(a, b = 10) = a + b\nf(3)").unwrap();
        assert_near(evaluate(&g, 0.0, 0.0, 0.0), 13.0, 1e-10);
    }

    #[test]
    fn user_func_default_override() {
        let g = parse_expr("f(a, b = 10) = a + b\nf(3, 5)").unwrap();
        assert_near(evaluate(&g, 0.0, 0.0, 0.0), 8.0, 1e-10);
    }

    // ---- schematics -----------------------------------------------------

    #[test]
    fn sch_basic() {
        let g = parse_expr("sch Foo(r) { sphere(r) }\nFoo(3)").unwrap();
        assert_near(evaluate(&g, 3.0, 0.0, 0.0), 0.0, 1e-6);
    }

    #[test]
    fn sch_multi_statement() {
        let src = "sch Bar(w, h) {\n let a = box(w, h, 1)\n let b = sphere(1)\n union(a, b)\n}\nBar(3, 2)";
        let g = parse_expr(src).unwrap();
        assert!(evaluate(&g, 0.0, 0.0, 0.0) < 0.0);
    }

    #[test]
    fn sch_with_transforms() {
        let src = "sch Arm(len) {\n translate(box(len, 0.5, 0.5), len/2, 0, 0)\n}\nArm(5)";
        let g = parse_expr(src).unwrap();
        assert!(evaluate(&g, 2.5, 0.0, 0.0) < 0.0);
    }

    #[test]
    fn sch_multiline_params() {
        let src = "sch Brace(\n w,\n h,\n t\n) {\n box(w, h, t)\n}\nBrace(3, 2, 1)";
        let g = parse_expr(src).unwrap();
        assert!(evaluate(&g, 0.0, 0.0, 0.0) < 0.0);
    }

    #[test]
    fn sch_default_params() {
        let g = parse_expr("sch Cube(s: 2) { box(s, s, s) }\nCube()").unwrap();
        assert!(evaluate(&g, 0.0, 0.0, 0.0) < 0.0);
    }

    #[test]
    fn sch_default_params_override() {
        let g = parse_expr("sch Cube(s: 2) { box(s, s, s) }\nCube(5)").unwrap();
        assert_near(evaluate(&g, 5.0, 0.0, 0.0), 0.0, 1e-6);
    }

    #[test]
    fn sch_mixed_defaults() {
        let g = parse_expr("sch Pillar(r, h: 10) {\n cylinder(r, h)\n}\nPillar(2)").unwrap();
        assert!(evaluate(&g, 0.0, 0.0, 0.0) < 0.0);
    }

    #[test]
    fn sch_value_returning() {
        let g = parse_expr("sch double(v) = v * 2\ndouble(5)").unwrap();
        assert_near(evaluate(&g, 0.0, 0.0, 0.0), 10.0, 1e-10);
    }

    #[test]
    fn sch_nested_definition() {
        let src = "sch Outer(r) {\n sch Inner(s) { sphere(s) }\n translate(Inner(r), r, 0, 0)\n}\nOuter(3)";
        let g = parse_expr(src).unwrap();
        assert!(evaluate(&g, 3.0, 0.0, 0.0) < 0.0);
    }

    #[test]
    fn sch_outer_scope_visible() {
        let g = parse_expr("let k = 5\nsch S(r) { sphere(r + k) }\nS(1)").unwrap();
        assert_near(evaluate(&g, 6.0, 0.0, 0.0), 0.0, 1e-6);
    }

    // ---- map ------------------------------------------------------------

    #[test]
    fn map_basic() {
        let g = parse_expr("map(i, 0..5) { translate(sphere(1), i * 3, 0, 0) }").unwrap();
        assert!(evaluate(&g, 0.0, 0.0, 0.0) < 0.0);
        assert!(evaluate(&g, 6.0, 0.0, 0.0) < 0.0);
        assert!(evaluate(&g, 1.5, 0.0, 0.0) > 0.0);
    }

    #[test]
    fn map_with_sch() {
        let src = "sch Peg(r) { sphere(r) }\nmap(i, 0..3) { translate(Peg(1), i * 4, 0, 0) }";
        let g = parse_expr(src).unwrap();
        for inside_x in [0.0, 4.0, 8.0] {
            assert!(evaluate(&g, inside_x, 0.0, 0.0) < 0.0);
        }
        assert!(evaluate(&g, 2.0, 0.0, 0.0) > 0.0);
    }

    #[test]
    fn map_rotation_ring() {
        let src = "map(i, 0..4) { rotate_z(translate(sphere(0.5), 5, 0, 0), i * pi/2) }";
        let g = parse_expr(src).unwrap();
        for (px, py) in [(5.0, 0.0), (0.0, 5.0), (-5.0, 0.0), (0.0, -5.0)] {
            assert!(evaluate(&g, px, py, 0.0) < 0.0);
        }
    }

    #[test]
    fn let_with_map() {
        let src = "let row: Obj = map(i, 0..3) { translate(sphere(1), i * 3, 0, 0) }\ncast()";
        let scene = parse_expr_scene(src).unwrap();
        assert!(scene.cast_all);
        let g = &scene.graph;
        assert!(evaluate(g, 0.0, 0.0, 0.0) < 0.0);
        assert!(evaluate(g, 3.0, 0.0, 0.0) < 0.0);
    }
}