Skip to content

Commit

Permalink
fix(typecheck): improve handling of generic-related type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
cecelot committed Apr 5, 2024
1 parent c735f58 commit 4297ddb
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 52 deletions.
10 changes: 5 additions & 5 deletions crates/kyac/snapshots/kyac__pass__typecheck__tests__classes.snap
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,18 @@ Err(
},
PreciseError {
filename: "test-cases/typecheck/classes.kya",
heading: "undefined reference to `baaz` (while reading `Bar`)",
heading: "type `Baz` does not exist",
source: " let baz: Baz = p.foo.bar.baaz;",
span: Span {
line: 34,
column: 30,
length: 4,
column: 14,
length: 3,
},
text: "",
},
PreciseError {
filename: "test-cases/typecheck/classes.kya",
heading: "`Baz` is not defined",
heading: "type `Baz` does not exist",
source: " let baz: Baz = p.foo.bar;",
span: Span {
line: 37,
Expand All @@ -94,7 +94,7 @@ Err(
},
PreciseError {
filename: "test-cases/typecheck/classes.kya",
heading: "`Baz` is not defined",
heading: "type `Baz` does not exist",
source: " let baz: Baz = p.foo.bar.baz;",
span: Span {
line: 38,
Expand Down
6 changes: 5 additions & 1 deletion crates/kyac/src/backend/kyir/translate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,11 @@ impl Translate<Expr> for ast::node::Unary {
impl Translate<Expr> for ast::node::Access {
// heh, this is basically the spiritual equivalent of LLVM's getelementptr
fn translate<I: ArchInstr, F: Frame<I>>(&self, translator: &mut Translator<I, F>) -> Expr {
let meta = translator.meta.access.get(&self.id).unwrap();
let meta = translator
.meta
.access
.get(&self.id)
.unwrap_or_else(|| panic!("expected metadata for access node {}", self.id));
let head = self.chain.first().unwrap();
let mut initial = vec![];
let ident = match head {
Expand Down
12 changes: 6 additions & 6 deletions crates/kyac/src/pass/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ pub enum Symbol {
}

impl Symbol {
pub fn class(&self) -> &node::ClassDecl {
pub fn class(&self) -> Option<&node::ClassDecl> {
match self {
Symbol::Class(cls) => cls,
_ => panic!("called `Symbol::class()` on a non-class symbol: {self:?}"),
Symbol::Class(cls) => Some(cls),
_ => None,
}
}

Expand Down Expand Up @@ -55,14 +55,14 @@ impl Symbol {
) -> Vec<&'a node::ClassDecl> {
let mut classes = vec![cls];
while let Some(parent) = cls.parent.as_ref() {
cls = symbols.get(&parent.to_string()).unwrap().class();
cls = symbols.get(&parent.to_string()).unwrap().class().unwrap();
classes.push(cls);
}
classes
}

pub fn fields(&self, symbols: &SymbolTable) -> Vec<node::Field> {
let cls = self.class();
let cls = self.class().unwrap();
let superclasses = Self::superclasses(cls, symbols);
superclasses
.iter()
Expand All @@ -72,7 +72,7 @@ impl Symbol {
}

pub fn methods(&self, symbols: &SymbolTable) -> Vec<(String, Rc<node::FuncDecl>)> {
let cls = self.class();
let cls = self.class().unwrap();
let superclasses = Self::superclasses(cls, symbols);
let mut used = HashSet::new();
let mut methods = vec![];
Expand Down
118 changes: 80 additions & 38 deletions crates/kyac/src/pass/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ struct TypeResolverContext<'a> {
scopes: Vec<Scope>,
function: Option<Token>,
class: Option<Token>,
instantiation: Option<HashMap<String, ResolvedType>>,
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct ResolvedType {
base: Symbol,
#[allow(dead_code)]
params: Vec<ResolvedType>,
meta: Type,
}
Expand Down Expand Up @@ -185,11 +185,6 @@ impl ResolveType for Type {
} else if let Some(ty) = cx.ty(&self.base.to_string()) {
ty.base.clone()
} else {
cx.error(
self.base.span,
format!("`{}` is not defined", self.base.lexeme.unwrap()),
String::new(),
);
return Err(TypeError::Undefined);
},
self.params
Expand Down Expand Up @@ -268,6 +263,16 @@ impl ResolveType for node::ClassDecl {
) -> Result<ResolvedType, TypeError> {
cx.begin_scope();
cx.set_type_parameters(meta, self.tp.as_ref());
for field in &self.fields {
if let Err(e) = field.ty.resolve(cx, meta) {
cx.error(
field.ty.base.span,
format!("`{}` is not defined", field.ty.base.lexeme.unwrap()),
String::new(),
);
return Err(e);
}
}
for method in &self.methods {
let _ = Decl::Function(Rc::clone(method)).resolve(cx, meta);
}
Expand Down Expand Up @@ -312,6 +317,14 @@ impl ResolveType for Rc<node::FuncDecl> {
);
}
for param in &self.params {
if let Err(e) = param.ty.resolve(cx, meta) {
cx.error(
param.ty.base.span,
format!("`{}` is not defined", param.ty.base.lexeme.unwrap()),
String::new(),
);
return Err(e);
}
cx.scope_mut()
.symbols
.insert(param.name.to_string(), Symbol::Function(Rc::clone(self)));
Expand Down Expand Up @@ -350,8 +363,41 @@ impl ResolveType for Rc<node::VarDecl> {
cx: &mut TypeResolverContext,
meta: &mut ResolvedMetaInfo,
) -> Result<ResolvedType, TypeError> {
let expected = match self.ty.resolve(cx, meta) {
Ok(ty) => ty,
Err(e) => {
cx.error(
self.ty.base.span,
format!("type `{}` does not exist", self.ty.base.lexeme.unwrap()),
String::new(),
);
return Err(e);
}
};
let empty = vec![];
if let Some(cls) = expected.base.class() {
let tp = cls.tp.as_ref().unwrap_or(&empty);
let mut instantiation = HashMap::new();
for (got, expected) in expected.params.iter().zip(tp.iter()) {
if let Some(ref bound) = expected.bound {
let ty = Type::new(bound.clone(), vec![]).resolve(cx, meta)?;
if cx.cast(&ty, got).is_none() {
cx.error(
self.ty.base.span,
format!("{} does not satisfy bound {}", got.meta, ty.meta),
String::from("in instantiation of type here"),
);
return Err(TypeError::Mismatch(
ty.meta.base.to_string(),
got.meta.base.to_string(),
));
}
}
instantiation.insert(expected.name.to_string(), got.clone());
}
cx.instantiation = Some(instantiation);
}
let got = self.expr.resolve(cx, meta)?;
let expected = self.ty.resolve(cx, meta)?;
if !matches!(
got.base,
Symbol::Bool | Symbol::Int | Symbol::Float | Symbol::Str | Symbol::Void
Expand Down Expand Up @@ -593,27 +639,29 @@ impl ResolveType for node::Init {
);
return Err(TypeError::Undefined);
}
cx.set_type_parameters(
meta,
cx.symbol(&self.name.to_string())
.unwrap()
.class()
.tp
.clone()
.as_ref(),
);
let typ = cx
.symbol(&self.name.to_string())
.unwrap()
.class()
.unwrap()
.tp
.clone();
let symbol = cx.symbol(&self.name.to_string()).unwrap();
let fields = symbol.fields(cx.symbols);
for initializer in &self.initializers {
let got = initializer.expr.resolve(cx, meta)?;
let expected = if let Some(field) = fields.iter().find(|f| f.name == initializer.name) {
field.ty.resolve(cx, meta)?
match field.ty.resolve(cx, meta) {
Ok(ty) => ty,
Err(e) => {
match cx.instantiation.as_ref().map(|instantiation| {
&instantiation[&field.ty.base.lexeme.unwrap().to_string()]
}) {
Some(ty) => ty.clone(),
None => return Err(e),
}
}
}
} else {
cx.error(
initializer.name.span,
Expand All @@ -623,25 +671,15 @@ impl ResolveType for node::Init {
continue;
};
let valid = if let Some(ref typ) = typ {
if let Some(ty) = typ
.iter()
.find(|t| t.name == expected.meta.base.lexeme.unwrap())
{
let uncastable = |bound: &Token| {
let field = fields.iter().find(|f| f.name == initializer.name).unwrap();
if let Some(ty) = typ.iter().find(|t| t.name == field.ty.base.lexeme.unwrap()) {
let castable = |bound: &Token| {
let raw_type = Type::new(bound.clone(), vec![]);
let expected = raw_type.resolve(cx, meta).unwrap();
cx.cast(&expected, &got).is_none()
cx.cast(&expected, &got).is_some()
};
if ty.bound.as_ref().is_some_and(uncastable) {
cx.error(
initializer.expr.span(),
format!("expected initializer to be of type {}", expected.meta),
format!("expression of type {}", got.meta),
);
false
} else {
true
}
let castable = ty.bound.as_ref().is_some_and(castable);
castable
} else {
true
}
Expand Down Expand Up @@ -729,10 +767,13 @@ impl ResolveType for node::Access {
} else {
ty
};
if let Some((index, _, field)) = left.field(cx.symbols, right) {
if let Some((index, cls, field)) = left.field(cx.symbols, right) {
symbols.push(left.base.clone());
indices.push(index);
ty = field.ty.resolve(cx, meta).unwrap();
cx.begin_scope();
cx.set_type_parameters(meta, cls.tp.as_ref());
ty = field.ty.resolve(cx, meta)?;
cx.end_scope();
} else if let Some(method) = left.method(cx.symbols, right) {
symbols.push(left.base.clone());
symbols.push(Symbol::Function(Rc::clone(&method)));
Expand Down Expand Up @@ -853,7 +894,7 @@ impl ResolveType for node::Ident {
cx.error(
self.name.span,
format!("`{}` is not defined", &self.name),
String::new(),
format!("the type of `{}` may not be valid", &self.name),
);
Err(TypeError::Undefined)
}
Expand Down Expand Up @@ -941,6 +982,7 @@ impl<'a> TypeResolverContext<'a> {
class: None,
function: None,
scopes: vec![],
instantiation: None,
}
}

Expand Down Expand Up @@ -1023,7 +1065,7 @@ impl<'a> TypeResolverContext<'a> {

fn cast(&self, expected: &ResolvedType, got: &ResolvedType) -> Option<String> {
let cls = self.symbol(&got.meta.to_string())?;
let cls = cls.class();
let cls = matches!(cls, Symbol::Class(_)).then(|| cls.class().unwrap())?;
Symbol::superclasses(cls, self.symbols)
.iter()
.filter(|c| c.name != cls.name)
Expand Down
7 changes: 7 additions & 0 deletions crates/kyanite/tests/kyir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,10 @@ fn dynamic_dispatch() -> Result<(), Box<dyn std::error::Error>> {
);
Ok(())
}

#[test]
fn basic_generics() -> Result<(), Box<dyn std::error::Error>> {
let res = run("kyir/basic-generics.kya")?;
assert_eq!(res.output, "42\n42\n42\n");
Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ fun freeFunction<T: Print>(obj: T) {

fun main() {
let bar: Bar = Bar:init(x: 42);
let foo: Foo<Bar> = Foo:init(val: bar);
foo.print();
let foo: Foo<int> = Foo:init(val: bar);
foo.val.print();
bar.print();
freeFunction(foo);
}

0 comments on commit 4297ddb

Please sign in to comment.