From 2bb589d655d1581b71ea6563b006f2daef8ea0ff Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Thu, 25 Apr 2024 00:21:01 +0200 Subject: rebel-lang: new crate Handle a lot of typechecking and evaluation of expressions. --- crates/rebel-lang/Cargo.toml | 17 ++ crates/rebel-lang/examples/eval-string.rs | 101 +++++++ crates/rebel-lang/examples/type-string.rs | 77 ++++++ crates/rebel-lang/src/func.rs | 31 +++ crates/rebel-lang/src/lib.rs | 4 + crates/rebel-lang/src/scope.rs | 35 +++ crates/rebel-lang/src/typing.rs | 375 ++++++++++++++++++++++++++ crates/rebel-lang/src/value.rs | 422 ++++++++++++++++++++++++++++++ 8 files changed, 1062 insertions(+) create mode 100644 crates/rebel-lang/Cargo.toml create mode 100644 crates/rebel-lang/examples/eval-string.rs create mode 100644 crates/rebel-lang/examples/type-string.rs create mode 100644 crates/rebel-lang/src/func.rs create mode 100644 crates/rebel-lang/src/lib.rs create mode 100644 crates/rebel-lang/src/scope.rs create mode 100644 crates/rebel-lang/src/typing.rs create mode 100644 crates/rebel-lang/src/value.rs (limited to 'crates') diff --git a/crates/rebel-lang/Cargo.toml b/crates/rebel-lang/Cargo.toml new file mode 100644 index 0000000..44183be --- /dev/null +++ b/crates/rebel-lang/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "rebel-lang" +version = "0.1.0" +authors = ["Matthias Schiffer "] +license = "MIT" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rebel-common = { path = "../rebel-common" } +rebel-parse = { path = "../rebel-parse" } + +enum-kinds = "0.5.1" + +[dev-dependencies] +clap = { version = "4.0.0", features = ["derive"] } diff --git a/crates/rebel-lang/examples/eval-string.rs b/crates/rebel-lang/examples/eval-string.rs new file mode 100644 index 0000000..a9c0b01 --- /dev/null +++ b/crates/rebel-lang/examples/eval-string.rs @@ -0,0 +1,101 @@ +use std::{fmt::Debug, process, time::Instant}; + +use clap::Parser; + +use rebel_lang::{ + func::{Func, FuncDef, FuncType}, + typing::{ArrayLen, Type, TypeFamily}, + value::{Context, EvalError, Result, Value}, +}; +use rebel_parse::{recipe, tokenize}; + +#[derive(Clone, Debug, Parser)] +struct Opts { + input: String, +} + +fn intrinsic_array_len(params: &[Value]) -> Result { + assert!(params.len() == 1); + let Value::Array(array) = ¶ms[0] else { + panic!(); + }; + Ok(Value::Integer(array.len().try_into().or(Err(EvalError))?)) +} +fn intrinsic_string_len(params: &[Value]) -> Result { + assert!(params.len() == 1); + let Value::Str(string) = ¶ms[0] else { + panic!(); + }; + Ok(Value::Integer( + string.chars().count().try_into().or(Err(EvalError))?, + )) +} + +fn main() { + let opts: Opts = Opts::parse(); + let input = opts.input.trim(); + + let start = Instant::now(); + let result = tokenize::token_stream(input); + let dur = Instant::now().duration_since(start); + println!("Tokenization took {} µs", dur.as_micros()); + + let tokens = match result { + Ok(value) => value, + Err(err) => { + println!("{err}"); + process::exit(1); + } + }; + + let start = Instant::now(); + let result = recipe::expr(&tokens); + let dur = Instant::now().duration_since(start); + println!("Parsing took {} µs", dur.as_micros()); + + let expr = match result { + Ok(value) => value, + Err(err) => { + println!("{err}"); + process::exit(1); + } + }; + + let mut ctx = Context::default(); + + ctx.methods.entry(TypeFamily::Array).or_default().insert( + "len", + Func { + typ: FuncType { + params: vec![Type::Array(Box::new(Type::Free), ArrayLen::Dynamic)], + ret: Type::Int, + }, + def: FuncDef::Intrinsic(intrinsic_array_len), + }, + ); + ctx.methods.entry(TypeFamily::Str).or_default().insert( + "len", + Func { + typ: FuncType { + params: vec![Type::Str], + ret: Type::Int, + }, + def: FuncDef::Intrinsic(intrinsic_string_len), + }, + ); + + let start = Instant::now(); + let result = ctx.eval(&expr); + let dur = Instant::now().duration_since(start); + println!("Eval took {} µs", dur.as_micros()); + + let value = match result { + Ok(value) => value, + Err(err) => { + println!("{err:?}"); + process::exit(1); + } + }; + + println!("{value}"); +} diff --git a/crates/rebel-lang/examples/type-string.rs b/crates/rebel-lang/examples/type-string.rs new file mode 100644 index 0000000..5490572 --- /dev/null +++ b/crates/rebel-lang/examples/type-string.rs @@ -0,0 +1,77 @@ +use std::{fmt::Debug, process, time::Instant}; + +use clap::Parser; + +use rebel_lang::{ + func::FuncType, + typing::{ArrayLen, Context, Type, TypeFamily}, +}; +use rebel_parse::{recipe, tokenize}; + +#[derive(Clone, Debug, Parser)] +struct Opts { + input: String, +} + +fn main() { + let opts: Opts = Opts::parse(); + let input = opts.input.trim(); + + let start = Instant::now(); + let result = tokenize::token_stream(input); + let dur = Instant::now().duration_since(start); + println!("Tokenization took {} µs", dur.as_micros()); + + let tokens = match result { + Ok(value) => value, + Err(err) => { + println!("{err}"); + process::exit(1); + } + }; + + let start = Instant::now(); + let result = recipe::expr(&tokens); + let dur = Instant::now().duration_since(start); + println!("Parsing took {} µs", dur.as_micros()); + + let expr = match result { + Ok(value) => value, + Err(err) => { + println!("{err}"); + process::exit(1); + } + }; + + let mut ctx = Context::default(); + + ctx.methods.entry(TypeFamily::Array).or_default().insert( + "len", + FuncType { + params: vec![Type::Array(Box::new(Type::Free), ArrayLen::Dynamic)], + ret: Type::Int, + }, + ); + ctx.methods.entry(TypeFamily::Str).or_default().insert( + "len", + FuncType { + params: vec![Type::Str], + ret: Type::Int, + }, + ); + + let start = Instant::now(); + let result = ctx.ast_expr_type(&expr); + let dur = Instant::now().duration_since(start); + println!("Typing took {} µs", dur.as_micros()); + + let typ = match result { + Ok(value) => value, + Err(err) => { + println!("{err:?}"); + process::exit(1); + } + }; + + println!("{typ}"); +} diff --git a/crates/rebel-lang/src/func.rs b/crates/rebel-lang/src/func.rs new file mode 100644 index 0000000..129d89e --- /dev/null +++ b/crates/rebel-lang/src/func.rs @@ -0,0 +1,31 @@ +use std::fmt::Display; + +use crate::{ + typing::Type, + value::{Result, Value}, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Func { + pub typ: FuncType, + pub def: FuncDef, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FuncType { + pub params: Vec, + pub ret: Type, +} + +impl Display for FuncType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // TODO + write!(f, "fn( ... ) -> {}", self.ret) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FuncDef { + Intrinsic(fn(&[Value]) -> Result), + Body, +} diff --git a/crates/rebel-lang/src/lib.rs b/crates/rebel-lang/src/lib.rs new file mode 100644 index 0000000..6b75d99 --- /dev/null +++ b/crates/rebel-lang/src/lib.rs @@ -0,0 +1,4 @@ +pub mod func; +pub mod scope; +pub mod typing; +pub mod value; diff --git a/crates/rebel-lang/src/scope.rs b/crates/rebel-lang/src/scope.rs new file mode 100644 index 0000000..6dcc7ee --- /dev/null +++ b/crates/rebel-lang/src/scope.rs @@ -0,0 +1,35 @@ +use std::collections::HashMap; + +use rebel_parse::ast; + +#[derive(Debug, Clone)] +pub struct Module(HashMap>); + +impl Module { + pub fn lookup(&self, path: &[ast::Ident<'_>]) -> Option<&T> { + let (ident, rest) = path.split_first()?; + + match self.0.get(ident.name)? { + ModuleEntry::Module(module) => module.lookup(rest), + ModuleEntry::Def(typ) => { + if rest.is_empty() { + Some(typ) + } else { + None + } + } + } + } +} + +impl Default for Module { + fn default() -> Self { + Self(Default::default()) + } +} + +#[derive(Debug, Clone)] +pub enum ModuleEntry { + Module(Module), + Def(T), +} diff --git a/crates/rebel-lang/src/typing.rs b/crates/rebel-lang/src/typing.rs new file mode 100644 index 0000000..34492a6 --- /dev/null +++ b/crates/rebel-lang/src/typing.rs @@ -0,0 +1,375 @@ +use std::{collections::HashMap, fmt::Display}; + +use enum_kinds::EnumKind; + +use rebel_parse::ast; + +use crate::{func::FuncType, scope::Module}; + +#[derive(Debug)] +pub struct TypeError; + +pub type Result = std::result::Result; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Coerce { + None, + Common, + Dynamic, + Assign, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumKind)] +#[enum_kind(TypeFamily, derive(Hash))] +pub enum Type { + Free, + Unit, + Bool, + Int, + Str, + Tuple(Vec), + Array(Box, ArrayLen), + Map(HashMap), + Fn(Box), +} + +impl Type { + pub fn unify(self, other: Type, coerce: Coerce) -> Result { + use Type::*; + + Ok(match (self, other) { + (Free, typ) => typ, + (typ, Free) => typ, + (Unit, Unit) => Unit, + (Bool, Bool) => Bool, + (Int, Int) => Int, + (Str, Str) => Str, + (Tuple(self_elems), Tuple(other_elems)) => { + if self_elems.len() != other_elems.len() { + return Err(TypeError); + } + Tuple( + self_elems + .into_iter() + .zip(other_elems.into_iter()) + .map(|(t1, t2)| t1.unify(t2, coerce)) + .collect::>()?, + ) + } + (Array(self_inner, self_len), Array(other_inner, other_len)) => Array( + Box::new(self_inner.unify(*other_inner, coerce)?), + ArrayLen::unify(self_len, other_len, coerce)?, + ), + (Map(self_entries), Map(mut other_entries)) => { + if self_entries.len() != other_entries.len() { + return Err(TypeError); + } + Map(self_entries + .into_iter() + .map(|(k, v)| { + let Some(v2) = other_entries.remove(&k) else { + return Err(TypeError); + }; + Ok((k, v.unify(v2, coerce)?)) + }) + .collect::>()?) + } + _ => return Err(TypeError), + }) + } +} + +impl Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Type::Free => write!(f, "_"), + Type::Unit => write!(f, "()"), + Type::Bool => write!(f, "bool"), + Type::Int => write!(f, "int"), + Type::Str => write!(f, "str"), + Type::Tuple(elems) => { + let mut first = true; + f.write_str("(")?; + if elems.is_empty() { + f.write_str(",")?; + } + for elem in elems { + if !first { + f.write_str(", ")?; + } + first = false; + elem.fmt(f)?; + } + f.write_str(")") + } + Type::Array(typ, len) => write!(f, "[{typ}{len}]"), + Type::Map(entries) => { + let mut first = true; + f.write_str("{")?; + for (key, typ) in entries { + if !first { + f.write_str(", ")?; + } + first = false; + write!(f, "{key}: {typ}")?; + } + f.write_str("}") + } + /* TODO */ + Type::Fn(func) => write!(f, "{func}"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum ArrayLen { + Free, + Fixed(usize), + Dynamic, +} + +impl ArrayLen { + fn unify(self, other: Self, coerce: Coerce) -> Result { + use ArrayLen::*; + + Ok(match (self, other, coerce) { + (Free, len, _) => len, + (len, Free, _) => len, + (l1, l2, _) if l1 == l2 => l1, + (_, _, Coerce::Common) => Dynamic, + (Dynamic, Fixed(_), Coerce::Dynamic | Coerce::Assign) => Dynamic, + (Fixed(_), Dynamic, Coerce::Dynamic) => Dynamic, + _ => return Err(TypeError), + }) + } + + fn add(self, other: Self) -> Result { + use ArrayLen::*; + + Ok(match (self, other) { + (Free, _) => return Err(TypeError), + (_, Free) => return Err(TypeError), + (Dynamic, _) => Dynamic, + (_, Dynamic) => Dynamic, + (Fixed(l1), Fixed(l2)) => Fixed(l1 + l2), + }) + } +} + +impl Display for ArrayLen { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArrayLen::Free => write!(f, "; _"), + ArrayLen::Fixed(len) => write!(f, "; {len}"), + ArrayLen::Dynamic => Ok(()), + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct Context { + pub values: Module, + pub methods: HashMap>, +} + +impl Context { + pub fn ast_expr_type(&self, expr: &ast::Expr<'_>) -> Result { + use ast::Expr::*; + + Ok(match expr { + Binary { left, op, right } => self.binary_op_type(left, *op, right)?, + Unary { op, expr } => self.unary_op_type(*op, expr)?, + Apply { expr, params } => self.apply_type(expr, params)?, + Method { + expr, + method, + params, + } => self.method_type(expr, method, params)?, + Index { expr, index } => self.index_type(expr, index)?, + Field { expr, field } => self.field_type(expr, field)?, + Paren(subexpr) => self.ast_expr_type(subexpr)?, + Path(path) => self.path_type(path)?, + Literal(lit) => self.literal_type(lit)?, + }) + } + + fn binary_op_type( + &self, + left: &ast::Expr<'_>, + op: ast::OpBinary, + right: &ast::Expr<'_>, + ) -> Result { + use ast::OpBinary::*; + use Type::*; + + let tl = self.ast_expr_type(left)?; + let tr = self.ast_expr_type(right)?; + + Ok(match (tl, op, tr) { + (Str, Add, Str) => Str, + (Int, Add, Int) => Int, + (Array(t1, l1), Add, Array(t2, l2)) => Array( + Box::new(t1.unify(*t2, Coerce::Common)?), + ArrayLen::add(l1, l2)?, + ), + (Int, Sub, Int) => Int, + (Int, Mul, Int) => Int, + (Int, Div, Int) => Int, + (Int, Rem, Int) => Int, + (Bool, And, Bool) => Bool, + (Bool, Or, Bool) => Bool, + (l, Eq, r) => { + l.unify(r, Coerce::Dynamic)?; + Bool + } + (l, Ne, r) => { + l.unify(r, Coerce::Dynamic)?; + Bool + } + (Int, Lt, Int) => Bool, + (Int, Le, Int) => Bool, + (Int, Ge, Int) => Bool, + (Int, Gt, Int) => Bool, + _ => return Err(TypeError), + }) + } + + fn unary_op_type(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result { + use ast::OpUnary::*; + use Type::*; + + let typ = self.ast_expr_type(expr)?; + + Ok(match (op, typ) { + (Not, Bool) => Bool, + (Neg, Int) => Int, + _ => return Err(TypeError), + }) + } + + fn index_type(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result { + use Type::*; + + let expr_type = self.ast_expr_type(expr)?; + let index_type = self.ast_expr_type(index)?; + + let Array(elem_type, _) = expr_type else { + return Err(TypeError); + }; + if index_type != Int { + return Err(TypeError); + } + Ok(*elem_type) + } + + fn apply_type(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result { + use Type::*; + + let expr_type = self.ast_expr_type(expr)?; + + let Fn(func) = expr_type else { + return Err(TypeError); + }; + + if func.params.len() != params.len() { + return Err(TypeError); + } + + for (func_param_type, call_param) in func.params.iter().zip(params) { + let call_param_type = self.ast_expr_type(call_param)?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Dynamic)?; + } + + Ok(func.ret) + } + + fn method_type( + &self, + expr: &ast::Expr<'_>, + method: &ast::Ident<'_>, + params: &[ast::Expr], + ) -> Result { + let expr_type = self.ast_expr_type(expr)?; + let type_family = TypeFamily::from(&expr_type); + + let methods = self.methods.get(&type_family).ok_or(TypeError)?; + let method = methods.get(method.name).ok_or(TypeError)?; + + let (self_param, func_params) = method.params.split_first().ok_or(TypeError)?; + self_param.clone().unify(expr_type, Coerce::Dynamic)?; + + if func_params.len() != params.len() { + return Err(TypeError); + } + + for (func_param_type, call_param) in func_params.iter().zip(params) { + let call_param_type = self.ast_expr_type(call_param)?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Dynamic)?; + } + + Ok(method.ret.clone()) + } + + fn field_type(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result { + use Type::*; + + let expr_type = self.ast_expr_type(expr)?; + let name = field.name; + + Ok(match expr_type { + Tuple(elems) => { + let index: usize = name.parse().or(Err(TypeError))?; + elems.into_iter().nth(index).ok_or(TypeError)? + } + Map(mut entries) => entries.remove(name).ok_or(TypeError)?, + _ => return Err(TypeError), + }) + } + + fn path_type(&self, path: &ast::Path<'_>) -> Result { + use Type::*; + + if path.components == [ast::Ident { name: "_" }] { + return Ok(Free); + } + + self.values + .lookup(&path.components) + .ok_or(TypeError) + .cloned() + } + + fn literal_type(&self, lit: &ast::Literal<'_>) -> Result { + use ast::Literal; + use Type::*; + + Ok(match lit { + Literal::Unit => Unit, + Literal::Bool(_) => Bool, + Literal::Int(_) => Int, + Literal::Str { .. } => Str, + Literal::Tuple(elems) => Tuple( + elems + .iter() + .map(|elem| self.ast_expr_type(elem)) + .collect::>()?, + ), + Literal::Array(elems) => Array( + Box::new(elems.iter().try_fold(Type::Free, |acc, elem| { + acc.unify(self.ast_expr_type(elem)?, Coerce::Common) + })?), + ArrayLen::Fixed(elems.len()), + ), + Literal::Map(entries) => Map(entries + .iter() + .map(|ast::MapEntry { key, value }| { + Ok(((*key).to_owned(), self.ast_expr_type(value)?)) + }) + .collect::>()?), + }) + } +} diff --git a/crates/rebel-lang/src/value.rs b/crates/rebel-lang/src/value.rs new file mode 100644 index 0000000..c7e971e --- /dev/null +++ b/crates/rebel-lang/src/value.rs @@ -0,0 +1,422 @@ +use std::{ + collections::HashMap, + fmt::{Display, Write}, + iter, +}; + +use rebel_parse::ast; + +use crate::{ + func::{Func, FuncDef}, + scope::Module, + typing::{self, ArrayLen, Coerce, Type, TypeFamily}, +}; + +#[derive(Debug)] +pub struct EvalError; + +pub type Result = std::result::Result; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Value { + Unit, + Boolean(bool), + Integer(i64), + Str(String), + Tuple(Vec), + Array(Vec), + Map(HashMap), + Fn(Box), +} + +impl Value { + pub fn typ(&self) -> typing::Result { + Ok(match self { + Value::Unit => Type::Unit, + Value::Boolean(_) => Type::Bool, + Value::Integer(_) => Type::Int, + Value::Str(_) => Type::Str, + Value::Tuple(elems) => Type::Tuple( + elems + .iter() + .map(|elem| elem.typ()) + .collect::>()?, + ), + Value::Array(elems) => Type::Array( + Box::new(Self::array_elem_type(elems)?), + ArrayLen::Fixed(elems.len()), + ), + Value::Map(entries) => Type::Map( + entries + .iter() + .map(|(k, v)| Ok((k.clone(), v.typ()?))) + .collect::>()?, + ), + Value::Fn(func) => Type::Fn(Box::new(func.typ.clone())), + }) + } + + fn array_elem_type(elems: &[Value]) -> typing::Result { + elems.iter().try_fold(Type::Free, |acc, elem| { + acc.unify(elem.typ()?, Coerce::Common) + }) + } +} + +impl Display for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Value::Unit => f.write_str("()"), + Value::Boolean(value) => value.fmt(f), + Value::Integer(value) => value.fmt(f), + Value::Str(value) => write!(f, "{value:?}"), + Value::Tuple(elems) => { + let mut first = true; + f.write_str("(")?; + if elems.is_empty() { + f.write_str(",")?; + } + for elem in elems { + if !first { + f.write_str(", ")?; + } + first = false; + elem.fmt(f)?; + } + f.write_str(")") + } + Value::Array(elems) => { + let mut first = true; + f.write_str("[")?; + for elem in elems { + if !first { + f.write_str(", ")?; + } + first = false; + elem.fmt(f)?; + } + f.write_str("]") + } + Value::Map(entries) => { + let mut first = true; + f.write_str("{")?; + for (key, value) in entries { + if !first { + f.write_str(", ")?; + } + first = false; + write!(f, "{key} = {value}")?; + } + f.write_str("}") + } + Value::Fn(func) => func.typ.fmt(f), + } + } +} + +#[derive(Debug)] +struct Stringify<'a>(&'a Value); + +impl<'a> Display for Stringify<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + Value::Boolean(value) => value.fmt(f), + Value::Integer(value) => value.fmt(f), + Value::Str(value) => value.fmt(f), + _ => Err(std::fmt::Error), + } + } +} + +#[derive(Debug)] +struct ScriptStringify<'a>(&'a Value); + +impl<'a> ScriptStringify<'a> { + fn fmt_list(elems: &'a [Value], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut first = true; + for elem in elems { + if !first { + f.write_char(' ')?; + } + ScriptStringify(elem).fmt(f)?; + first = false; + } + Ok(()) + } +} + +impl<'a> Display for ScriptStringify<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + Value::Boolean(value) => { + value.fmt(f)?; + } + Value::Integer(value) => { + value.fmt(f)?; + } + Value::Str(value) => { + f.write_char('\'')?; + f.write_str(&value.replace('\'', "'\\''"))?; + f.write_char('\'')?; + } + Value::Array(elems) => { + Self::fmt_list(elems, f)?; + } + Value::Tuple(elems) => { + Self::fmt_list(elems, f)?; + } + _ => return Err(std::fmt::Error), + }; + Ok(()) + } +} + +#[derive(Debug, Clone, Default)] +pub struct Context { + pub values: Module, + pub methods: HashMap>, +} + +impl Context { + pub fn eval(&self, expr: &ast::Expr<'_>) -> Result { + use ast::Expr::*; + + Ok(match expr { + Binary { left, op, right } => self.eval_binary_op(left, *op, right)?, + Unary { op, expr } => self.eval_unary_op(*op, expr)?, + Apply { expr, params } => self.eval_apply(expr, params)?, + Method { + expr, + method, + params, + } => self.eval_method(expr, method, params)?, + Index { expr, index } => self.eval_index(expr, index)?, + Field { expr, field } => self.eval_field(expr, field)?, + Paren(subexpr) => self.eval(subexpr)?, + Path(path) => self.eval_path(path)?, + Literal(lit) => self.eval_literal(lit)?, + }) + } + + fn eval_binary_op( + &self, + left: &ast::Expr<'_>, + op: ast::OpBinary, + right: &ast::Expr<'_>, + ) -> Result { + use ast::OpBinary::*; + use Value::*; + + let tl = self.eval(left)?; + let tr = self.eval(right)?; + + Ok(match (tl, op, tr) { + (Str(s1), Add, Str(s2)) => Str(s1 + &s2), + (Integer(i1), Add, Integer(i2)) => Integer(i1.checked_add(i2).ok_or(EvalError)?), + (Array(elems1), Add, Array(elems2)) => Array([elems1, elems2].concat()), + (Integer(i1), Sub, Integer(i2)) => Integer(i1.checked_sub(i2).ok_or(EvalError)?), + (Integer(i1), Mul, Integer(i2)) => Integer(i1.checked_mul(i2).ok_or(EvalError)?), + (Integer(i1), Div, Integer(i2)) => Integer(i1.checked_div(i2).ok_or(EvalError)?), + (Integer(i1), Rem, Integer(i2)) => Integer(i1.checked_rem(i2).ok_or(EvalError)?), + (Boolean(b1), And, Boolean(b2)) => Boolean(b1 && b2), + (Boolean(b1), Or, Boolean(b2)) => Boolean(b1 || b2), + (l, Eq, r) => Boolean(l == r), + (l, Ne, r) => Boolean(l != r), + (Integer(i1), Lt, Integer(i2)) => Boolean(i1 < i2), + (Integer(i1), Le, Integer(i2)) => Boolean(i1 <= i2), + (Integer(i1), Ge, Integer(i2)) => Boolean(i1 >= i2), + (Integer(i1), Gt, Integer(i2)) => Boolean(i1 > i2), + _ => return Err(EvalError), + }) + } + + fn eval_unary_op(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result { + use ast::OpUnary::*; + use Value::*; + + let typ = self.eval(expr)?; + + Ok(match (op, typ) { + (Not, Boolean(val)) => Boolean(!val), + (Neg, Integer(val)) => Integer(val.checked_neg().ok_or(EvalError)?), + _ => return Err(EvalError), + }) + } + + fn eval_index(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result { + use Value::*; + + let expr_value = self.eval(expr)?; + let index_value = self.eval(index)?; + + let Array(elems) = expr_value else { + return Err(EvalError); + }; + let Integer(index) = index_value else { + return Err(EvalError); + }; + let index: usize = index.try_into().or(Err(EvalError))?; + elems.into_iter().nth(index).ok_or(EvalError) + } + + fn eval_apply(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result { + use Value::*; + + let value = self.eval(expr)?; + + let Fn(func) = value else { + return Err(EvalError); + }; + + if func.typ.params.len() != params.len() { + return Err(EvalError); + } + + let param_values: Vec<_> = params + .iter() + .map(|param| self.eval(param)) + .collect::>()?; + + for (func_param_type, call_param_value) in func.typ.params.iter().zip(¶m_values) { + let call_param_type = call_param_value.typ().or(Err(EvalError))?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Assign) + .or(Err(EvalError))?; + } + + match func.def { + FuncDef::Intrinsic(f) => f(¶m_values), + FuncDef::Body => todo!(), + } + } + + fn eval_method( + &self, + expr: &ast::Expr<'_>, + method: &ast::Ident<'_>, + params: &[ast::Expr], + ) -> Result { + let expr_value = self.eval(expr)?; + let expr_type = expr_value.typ().or(Err(EvalError))?; + let type_family = TypeFamily::from(&expr_type); + + let methods = self.methods.get(&type_family).ok_or(EvalError)?; + let method = methods.get(method.name).ok_or(EvalError)?; + + let (self_param_type, func_param_types) = + method.typ.params.split_first().ok_or(EvalError)?; + self_param_type + .clone() + .unify(expr_type, Coerce::Assign) + .or(Err(EvalError))?; + + if func_param_types.len() != params.len() { + return Err(EvalError); + } + + let param_values: Vec<_> = iter::once(Ok(expr_value)) + .chain(params.iter().map(|param| self.eval(param))) + .collect::>()?; + + for (func_param_type, call_param_value) in func_param_types.iter().zip(¶m_values[1..]) { + let call_param_type = call_param_value.typ().or(Err(EvalError))?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Assign) + .or(Err(EvalError))?; + } + + match method.def { + FuncDef::Intrinsic(f) => f(¶m_values), + FuncDef::Body => todo!(), + } + } + + fn eval_field(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result { + use Value::*; + + let expr_value = self.eval(expr)?; + let name = field.name; + + Ok(match expr_value { + Tuple(elems) => { + let index: usize = name.parse().or(Err(EvalError))?; + elems.into_iter().nth(index).ok_or(EvalError)? + } + Map(mut entries) => entries.remove(name).ok_or(EvalError)?, + _ => return Err(EvalError), + }) + } + + fn eval_path(&self, path: &ast::Path<'_>) -> Result { + if path.components == [ast::Ident { name: "_" }] { + return Err(EvalError); + } + + self.values + .lookup(&path.components) + .ok_or(EvalError) + .cloned() + } + + fn eval_literal(&self, lit: &ast::Literal<'_>) -> Result { + use ast::Literal; + use Value::*; + + Ok(match lit { + Literal::Unit => Unit, + Literal::Bool(val) => Boolean(*val), + Literal::Int(val) => Integer(i64::try_from(*val).or(Err(EvalError))?), + Literal::Str { pieces, kind } => Str(StrDisplay { + pieces, + kind: *kind, + ctx: self, + } + .to_string()), + Literal::Tuple(elems) => Tuple( + elems + .iter() + .map(|elem| self.eval(elem)) + .collect::>()?, + ), + Literal::Array(elems) => Array( + elems + .iter() + .map(|elem| self.eval(elem)) + .collect::>()?, + ), + Literal::Map(entries) => Map(entries + .iter() + .map(|ast::MapEntry { key, value }| Ok(((*key).to_owned(), self.eval(value)?))) + .collect::>()?), + }) + } +} + +#[derive(Debug)] +struct StrDisplay<'a> { + pieces: &'a [ast::StrPiece<'a>], + kind: ast::StrKind, + ctx: &'a Context, +} + +impl<'a> Display for StrDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for piece in self.pieces { + match piece { + ast::StrPiece::Chars(chars) => f.write_str(chars)?, + ast::StrPiece::Escape(c) => f.write_char(*c)?, + ast::StrPiece::Interp(expr) => { + let val = self.ctx.eval(expr).or(Err(std::fmt::Error))?; + match self.kind { + ast::StrKind::Regular => Stringify(&val).fmt(f), + ast::StrKind::Raw => unreachable!(), + ast::StrKind::Script => ScriptStringify(&val).fmt(f), + }?; + } + }; + } + Ok(()) + } +} -- cgit v1.2.3