From e9c2b9ed804a92a0a88ab9a48e91d594a4d2cd68 Mon Sep 17 00:00:00 2001 From: jess Date: Tue, 26 May 2026 18:24:23 -0700 Subject: [PATCH] Impl, Trait --- compile/src/lib.rs | 36 +++++ core/src/doc.rs | 2 + core/src/interp.rs | 317 ++++++++++++++++++++++++++++++++++++++++- viewport/src/syntax.rs | 2 +- 4 files changed, 353 insertions(+), 4 deletions(-) diff --git a/compile/src/lib.rs b/compile/src/lib.rs index a4e7bd2..3ffc36a 100644 --- a/compile/src/lib.rs +++ b/compile/src/lib.rs @@ -183,6 +183,24 @@ fn emit_stmt(out: &mut String, stmt: &Stmt, depth: usize, deps: &mut Vec { + indent(out, depth, &format!("pub trait {} {{", ident(name))); + for m in methods { + indent(out, depth + 1, &format!("fn {}(&self) -> V;", ident(m))); + } + indent(out, depth, "}"); + } + Stmt::ImplBlock { type_name, trait_name, methods } => { + let header = match trait_name { + Some(t) => format!("impl {} for {} {{", ident(t), ident(type_name)), + None => format!("impl {} {{", ident(type_name)), + }; + indent(out, depth, &header); + for m in methods { + emit_stmt(out, m, depth + 1, deps, hook)?; + } + indent(out, depth, "}"); + } } Ok(()) } @@ -287,6 +305,23 @@ fn emit_expr(expr: &Expr, hook: &dyn DecomposeHook) -> Result { let b = emit_expr(base, hook)?; format!("v_field(&{}, {:?})", b, field) } + Expr::MethodCall(recv, method, args) => { + let r = emit_expr(recv, hook)?; + let arg_list: Vec = args.iter() + .map(|a| emit_expr(a, hook)) + .collect::>()?; + if arg_list.is_empty() { + format!("v_method_call(&{}, {:?}, &[])", r, method) + } else { + format!("v_method_call(&{}, {:?}, &[{}])", r, method, arg_list.join(", ")) + } + } + Expr::StaticCall(type_name, method, args) => { + let arg_list: Vec = args.iter() + .map(|a| emit_expr(a, hook)) + .collect::>()?; + format!("{}__{}({})", ident(type_name), ident(method), arg_list.join(", ")) + } }) } @@ -437,6 +472,7 @@ fn v_cell_table(_table: &str) -> V { V::Array(Vec::new()) } fn v_cell_set(_table: &str, _col: u32, _row: u32, _val: &V) {} fn v_solve_newton(_target_var: &str, _source_fn: &str) -> V { V::Void } fn v_solve_call(_var: &str, _source_fn: &str) -> V { V::Void } +fn v_method_call(_recv: &V, _method: &str, _args: &[V]) -> V { V::Void } // --- generated code below --- diff --git a/core/src/doc.rs b/core/src/doc.rs index e6c6c67..c214c4b 100644 --- a/core/src/doc.rs +++ b/core/src/doc.rs @@ -62,6 +62,8 @@ fn is_cordial(line: &str) -> bool { if line.starts_with("while ") || line.starts_with("while(") { return true; } if line.starts_with("fn ") { return true; } + if line.starts_with("impl ") { return true; } + if line.starts_with("trait ") { return true; } if line.starts_with("if ") || line.starts_with("if(") { return true; } if line.starts_with("else ") || line == "else" || line.starts_with("else{") { return true; } if line.starts_with("for ") { return true; } diff --git a/core/src/interp.rs b/core/src/interp.rs index fbcf9e4..a0c421a 100644 --- a/core/src/interp.rs +++ b/core/src/interp.rs @@ -380,6 +380,8 @@ enum Token { Return, Is, Use, + Impl, + Trait, ColonColon, Newline, Eof, @@ -689,6 +691,8 @@ fn tokenize(input: &str, spice: bool) -> Result, String> { "not" => tokens.push(Token::Bang), "is" => tokens.push(Token::Is), "use" => tokens.push(Token::Use), + "impl" => tokens.push(Token::Impl), + "trait" => tokens.push(Token::Trait), _ => tokens.push(Token::Ident(word)), } } @@ -742,6 +746,15 @@ pub enum Stmt { result_var: String, }, ExprStmt(Expr), + TraitDef { + name: String, + methods: Vec, + }, + ImplBlock { + type_name: String, + trait_name: Option, + methods: Vec, + }, } #[derive(Debug, Clone, PartialEq)] @@ -775,6 +788,8 @@ pub enum Expr { }, Struct(Vec<(String, Expr)>), Field(Box, String), + MethodCall(Box, String, Vec), + StaticCall(String, String, Vec), } // --- Parser --- @@ -846,6 +861,8 @@ impl Parser { Token::Return => self.parse_return(), Token::Fn => self.parse_fn_def(), Token::Use => self.parse_use(), + Token::Trait => self.parse_trait_def(), + Token::Impl => self.parse_impl_block(), Token::At => { let saved = self.pos; let cref = self.parse_cell_ref()?; @@ -1130,6 +1147,83 @@ impl Parser { Ok(Stmt::Use(module, item)) } + fn parse_trait_def(&mut self) -> Result { + self.expect(&Token::Trait)?; + let name = match self.advance() { + Token::Ident(n) => n, + other => return Err(format!("expected trait name, got {:?}", other)), + }; + self.skip_newlines(); + self.expect(&Token::LBrace)?; + self.skip_newlines(); + let mut methods = Vec::new(); + while self.peek() != &Token::RBrace && self.peek() != &Token::Eof { + if self.peek() == &Token::Fn { + self.advance(); + match self.advance() { + Token::Ident(m) => methods.push(m), + other => return Err(format!("expected method name in trait, got {:?}", other)), + } + // skip signature: consume until newline or { + while !matches!(self.peek(), Token::Newline | Token::LBrace | Token::Eof) { + self.advance(); + } + // skip body if present + if self.peek() == &Token::LBrace { + let mut depth = 1; + self.advance(); + while depth > 0 && self.peek() != &Token::Eof { + match self.advance() { + Token::LBrace => depth += 1, + Token::RBrace => depth -= 1, + _ => {} + } + } + } + } + self.skip_newlines(); + } + self.expect(&Token::RBrace)?; + self.skip_newlines(); + Ok(Stmt::TraitDef { name, methods }) + } + + fn parse_impl_block(&mut self) -> Result { + self.expect(&Token::Impl)?; + let first_name = match self.advance() { + Token::Ident(n) => n, + other => return Err(format!("expected type name after 'impl', got {:?}", other)), + }; + // impl Trait for Type OR impl Type + let (type_name, trait_name) = if self.peek() == &Token::For { + self.advance(); + let tn = match self.advance() { + Token::Ident(n) => n, + other => return Err(format!("expected type name after 'for', got {:?}", other)), + }; + (tn, Some(first_name)) + } else { + (first_name, None) + }; + self.skip_newlines(); + self.expect(&Token::LBrace)?; + self.skip_newlines(); + let mut methods = Vec::new(); + while self.peek() != &Token::RBrace && self.peek() != &Token::Eof { + if self.peek() == &Token::Fn { + methods.push(self.parse_fn_def()?); + } else { + self.skip_newlines(); + if self.peek() != &Token::Fn && self.peek() != &Token::RBrace { + self.advance(); + } + } + } + self.expect(&Token::RBrace)?; + self.skip_newlines(); + Ok(Stmt::ImplBlock { type_name, trait_name, methods }) + } + fn parse_cell_ref(&mut self) -> Result { self.expect(&Token::At)?; let first = match self.advance() { @@ -1399,6 +1493,16 @@ impl Parser { fn parse_call(&mut self) -> Result { let mut expr = self.parse_atom()?; if let Expr::Ident(ref name) = expr { + if self.peek() == &Token::ColonColon { + let type_name = name.clone(); + self.advance(); + if let Token::Ident(method) = self.advance() { + let args = self.parse_arg_list()?; + return Ok(Expr::StaticCall(type_name, method, args)); + } else { + return Err("expected method name after '::'".into()); + } + } if self.peek() == &Token::LParen { self.advance(); let mut args = Vec::new(); @@ -1413,7 +1517,7 @@ impl Parser { expr = Expr::Call(name.clone(), args); } } - while matches!(self.peek(), Token::LBracket | Token::Dot) { + loop { match self.peek() { Token::LBracket => { self.advance(); @@ -1427,14 +1531,33 @@ impl Parser { Token::Ident(name) => name, other => return Err(format!("expected field name after '.', got {:?}", other)), }; - expr = Expr::Field(Box::new(expr), field); + if self.peek() == &Token::LParen { + let args = self.parse_arg_list()?; + expr = Expr::MethodCall(Box::new(expr), field, args); + } else { + expr = Expr::Field(Box::new(expr), field); + } } - _ => unreachable!(), + _ => break, } } Ok(expr) } + fn parse_arg_list(&mut self) -> Result, String> { + self.expect(&Token::LParen)?; + let mut args = Vec::new(); + if self.peek() != &Token::RParen { + args.push(self.parse_expr()?); + while self.peek() == &Token::Comma { + self.advance(); + args.push(self.parse_expr()?); + } + } + self.expect(&Token::RParen)?; + Ok(args) + } + fn parse_atom(&mut self) -> Result { match self.peek().clone() { Token::Number(n) => { self.advance(); Ok(Expr::Num(n)) } @@ -1534,6 +1657,8 @@ pub struct Interpreter { current_table: Option, current_block: Option, table_writes: Vec, + methods: HashMap<(String, String), FnDef>, + traits: HashMap>, } #[derive(Debug, Clone)] @@ -1561,6 +1686,8 @@ impl Interpreter { current_table: None, current_block: None, table_writes: Vec::new(), + methods: HashMap::new(), + traits: HashMap::new(), } } @@ -1933,6 +2060,35 @@ impl Interpreter { Stmt::ExprStmt(expr) => { self.eval_expr(expr, depth) } + Stmt::TraitDef { name, methods } => { + self.traits.insert(name.clone(), methods.clone()); + Ok(Value::Void) + } + Stmt::ImplBlock { type_name, trait_name, methods } => { + for m in methods { + if let Stmt::FnDef { name, params, return_type, body } = m { + let fndef = FnDef { + params: params.clone(), + return_type: return_type.clone(), + body: body.clone(), + }; + self.methods.insert((type_name.clone(), name.clone()), fndef); + } + } + if let Some(trait_name) = trait_name { + if let Some(required) = self.traits.get(trait_name) { + for req in required { + if !self.methods.contains_key(&(type_name.clone(), req.clone())) { + return Err(format!( + "impl {} for {}: missing required method '{}'", + trait_name, type_name, req + )); + } + } + } + } + Ok(Value::Void) + } } } @@ -2046,6 +2202,39 @@ impl Interpreter { other => Err(format!("cannot read field '{}' on {}", name, type_name(&other))), } } + Expr::MethodCall(receiver, method, args) => { + let recv_val = self.eval_expr(receiver, depth)?; + let type_tag = match &recv_val { + Value::Struct(s) => { + s.borrow().get("__type").and_then(|v| match v { + Value::Str(s) => Some(s.clone()), + _ => None, + }) + } + _ => None, + }; + let type_tag = type_tag.ok_or_else(|| + format!("cannot call .{}() — receiver has no __type", method) + )?; + let fndef = self.methods.get(&(type_tag.clone(), method.clone())) + .cloned() + .ok_or_else(|| format!("no method '{}' on type '{}'", method, type_tag))?; + let mut eval_args = vec![recv_val]; + for a in args { + eval_args.push(self.eval_expr(a, depth)?); + } + self.call_fndef(&fndef, &eval_args, depth) + } + Expr::StaticCall(type_name_str, method, args) => { + let fndef = self.methods.get(&(type_name_str.clone(), method.clone())) + .cloned() + .ok_or_else(|| format!("no static method '{}::{}' ", type_name_str, method))?; + let mut eval_args = Vec::new(); + for a in args { + eval_args.push(self.eval_expr(a, depth)?); + } + self.call_fndef(&fndef, &eval_args, depth) + } Expr::Range(start, end) => { let sv = self.eval_expr(start, depth)?; let ev = self.eval_expr(end, depth)?; @@ -2958,6 +3147,46 @@ impl Interpreter { apply_fn_return_type(&fdef.return_type, result, name) } + fn call_fndef(&mut self, fndef: &FnDef, arg_vals: &[Value], depth: u32) -> Result { + if depth >= MAX_CALL_DEPTH { + return Err("maximum call depth exceeded".into()); + } + let saved_vars = self.vars.clone(); + let saved_types = self.var_types.clone(); + for ((pname, pty), val) in fndef.params.iter().zip(arg_vals.iter()) { + let bound = match pty { + Some(t) => apply_type_annotation(val, Some(t)) + .map_err(|e| format!("parameter '{}': {}", pname, e))?, + None => val.clone(), + }; + if let Some(t) = pty { + self.var_types.insert(pname.clone(), t.clone()); + } else { + self.var_types.remove(pname); + } + self.vars.insert(pname.clone(), bound); + } + let mut result = Value::Void; + for stmt in &fndef.body { + match self.exec_stmt(stmt, depth + 1) { + Ok(v) => result = v, + Err(e) if e.starts_with('\x00') => { + self.vars = saved_vars; + self.var_types = saved_types; + return Ok(self.return_slot.take().unwrap_or(Value::Void)); + } + Err(e) => { + self.vars = saved_vars; + self.var_types = saved_types; + return Err(e); + } + } + } + self.vars = saved_vars; + self.var_types = saved_types; + Ok(result) + } + fn build_solved_fn_def( &self, source_fn: &str, @@ -3494,6 +3723,17 @@ fn collect_formula_refs(expr: &Expr, current_table: &str, out: &mut Vec { + collect_formula_refs(recv, current_table, out); + for a in args { + collect_formula_refs(a, current_table, out); + } + } + Expr::StaticCall(_, _, args) => { + for a in args { + collect_formula_refs(a, current_table, out); + } + } Expr::Num(_) | Expr::Str(_) | Expr::Bool(_) | Expr::SolveMacro { .. } => {} } } @@ -5658,4 +5898,75 @@ fn find(arr, target) { let want = 1.0 / (4.0 * pi * pi * 2600.0 * 2600.0 * 1e-9); assert!((n - want).abs() / want < 1e-6, "got {}, want {}", n, want); } + + #[test] + fn impl_static_constructor() { + let mut i = Interpreter::new(); + i.exec_line("impl Point {\n fn new(x, y) {\n return {__type: \"Point\", x: x, y: y}\n }\n}").unwrap(); + let v = i.eval_expr_str("Point::new(3, 4)").unwrap(); + match v { + Value::Struct(s) => { + assert!(matches!(s.borrow().get("x"), Some(Value::Number(n)) if *n == 3.0)); + assert!(matches!(s.borrow().get("y"), Some(Value::Number(n)) if *n == 4.0)); + } + _ => panic!("expected struct"), + } + } + + #[test] + fn impl_method_call() { + let mut i = Interpreter::new(); + i.exec_line("impl Vec2 {\n fn new(x, y) {\n return {__type: \"Vec2\", x: x, y: y}\n }\n fn length(self) {\n return sqrt(self.x^2 + self.y^2)\n }\n}").unwrap(); + i.exec_line("let v = Vec2::new(3, 4)").unwrap(); + let v = i.eval_expr_str("v.length()").unwrap(); + assert!(matches!(v, Value::Number(n) if n == 5.0)); + } + + #[test] + fn impl_method_with_args() { + let mut i = Interpreter::new(); + i.exec_line("impl Vec2 {\n fn new(x, y) {\n return {__type: \"Vec2\", x: x, y: y}\n }\n fn add(self, other) {\n return Vec2::new(self.x + other.x, self.y + other.y)\n }\n}").unwrap(); + i.exec_line("let a = Vec2::new(1, 2)").unwrap(); + i.exec_line("let b = Vec2::new(3, 4)").unwrap(); + let v = i.eval_expr_str("a.add(b)").unwrap(); + match v { + Value::Struct(s) => { + assert!(matches!(s.borrow().get("x"), Some(Value::Number(n)) if *n == 4.0)); + assert!(matches!(s.borrow().get("y"), Some(Value::Number(n)) if *n == 6.0)); + } + _ => panic!("expected struct"), + } + } + + #[test] + fn trait_def_and_impl() { + let mut i = Interpreter::new(); + i.exec_line("trait Measurable {\n fn area(self)\n}").unwrap(); + i.exec_line("impl Measurable for Circle {\n fn area(self) {\n return pi * self.r^2\n }\n}").unwrap(); + i.exec_line("impl Circle {\n fn new(r) {\n return {__type: \"Circle\", r: r}\n }\n}").unwrap(); + i.exec_line("let c = Circle::new(5)").unwrap(); + let v = i.eval_expr_str("c.area()").unwrap(); + let expected = std::f64::consts::PI * 25.0; + match v { + Value::Number(n) => assert!((n - expected).abs() < 1e-10), + _ => panic!("expected number"), + } + } + + #[test] + fn trait_missing_method_errors() { + let mut i = Interpreter::new(); + i.exec_line("trait Drawable {\n fn draw(self)\n fn bounds(self)\n}").unwrap(); + let result = i.exec_line("impl Drawable for Box {\n fn draw(self) {\n return 0\n }\n}"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("missing required method 'bounds'")); + } + + #[test] + fn method_on_untyped_struct_errors() { + let mut i = Interpreter::new(); + i.exec_line("let p = {x: 1, y: 2}").unwrap(); + let result = i.eval_expr_str("p.length()"); + assert!(result.is_err()); + } } diff --git a/viewport/src/syntax.rs b/viewport/src/syntax.rs index 0e60498..bb9356f 100644 --- a/viewport/src/syntax.rs +++ b/viewport/src/syntax.rs @@ -626,7 +626,7 @@ fn is_operator_byte(b: u8) -> bool { fn is_cordial_keyword(w: &str) -> bool { matches!(w, "let" | "fn" | "if" | "else" | "while" | "for" | "in" | "return" | "use" | "is" | "true" | "false" | "and" | "or" | "not" - | "solve" | "where" | "from") + | "solve" | "where" | "from" | "impl" | "trait") } fn is_cordial_builtin(w: &str) -> bool {