101 lines
3.1 KiB
Rust
101 lines
3.1 KiB
Rust
use crate::ast::{Literal, Node};
|
|
use crate::constants::DEFAULT_FUNCTIONS;
|
|
use crate::context::{EvalContext, FunctionProvider, ValueProvider};
|
|
use crate::value::{Number, Value};
|
|
use thiserror::Error;
|
|
|
|
#[derive(Debug, Error)]
|
|
pub enum EvalError {
|
|
#[error("Missing value: {0}")]
|
|
MissingValue(String),
|
|
|
|
#[error("Missing function: {0}")]
|
|
MissingFunction(String),
|
|
#[error("Wrong type for function call")]
|
|
TypeError,
|
|
}
|
|
|
|
impl Node {
|
|
pub fn eval<V: ValueProvider, F: FunctionProvider>(&self, context: &EvalContext<V, F>) -> Result<Value, EvalError> {
|
|
match self {
|
|
Node::Lit(lit) => match lit {
|
|
Literal::Float(num) => Ok(Value::from_f64(*num)),
|
|
Literal::Complex(num) => Ok(Value::Number(Number::Complex(*num))),
|
|
},
|
|
|
|
Node::BinOp { lhs, op, rhs } => match (lhs.eval(context)?, rhs.eval(context)?) {
|
|
(Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs))),
|
|
},
|
|
Node::UnaryOp { expr, op } => match expr.eval(context)? {
|
|
Value::Number(num) => Ok(Value::Number(num.unary_op(*op))),
|
|
},
|
|
Node::Var(name) => context.get_value(name).ok_or_else(|| EvalError::MissingValue(name.clone())),
|
|
Node::FnCall { name, expr } => {
|
|
let values = expr.iter().map(|expr| expr.eval(context)).collect::<Result<Vec<Value>, EvalError>>()?;
|
|
if let Some(function) = DEFAULT_FUNCTIONS.get(&name.as_str()) {
|
|
function(&values).ok_or(EvalError::TypeError)
|
|
} else if let Some(val) = context.run_function(name, &values) {
|
|
Ok(val)
|
|
} else {
|
|
context.get_value(name).ok_or_else(|| EvalError::MissingFunction(name.to_string()))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::EvalError;
    use crate::ast::{BinaryOp, Literal, Node, UnaryOp};
    use crate::context::EvalContext;
    use crate::value::Value;

    /// Generates one `#[test]` function per entry. Each test evaluates `$expr`
    /// in a default [`EvalContext`] and asserts that the result equals `$expected`.
    macro_rules! eval_tests {
        ($($name:ident: $expected:expr_2021 => $expr:expr_2021),* $(,)?) => {
            $(
                #[test]
                fn $name() {
                    let result = $expr.eval(&EvalContext::default()).unwrap();
                    assert_eq!(result, $expected);
                }
            )*
        };
    }

    eval_tests! {
        test_addition: Value::from_f64(7.0) => Node::BinOp {
            lhs: Box::new(Node::Lit(Literal::Float(3.0))),
            op: BinaryOp::Add,
            rhs: Box::new(Node::Lit(Literal::Float(4.0))),
        },
        test_subtraction: Value::from_f64(1.0) => Node::BinOp {
            lhs: Box::new(Node::Lit(Literal::Float(5.0))),
            op: BinaryOp::Sub,
            rhs: Box::new(Node::Lit(Literal::Float(4.0))),
        },
        test_multiplication: Value::from_f64(12.0) => Node::BinOp {
            lhs: Box::new(Node::Lit(Literal::Float(3.0))),
            op: BinaryOp::Mul,
            rhs: Box::new(Node::Lit(Literal::Float(4.0))),
        },
        test_division: Value::from_f64(2.5) => Node::BinOp {
            lhs: Box::new(Node::Lit(Literal::Float(5.0))),
            op: BinaryOp::Div,
            rhs: Box::new(Node::Lit(Literal::Float(2.0))),
        },
        test_negation: Value::from_f64(-3.0) => Node::UnaryOp {
            expr: Box::new(Node::Lit(Literal::Float(3.0))),
            op: UnaryOp::Neg,
        },
        test_sqrt: Value::from_f64(2.0) => Node::UnaryOp {
            expr: Box::new(Node::Lit(Literal::Float(4.0))),
            op: UnaryOp::Sqrt,
        },
        test_power: Value::from_f64(8.0) => Node::BinOp {
            lhs: Box::new(Node::Lit(Literal::Float(2.0))),
            op: BinaryOp::Pow,
            rhs: Box::new(Node::Lit(Literal::Float(3.0))),
        },
    }

    /// An unknown variable must surface as `MissingValue` carrying its name,
    /// rather than panicking or yielding a value.
    #[test]
    fn test_missing_variable() {
        let name = "__undefined_variable__";
        let result = Node::Var(name.to_string()).eval(&EvalContext::default());
        assert!(
            matches!(&result, Err(EvalError::MissingValue(missing)) if missing == name),
            "expected MissingValue({name:?}), got {result:?}"
        );
    }
}
|