summaryrefslogtreecommitdiffstats
path: root/crates/rebel-lang/src/value.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/rebel-lang/src/value.rs')
-rw-r--r--crates/rebel-lang/src/value.rs422
1 files changed, 422 insertions, 0 deletions
diff --git a/crates/rebel-lang/src/value.rs b/crates/rebel-lang/src/value.rs
new file mode 100644
index 0000000..c7e971e
--- /dev/null
+++ b/crates/rebel-lang/src/value.rs
@@ -0,0 +1,422 @@
+use std::{
+ collections::HashMap,
+ fmt::{Display, Write},
+ iter,
+};
+
+use rebel_parse::ast;
+
+use crate::{
+ func::{Func, FuncDef},
+ scope::Module,
+ typing::{self, ArrayLen, Coerce, Type, TypeFamily},
+};
+
+#[derive(Debug)]
+pub struct EvalError;
+
+pub type Result<T> = std::result::Result<T, EvalError>;
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum Value {
+ Unit,
+ Boolean(bool),
+ Integer(i64),
+ Str(String),
+ Tuple(Vec<Value>),
+ Array(Vec<Value>),
+ Map(HashMap<String, Value>),
+ Fn(Box<Func>),
+}
+
+impl Value {
+ pub fn typ(&self) -> typing::Result<Type> {
+ Ok(match self {
+ Value::Unit => Type::Unit,
+ Value::Boolean(_) => Type::Bool,
+ Value::Integer(_) => Type::Int,
+ Value::Str(_) => Type::Str,
+ Value::Tuple(elems) => Type::Tuple(
+ elems
+ .iter()
+ .map(|elem| elem.typ())
+ .collect::<typing::Result<_>>()?,
+ ),
+ Value::Array(elems) => Type::Array(
+ Box::new(Self::array_elem_type(elems)?),
+ ArrayLen::Fixed(elems.len()),
+ ),
+ Value::Map(entries) => Type::Map(
+ entries
+ .iter()
+ .map(|(k, v)| Ok((k.clone(), v.typ()?)))
+ .collect::<typing::Result<_>>()?,
+ ),
+ Value::Fn(func) => Type::Fn(Box::new(func.typ.clone())),
+ })
+ }
+
+ fn array_elem_type(elems: &[Value]) -> typing::Result<Type> {
+ elems.iter().try_fold(Type::Free, |acc, elem| {
+ acc.unify(elem.typ()?, Coerce::Common)
+ })
+ }
+}
+
+impl Display for Value {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Value::Unit => f.write_str("()"),
+ Value::Boolean(value) => value.fmt(f),
+ Value::Integer(value) => value.fmt(f),
+ Value::Str(value) => write!(f, "{value:?}"),
+ Value::Tuple(elems) => {
+ let mut first = true;
+ f.write_str("(")?;
+ if elems.is_empty() {
+ f.write_str(",")?;
+ }
+ for elem in elems {
+ if !first {
+ f.write_str(", ")?;
+ }
+ first = false;
+ elem.fmt(f)?;
+ }
+ f.write_str(")")
+ }
+ Value::Array(elems) => {
+ let mut first = true;
+ f.write_str("[")?;
+ for elem in elems {
+ if !first {
+ f.write_str(", ")?;
+ }
+ first = false;
+ elem.fmt(f)?;
+ }
+ f.write_str("]")
+ }
+ Value::Map(entries) => {
+ let mut first = true;
+ f.write_str("{")?;
+ for (key, value) in entries {
+ if !first {
+ f.write_str(", ")?;
+ }
+ first = false;
+ write!(f, "{key} = {value}")?;
+ }
+ f.write_str("}")
+ }
+ Value::Fn(func) => func.typ.fmt(f),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct Stringify<'a>(&'a Value);
+
+impl<'a> Display for Stringify<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self.0 {
+ Value::Boolean(value) => value.fmt(f),
+ Value::Integer(value) => value.fmt(f),
+ Value::Str(value) => value.fmt(f),
+ _ => Err(std::fmt::Error),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct ScriptStringify<'a>(&'a Value);
+
+impl<'a> ScriptStringify<'a> {
+ fn fmt_list(elems: &'a [Value], f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let mut first = true;
+ for elem in elems {
+ if !first {
+ f.write_char(' ')?;
+ }
+ ScriptStringify(elem).fmt(f)?;
+ first = false;
+ }
+ Ok(())
+ }
+}
+
+impl<'a> Display for ScriptStringify<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self.0 {
+ Value::Boolean(value) => {
+ value.fmt(f)?;
+ }
+ Value::Integer(value) => {
+ value.fmt(f)?;
+ }
+ Value::Str(value) => {
+ f.write_char('\'')?;
+ f.write_str(&value.replace('\'', "'\\''"))?;
+ f.write_char('\'')?;
+ }
+ Value::Array(elems) => {
+ Self::fmt_list(elems, f)?;
+ }
+ Value::Tuple(elems) => {
+ Self::fmt_list(elems, f)?;
+ }
+ _ => return Err(std::fmt::Error),
+ };
+ Ok(())
+ }
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct Context {
+ pub values: Module<Value>,
+ pub methods: HashMap<TypeFamily, HashMap<&'static str, Func>>,
+}
+
+impl Context {
+ pub fn eval(&self, expr: &ast::Expr<'_>) -> Result<Value> {
+ use ast::Expr::*;
+
+ Ok(match expr {
+ Binary { left, op, right } => self.eval_binary_op(left, *op, right)?,
+ Unary { op, expr } => self.eval_unary_op(*op, expr)?,
+ Apply { expr, params } => self.eval_apply(expr, params)?,
+ Method {
+ expr,
+ method,
+ params,
+ } => self.eval_method(expr, method, params)?,
+ Index { expr, index } => self.eval_index(expr, index)?,
+ Field { expr, field } => self.eval_field(expr, field)?,
+ Paren(subexpr) => self.eval(subexpr)?,
+ Path(path) => self.eval_path(path)?,
+ Literal(lit) => self.eval_literal(lit)?,
+ })
+ }
+
+ fn eval_binary_op(
+ &self,
+ left: &ast::Expr<'_>,
+ op: ast::OpBinary,
+ right: &ast::Expr<'_>,
+ ) -> Result<Value> {
+ use ast::OpBinary::*;
+ use Value::*;
+
+ let tl = self.eval(left)?;
+ let tr = self.eval(right)?;
+
+ Ok(match (tl, op, tr) {
+ (Str(s1), Add, Str(s2)) => Str(s1 + &s2),
+ (Integer(i1), Add, Integer(i2)) => Integer(i1.checked_add(i2).ok_or(EvalError)?),
+ (Array(elems1), Add, Array(elems2)) => Array([elems1, elems2].concat()),
+ (Integer(i1), Sub, Integer(i2)) => Integer(i1.checked_sub(i2).ok_or(EvalError)?),
+ (Integer(i1), Mul, Integer(i2)) => Integer(i1.checked_mul(i2).ok_or(EvalError)?),
+ (Integer(i1), Div, Integer(i2)) => Integer(i1.checked_div(i2).ok_or(EvalError)?),
+ (Integer(i1), Rem, Integer(i2)) => Integer(i1.checked_rem(i2).ok_or(EvalError)?),
+ (Boolean(b1), And, Boolean(b2)) => Boolean(b1 && b2),
+ (Boolean(b1), Or, Boolean(b2)) => Boolean(b1 || b2),
+ (l, Eq, r) => Boolean(l == r),
+ (l, Ne, r) => Boolean(l != r),
+ (Integer(i1), Lt, Integer(i2)) => Boolean(i1 < i2),
+ (Integer(i1), Le, Integer(i2)) => Boolean(i1 <= i2),
+ (Integer(i1), Ge, Integer(i2)) => Boolean(i1 >= i2),
+ (Integer(i1), Gt, Integer(i2)) => Boolean(i1 > i2),
+ _ => return Err(EvalError),
+ })
+ }
+
+ fn eval_unary_op(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Value> {
+ use ast::OpUnary::*;
+ use Value::*;
+
+ let typ = self.eval(expr)?;
+
+ Ok(match (op, typ) {
+ (Not, Boolean(val)) => Boolean(!val),
+ (Neg, Integer(val)) => Integer(val.checked_neg().ok_or(EvalError)?),
+ _ => return Err(EvalError),
+ })
+ }
+
+ fn eval_index(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Value> {
+ use Value::*;
+
+ let expr_value = self.eval(expr)?;
+ let index_value = self.eval(index)?;
+
+ let Array(elems) = expr_value else {
+ return Err(EvalError);
+ };
+ let Integer(index) = index_value else {
+ return Err(EvalError);
+ };
+ let index: usize = index.try_into().or(Err(EvalError))?;
+ elems.into_iter().nth(index).ok_or(EvalError)
+ }
+
+ fn eval_apply(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Value> {
+ use Value::*;
+
+ let value = self.eval(expr)?;
+
+ let Fn(func) = value else {
+ return Err(EvalError);
+ };
+
+ if func.typ.params.len() != params.len() {
+ return Err(EvalError);
+ }
+
+ let param_values: Vec<_> = params
+ .iter()
+ .map(|param| self.eval(param))
+ .collect::<Result<_>>()?;
+
+ for (func_param_type, call_param_value) in func.typ.params.iter().zip(&param_values) {
+ let call_param_type = call_param_value.typ().or(Err(EvalError))?;
+ func_param_type
+ .clone()
+ .unify(call_param_type, Coerce::Assign)
+ .or(Err(EvalError))?;
+ }
+
+ match func.def {
+ FuncDef::Intrinsic(f) => f(&param_values),
+ FuncDef::Body => todo!(),
+ }
+ }
+
+ fn eval_method(
+ &self,
+ expr: &ast::Expr<'_>,
+ method: &ast::Ident<'_>,
+ params: &[ast::Expr],
+ ) -> Result<Value> {
+ let expr_value = self.eval(expr)?;
+ let expr_type = expr_value.typ().or(Err(EvalError))?;
+ let type_family = TypeFamily::from(&expr_type);
+
+ let methods = self.methods.get(&type_family).ok_or(EvalError)?;
+ let method = methods.get(method.name).ok_or(EvalError)?;
+
+ let (self_param_type, func_param_types) =
+ method.typ.params.split_first().ok_or(EvalError)?;
+ self_param_type
+ .clone()
+ .unify(expr_type, Coerce::Assign)
+ .or(Err(EvalError))?;
+
+ if func_param_types.len() != params.len() {
+ return Err(EvalError);
+ }
+
+ let param_values: Vec<_> = iter::once(Ok(expr_value))
+ .chain(params.iter().map(|param| self.eval(param)))
+ .collect::<Result<_>>()?;
+
+ for (func_param_type, call_param_value) in func_param_types.iter().zip(&param_values[1..]) {
+ let call_param_type = call_param_value.typ().or(Err(EvalError))?;
+ func_param_type
+ .clone()
+ .unify(call_param_type, Coerce::Assign)
+ .or(Err(EvalError))?;
+ }
+
+ match method.def {
+ FuncDef::Intrinsic(f) => f(&param_values),
+ FuncDef::Body => todo!(),
+ }
+ }
+
+ fn eval_field(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Value> {
+ use Value::*;
+
+ let expr_value = self.eval(expr)?;
+ let name = field.name;
+
+ Ok(match expr_value {
+ Tuple(elems) => {
+ let index: usize = name.parse().or(Err(EvalError))?;
+ elems.into_iter().nth(index).ok_or(EvalError)?
+ }
+ Map(mut entries) => entries.remove(name).ok_or(EvalError)?,
+ _ => return Err(EvalError),
+ })
+ }
+
+ fn eval_path(&self, path: &ast::Path<'_>) -> Result<Value> {
+ if path.components == [ast::Ident { name: "_" }] {
+ return Err(EvalError);
+ }
+
+ self.values
+ .lookup(&path.components)
+ .ok_or(EvalError)
+ .cloned()
+ }
+
+ fn eval_literal(&self, lit: &ast::Literal<'_>) -> Result<Value> {
+ use ast::Literal;
+ use Value::*;
+
+ Ok(match lit {
+ Literal::Unit => Unit,
+ Literal::Bool(val) => Boolean(*val),
+ Literal::Int(val) => Integer(i64::try_from(*val).or(Err(EvalError))?),
+ Literal::Str { pieces, kind } => Str(StrDisplay {
+ pieces,
+ kind: *kind,
+ ctx: self,
+ }
+ .to_string()),
+ Literal::Tuple(elems) => Tuple(
+ elems
+ .iter()
+ .map(|elem| self.eval(elem))
+ .collect::<Result<_>>()?,
+ ),
+ Literal::Array(elems) => Array(
+ elems
+ .iter()
+ .map(|elem| self.eval(elem))
+ .collect::<Result<_>>()?,
+ ),
+ Literal::Map(entries) => Map(entries
+ .iter()
+ .map(|ast::MapEntry { key, value }| Ok(((*key).to_owned(), self.eval(value)?)))
+ .collect::<Result<_>>()?),
+ })
+ }
+}
+
+#[derive(Debug)]
+struct StrDisplay<'a> {
+ pieces: &'a [ast::StrPiece<'a>],
+ kind: ast::StrKind,
+ ctx: &'a Context,
+}
+
+impl<'a> Display for StrDisplay<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ for piece in self.pieces {
+ match piece {
+ ast::StrPiece::Chars(chars) => f.write_str(chars)?,
+ ast::StrPiece::Escape(c) => f.write_char(*c)?,
+ ast::StrPiece::Interp(expr) => {
+ let val = self.ctx.eval(expr).or(Err(std::fmt::Error))?;
+ match self.kind {
+ ast::StrKind::Regular => Stringify(&val).fmt(f),
+ ast::StrKind::Raw => unreachable!(),
+ ast::StrKind::Script => ScriptStringify(&val).fmt(f),
+ }?;
+ }
+ };
+ }
+ Ok(())
+ }
+}