diff options
Diffstat (limited to 'crates/rebel-lang/src/value.rs')
-rw-r--r-- | crates/rebel-lang/src/value.rs | 422 |
1 files changed, 422 insertions, 0 deletions
diff --git a/crates/rebel-lang/src/value.rs b/crates/rebel-lang/src/value.rs new file mode 100644 index 0000000..c7e971e --- /dev/null +++ b/crates/rebel-lang/src/value.rs @@ -0,0 +1,422 @@ +use std::{ + collections::HashMap, + fmt::{Display, Write}, + iter, +}; + +use rebel_parse::ast; + +use crate::{ + func::{Func, FuncDef}, + scope::Module, + typing::{self, ArrayLen, Coerce, Type, TypeFamily}, +}; + +#[derive(Debug)] +pub struct EvalError; + +pub type Result<T> = std::result::Result<T, EvalError>; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Value { + Unit, + Boolean(bool), + Integer(i64), + Str(String), + Tuple(Vec<Value>), + Array(Vec<Value>), + Map(HashMap<String, Value>), + Fn(Box<Func>), +} + +impl Value { + pub fn typ(&self) -> typing::Result<Type> { + Ok(match self { + Value::Unit => Type::Unit, + Value::Boolean(_) => Type::Bool, + Value::Integer(_) => Type::Int, + Value::Str(_) => Type::Str, + Value::Tuple(elems) => Type::Tuple( + elems + .iter() + .map(|elem| elem.typ()) + .collect::<typing::Result<_>>()?, + ), + Value::Array(elems) => Type::Array( + Box::new(Self::array_elem_type(elems)?), + ArrayLen::Fixed(elems.len()), + ), + Value::Map(entries) => Type::Map( + entries + .iter() + .map(|(k, v)| Ok((k.clone(), v.typ()?))) + .collect::<typing::Result<_>>()?, + ), + Value::Fn(func) => Type::Fn(Box::new(func.typ.clone())), + }) + } + + fn array_elem_type(elems: &[Value]) -> typing::Result<Type> { + elems.iter().try_fold(Type::Free, |acc, elem| { + acc.unify(elem.typ()?, Coerce::Common) + }) + } +} + +impl Display for Value { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Value::Unit => f.write_str("()"), + Value::Boolean(value) => value.fmt(f), + Value::Integer(value) => value.fmt(f), + Value::Str(value) => write!(f, "{value:?}"), + Value::Tuple(elems) => { + let mut first = true; + f.write_str("(")?; + if elems.is_empty() { + f.write_str(",")?; + } + for elem in elems { + if !first { + f.write_str(", ")?; + } + first = false; + elem.fmt(f)?; + } + f.write_str(")") + } + Value::Array(elems) => { + let mut first = true; + f.write_str("[")?; + for elem in elems { + if !first { + f.write_str(", ")?; + } + first = false; + elem.fmt(f)?; + } + f.write_str("]") + } + Value::Map(entries) => { + let mut first = true; + f.write_str("{")?; + for (key, value) in entries { + if !first { + f.write_str(", ")?; + } + first = false; + write!(f, "{key} = {value}")?; + } + f.write_str("}") + } + Value::Fn(func) => func.typ.fmt(f), + } + } +} + +#[derive(Debug)] +struct Stringify<'a>(&'a Value); + +impl<'a> Display for Stringify<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + Value::Boolean(value) => value.fmt(f), + Value::Integer(value) => value.fmt(f), + Value::Str(value) => value.fmt(f), + _ => Err(std::fmt::Error), + } + } +} + +#[derive(Debug)] +struct ScriptStringify<'a>(&'a Value); + +impl<'a> ScriptStringify<'a> { + fn fmt_list(elems: &'a [Value], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut first = true; + for elem in elems { + if !first { + f.write_char(' ')?; + } + ScriptStringify(elem).fmt(f)?; + first = false; + } + Ok(()) + } +} + +impl<'a> Display for ScriptStringify<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + Value::Boolean(value) => { + value.fmt(f)?; + } + Value::Integer(value) => { + value.fmt(f)?; + } + Value::Str(value) => { + f.write_char('\'')?; + f.write_str(&value.replace('\'', "'\\''"))?; + f.write_char('\'')?; + } + Value::Array(elems) => { + Self::fmt_list(elems, f)?; + } + Value::Tuple(elems) => { + Self::fmt_list(elems, f)?; + } + _ => return Err(std::fmt::Error), + }; + Ok(()) + } +} + +#[derive(Debug, Clone, Default)] +pub struct Context { + pub values: Module<Value>, + pub methods: HashMap<TypeFamily, HashMap<&'static str, Func>>, +} + +impl Context { + pub fn eval(&self, expr: &ast::Expr<'_>) -> Result<Value> { + use ast::Expr::*; + + Ok(match expr { + Binary { left, op, right } => self.eval_binary_op(left, *op, right)?, + Unary { op, expr } => self.eval_unary_op(*op, expr)?, + Apply { expr, params } => self.eval_apply(expr, params)?, + Method { + expr, + method, + params, + } => self.eval_method(expr, method, params)?, + Index { expr, index } => self.eval_index(expr, index)?, + Field { expr, field } => self.eval_field(expr, field)?, + Paren(subexpr) => self.eval(subexpr)?, + Path(path) => self.eval_path(path)?, + Literal(lit) => self.eval_literal(lit)?, + }) + } + + fn eval_binary_op( + &self, + left: &ast::Expr<'_>, + op: ast::OpBinary, + right: &ast::Expr<'_>, + ) -> Result<Value> { + use ast::OpBinary::*; + use Value::*; + + let tl = self.eval(left)?; + let tr = self.eval(right)?; + + Ok(match (tl, op, tr) { + (Str(s1), Add, Str(s2)) => Str(s1 + &s2), + (Integer(i1), Add, Integer(i2)) => Integer(i1.checked_add(i2).ok_or(EvalError)?), + (Array(elems1), Add, Array(elems2)) => Array([elems1, elems2].concat()), + (Integer(i1), Sub, Integer(i2)) => Integer(i1.checked_sub(i2).ok_or(EvalError)?), + (Integer(i1), Mul, Integer(i2)) => Integer(i1.checked_mul(i2).ok_or(EvalError)?), + (Integer(i1), Div, Integer(i2)) => Integer(i1.checked_div(i2).ok_or(EvalError)?), + (Integer(i1), Rem, Integer(i2)) => Integer(i1.checked_rem(i2).ok_or(EvalError)?), + (Boolean(b1), And, Boolean(b2)) => Boolean(b1 && b2), + (Boolean(b1), Or, Boolean(b2)) => Boolean(b1 || b2), + (l, Eq, r) => Boolean(l == r), + (l, Ne, r) => Boolean(l != r), + (Integer(i1), Lt, Integer(i2)) => Boolean(i1 < i2), + (Integer(i1), Le, Integer(i2)) => Boolean(i1 <= i2), + (Integer(i1), Ge, Integer(i2)) => Boolean(i1 >= i2), + (Integer(i1), Gt, Integer(i2)) => Boolean(i1 > i2), + _ => return Err(EvalError), + }) + } + + fn eval_unary_op(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Value> { + use ast::OpUnary::*; + use Value::*; + + let typ = self.eval(expr)?; + + Ok(match (op, typ) { + (Not, Boolean(val)) => Boolean(!val), + (Neg, Integer(val)) => Integer(val.checked_neg().ok_or(EvalError)?), + _ => return Err(EvalError), + }) + } + + fn eval_index(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Value> { + use Value::*; + + let expr_value = self.eval(expr)?; + let index_value = self.eval(index)?; + + let Array(elems) = expr_value else { + return Err(EvalError); + }; + let Integer(index) = index_value else { + return Err(EvalError); + }; + let index: usize = index.try_into().or(Err(EvalError))?; + elems.into_iter().nth(index).ok_or(EvalError) + } + + fn eval_apply(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Value> { + use Value::*; + + let value = self.eval(expr)?; + + let Fn(func) = value else { + return Err(EvalError); + }; + + if func.typ.params.len() != params.len() { + return Err(EvalError); + } + + let param_values: Vec<_> = params + .iter() + .map(|param| self.eval(param)) + .collect::<Result<_>>()?; + + for (func_param_type, call_param_value) in func.typ.params.iter().zip(¶m_values) { + let call_param_type = call_param_value.typ().or(Err(EvalError))?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Assign) + .or(Err(EvalError))?; + } + + match func.def { + FuncDef::Intrinsic(f) => f(¶m_values), + FuncDef::Body => todo!(), + } + } + + fn eval_method( + &self, + expr: &ast::Expr<'_>, + method: &ast::Ident<'_>, + params: &[ast::Expr], + ) -> Result<Value> { + let expr_value = self.eval(expr)?; + let expr_type = expr_value.typ().or(Err(EvalError))?; + let type_family = TypeFamily::from(&expr_type); + + let methods = self.methods.get(&type_family).ok_or(EvalError)?; + let method = methods.get(method.name).ok_or(EvalError)?; + + let (self_param_type, func_param_types) = + method.typ.params.split_first().ok_or(EvalError)?; + self_param_type + .clone() + .unify(expr_type, Coerce::Assign) + .or(Err(EvalError))?; + + if func_param_types.len() != params.len() { + return Err(EvalError); + } + + let param_values: Vec<_> = iter::once(Ok(expr_value)) + .chain(params.iter().map(|param| self.eval(param))) + .collect::<Result<_>>()?; + + for (func_param_type, call_param_value) in func_param_types.iter().zip(¶m_values[1..]) { + let call_param_type = call_param_value.typ().or(Err(EvalError))?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Assign) + .or(Err(EvalError))?; + } + + match method.def { + FuncDef::Intrinsic(f) => f(¶m_values), + FuncDef::Body => todo!(), + } + } + + fn eval_field(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Value> { + use Value::*; + + let expr_value = self.eval(expr)?; + let name = field.name; + + Ok(match expr_value { + Tuple(elems) => { + let index: usize = name.parse().or(Err(EvalError))?; + elems.into_iter().nth(index).ok_or(EvalError)? + } + Map(mut entries) => entries.remove(name).ok_or(EvalError)?, + _ => return Err(EvalError), + }) + } + + fn eval_path(&self, path: &ast::Path<'_>) -> Result<Value> { + if path.components == [ast::Ident { name: "_" }] { + return Err(EvalError); + } + + self.values + .lookup(&path.components) + .ok_or(EvalError) + .cloned() + } + + fn eval_literal(&self, lit: &ast::Literal<'_>) -> Result<Value> { + use ast::Literal; + use Value::*; + + Ok(match lit { + Literal::Unit => Unit, + Literal::Bool(val) => Boolean(*val), + Literal::Int(val) => Integer(i64::try_from(*val).or(Err(EvalError))?), + Literal::Str { pieces, kind } => Str(StrDisplay { + pieces, + kind: *kind, + ctx: self, + } + .to_string()), + Literal::Tuple(elems) => Tuple( + elems + .iter() + .map(|elem| self.eval(elem)) + .collect::<Result<_>>()?, + ), + Literal::Array(elems) => Array( + elems + .iter() + .map(|elem| self.eval(elem)) + .collect::<Result<_>>()?, + ), + Literal::Map(entries) => Map(entries + .iter() + .map(|ast::MapEntry { key, value }| Ok(((*key).to_owned(), self.eval(value)?))) + .collect::<Result<_>>()?), + }) + } +} + +#[derive(Debug)] +struct StrDisplay<'a> { + pieces: &'a [ast::StrPiece<'a>], + kind: ast::StrKind, + ctx: &'a Context, +} + +impl<'a> Display for StrDisplay<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for piece in self.pieces { + match piece { + ast::StrPiece::Chars(chars) => f.write_str(chars)?, + ast::StrPiece::Escape(c) => f.write_char(*c)?, + ast::StrPiece::Interp(expr) => { + let val = self.ctx.eval(expr).or(Err(std::fmt::Error))?; + match self.kind { + ast::StrKind::Regular => Stringify(&val).fmt(f), + ast::StrKind::Raw => unreachable!(), + ast::StrKind::Script => ScriptStringify(&val).fmt(f), + }?; + } + }; + } + Ok(()) + } +} |