diff options
Diffstat (limited to 'crates/rebel-lang/src/typing.rs')
-rw-r--r-- | crates/rebel-lang/src/typing.rs | 81 |
1 files changed, 37 insertions, 44 deletions
diff --git a/crates/rebel-lang/src/typing.rs b/crates/rebel-lang/src/typing.rs index 2545746..b7a7d34 100644 --- a/crates/rebel-lang/src/typing.rs +++ b/crates/rebel-lang/src/typing.rs @@ -4,7 +4,7 @@ use enum_kinds::EnumKind; use rebel_parse::ast; -use crate::{func::FuncType, scope::Module}; +use crate::{func::FuncType, scope::Context}; #[derive(Debug)] pub struct TypeError; @@ -77,37 +77,29 @@ impl Type { _ => return Err(TypeError), }) } -} - -#[derive(Debug, Default)] -pub struct Context { - pub values: Module<Type>, - pub methods: HashMap<TypeFamily, HashMap<&'static str, FuncType>>, -} -impl Context { - pub fn ast_expr_type(&self, expr: &ast::Expr<'_>) -> Result<Type> { + pub fn ast_expr_type(ctx: &Context, expr: &ast::Expr<'_>) -> Result<Type> { use ast::Expr::*; Ok(match expr { - Binary { left, op, right } => self.binary_op_type(left, *op, right)?, - Unary { op, expr } => self.unary_op_type(*op, expr)?, - Apply { expr, params } => self.apply_type(expr, params)?, + Binary { left, op, right } => Self::binary_op_type(ctx, left, *op, right)?, + Unary { op, expr } => Self::unary_op_type(ctx, *op, expr)?, + Apply { expr, params } => Self::apply_type(ctx, expr, params)?, Method { expr, method, params, - } => self.method_type(expr, method, params)?, - Index { expr, index } => self.index_type(expr, index)?, - Field { expr, field } => self.field_type(expr, field)?, - Paren(subexpr) => self.ast_expr_type(subexpr)?, - Path(path) => self.path_type(path)?, - Literal(lit) => self.literal_type(lit)?, + } => Self::method_type(ctx, expr, method, params)?, + Index { expr, index } => Self::index_type(ctx, expr, index)?, + Field { expr, field } => Self::field_type(ctx, expr, field)?, + Paren(subexpr) => Self::ast_expr_type(ctx, subexpr)?, + Path(path) => Self::path_type(ctx, path)?, + Literal(lit) => Self::literal_type(ctx, lit)?, }) } fn binary_op_type( - &self, + ctx: &Context, left: &ast::Expr<'_>, op: ast::OpBinary, right: &ast::Expr<'_>, @@ -115,8 +107,8 @@ impl Context { use ast::OpBinary::*; use Type::*; - let tl = self.ast_expr_type(left)?; - let tr = self.ast_expr_type(right)?; + let tl = Self::ast_expr_type(ctx, left)?; + let tr = Self::ast_expr_type(ctx, right)?; Ok(match (tl, op, tr) { (Str, Add, Str) => Str, @@ -147,11 +139,11 @@ impl Context { }) } - fn unary_op_type(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> { + fn unary_op_type(ctx: &Context, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> { use ast::OpUnary::*; use Type::*; - let typ = self.ast_expr_type(expr)?; + let typ = Self::ast_expr_type(ctx, expr)?; Ok(match (op, typ) { (Not, Bool) => Bool, @@ -160,11 +152,11 @@ impl Context { }) } - fn index_type(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> { + fn index_type(ctx: &Context, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> { use Type::*; - let expr_type = self.ast_expr_type(expr)?; - let index_type = self.ast_expr_type(index)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; + let index_type = Self::ast_expr_type(ctx, index)?; let Array(elem_type, _) = expr_type else { return Err(TypeError); @@ -175,10 +167,10 @@ impl Context { Ok(*elem_type) } - fn apply_type(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> { + fn apply_type(ctx: &Context, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> { use Type::*; - let expr_type = self.ast_expr_type(expr)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; let Fn(func) = expr_type else { return Err(TypeError); @@ -189,7 +181,7 @@ impl Context { } for (func_param_type, call_param) in func.params.iter().zip(params) { - let call_param_type = self.ast_expr_type(call_param)?; + let call_param_type = Self::ast_expr_type(ctx, call_param)?; func_param_type .clone() .unify(call_param_type, Coerce::Dynamic)?; @@ -199,18 +191,18 @@ impl Context { } fn method_type( - &self, + ctx: &Context, expr: &ast::Expr<'_>, method: &ast::Ident<'_>, params: &[ast::Expr], ) -> Result<Type> { - let expr_type = self.ast_expr_type(expr)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; let type_family = TypeFamily::from(&expr_type); - let methods = self.methods.get(&type_family).ok_or(TypeError)?; + let methods = ctx.methods.get(&type_family).ok_or(TypeError)?; let method = methods.get(method.name).ok_or(TypeError)?; - let (self_param, func_params) = method.params.split_first().ok_or(TypeError)?; + let (self_param, func_params) = method.typ.params.split_first().ok_or(TypeError)?; self_param.clone().unify(expr_type, Coerce::Dynamic)?; if func_params.len() != params.len() { @@ -218,19 +210,19 @@ impl Context { } for (func_param_type, call_param) in func_params.iter().zip(params) { - let call_param_type = self.ast_expr_type(call_param)?; + let call_param_type = Self::ast_expr_type(ctx, call_param)?; func_param_type .clone() .unify(call_param_type, Coerce::Dynamic)?; } - Ok(method.ret.clone()) + Ok(method.typ.ret.clone()) } - fn field_type(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> { + fn field_type(ctx: &Context, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> { use Type::*; - let expr_type = self.ast_expr_type(expr)?; + let expr_type = Self::ast_expr_type(ctx, expr)?; let name = field.name; Ok(match expr_type { @@ -243,20 +235,21 @@ impl Context { }) } - fn path_type(&self, path: &ast::Path<'_>) -> Result<Type> { + fn path_type(ctx: &Context, path: &ast::Path<'_>) -> Result<Type> { use Type::*; if path.components == [ast::Ident { name: "_" }] { return Ok(Free); } - self.values + ctx.values .lookup(&path.components) + .map(|(typ, _)| typ) .ok_or(TypeError) .cloned() } - fn literal_type(&self, lit: &ast::Literal<'_>) -> Result<Type> { + fn literal_type(ctx: &Context, lit: &ast::Literal<'_>) -> Result<Type> { use ast::Literal; use Type::*; @@ -268,19 +261,19 @@ impl Context { Literal::Tuple(elems) => Tuple( elems .iter() - .map(|elem| self.ast_expr_type(elem)) + .map(|elem| Self::ast_expr_type(ctx, elem)) .collect::<Result<_>>()?, ), Literal::Array(elems) => Array( Box::new(elems.iter().try_fold(Type::Free, |acc, elem| { - acc.unify(self.ast_expr_type(elem)?, Coerce::Common) + acc.unify(Self::ast_expr_type(ctx, elem)?, Coerce::Common) })?), ArrayLen::Fixed(elems.len()), ), Literal::Map(entries) => Map(entries .iter() .map(|ast::MapEntry { key, value }| { - Ok(((*key).to_owned(), self.ast_expr_type(value)?)) + Ok(((*key).to_owned(), Self::ast_expr_type(ctx, value)?)) }) .collect::<Result<_>>()?), }) |