summaryrefslogtreecommitdiffstats
path: root/crates/rebel-lang/src/typing.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/rebel-lang/src/typing.rs')
-rw-r--r--crates/rebel-lang/src/typing.rs375
1 files changed, 375 insertions, 0 deletions
diff --git a/crates/rebel-lang/src/typing.rs b/crates/rebel-lang/src/typing.rs
new file mode 100644
index 0000000..34492a6
--- /dev/null
+++ b/crates/rebel-lang/src/typing.rs
@@ -0,0 +1,375 @@
+use std::{collections::HashMap, fmt::Display};
+
+use enum_kinds::EnumKind;
+
+use rebel_parse::ast;
+
+use crate::{func::FuncType, scope::Module};
+
+#[derive(Debug)]
+pub struct TypeError;
+
+pub type Result<T> = std::result::Result<T, TypeError>;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum Coerce {
+ None,
+ Common,
+ Dynamic,
+ Assign,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, EnumKind)]
+#[enum_kind(TypeFamily, derive(Hash))]
+pub enum Type {
+ Free,
+ Unit,
+ Bool,
+ Int,
+ Str,
+ Tuple(Vec<Type>),
+ Array(Box<Type>, ArrayLen),
+ Map(HashMap<String, Type>),
+ Fn(Box<FuncType>),
+}
+
+impl Type {
+ pub fn unify(self, other: Type, coerce: Coerce) -> Result<Type> {
+ use Type::*;
+
+ Ok(match (self, other) {
+ (Free, typ) => typ,
+ (typ, Free) => typ,
+ (Unit, Unit) => Unit,
+ (Bool, Bool) => Bool,
+ (Int, Int) => Int,
+ (Str, Str) => Str,
+ (Tuple(self_elems), Tuple(other_elems)) => {
+ if self_elems.len() != other_elems.len() {
+ return Err(TypeError);
+ }
+ Tuple(
+ self_elems
+ .into_iter()
+ .zip(other_elems.into_iter())
+ .map(|(t1, t2)| t1.unify(t2, coerce))
+ .collect::<Result<_>>()?,
+ )
+ }
+ (Array(self_inner, self_len), Array(other_inner, other_len)) => Array(
+ Box::new(self_inner.unify(*other_inner, coerce)?),
+ ArrayLen::unify(self_len, other_len, coerce)?,
+ ),
+ (Map(self_entries), Map(mut other_entries)) => {
+ if self_entries.len() != other_entries.len() {
+ return Err(TypeError);
+ }
+ Map(self_entries
+ .into_iter()
+ .map(|(k, v)| {
+ let Some(v2) = other_entries.remove(&k) else {
+ return Err(TypeError);
+ };
+ Ok((k, v.unify(v2, coerce)?))
+ })
+ .collect::<Result<_>>()?)
+ }
+ _ => return Err(TypeError),
+ })
+ }
+}
+
+impl Display for Type {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Type::Free => write!(f, "_"),
+ Type::Unit => write!(f, "()"),
+ Type::Bool => write!(f, "bool"),
+ Type::Int => write!(f, "int"),
+ Type::Str => write!(f, "str"),
+ Type::Tuple(elems) => {
+ let mut first = true;
+ f.write_str("(")?;
+ if elems.is_empty() {
+ f.write_str(",")?;
+ }
+ for elem in elems {
+ if !first {
+ f.write_str(", ")?;
+ }
+ first = false;
+ elem.fmt(f)?;
+ }
+ f.write_str(")")
+ }
+ Type::Array(typ, len) => write!(f, "[{typ}{len}]"),
+ Type::Map(entries) => {
+ let mut first = true;
+ f.write_str("{")?;
+ for (key, typ) in entries {
+ if !first {
+ f.write_str(", ")?;
+ }
+ first = false;
+ write!(f, "{key}: {typ}")?;
+ }
+ f.write_str("}")
+ }
+ /* TODO */
+ Type::Fn(func) => write!(f, "{func}"),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum ArrayLen {
+ Free,
+ Fixed(usize),
+ Dynamic,
+}
+
+impl ArrayLen {
+ fn unify(self, other: Self, coerce: Coerce) -> Result<Self> {
+ use ArrayLen::*;
+
+ Ok(match (self, other, coerce) {
+ (Free, len, _) => len,
+ (len, Free, _) => len,
+ (l1, l2, _) if l1 == l2 => l1,
+ (_, _, Coerce::Common) => Dynamic,
+ (Dynamic, Fixed(_), Coerce::Dynamic | Coerce::Assign) => Dynamic,
+ (Fixed(_), Dynamic, Coerce::Dynamic) => Dynamic,
+ _ => return Err(TypeError),
+ })
+ }
+
+ fn add(self, other: Self) -> Result<Self> {
+ use ArrayLen::*;
+
+ Ok(match (self, other) {
+ (Free, _) => return Err(TypeError),
+ (_, Free) => return Err(TypeError),
+ (Dynamic, _) => Dynamic,
+ (_, Dynamic) => Dynamic,
+ (Fixed(l1), Fixed(l2)) => Fixed(l1 + l2),
+ })
+ }
+}
+
+impl Display for ArrayLen {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ArrayLen::Free => write!(f, "; _"),
+ ArrayLen::Fixed(len) => write!(f, "; {len}"),
+ ArrayLen::Dynamic => Ok(()),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct Context {
+ pub values: Module<Type>,
+ pub methods: HashMap<TypeFamily, HashMap<&'static str, FuncType>>,
+}
+
+impl Context {
+ pub fn ast_expr_type(&self, expr: &ast::Expr<'_>) -> Result<Type> {
+ use ast::Expr::*;
+
+ Ok(match expr {
+ Binary { left, op, right } => self.binary_op_type(left, *op, right)?,
+ Unary { op, expr } => self.unary_op_type(*op, expr)?,
+ Apply { expr, params } => self.apply_type(expr, params)?,
+ Method {
+ expr,
+ method,
+ params,
+ } => self.method_type(expr, method, params)?,
+ Index { expr, index } => self.index_type(expr, index)?,
+ Field { expr, field } => self.field_type(expr, field)?,
+ Paren(subexpr) => self.ast_expr_type(subexpr)?,
+ Path(path) => self.path_type(path)?,
+ Literal(lit) => self.literal_type(lit)?,
+ })
+ }
+
+ fn binary_op_type(
+ &self,
+ left: &ast::Expr<'_>,
+ op: ast::OpBinary,
+ right: &ast::Expr<'_>,
+ ) -> Result<Type> {
+ use ast::OpBinary::*;
+ use Type::*;
+
+ let tl = self.ast_expr_type(left)?;
+ let tr = self.ast_expr_type(right)?;
+
+ Ok(match (tl, op, tr) {
+ (Str, Add, Str) => Str,
+ (Int, Add, Int) => Int,
+ (Array(t1, l1), Add, Array(t2, l2)) => Array(
+ Box::new(t1.unify(*t2, Coerce::Common)?),
+ ArrayLen::add(l1, l2)?,
+ ),
+ (Int, Sub, Int) => Int,
+ (Int, Mul, Int) => Int,
+ (Int, Div, Int) => Int,
+ (Int, Rem, Int) => Int,
+ (Bool, And, Bool) => Bool,
+ (Bool, Or, Bool) => Bool,
+ (l, Eq, r) => {
+ l.unify(r, Coerce::Dynamic)?;
+ Bool
+ }
+ (l, Ne, r) => {
+ l.unify(r, Coerce::Dynamic)?;
+ Bool
+ }
+ (Int, Lt, Int) => Bool,
+ (Int, Le, Int) => Bool,
+ (Int, Ge, Int) => Bool,
+ (Int, Gt, Int) => Bool,
+ _ => return Err(TypeError),
+ })
+ }
+
+ fn unary_op_type(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> {
+ use ast::OpUnary::*;
+ use Type::*;
+
+ let typ = self.ast_expr_type(expr)?;
+
+ Ok(match (op, typ) {
+ (Not, Bool) => Bool,
+ (Neg, Int) => Int,
+ _ => return Err(TypeError),
+ })
+ }
+
+ fn index_type(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> {
+ use Type::*;
+
+ let expr_type = self.ast_expr_type(expr)?;
+ let index_type = self.ast_expr_type(index)?;
+
+ let Array(elem_type, _) = expr_type else {
+ return Err(TypeError);
+ };
+ if index_type != Int {
+ return Err(TypeError);
+ }
+ Ok(*elem_type)
+ }
+
+ fn apply_type(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> {
+ use Type::*;
+
+ let expr_type = self.ast_expr_type(expr)?;
+
+ let Fn(func) = expr_type else {
+ return Err(TypeError);
+ };
+
+ if func.params.len() != params.len() {
+ return Err(TypeError);
+ }
+
+ for (func_param_type, call_param) in func.params.iter().zip(params) {
+ let call_param_type = self.ast_expr_type(call_param)?;
+ func_param_type
+ .clone()
+ .unify(call_param_type, Coerce::Dynamic)?;
+ }
+
+ Ok(func.ret)
+ }
+
+ fn method_type(
+ &self,
+ expr: &ast::Expr<'_>,
+ method: &ast::Ident<'_>,
+ params: &[ast::Expr],
+ ) -> Result<Type> {
+ let expr_type = self.ast_expr_type(expr)?;
+ let type_family = TypeFamily::from(&expr_type);
+
+ let methods = self.methods.get(&type_family).ok_or(TypeError)?;
+ let method = methods.get(method.name).ok_or(TypeError)?;
+
+ let (self_param, func_params) = method.params.split_first().ok_or(TypeError)?;
+ self_param.clone().unify(expr_type, Coerce::Dynamic)?;
+
+ if func_params.len() != params.len() {
+ return Err(TypeError);
+ }
+
+ for (func_param_type, call_param) in func_params.iter().zip(params) {
+ let call_param_type = self.ast_expr_type(call_param)?;
+ func_param_type
+ .clone()
+ .unify(call_param_type, Coerce::Dynamic)?;
+ }
+
+ Ok(method.ret.clone())
+ }
+
+ fn field_type(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> {
+ use Type::*;
+
+ let expr_type = self.ast_expr_type(expr)?;
+ let name = field.name;
+
+ Ok(match expr_type {
+ Tuple(elems) => {
+ let index: usize = name.parse().or(Err(TypeError))?;
+ elems.into_iter().nth(index).ok_or(TypeError)?
+ }
+ Map(mut entries) => entries.remove(name).ok_or(TypeError)?,
+ _ => return Err(TypeError),
+ })
+ }
+
+ fn path_type(&self, path: &ast::Path<'_>) -> Result<Type> {
+ use Type::*;
+
+ if path.components == [ast::Ident { name: "_" }] {
+ return Ok(Free);
+ }
+
+ self.values
+ .lookup(&path.components)
+ .ok_or(TypeError)
+ .cloned()
+ }
+
+ fn literal_type(&self, lit: &ast::Literal<'_>) -> Result<Type> {
+ use ast::Literal;
+ use Type::*;
+
+ Ok(match lit {
+ Literal::Unit => Unit,
+ Literal::Bool(_) => Bool,
+ Literal::Int(_) => Int,
+ Literal::Str { .. } => Str,
+ Literal::Tuple(elems) => Tuple(
+ elems
+ .iter()
+ .map(|elem| self.ast_expr_type(elem))
+ .collect::<Result<_>>()?,
+ ),
+ Literal::Array(elems) => Array(
+ Box::new(elems.iter().try_fold(Type::Free, |acc, elem| {
+ acc.unify(self.ast_expr_type(elem)?, Coerce::Common)
+ })?),
+ ArrayLen::Fixed(elems.len()),
+ ),
+ Literal::Map(entries) => Map(entries
+ .iter()
+ .map(|ast::MapEntry { key, value }| {
+ Ok(((*key).to_owned(), self.ast_expr_type(value)?))
+ })
+ .collect::<Result<_>>()?),
+ })
+ }
+}