diff options
author | Matthias Schiffer <mschiffer@universe-factory.net> | 2024-04-28 09:56:55 +0200 |
---|---|---|
committer | Matthias Schiffer <mschiffer@universe-factory.net> | 2024-04-28 10:00:54 +0200 |
commit | b93a4d37ddb2d955ad4e4ba47020fd2b60551f42 (patch) | |
tree | 001e8e18abdfe18f236a3121fb7f4a4265df0de4 | |
parent | abd9c712b8f3fea02726a74a2daf5c7463aa790f (diff) | |
download | rebel-b93a4d37ddb2d955ad4e4ba47020fd2b60551f42.tar rebel-b93a4d37ddb2d955ad4e4ba47020fd2b60551f42.zip |
rebel-lang: merge typing and evaluation Context structs
Avoid having to convert between different context types for typechecking
and evaluation. During typechecking, upvalues etc. are represented as
None values.
-rw-r--r-- | crates/rebel-lang/examples/repl.rs | 9 | ||||
-rw-r--r-- | crates/rebel-lang/examples/type-string.rs | 25 | ||||
-rw-r--r-- | crates/rebel-lang/src/func.rs | 2 | ||||
-rw-r--r-- | crates/rebel-lang/src/scope.rs | 36 | ||||
-rw-r--r-- | crates/rebel-lang/src/typing.rs | 81 | ||||
-rw-r--r-- | crates/rebel-lang/src/value.rs | 91 |
6 files changed, 125 insertions, 119 deletions
diff --git a/crates/rebel-lang/examples/repl.rs b/crates/rebel-lang/examples/repl.rs index 0163c6b..f7d29e3 100644 --- a/crates/rebel-lang/examples/repl.rs +++ b/crates/rebel-lang/examples/repl.rs @@ -1,7 +1,8 @@ use rebel_lang::{ func::{Func, FuncDef, FuncType}, + scope::Context, typing::{ArrayLen, Type, TypeFamily}, - value::{Context, EvalError, Result, Value}, + value::{EvalError, Result, Value}, }; use rebel_parse::{recipe, tokenize}; @@ -32,7 +33,7 @@ fn main() -> rustyline::Result<()> { params: vec![Type::Array(Box::new(Type::Free), ArrayLen::Dynamic)], ret: Type::Int, }, - def: FuncDef::Intrinsic(intrinsic_array_len), + def: Some(FuncDef::Intrinsic(intrinsic_array_len)), }, ); ctx.methods.entry(TypeFamily::Str).or_default().insert( @@ -42,7 +43,7 @@ fn main() -> rustyline::Result<()> { params: vec![Type::Str], ret: Type::Int, }, - def: FuncDef::Intrinsic(intrinsic_string_len), + def: Some(FuncDef::Intrinsic(intrinsic_string_len)), }, ); @@ -64,7 +65,7 @@ fn main() -> rustyline::Result<()> { continue; } }; - let value = match ctx.eval(&expr) { + let value = match Value::eval(&ctx, &expr) { Ok(value) => value, Err(err) => { println!("Eval error: {err:?}"); diff --git a/crates/rebel-lang/examples/type-string.rs b/crates/rebel-lang/examples/type-string.rs index 5490572..8ade044 100644 --- a/crates/rebel-lang/examples/type-string.rs +++ b/crates/rebel-lang/examples/type-string.rs @@ -3,8 +3,9 @@ use std::{fmt::Debug, process, time::Instant}; use clap::Parser; use rebel_lang::{ - func::FuncType, - typing::{ArrayLen, Context, Type, TypeFamily}, + func::{Func, FuncType}, + scope::Context, + typing::{ArrayLen, Type, TypeFamily}, }; use rebel_parse::{recipe, tokenize}; @@ -47,21 +48,27 @@ fn main() { ctx.methods.entry(TypeFamily::Array).or_default().insert( "len", - FuncType { - params: vec![Type::Array(Box::new(Type::Free), ArrayLen::Dynamic)], - ret: Type::Int, + Func { + typ: FuncType { + params: vec![Type::Array(Box::new(Type::Free), ArrayLen::Dynamic)], + ret: Type::Int, + }, + def: None, }, ); ctx.methods.entry(TypeFamily::Str).or_default().insert( "len", - FuncType { - params: vec![Type::Str], - ret: Type::Int, + Func { + typ: FuncType { + params: vec![Type::Str], + ret: Type::Int, + }, + def: None, }, ); let start = Instant::now(); - let result = ctx.ast_expr_type(&expr); + let result = Type::ast_expr_type(&ctx, &expr); let dur = Instant::now().duration_since(start); println!("Typing took {} µs", dur.as_micros()); diff --git a/crates/rebel-lang/src/func.rs b/crates/rebel-lang/src/func.rs index 129d89e..8f1aed3 100644 --- a/crates/rebel-lang/src/func.rs +++ b/crates/rebel-lang/src/func.rs @@ -8,7 +8,7 @@ use crate::{ #[derive(Debug, Clone, PartialEq, Eq)] pub struct Func { pub typ: FuncType, - pub def: FuncDef, + pub def: Option<FuncDef>, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/crates/rebel-lang/src/scope.rs b/crates/rebel-lang/src/scope.rs index b8a66ce..bae9678 100644 --- a/crates/rebel-lang/src/scope.rs +++ b/crates/rebel-lang/src/scope.rs @@ -2,18 +2,30 @@ use std::collections::HashMap; use rebel_parse::ast; -#[derive(Debug)] -pub struct Module<T>(HashMap<String, ModuleEntry<T>>); +use crate::{ + func::Func, + typing::{Type, TypeFamily}, + value::Value, +}; + +#[derive(Debug, Default)] +pub struct Context { + pub values: Module, + pub methods: HashMap<TypeFamily, HashMap<&'static str, Func>>, +} + +#[derive(Debug, Default)] +pub struct Module(HashMap<String, ModuleEntry>); -impl<T> Module<T> { - pub fn lookup(&self, path: &[ast::Ident<'_>]) -> Option<&T> { +impl Module { + pub fn lookup(&self, path: &[ast::Ident<'_>]) -> Option<(&Type, Option<&Value>)> { let (ident, rest) = path.split_first()?; match self.0.get(ident.name)? { ModuleEntry::Module(module) => module.lookup(rest), - ModuleEntry::Def(typ) => { + ModuleEntry::Def((typ, val)) => { if rest.is_empty() { - Some(typ) + Some((&typ, val.as_ref())) } else { None } @@ -22,14 +34,8 @@ impl<T> Module<T> { } } -impl<T> Default for Module<T> { - fn default() -> Self { - Self(Default::default()) - } -} - #[derive(Debug)] -pub enum ModuleEntry<T> { - Module(Module<T>), - Def(T), +pub enum ModuleEntry { + Module(Module), + Def((Type, Option<Value>)), } diff --git a/crates/rebel-lang/src/typing.rs b/crates/rebel-lang/src/typing.rs index 2545746..b7a7d34 100644 --- a/crates/rebel-lang/src/typing.rs +++ b/crates/rebel-lang/src/typing.rs @@ -4,7 +4,7 @@ use enum_kinds::EnumKind; use rebel_parse::ast; -use crate::{func::FuncType, scope::Module}; +use crate::{func::FuncType, scope::Context}; #[derive(Debug)] pub struct TypeError; @@ -77,37 +77,29 @@ impl Type { _ => return Err(TypeError), }) } -} - -#[derive(Debug, Default)] -pub struct Context { - pub values: Module<Type>, - pub methods: HashMap<TypeFamily, HashMap<&'static str, FuncType>>, -} -impl Context { - pub fn ast_expr_type(&self, expr: &ast::Expr<'_>) -> Result<Type> { + pub fn ast_expr_type(ctx: &Context, expr: &ast::Expr<'_>) -> Result<Type> { use ast::Expr::*; Ok(match expr { - Binary { left, op, right } => self.binary_op_type(left, *op, right)?, - Unary { op, expr } => self.unary_op_type(*op, expr)?, - Apply { expr, params } => self.apply_type(expr, params)?, + Binary { left, op, right } => Self::binary_op_type(ctx, left, *op, right)?, + Unary { op, expr } => Self::unary_op_type(ctx, *op, expr)?, + Apply { expr, params } => Self::apply_type(ctx, expr, params)?, Method { expr, method, params, - } => self.method_type(expr, method, params)?, - Index { expr, index } => self.index_type(expr, index)?, - Field { expr, field } => self.field_type(expr, field)?, - Paren(subexpr) => self.ast_expr_type(subexpr)?, - Path(path) => self.path_type(path)?, - Literal(lit) => self.literal_type(lit)?, + } => Self::method_type(ctx, expr, method, params)?, + Index { expr, index } => Self::index_type(ctx, expr, index)?, + Field { expr, field } => Self::field_type(ctx, expr, field)?, + Paren(subexpr) => Self::ast_expr_type(ctx, subexpr)?, + Path(path) => Self::path_type(ctx, path)?, + Literal(lit) => Self::literal_type(ctx, lit)?, }) } fn binary_op_type( - &self, + ctx: &Context, left: &ast::Expr<'_>, op: ast::OpBinary, right: &ast::Expr<'_>, @@ -115,8 +107,8 @@ impl Context { use ast::OpBinary::*; use Type::*; - let tl = self.ast_expr_type(left)?; - let tr = self.ast_expr_type(right)?; + let tl = Self::ast_expr_type(ctx, left)?; + let tr = Self::ast_expr_type(ctx, right)?; Ok(match (tl, op, tr) { (Str, Add, Str) => Str, @@ -147,11 +139,11 @@ impl Context { }) } - fn unary_op_type(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> { + fn unary_op_type(ctx: &Context, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> { use ast::OpUnary::*; use Type::*; - let typ = self.ast_expr_type(expr)?; + let typ = Self::ast_expr_type(ctx, expr)?; Ok(match (op, typ) { (Not, Bool) => Bool, @@ -160,11 +152,11 @@ impl Context { }) } - fn index_type(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> { + fn index_type(ctx: &Context, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> { use Type::*; - let expr_type = self.ast_expr_type(expr)?; - let index_type = self.ast_expr_type(index)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; + let index_type = Self::ast_expr_type(ctx, index)?; let Array(elem_type, _) = expr_type else { return Err(TypeError); @@ -175,10 +167,10 @@ impl Context { Ok(*elem_type) } - fn apply_type(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> { + fn apply_type(ctx: &Context, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> { use Type::*; - let expr_type = self.ast_expr_type(expr)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; let Fn(func) = expr_type else { return Err(TypeError); @@ -189,7 +181,7 @@ impl Context { } for (func_param_type, call_param) in func.params.iter().zip(params) { - let call_param_type = self.ast_expr_type(call_param)?; + let call_param_type = Self::ast_expr_type(ctx, call_param)?; func_param_type .clone() .unify(call_param_type, Coerce::Dynamic)?; @@ -199,18 +191,18 @@ impl Context { } fn method_type( - &self, + ctx: &Context, expr: &ast::Expr<'_>, method: &ast::Ident<'_>, params: &[ast::Expr], ) -> Result<Type> { - let expr_type = self.ast_expr_type(expr)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; let type_family = TypeFamily::from(&expr_type); - let methods = self.methods.get(&type_family).ok_or(TypeError)?; + let methods = ctx.methods.get(&type_family).ok_or(TypeError)?; let method = methods.get(method.name).ok_or(TypeError)?; - let (self_param, func_params) = method.params.split_first().ok_or(TypeError)?; + let (self_param, func_params) = method.typ.params.split_first().ok_or(TypeError)?; self_param.clone().unify(expr_type, Coerce::Dynamic)?; if func_params.len() != params.len() { @@ -218,19 +210,19 @@ impl Context { } for (func_param_type, call_param) in func_params.iter().zip(params) { - let call_param_type = self.ast_expr_type(call_param)?; + let call_param_type = Self::ast_expr_type(ctx, call_param)?; func_param_type .clone() .unify(call_param_type, Coerce::Dynamic)?; } - Ok(method.ret.clone()) + Ok(method.typ.ret.clone()) } - fn field_type(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> { + fn field_type(ctx: &Context, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> { use Type::*; - let expr_type = self.ast_expr_type(expr)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; let name = field.name; Ok(match expr_type { @@ -243,20 +235,21 @@ impl Context { }) } - fn path_type(&self, path: &ast::Path<'_>) -> Result<Type> { + fn path_type(ctx: &Context, path: &ast::Path<'_>) -> Result<Type> { use Type::*; if path.components == [ast::Ident { name: "_" }] { return Ok(Free); } - self.values + ctx.values .lookup(&path.components) + .map(|(typ, _)| typ) .ok_or(TypeError) .cloned() } - fn literal_type(&self, lit: &ast::Literal<'_>) -> Result<Type> { + fn literal_type(ctx: &Context, lit: &ast::Literal<'_>) -> Result<Type> { use ast::Literal; use Type::*; @@ -268,19 +261,19 @@ impl Context { Literal::Tuple(elems) => Tuple( elems .iter() - .map(|elem| self.ast_expr_type(elem)) + .map(|elem| Self::ast_expr_type(ctx, elem)) .collect::<Result<_>>()?, ), Literal::Array(elems) => Array( Box::new(elems.iter().try_fold(Type::Free, |acc, elem| { - acc.unify(self.ast_expr_type(elem)?, Coerce::Common) + acc.unify(Self::ast_expr_type(ctx, elem)?, Coerce::Common) })?), ArrayLen::Fixed(elems.len()), ), Literal::Map(entries) => Map(entries .iter() .map(|ast::MapEntry { key, value }| { - Ok(((*key).to_owned(), self.ast_expr_type(value)?)) + Ok(((*key).to_owned(), Self::ast_expr_type(ctx, value)?)) }) .collect::<Result<_>>()?), }) diff --git a/crates/rebel-lang/src/value.rs b/crates/rebel-lang/src/value.rs index 735ae63..4f30dfd 100644 --- a/crates/rebel-lang/src/value.rs +++ b/crates/rebel-lang/src/value.rs @@ -8,7 +8,7 @@ use rebel_parse::ast; use crate::{ func::{Func, FuncDef}, - scope::Module, + scope::Context, typing::{self, ArrayLen, Coerce, Type, TypeFamily}, }; @@ -61,37 +61,29 @@ impl Value { acc.unify(elem.typ()?, Coerce::Common) }) } -} - -#[derive(Debug, Default)] -pub struct Context { - pub values: Module<Value>, - pub methods: HashMap<TypeFamily, HashMap<&'static str, Func>>, -} -impl Context { - pub fn eval(&self, expr: &ast::Expr<'_>) -> Result<Value> { + pub fn eval(ctx: &Context, expr: &ast::Expr<'_>) -> Result<Value> { use ast::Expr::*; Ok(match expr { - Binary { left, op, right } => self.eval_binary_op(left, *op, right)?, - Unary { op, expr } => self.eval_unary_op(*op, expr)?, - Apply { expr, params } => self.eval_apply(expr, params)?, + Binary { left, op, right } => Self::eval_binary_op(ctx, left, *op, right)?, + Unary { op, expr } => Self::eval_unary_op(ctx, *op, expr)?, + Apply { expr, params } => Self::eval_apply(ctx, expr, params)?, Method { expr, method, params, - } => self.eval_method(expr, method, params)?, - Index { expr, index } => self.eval_index(expr, index)?, - Field { expr, field } => self.eval_field(expr, field)?, - Paren(subexpr) => self.eval(subexpr)?, - Path(path) => self.eval_path(path)?, - Literal(lit) => self.eval_literal(lit)?, + } => Self::eval_method(ctx, expr, method, params)?, + Index { expr, index } => Self::eval_index(ctx, expr, index)?, + Field { expr, field } => Self::eval_field(ctx, expr, field)?, + Paren(subexpr) => Self::eval(ctx, subexpr)?, + Path(path) => Self::eval_path(ctx, path)?, + Literal(lit) => Self::eval_literal(ctx, lit)?, }) } fn eval_binary_op( - &self, + ctx: &Context, left: &ast::Expr<'_>, op: ast::OpBinary, right: &ast::Expr<'_>, @@ -99,8 +91,8 @@ impl Context { use ast::OpBinary::*; use Value::*; - let tl = self.eval(left)?; - let tr = self.eval(right)?; + let tl = Self::eval(ctx, left)?; + let tr = Self::eval(ctx, right)?; Ok(match (tl, op, tr) { (Str(s1), Add, Str(s2)) => Str(s1 + &s2), @@ -122,11 +114,11 @@ impl Context { }) } - fn eval_unary_op(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Value> { + fn eval_unary_op(ctx: &Context, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Value> { use ast::OpUnary::*; use Value::*; - let typ = self.eval(expr)?; + let typ = Self::eval(ctx, expr)?; Ok(match (op, typ) { (Not, Boolean(val)) => Boolean(!val), @@ -135,11 +127,11 @@ impl Context { }) } - fn eval_index(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Value> { + fn eval_index(ctx: &Context, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Value> { use Value::*; - let expr_value = self.eval(expr)?; - let index_value = self.eval(index)?; + let expr_value = Self::eval(ctx, expr)?; + let index_value = Self::eval(ctx, index)?; let Array(elems) = expr_value else { return Err(EvalError); @@ -151,10 +143,10 @@ impl Context { elems.into_iter().nth(index).ok_or(EvalError) } - fn eval_apply(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Value> { + fn eval_apply(ctx: &Context, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Value> { use Value::*; - let value = self.eval(expr)?; + let value = Self::eval(ctx, expr)?; let Fn(func) = value else { return Err(EvalError); @@ -166,7 +158,7 @@ impl Context { let param_values: Vec<_> = params .iter() - .map(|param| self.eval(param)) + .map(|param| Self::eval(ctx, param)) .collect::<Result<_>>()?; for (func_param_type, call_param_value) in func.typ.params.iter().zip(¶m_values) { @@ -177,23 +169,25 @@ impl Context { .or(Err(EvalError))?; } - match func.def { + let def = func.def.ok_or(EvalError)?; + + match def { FuncDef::Intrinsic(f) => f(¶m_values), FuncDef::Body => todo!(), } } fn eval_method( - &self, + ctx: &Context, expr: &ast::Expr<'_>, method: &ast::Ident<'_>, params: &[ast::Expr], ) -> Result<Value> { - let expr_value = self.eval(expr)?; + let expr_value = Self::eval(ctx, expr)?; let expr_type = expr_value.typ().or(Err(EvalError))?; let type_family = TypeFamily::from(&expr_type); - let methods = self.methods.get(&type_family).ok_or(EvalError)?; + let methods = ctx.methods.get(&type_family).ok_or(EvalError)?; let method = methods.get(method.name).ok_or(EvalError)?; let (self_param_type, func_param_types) = @@ -208,7 +202,7 @@ impl Context { } let param_values: Vec<_> = iter::once(Ok(expr_value)) - .chain(params.iter().map(|param| self.eval(param))) + .chain(params.iter().map(|param| Self::eval(ctx, param))) .collect::<Result<_>>()?; for (func_param_type, call_param_value) in func_param_types.iter().zip(¶m_values[1..]) { @@ -219,16 +213,18 @@ impl Context { .or(Err(EvalError))?; } - match method.def { + let def = method.def.as_ref().ok_or(EvalError)?; + + match def { FuncDef::Intrinsic(f) => f(¶m_values), FuncDef::Body => todo!(), } } - fn eval_field(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Value> { + fn eval_field(ctx: &Context, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Value> { use Value::*; - let expr_value = self.eval(expr)?; + let expr_value = Self::eval(ctx, expr)?; let name = field.name; Ok(match expr_value { @@ -241,18 +237,19 @@ impl Context { }) } - fn eval_path(&self, path: &ast::Path<'_>) -> Result<Value> { + fn eval_path(ctx: &Context, path: &ast::Path<'_>) -> Result<Value> { if path.components == [ast::Ident { name: "_" }] { return Err(EvalError); } - self.values + ctx.values .lookup(&path.components) + .and_then(|(_, val)| val) .ok_or(EvalError) .cloned() } - fn eval_literal(&self, lit: &ast::Literal<'_>) -> Result<Value> { + fn eval_literal(ctx: &Context, lit: &ast::Literal<'_>) -> Result<Value> { use ast::Literal; use Value::*; @@ -263,24 +260,26 @@ impl Context { Literal::Str { pieces, kind } => Str(StrDisplay { pieces, kind: *kind, - ctx: self, + ctx, } .to_string()), Literal::Tuple(elems) => Tuple( elems .iter() - .map(|elem| self.eval(elem)) + .map(|elem| Self::eval(ctx, elem)) .collect::<Result<_>>()?, ), Literal::Array(elems) => Array( elems .iter() - .map(|elem| self.eval(elem)) + .map(|elem| Self::eval(ctx, elem)) .collect::<Result<_>>()?, ), Literal::Map(entries) => Map(entries .iter() - .map(|ast::MapEntry { key, value }| Ok(((*key).to_owned(), self.eval(value)?))) + .map(|ast::MapEntry { key, value }| { + Ok(((*key).to_owned(), Self::eval(ctx, value)?)) + }) .collect::<Result<_>>()?), }) } @@ -408,7 +407,7 @@ impl<'a> Display for StrDisplay<'a> { ast::StrPiece::Chars(chars) => f.write_str(chars)?, ast::StrPiece::Escape(c) => f.write_char(*c)?, ast::StrPiece::Interp(expr) => { - let val = self.ctx.eval(expr).or(Err(std::fmt::Error))?; + let val = Value::eval(self.ctx, expr).or(Err(std::fmt::Error))?; match self.kind { ast::StrKind::Regular => Stringify(&val).fmt(f), ast::StrKind::Raw => unreachable!(), |