summaryrefslogtreecommitdiffstats
path: root/crates/rebel-lang/src/typing.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/rebel-lang/src/typing.rs')
-rw-r--r--crates/rebel-lang/src/typing.rs81
1 files changed, 37 insertions, 44 deletions
diff --git a/crates/rebel-lang/src/typing.rs b/crates/rebel-lang/src/typing.rs
index 2545746..b7a7d34 100644
--- a/crates/rebel-lang/src/typing.rs
+++ b/crates/rebel-lang/src/typing.rs
@@ -4,7 +4,7 @@ use enum_kinds::EnumKind;
use rebel_parse::ast;
-use crate::{func::FuncType, scope::Module};
+use crate::{func::FuncType, scope::Context};
#[derive(Debug)]
pub struct TypeError;
@@ -77,37 +77,29 @@ impl Type {
_ => return Err(TypeError),
})
}
-}
-
-#[derive(Debug, Default)]
-pub struct Context {
- pub values: Module<Type>,
- pub methods: HashMap<TypeFamily, HashMap<&'static str, FuncType>>,
-}
-impl Context {
- pub fn ast_expr_type(&self, expr: &ast::Expr<'_>) -> Result<Type> {
+ pub fn ast_expr_type(ctx: &Context, expr: &ast::Expr<'_>) -> Result<Type> {
use ast::Expr::*;
Ok(match expr {
- Binary { left, op, right } => self.binary_op_type(left, *op, right)?,
- Unary { op, expr } => self.unary_op_type(*op, expr)?,
- Apply { expr, params } => self.apply_type(expr, params)?,
+ Binary { left, op, right } => Self::binary_op_type(ctx, left, *op, right)?,
+ Unary { op, expr } => Self::unary_op_type(ctx, *op, expr)?,
+ Apply { expr, params } => Self::apply_type(ctx, expr, params)?,
Method {
expr,
method,
params,
- } => self.method_type(expr, method, params)?,
- Index { expr, index } => self.index_type(expr, index)?,
- Field { expr, field } => self.field_type(expr, field)?,
- Paren(subexpr) => self.ast_expr_type(subexpr)?,
- Path(path) => self.path_type(path)?,
- Literal(lit) => self.literal_type(lit)?,
+ } => Self::method_type(ctx, expr, method, params)?,
+ Index { expr, index } => Self::index_type(ctx, expr, index)?,
+ Field { expr, field } => Self::field_type(ctx, expr, field)?,
+ Paren(subexpr) => Self::ast_expr_type(ctx, subexpr)?,
+ Path(path) => Self::path_type(ctx, path)?,
+ Literal(lit) => Self::literal_type(ctx, lit)?,
})
}
fn binary_op_type(
- &self,
+ ctx: &Context,
left: &ast::Expr<'_>,
op: ast::OpBinary,
right: &ast::Expr<'_>,
@@ -115,8 +107,8 @@ impl Context {
use ast::OpBinary::*;
use Type::*;
- let tl = self.ast_expr_type(left)?;
- let tr = self.ast_expr_type(right)?;
+ let tl = Self::ast_expr_type(ctx, left)?;
+ let tr = Self::ast_expr_type(ctx, right)?;
Ok(match (tl, op, tr) {
(Str, Add, Str) => Str,
@@ -147,11 +139,11 @@ impl Context {
})
}
- fn unary_op_type(&self, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> {
+ fn unary_op_type(ctx: &Context, op: ast::OpUnary, expr: &ast::Expr<'_>) -> Result<Type> {
use ast::OpUnary::*;
use Type::*;
- let typ = self.ast_expr_type(expr)?;
+ let typ = Self::ast_expr_type(ctx, expr)?;
Ok(match (op, typ) {
(Not, Bool) => Bool,
@@ -160,11 +152,11 @@ impl Context {
})
}
- fn index_type(&self, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> {
+ fn index_type(ctx: &Context, expr: &ast::Expr<'_>, index: &ast::Expr<'_>) -> Result<Type> {
use Type::*;
- let expr_type = self.ast_expr_type(expr)?;
- let index_type = self.ast_expr_type(index)?;
+ let expr_type = Self::ast_expr_type(ctx, expr)?;
+ let index_type = Self::ast_expr_type(ctx, index)?;
let Array(elem_type, _) = expr_type else {
return Err(TypeError);
@@ -175,10 +167,10 @@ impl Context {
Ok(*elem_type)
}
- fn apply_type(&self, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> {
+ fn apply_type(ctx: &Context, expr: &ast::Expr<'_>, params: &[ast::Expr]) -> Result<Type> {
use Type::*;
- let expr_type = self.ast_expr_type(expr)?;
+ let expr_type = Self::ast_expr_type(ctx, expr)?;
let Fn(func) = expr_type else {
return Err(TypeError);
@@ -189,7 +181,7 @@ impl Context {
}
for (func_param_type, call_param) in func.params.iter().zip(params) {
- let call_param_type = self.ast_expr_type(call_param)?;
+ let call_param_type = Self::ast_expr_type(ctx, call_param)?;
func_param_type
.clone()
.unify(call_param_type, Coerce::Dynamic)?;
@@ -199,18 +191,18 @@ impl Context {
}
fn method_type(
- &self,
+ ctx: &Context,
expr: &ast::Expr<'_>,
method: &ast::Ident<'_>,
params: &[ast::Expr],
) -> Result<Type> {
- let expr_type = self.ast_expr_type(expr)?;
+ let expr_type = Self::ast_expr_type(ctx, expr)?;
let type_family = TypeFamily::from(&expr_type);
- let methods = self.methods.get(&type_family).ok_or(TypeError)?;
+ let methods = ctx.methods.get(&type_family).ok_or(TypeError)?;
let method = methods.get(method.name).ok_or(TypeError)?;
- let (self_param, func_params) = method.params.split_first().ok_or(TypeError)?;
+ let (self_param, func_params) = method.typ.params.split_first().ok_or(TypeError)?;
self_param.clone().unify(expr_type, Coerce::Dynamic)?;
if func_params.len() != params.len() {
@@ -218,19 +210,19 @@ impl Context {
}
for (func_param_type, call_param) in func_params.iter().zip(params) {
- let call_param_type = self.ast_expr_type(call_param)?;
+ let call_param_type = Self::ast_expr_type(ctx, call_param)?;
func_param_type
.clone()
.unify(call_param_type, Coerce::Dynamic)?;
}
- Ok(method.ret.clone())
+ Ok(method.typ.ret.clone())
}
- fn field_type(&self, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> {
+ fn field_type(ctx: &Context, expr: &ast::Expr<'_>, field: &ast::Ident<'_>) -> Result<Type> {
use Type::*;
- let expr_type = self.ast_expr_type(expr)?;
+ let expr_type = Self::ast_expr_type(ctx, expr)?;
let name = field.name;
Ok(match expr_type {
@@ -243,20 +235,21 @@ impl Context {
})
}
- fn path_type(&self, path: &ast::Path<'_>) -> Result<Type> {
+ fn path_type(ctx: &Context, path: &ast::Path<'_>) -> Result<Type> {
use Type::*;
if path.components == [ast::Ident { name: "_" }] {
return Ok(Free);
}
- self.values
+ ctx.values
.lookup(&path.components)
+ .map(|(typ, _)| typ)
.ok_or(TypeError)
.cloned()
}
- fn literal_type(&self, lit: &ast::Literal<'_>) -> Result<Type> {
+ fn literal_type(ctx: &Context, lit: &ast::Literal<'_>) -> Result<Type> {
use ast::Literal;
use Type::*;
@@ -268,19 +261,19 @@ impl Context {
Literal::Tuple(elems) => Tuple(
elems
.iter()
- .map(|elem| self.ast_expr_type(elem))
+ .map(|elem| Self::ast_expr_type(ctx, elem))
.collect::<Result<_>>()?,
),
Literal::Array(elems) => Array(
Box::new(elems.iter().try_fold(Type::Free, |acc, elem| {
- acc.unify(self.ast_expr_type(elem)?, Coerce::Common)
+ acc.unify(Self::ast_expr_type(ctx, elem)?, Coerce::Common)
})?),
ArrayLen::Fixed(elems.len()),
),
Literal::Map(entries) => Map(entries
.iter()
.map(|ast::MapEntry { key, value }| {
- Ok(((*key).to_owned(), self.ast_expr_type(value)?))
+ Ok(((*key).to_owned(), Self::ast_expr_type(ctx, value)?))
})
.collect::<Result<_>>()?),
})