diff options
Diffstat (limited to 'crates/rebel-lang/src/typing.rs')
-rw-r--r-- | crates/rebel-lang/src/typing.rs | 375 |
1 files changed, 375 insertions, 0 deletions
diff --git a/crates/rebel-lang/src/typing.rs b/crates/rebel-lang/src/typing.rs new file mode 100644 index 0000000..34492a6 --- /dev/null +++ b/crates/rebel-lang/src/typing.rs @@ -0,0 +1,375 @@ +use std::{collections::HashMap, fmt::Display}; + +use enum_kinds::EnumKind; + +use rebel_parse::ast; + +use crate::{func::FuncType, scope::Module}; + +#[derive(Debug)] +pub struct TypeError; + +pub type Result<T> = std::result::Result<T, TypeError>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Coerce { + None, + Common, + Dynamic, + Assign, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumKind)] +#[enum_kind(TypeFamily, derive(Hash))] +pub enum Type { + Free, + Unit, + Bool, + Int, + Str, + Tuple(Vec<Type>), + Array(Box<Type>, ArrayLen), + Map(HashMap<String, Type>), + Fn(Box<FuncType>), +} + +impl Type { + pub fn unify(self, other: Type, coerce: Coerce) -> Result<Type> { + use Type::*; + + Ok(match (self, other) { + (Free, typ) => typ, + (typ, Free) => typ, + (Unit, Unit) => Unit, + (Bool, Bool) => Bool, + (Int, Int) => Int, + (Str, Str) => Str, + (Tuple(self_elems), Tuple(other_elems)) => { + if self_elems.len() != other_elems.len() { + return Err(TypeError); + } + Tuple( + self_elems + .into_iter() + .zip(other_elems.into_iter()) + .map(|(t1, t2)| t1.unify(t2, coerce)) + .collect::<Result<_>>()?, + ) + } + (Array(self_inner, self_len), Array(other_inner, other_len)) => Array( + Box::new(self_inner.unify(*other_inner, coerce)?), + ArrayLen::unify(self_len, other_len, coerce)?, + ), + (Map(self_entries), Map(mut other_entries)) => { + if self_entries.len() != other_entries.len() { + return Err(TypeError); + } + Map(self_entries + .into_iter() + .map(|(k, v)| { + let Some(v2) = other_entries.remove(&k) else { + return Err(TypeError); + }; + Ok((k, v.unify(v2, coerce)?)) + }) + .collect::<Result<_>>()?) + } + _ => return Err(TypeError), + }) + } +} + +impl Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Type::Free => write!(f, "_"), + Type::Unit => write!(f, "()"), + Type::Bool => write!(f, "bool"), + Type::Int => write!(f, "int"), + Type::Str => write!(f, "str"), + Type::Tuple(elems) => { + let mut first = true; + f.write_str("(")?; + if elems.is_empty() { + f.write_str(",")?; + } + for elem in elems { + if !first { + f.write_str(", ")?; + } + first = false; + elem.fmt(f)?; + } + f.write_str(")") + } + Type::Array(typ, len) => write!(f, "[{typ}{len}]"), + Type::Map(entries) => { + let mut first = true; + f.write_str("{")?; + for (key, typ) in entries { + if !first { + f.write_str(", ")?; + } + first = false; + write!(f, "{key}: {typ}")?; + } + f.write_str("}") + } + /* TODO */ + Type::Fn(func) => write!(f, "{func}"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum ArrayLen { + Free, + Fixed(usize), + Dynamic, +} + +impl ArrayLen { + fn unify(self, other: Self, coerce: Coerce) -> Result<Self> { + use ArrayLen::*; + + Ok(match (self, other, coerce) { + (Free, len, _) => len, + (len, Free, _) => len, + (l1, l2, _) if l1 == l2 => l1, + (_, _, Coerce::Common) => Dynamic, + (Dynamic, Fixed(_), Coerce::Dynamic | Coerce::Assign) => Dynamic, + (Fixed(_), Dynamic, Coerce::Dynamic) => Dynamic, + _ => return Err(TypeError), + }) + } + + fn add(self, other: Self) -> Result<Self> { + use ArrayLen::*; + + Ok(match (self, other) { + (Free, _) => return Err(TypeError), + (_, Free) => return Err(TypeError), + (Dynamic, _) => Dynamic, + (_, Dynamic) => Dynamic, + (Fixed(l1), Fixed(l2)) => Fixed(l1 + l2), + }) + } +} + +impl Display for ArrayLen { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArrayLen::Free => write!(f, "; _"), + ArrayLen::Fixed(len) => write!(f, "; {len}"), + ArrayLen::Dynamic => Ok(()), + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct Context { + pub values: Module<Type>, + pub methods: HashMap<TypeFamily, HashMap<&'static str, FuncType>>, +} + +impl Context { + pub fn ast_expr_type(&self, expr: &ast::Expr<'_>) -> Result<Type> { + use ast::Expr::*; + + Ok(match expr { + Binary { left, op, right } => self.binary_op_type(left, *op, right)?, + Unary { op, expr } => self.unary_op_type(*op, expr)?, + Apply { expr, params } => self.apply_type(expr, params)?, + Method { + expr, + method, + params, + } => self.method_type(expr, method, params)?, + Index { expr, index } => self.index_type(expr, index)?, + Field { expr, field } => self.field_type(expr, field)?, + Paren(subexpr) => self.ast_expr_type(subexpr)?, + Path(path) => self.path_type(path)?, + Literal(lit) => self.literal_type(lit)?, + }) + } + + fn binary_op_type( + &self, + left: &ast::Expr<'_>, + op: ast::OpBinary, + right: &ast::Expr<'_>, + ) -> Result<Type> { + use ast::OpBinary::*; + use Type::*; + + let tl = self.ast_expr_type(left)?; + let tr = self.ast_expr_type(right)?; + + Ok(match (tl, op, tr) { + (Str, Add, Str) => Str, + (Int, Add, Int) => Int, + (Array(t1, l1), Add, Array(t2, l2)) => Array( + Box::new(t1.unify(*t2, Coerce::Common)?), + ArrayLen::add(l1, l2)?, + ), + (Int, Sub, Int) => Int, + (Int, Mul, Int) => Int, + (Int, Div, Int) => Int, + (Int, Rem, Int) => Int, + (Bool, And, Bool) => Bool, + (Bool, Or, Bool) => Bool, + (l, Eq, r) => { + l.unify(r, Coerce::Dynamic)?; + Bool + } + (l, Ne, r) => { + l.unify(r, Coerce::Dynamic)?; + Bool + } + (Int, Lt, Int) => Bool, + (Int, Le, Int) => Bool, + (Int, Ge, Int) => Bool, + (Int, Gt, Int) => Bool, + _ => return Err(TypeError), + }) + } + + fn unary_op_type(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> { + use ast::OpUnary::*; + use Type::*; + + let typ = self.ast_expr_type(expr)?; + + Ok(match (op, typ) { + (Not, Bool) => Bool, + (Neg, Int) => Int, + _ => return Err(TypeError), + }) + } + + fn index_type(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> { + use Type::*; + + let expr_type = self.ast_expr_type(expr)?; + let index_type = self.ast_expr_type(index)?; + + let Array(elem_type, _) = expr_type else { + return Err(TypeError); + }; + if index_type != Int { + return Err(TypeError); + } + Ok(*elem_type) + } + + fn apply_type(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> { + use Type::*; + + let expr_type = self.ast_expr_type(expr)?; + + let Fn(func) = expr_type else { + return Err(TypeError); + }; + + if func.params.len() != params.len() { + return Err(TypeError); + } + + for (func_param_type, call_param) in func.params.iter().zip(params) { + let call_param_type = self.ast_expr_type(call_param)?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Dynamic)?; + } + + Ok(func.ret) + } + + fn method_type( + &self, + expr: &ast::Expr<'_>, + method: &ast::Ident<'_>, + params: &[ast::Expr], + ) -> Result<Type> { + let expr_type = self.ast_expr_type(expr)?; + let type_family = TypeFamily::from(&expr_type); + + let methods = self.methods.get(&type_family).ok_or(TypeError)?; + let method = methods.get(method.name).ok_or(TypeError)?; + + let (self_param, func_params) = method.params.split_first().ok_or(TypeError)?; + self_param.clone().unify(expr_type, Coerce::Dynamic)?; + + if func_params.len() != params.len() { + return Err(TypeError); + } + + for (func_param_type, call_param) in func_params.iter().zip(params) { + let call_param_type = self.ast_expr_type(call_param)?; + func_param_type + .clone() + .unify(call_param_type, Coerce::Dynamic)?; + } + + Ok(method.ret.clone()) + } + + fn field_type(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> { + use Type::*; + + let expr_type = self.ast_expr_type(expr)?; + let name = field.name; + + Ok(match expr_type { + Tuple(elems) => { + let index: usize = name.parse().or(Err(TypeError))?; + elems.into_iter().nth(index).ok_or(TypeError)? + } + Map(mut entries) => entries.remove(name).ok_or(TypeError)?, + _ => return Err(TypeError), + }) + } + + fn path_type(&self, path: &ast::Path<'_>) -> Result<Type> { + use Type::*; + + if path.components == [ast::Ident { name: "_" }] { + return Ok(Free); + } + + self.values + .lookup(&path.components) + .ok_or(TypeError) + .cloned() + } + + fn literal_type(&self, lit: &ast::Literal<'_>) -> Result<Type> { + use ast::Literal; + use Type::*; + + Ok(match lit { + Literal::Unit => Unit, + Literal::Bool(_) => Bool, + Literal::Int(_) => Int, + Literal::Str { .. } => Str, + Literal::Tuple(elems) => Tuple( + elems + .iter() + .map(|elem| self.ast_expr_type(elem)) + .collect::<Result<_>>()?, + ), + Literal::Array(elems) => Array( + Box::new(elems.iter().try_fold(Type::Free, |acc, elem| { + acc.unify(self.ast_expr_type(elem)?, Coerce::Common) + })?), + ArrayLen::Fixed(elems.len()), + ), + Literal::Map(entries) => Map(entries + .iter() + .map(|ast::MapEntry { key, value }| { + Ok(((*key).to_owned(), self.ast_expr_type(value)?)) + }) + .collect::<Result<_>>()?), + }) + } +} |