From 9025dbdbeaa125d9a1255a0aacfda2aea19d1f91 Mon Sep 17 00:00:00 2001 From: Alaina <68250402+alaidriel@users.noreply.github.com> Date: Thu, 21 Mar 2024 11:08:03 -0500 Subject: [PATCH 1/5] feat: generics --- .../snapshots/kyac__ast__tests__calls.snap | 61 +- .../snapshots/kyac__ast__tests__empty.snap | 24 +- .../snapshots/kyac__ast__tests__expr.snap | 1 + .../kyac__ast__tests__hello_world.snap | 1 + .../snapshots/kyac__ast__tests__mixed.snap | 147 ++--- ...kyac__pass__typecheck__tests__classes.snap | 22 +- crates/kyac/src/ast/mod.rs | 85 +-- crates/kyac/src/ast/node.rs | 36 +- crates/kyac/src/ast/span.rs | 18 +- crates/kyac/src/ast/ty.rs | 32 ++ .../kyac/src/backend/kyir/arch/armv8a/mod.rs | 4 +- crates/kyac/src/backend/kyir/translate/mod.rs | 28 +- crates/kyac/src/backend/llvm/mod.rs | 4 +- crates/kyac/src/backend/mod.rs | 1 + crates/kyac/src/lib.rs | 16 +- crates/kyac/src/parse.rs | 93 +++- crates/kyac/src/pass/symbol.rs | 30 +- crates/kyac/src/pass/typecheck.rs | 522 +++++++++++------- crates/kyanite/tests/mod.rs | 1 + examples/kyir/parametric-polymorphism.kya | 31 ++ 20 files changed, 698 insertions(+), 459 deletions(-) create mode 100644 crates/kyac/src/ast/ty.rs create mode 100644 examples/kyir/parametric-polymorphism.kya diff --git a/crates/kyac/snapshots/kyac__ast__tests__calls.snap b/crates/kyac/snapshots/kyac__ast__tests__calls.snap index 2fb8a34..abb2fd4 100644 --- a/crates/kyac/snapshots/kyac__ast__tests__calls.snap +++ b/crates/kyac/snapshots/kyac__ast__tests__calls.snap @@ -30,32 +30,39 @@ Ast { length: 3, }, }, - ty: Token { + ty: Type { + base: Token { + kind: Identifier, + lexeme: Some( + "int", + ), + span: Span { + line: 1, + column: 14, + length: 3, + }, + }, + params: [], + }, + }, + ], + ty: Some( + Type { + base: Token { kind: Identifier, lexeme: Some( "int", ), span: Span { line: 1, - column: 14, + column: 20, length: 3, }, }, - }, - ], - ty: Some( - Token { - kind: Identifier, - lexeme: Some( - "int", - ), - span: Span { - line: 1, - column: 20, - length: 3, - }, + params: [], }, ), + tp: [], body: [ Return( Return { @@ -105,18 +112,22 @@ Ast { }, params: [], ty: Some( - Token { - kind: Identifier, - lexeme: Some( - "float", - ), - span: Span { - line: 5, - column: 12, - length: 5, + Type { + base: Token { + kind: Identifier, + lexeme: Some( + "float", + ), + span: Span { + line: 5, + column: 12, + length: 5, + }, }, + params: [], }, ), + tp: [], body: [ Return( Return { @@ -167,6 +178,7 @@ Ast { }, params: [], ty: None, + tp: [], body: [ Expr( Call( @@ -249,6 +261,7 @@ Ast { }, params: [], ty: None, + tp: [], body: [ Expr( Call( diff --git a/crates/kyac/snapshots/kyac__ast__tests__empty.snap b/crates/kyac/snapshots/kyac__ast__tests__empty.snap index b0e3978..df3dffc 100644 --- a/crates/kyac/snapshots/kyac__ast__tests__empty.snap +++ b/crates/kyac/snapshots/kyac__ast__tests__empty.snap @@ -1,5 +1,5 @@ --- -source: kyanite-core/src/ast/mod.rs +source: crates/kyac/src/ast/mod.rs expression: ast --- Ast { @@ -19,18 +19,22 @@ Ast { }, params: [], ty: Some( - Token { - kind: Identifier, - lexeme: Some( - "void", - ), - span: Span { - line: 1, - column: 12, - length: 4, + Type { + base: Token { + kind: Identifier, + lexeme: Some( + "void", + ), + span: Span { + line: 1, + column: 12, + length: 4, + }, }, + params: [], }, ), + tp: [], body: [], external: false, id: 0, diff --git a/crates/kyac/snapshots/kyac__ast__tests__expr.snap b/crates/kyac/snapshots/kyac__ast__tests__expr.snap index cead230..576cbd8 100644 --- a/crates/kyac/snapshots/kyac__ast__tests__expr.snap +++ b/crates/kyac/snapshots/kyac__ast__tests__expr.snap @@ -19,6 +19,7 @@ Ast { }, params: [], ty: None, + tp: [], body: [ Expr( Call( diff --git a/crates/kyac/snapshots/kyac__ast__tests__hello_world.snap b/crates/kyac/snapshots/kyac__ast__tests__hello_world.snap index 0b2b4b3..bdfa5d1 100644 --- a/crates/kyac/snapshots/kyac__ast__tests__hello_world.snap +++ b/crates/kyac/snapshots/kyac__ast__tests__hello_world.snap @@ -19,6 +19,7 @@ Ast { }, params: [], ty: None, + tp: [], body: [ Expr( Call( diff --git a/crates/kyac/snapshots/kyac__ast__tests__mixed.snap b/crates/kyac/snapshots/kyac__ast__tests__mixed.snap index 10601c9..b96e58b 100644 --- a/crates/kyac/snapshots/kyac__ast__tests__mixed.snap +++ b/crates/kyac/snapshots/kyac__ast__tests__mixed.snap @@ -17,16 +17,19 @@ Ast { length: 3, }, }, - ty: Token { - kind: Identifier, - lexeme: Some( - "float", - ), - span: Span { - line: 1, - column: 12, - length: 5, + ty: Type { + base: Token { + kind: Identifier, + lexeme: Some( + "float", + ), + span: Span { + line: 1, + column: 12, + length: 5, + }, }, + params: [], }, expr: Float( Literal { @@ -72,16 +75,19 @@ Ast { length: 3, }, }, - ty: Token { - kind: Identifier, - lexeme: Some( - "float", - ), - span: Span { - line: 3, - column: 19, - length: 5, + ty: Type { + base: Token { + kind: Identifier, + lexeme: Some( + "float", + ), + span: Span { + line: 3, + column: 19, + length: 5, + }, }, + params: [], }, }, Param { @@ -96,32 +102,39 @@ Ast { length: 3, }, }, - ty: Token { + ty: Type { + base: Token { + kind: Identifier, + lexeme: Some( + "int", + ), + span: Span { + line: 3, + column: 31, + length: 3, + }, + }, + params: [], + }, + }, + ], + ty: Some( + Type { + base: Token { kind: Identifier, lexeme: Some( - "int", + "float", ), span: Span { line: 3, - column: 31, - length: 3, + column: 37, + length: 5, }, }, - }, - ], - ty: Some( - Token { - kind: Identifier, - lexeme: Some( - "float", - ), - span: Span { - line: 3, - column: 37, - length: 5, - }, + params: [], }, ), + tp: [], body: [ Var( VarDecl { @@ -136,16 +149,19 @@ Ast { length: 3, }, }, - ty: Token { - kind: Identifier, - lexeme: Some( - "float", - ), - span: Span { - line: 4, - column: 14, - length: 5, + ty: Type { + base: Token { + kind: Identifier, + lexeme: Some( + "float", + ), + span: Span { + line: 4, + column: 14, + length: 5, + }, }, + params: [], }, expr: Float( Literal { @@ -269,6 +285,7 @@ Ast { }, params: [], ty: None, + tp: [], body: [ Var( VarDecl { @@ -283,16 +300,19 @@ Ast { length: 1, }, }, - ty: Token { - kind: Identifier, - lexeme: Some( - "float", - ), - span: Span { - line: 9, - column: 12, - length: 5, + ty: Type { + base: Token { + kind: Identifier, + lexeme: Some( + "float", + ), + span: Span { + line: 9, + column: 12, + length: 5, + }, }, + params: [], }, expr: Float( Literal { @@ -325,16 +345,19 @@ Ast { length: 1, }, }, - ty: Token { - kind: Identifier, - lexeme: Some( - "int", - ), - span: Span { - line: 10, - column: 12, - length: 3, + ty: Type { + base: Token { + kind: Identifier, + lexeme: Some( + "int", + ), + span: Span { + line: 10, + column: 12, + length: 3, + }, }, + params: [], }, expr: Int( Literal { diff --git a/crates/kyac/snapshots/kyac__pass__typecheck__tests__classes.snap b/crates/kyac/snapshots/kyac__pass__typecheck__tests__classes.snap index b46bb74..d927e38 100644 --- a/crates/kyac/snapshots/kyac__pass__typecheck__tests__classes.snap +++ b/crates/kyac/snapshots/kyac__pass__typecheck__tests__classes.snap @@ -50,7 +50,7 @@ Err( }, PreciseError { filename: "test-cases/typecheck/classes.kya", - heading: "no field `faavorite` on type `Person`", + heading: "undefined reference to `faavorite` (while reading `Person`)", source: " let favorite: int = p.faavorite;", span: Span { line: 31, @@ -61,7 +61,7 @@ Err( }, PreciseError { filename: "test-cases/typecheck/classes.kya", - heading: "no field `barr` on type `Foo`", + heading: "undefined reference to `barr` (while reading `Foo`)", source: " let bar: Bar = p.foo.barr;", span: Span { line: 33, @@ -72,7 +72,7 @@ Err( }, PreciseError { filename: "test-cases/typecheck/classes.kya", - heading: "no field `baaz` on type `Bar`", + heading: "undefined reference to `baaz` (while reading `Bar`)", source: " let baz: Baz = p.foo.bar.baaz;", span: Span { line: 34, @@ -83,25 +83,25 @@ Err( }, PreciseError { filename: "test-cases/typecheck/classes.kya", - heading: "Bar is not a subclass of Baz", + heading: "`Baz` is not defined", source: " let baz: Baz = p.foo.bar;", span: Span { line: 37, - column: 20, - length: 9, + column: 14, + length: 3, }, - text: "expression of type Bar", + text: "", }, PreciseError { filename: "test-cases/typecheck/classes.kya", - heading: "expected initializer to be of type Baz", + heading: "`Baz` is not defined", source: " let baz: Baz = p.foo.bar.baz;", span: Span { line: 38, - column: 20, - length: 13, + column: 14, + length: 3, }, - text: "expression of type bool", + text: "", }, PreciseError { filename: "test-cases/typecheck/classes.kya", diff --git a/crates/kyac/src/ast/mod.rs b/crates/kyac/src/ast/mod.rs index a53497d..ae780eb 100644 --- a/crates/kyac/src/ast/mod.rs +++ b/crates/kyac/src/ast/mod.rs @@ -2,13 +2,10 @@ pub mod node; pub mod span; #[cfg(test)] mod strip; +pub mod ty; -use crate::{ - parse::Parser, - token::{Lexer, Token}, - PipelineError, Source, -}; -use std::{fmt, rc::Rc}; +use crate::{parse::Parser, token::Lexer, PipelineError, Source}; +use std::rc::Rc; #[derive(Debug)] pub struct Ast { @@ -97,82 +94,6 @@ impl Expr { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Type { - Str, - Int, - Float, - Bool, - Void, - UserDefined(String), -} - -impl From<&Token> for Type { - fn from(value: &Token) -> Self { - match value.lexeme.expect("token should have lexeme") { - "str" => Self::Str, - "int" => Self::Int, - "float" => Self::Float, - "bool" => Self::Bool, - "void" => Self::Void, - name => Self::UserDefined(name.to_string()), - } - } -} - -impl From> for Type { - fn from(token: Option<&Token>) -> Self { - match token { - Some(token) => Self::from(token), - None => Self::Void, - } - } -} - -impl PartialEq for Option { - fn eq(&self, other: &Type) -> bool { - match self { - Some(token) => Type::from(token) == *other, - None => *other == Type::Void, - } - } -} - -impl fmt::Display for Type { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}", - match self { - Self::Str => "str", - Self::Int => "int", - Self::Float => "float", - Self::Bool => "bool", - Self::Void => "void", - Self::UserDefined(name) => name, - } - ) - } -} - -impl Expr { - pub fn ty(&self) -> Type { - match self { - Expr::Str(..) => Type::Str, - Expr::Int(..) => Type::Int, - Expr::Float(..) => Type::Float, - Expr::Bool(..) => Type::Bool, - Expr::Range(r) => r.start.ty(), - Expr::Binary(binary) => binary.left.ty(), - Expr::Unary(unary) => unary.expr.ty(), - Expr::Call(call) => call.left.ty(), - Expr::Ident(_) => unimplemented!(), - Expr::Init(..) => unimplemented!(), - Expr::Access(..) => unimplemented!(), - } - } -} - macro_rules! assert_ast { ($($path:expr => $name:ident),*) => { #[cfg(test)] diff --git a/crates/kyac/src/ast/node.rs b/crates/kyac/src/ast/node.rs index 4fbe052..acedbeb 100644 --- a/crates/kyac/src/ast/node.rs +++ b/crates/kyac/src/ast/node.rs @@ -1,5 +1,8 @@ use crate::{ - ast::{Decl, Expr, Stmt}, + ast::{ + ty::{Type, TypeParameter}, + Decl, Expr, Stmt, + }, token::Token, }; use std::{ @@ -11,7 +14,8 @@ use std::{ pub struct FuncDecl { pub name: Token, pub params: Vec, - pub ty: Option, + pub ty: Option, + pub tp: Vec, pub body: Vec, pub external: bool, pub id: usize, @@ -21,7 +25,8 @@ impl FuncDecl { pub fn new( name: Token, params: Vec, - ty: Option, + ty: Option, + tp: Vec, body: Vec, external: bool, ) -> Self { @@ -31,6 +36,7 @@ impl FuncDecl { name, params, ty, + tp, body, external, id, @@ -40,11 +46,12 @@ impl FuncDecl { pub fn wrapped( name: Token, params: Vec, - ty: Option, + ty: Option, + tp: Vec, body: Vec, external: bool, ) -> Decl { - Decl::Function(Rc::new(Self::new(name, params, ty, body, external))) + Decl::Function(Rc::new(Self::new(name, params, ty, tp, body, external))) } } @@ -54,6 +61,7 @@ pub struct ClassDecl { pub fields: Vec, pub methods: Vec>, pub parent: Option, + pub tp: Option>, } impl ClassDecl { @@ -62,12 +70,14 @@ impl ClassDecl { fields: Vec, methods: Vec>, parent: Option, + tp: Option>, ) -> Decl { Decl::Class(Rc::new(Self { name, fields, methods, parent, + tp, })) } } @@ -75,12 +85,12 @@ impl ClassDecl { #[derive(Debug)] pub struct ConstantDecl { pub name: Token, - pub ty: Token, + pub ty: Type, pub expr: Expr, } impl ConstantDecl { - pub fn wrapped(name: Token, ty: Token, expr: Expr) -> Decl { + pub fn wrapped(name: Token, ty: Type, expr: Expr) -> Decl { Decl::Constant(Rc::new(Self { name, ty, expr })) } } @@ -88,12 +98,12 @@ impl ConstantDecl { #[derive(Debug, PartialEq)] pub struct VarDecl { pub name: Token, - pub ty: Token, + pub ty: Type, pub expr: Expr, } impl VarDecl { - pub fn wrapped(name: Token, ty: Token, expr: Expr) -> Stmt { + pub fn wrapped(name: Token, ty: Type, expr: Expr) -> Stmt { Stmt::Var(Rc::new(Self { name, ty, expr })) } } @@ -339,11 +349,11 @@ impl Initializer { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Param { pub name: Token, - pub ty: Token, + pub ty: Type, } impl Param { - pub fn new(name: Token, ty: Token) -> Self { + pub fn new(name: Token, ty: Type) -> Self { Self { name, ty } } } @@ -351,11 +361,11 @@ impl Param { #[derive(Debug, Clone, PartialEq, Eq)] pub struct Field { pub name: Token, - pub ty: Token, + pub ty: Type, } impl Field { - pub fn new(name: Token, ty: Token) -> Self { + pub fn new(name: Token, ty: Type) -> Self { Self { name, ty } } } diff --git a/crates/kyac/src/ast/span.rs b/crates/kyac/src/ast/span.rs index 9d17ebd..d391e61 100644 --- a/crates/kyac/src/ast/span.rs +++ b/crates/kyac/src/ast/span.rs @@ -1,5 +1,5 @@ use crate::{ - ast::{Expr, Stmt}, + ast::{ty::Type, Expr, Stmt}, token::Span, }; @@ -13,6 +13,22 @@ pub trait Combined { fn line(&self) -> usize; } +impl Combined for Type { + fn start(&self) -> usize { + self.base.span.column + } + + fn end(&self) -> usize { + self.params + .last() + .map_or(self.base.span.column + self.base.span.length, Combined::end) + } + + fn line(&self) -> usize { + self.base.span.line + } +} + impl Combined for Stmt { fn start(&self) -> usize { match self { diff --git a/crates/kyac/src/ast/ty.rs b/crates/kyac/src/ast/ty.rs new file mode 100644 index 0000000..318be3f --- /dev/null +++ b/crates/kyac/src/ast/ty.rs @@ -0,0 +1,32 @@ +use crate::token::Token; +use std::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Type { + pub base: Token, + pub params: Vec, +} + +impl Type { + pub fn new(base: Token, params: Vec) -> Self { + Self { base, params } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.base.lexeme.unwrap_or("no lexeme found")) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TypeParameter { + pub name: Token, + pub bound: Option, +} + +impl TypeParameter { + pub fn new(name: Token, bound: Option) -> Self { + Self { name, bound } + } +} diff --git a/crates/kyac/src/backend/kyir/arch/armv8a/mod.rs b/crates/kyac/src/backend/kyir/arch/armv8a/mod.rs index dff021a..2e24e5f 100644 --- a/crates/kyac/src/backend/kyir/arch/armv8a/mod.rs +++ b/crates/kyac/src/backend/kyir/arch/armv8a/mod.rs @@ -1,7 +1,7 @@ pub mod isa; use crate::{ - ast::{node::FuncDecl, Type}, + ast::node::FuncDecl, backend::kyir::{ arch::{Location, RegisterMap}, ir::{BinOp, Binary, Const, Expr, Mem, Temp}, @@ -32,7 +32,7 @@ impl Frame for Armv8a { param.name.to_string(), Variable::new( offset, - matches!(Type::from(¶m.ty), Type::UserDefined(_)), + !matches!(param.ty.base.lexeme, Some("int" | "float" | "bool")), ), ); offset -= i64::try_from(Self::word_size()).unwrap(); diff --git a/crates/kyac/src/backend/kyir/translate/mod.rs b/crates/kyac/src/backend/kyir/translate/mod.rs index bec28d8..e1218fe 100644 --- a/crates/kyac/src/backend/kyir/translate/mod.rs +++ b/crates/kyac/src/backend/kyir/translate/mod.rs @@ -4,7 +4,7 @@ pub(super) use canon::canonicalize; #[allow(clippy::wildcard_imports)] use crate::{ - ast::{self, node::FuncDecl, Decl as AstDecl, Expr as AstExpr, Stmt as AstStmt, Type}, + ast::{self, node::FuncDecl, ty::Type, Decl as AstDecl, Expr as AstExpr, Stmt as AstStmt}, backend::kyir::{ arch::{ArchInstr, Frame}, ir::*, @@ -182,18 +182,18 @@ impl Translate for ast::node::Call { }, _ => panic!("Expected either `AstExpr::Ident` or `AstExpr::Access` on left side of call expression"), }; - let ty = match *self.left { - AstExpr::Ident(_) => translator.symbols[&name].ty(), + let ptr = match *self.left { + AstExpr::Ident(_) => translator.symbols[&name].is_ptr(), AstExpr::Access(ref access) => { let meta = translator.meta.access.get(&access.id).unwrap(); - meta.ty.clone() + meta.symbols.last().unwrap().is_ptr() } _ => unimplemented!(), }; let temp = Temp::next(); let id = translator.function.unwrap(); let frame = translator.functions.get_mut(&id).unwrap(); - let saved = frame.allocate(&temp, matches!(ty, Type::UserDefined(_))); + let saved = frame.allocate(&temp, ptr); let mut stmts = vec![]; let address = if cls.is_some_and(|cls| Symbol::has_subclass(cls, translator.symbols)) { // This call could be overridden by a subclass in which case we need to use dynamic dispatch. @@ -272,7 +272,10 @@ impl Translate for ast::node::Access { let name = Temp::next(); let decl = ast::node::VarDecl::wrapped( Token::new(Kind::Identifier, Some(name.clone().leak()), Span::default()), - Token::new(Kind::Literal, init.name.lexeme, Span::default()), + Type::new( + Token::new(Kind::Literal, init.name.lexeme, Span::default()), + vec![], + ), head.clone(), ); let stmt = decl.translate(translator); @@ -510,7 +513,10 @@ impl Translate for ast::node::For { let cur = ast::node::Ident::wrapped(self.index.clone()); let start = ast::node::VarDecl::wrapped( cur.clone().ident().name.clone(), - Token::new(Kind::Literal, Some("int"), Span::default()), + Type::new( + Token::new(Kind::Literal, Some("int"), Span::default()), + vec![], + ), range.start.clone(), ); let w = ast::node::While { @@ -586,7 +592,10 @@ impl Translate for ast::node::VarDecl { let id = translator.function.unwrap(); let frame = translator.functions.get_mut(&id).unwrap(); // No matter what, variables are always F::word_size() (either pointer to first element or the value itself) - let target = frame.allocate(&name, matches!(Type::from(&self.ty), Type::UserDefined(_))); + let target = frame.allocate( + &name, + !matches!(self.ty.base.lexeme, Some("int" | "float" | "bool")), + ); Stmt::checked_move(target, expr) } } @@ -601,7 +610,7 @@ impl Translate for ast::node::FuncDecl { .into_iter() .chain(self.body.iter().map(|stmt| stmt.translate(translator))) .collect(); - if self.ty == Type::Void { + if self.ty.is_none() || self.ty.as_ref().unwrap().base.lexeme == Some("void") { // If the function returns void, explicitly zero out the return register. This can // cause unwanted behavior in the garbage collector because if the last call is an // allocation: that pointer will be copied to the parent frame and be considered @@ -624,6 +633,7 @@ impl Translate> for ast::node::ClassDecl { method.name.clone(), method.params.clone(), method.ty.clone(), + method.tp.clone(), method.body.clone(), false, ); diff --git a/crates/kyac/src/backend/llvm/mod.rs b/crates/kyac/src/backend/llvm/mod.rs index 6c5fddc..213aacf 100644 --- a/crates/kyac/src/backend/llvm/mod.rs +++ b/crates/kyac/src/backend/llvm/mod.rs @@ -3,7 +3,8 @@ mod builtins; use crate::{ ast::{ node::{self, ClassDecl, Ident}, - Decl, Expr, Stmt, Type, + ty::Type, + Decl, Expr, Stmt, }, backend::llvm::builtins::Builtins, pass::{ResolvedMetaInfo, Symbol, SymbolTable}, @@ -298,6 +299,7 @@ impl<'a, 'ctx> Ir<'a, 'ctx> { func.name.clone(), func.params.clone(), func.ty.clone(), + func.tp.clone(), func.body.clone(), func.external, ); diff --git a/crates/kyac/src/backend/mod.rs b/crates/kyac/src/backend/mod.rs index 56fb20f..8659e0f 100644 --- a/crates/kyac/src/backend/mod.rs +++ b/crates/kyac/src/backend/mod.rs @@ -1,2 +1,3 @@ pub mod kyir; +#[cfg(feature = "llvm")] pub mod llvm; diff --git a/crates/kyac/src/lib.rs b/crates/kyac/src/lib.rs index 97ef09b..df29c6b 100644 --- a/crates/kyac/src/lib.rs +++ b/crates/kyac/src/lib.rs @@ -17,25 +17,28 @@ pub mod isa { pub use crate::backend::kyir::arch::armv8a::isa::A64; } -use crate::{ - arch::Armv8a, - backend::{kyir, llvm}, - isa::A64, - pass::SymbolTable, -}; +#[cfg(feature = "llvm")] +use crate::backend::llvm; +use crate::{arch::Armv8a, backend::kyir, isa::A64, pass::SymbolTable}; use std::{fs::File, io::Read, path::Path}; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn compile(source: &Source, backend: &Backend) -> Result { + #[cfg(feature = "llvm")] let mut ast = ast::Ast::try_from(source)?; + #[cfg(not(feature = "llvm"))] + let ast = ast::Ast::try_from(source)?; let symbols = SymbolTable::from(&ast.nodes); let meta = pass::resolve_types(source, &symbols, &ast.nodes) .map_err(|e| PipelineError::TypeError(e.len()))?; match backend { + #[cfg(feature = "llvm")] Backend::Llvm => Ok(Output::Llvm( llvm::Ir::build(&mut ast.nodes, symbols, meta).map_err(PipelineError::IrError)?, )), + #[cfg(not(feature = "llvm"))] + Backend::Llvm => panic!("LLVM backend not enabled"), Backend::Kyir => Ok(Output::Asm(kyir::asm::( &ast.nodes, &symbols, &meta, ))), @@ -98,6 +101,7 @@ pub enum PipelineError { ParseError(usize), #[error("(while type checking) {0} error(s) encountered")] TypeError(usize), + #[cfg(feature = "llvm")] #[error("(while building ir) {0}")] IrError(llvm::IrError), #[error("failed to compile (see output)")] diff --git a/crates/kyac/src/parse.rs b/crates/kyac/src/parse.rs index 05ad046..ede86d7 100644 --- a/crates/kyac/src/parse.rs +++ b/crates/kyac/src/parse.rs @@ -1,6 +1,10 @@ #[allow(clippy::wildcard_imports)] use crate::{ - ast::{node::*, Decl, Expr, Stmt}, + ast::{ + node::*, + ty::{Type, TypeParameter}, + Decl, Expr, Stmt, + }, error::PreciseError, token::{Kind, Span, Token}, Source, @@ -71,14 +75,15 @@ impl<'a> Parser<'a> { fn class(&mut self) -> Result { self.consume(Kind::Class)?; let name = self.consume(Kind::Identifier)?; - let parent = (self.peek()?.kind == Kind::Colon).then_some(|| { - self.consume(Kind::Colon)?; - self.consume(Kind::Identifier) - }); - let parent = match parent { - Some(mut parent) => Some(parent()?), - None => None, - }; + let tp = (self.peek()?.kind == Kind::Less) + .then(|| self.type_parameters()) + .transpose()?; + let parent = (self.peek()?.kind == Kind::Colon) + .then(|| { + self.consume(Kind::Colon)?; + self.consume(Kind::Identifier) + }) + .transpose()?; self.consume(Kind::LeftBrace)?; let fields = self.fields()?; let mut methods = vec![]; @@ -89,7 +94,27 @@ impl<'a> Parser<'a> { }); } self.consume(Kind::RightBrace)?; - Ok(ClassDecl::wrapped(name, fields, methods, parent)) + Ok(ClassDecl::wrapped(name, fields, methods, parent, tp)) + } + + fn type_parameters(&mut self) -> Result, ParseError> { + let mut tp = vec![]; + self.consume(Kind::Less)?; + while self.peek()?.kind != Kind::Greater { + let name = self.consume(Kind::Identifier)?; + let bound = (!matches!(self.peek()?.kind, Kind::Comma | Kind::Greater)) + .then(|| { + self.consume(Kind::Colon)?; + self.consume(Kind::Identifier) + }) + .transpose()?; + if self.peek()?.kind != Kind::Greater { + self.consume(Kind::Comma)?; + } + tp.push(TypeParameter::new(name, bound)); + } + self.consume(Kind::Greater)?; + Ok(tp) } fn function(&mut self, method: &Option, external: bool) -> Result { @@ -98,18 +123,29 @@ impl<'a> Parser<'a> { } self.consume(Kind::Fun)?; let name = self.consume(Kind::Identifier)?; + let tp = (self.peek()?.kind == Kind::Less) + .then(|| self.type_parameters()) + .transpose()? + .unwrap_or(vec![]); self.consume(Kind::LeftParen)?; let params = self.params(method)?; self.consume(Kind::RightParen)?; - let mut ty: Option = None; + let mut ty: Option = None; if self.peek()?.kind == Kind::Colon { self.consume(Kind::Colon)?; - ty = Some(self.consume(Kind::Identifier)?); + ty = Some(self.ty()?); } if external { - Ok(FuncDecl::wrapped(name, params, ty, vec![], external)) + Ok(FuncDecl::wrapped(name, params, ty, tp, vec![], external)) } else { - Ok(FuncDecl::wrapped(name, params, ty, self.block()?, external)) + Ok(FuncDecl::wrapped( + name, + params, + ty, + tp, + self.block()?, + external, + )) } } @@ -120,10 +156,10 @@ impl<'a> Parser<'a> { let name = self.consume(Kind::Identifier)?; let ty = if method.as_ref().is_some_and(|_| index == 0) { index += 1; - method.clone().unwrap() + Type::new(method.clone().unwrap(), vec![]) } else { self.consume(Kind::Colon)?; - self.consume(Kind::Identifier)? + self.ty()? }; params.push(Param::new(name, ty)); if self.peek()?.kind != Kind::RightParen { @@ -138,7 +174,7 @@ impl<'a> Parser<'a> { while !matches!(self.peek()?.kind, Kind::RightBrace | Kind::Fun) { let name = self.consume(Kind::Identifier)?; self.consume(Kind::Colon)?; - let ty = self.consume(Kind::Identifier)?; + let ty = self.ty()?; fields.push(Field::new(name, ty)); if !matches!(self.peek()?.kind, Kind::RightBrace | Kind::Fun) { self.consume(Kind::Comma)?; @@ -147,6 +183,25 @@ impl<'a> Parser<'a> { Ok(fields) } + fn ty(&mut self) -> Result { + let base = self.consume(Kind::Identifier)?; + (self.peek()?.kind == Kind::Less) + .then(|| { + self.consume(Kind::Less)?; + let mut params = vec![]; + while self.peek()?.kind != Kind::Greater { + params.push(self.ty()?); + if self.peek()?.kind != Kind::Greater { + self.consume(Kind::Comma)?; + } + } + self.consume(Kind::Greater)?; + Ok(params) + }) + .transpose() + .map(|params| Type::new(base, params.unwrap_or_default())) + } + fn block(&mut self) -> Result, ParseError> { self.consume(Kind::LeftBrace)?; let mut stmts: Vec = vec![]; @@ -168,7 +223,7 @@ impl<'a> Parser<'a> { self.consume(Kind::Const)?; let name = self.consume(Kind::Identifier)?; self.consume(Kind::Colon)?; - let ty = self.consume(Kind::Identifier)?; + let ty = self.ty()?; self.consume(Kind::Equal)?; let value = self.expression()?; self.consume(Kind::Semicolon)?; @@ -179,7 +234,7 @@ impl<'a> Parser<'a> { self.consume(Kind::Let)?; let name = self.consume(Kind::Identifier)?; self.consume(Kind::Colon)?; - let ty = self.consume(Kind::Identifier)?; + let ty = self.ty()?; self.consume(Kind::Equal)?; let expr = self.expression()?; self.consume(Kind::Semicolon)?; diff --git a/crates/kyac/src/pass/symbol.rs b/crates/kyac/src/pass/symbol.rs index 74cdfa7..1b894f6 100644 --- a/crates/kyac/src/pass/symbol.rs +++ b/crates/kyac/src/pass/symbol.rs @@ -1,5 +1,5 @@ use crate::{ - ast::{node, Decl, Type}, + ast::{node, Decl}, builtins, }; use std::{ @@ -13,16 +13,25 @@ pub enum Symbol { Function(Rc), Constant(Rc), Variable(Rc), + Str, + Int, + Float, + Bool, + Void, } impl Symbol { pub fn class(&self) -> &node::ClassDecl { match self { Symbol::Class(cls) => cls, - _ => panic!("called `Symbol::class()` on a non-class symbol"), + _ => panic!("called `Symbol::class()` on a non-class symbol: {self:?}"), } } + pub fn is_ptr(&self) -> bool { + matches!(self, Symbol::Class(_) | Symbol::Str) + } + pub fn function(&self) -> &node::FuncDecl { match self { Symbol::Function(fun) => fun, @@ -97,23 +106,14 @@ impl Symbol { let fields = self .fields(symbols) .iter() - .map(|f| match Type::from(&f.ty) { - Type::Int | Type::Float | Type::Bool => 'i', - Type::Str | Type::UserDefined(_) => 'p', - Type::Void => panic!("class cannot contain void field"), + .map(|f| match f.ty.base.lexeme.unwrap() { + "int" | "float" | "bool" => 'i', + "void" => panic!("class cannot contain void field"), + _ => 'p', }) .collect(); (fields, methods) } - - pub fn ty(&self) -> Type { - match self { - Self::Class(cls) => Type::from(&cls.name), - Self::Function(fun) => fun.ty.as_ref().into(), - Self::Constant(c) => Type::from(&c.ty), - Self::Variable(v) => Type::from(&v.ty), - } - } } crate::newtype!(SymbolTable:HashMap); diff --git a/crates/kyac/src/pass/typecheck.rs b/crates/kyac/src/pass/typecheck.rs index 1d16441..7640603 100644 --- a/crates/kyac/src/pass/typecheck.rs +++ b/crates/kyac/src/pass/typecheck.rs @@ -1,9 +1,5 @@ use crate::{ - ast::{ - node::{self, Ident}, - span::Combined, - Decl, Expr, Stmt, Type, - }, + ast::{node, span::Combined, ty::Type, Decl, Expr, Stmt}, error::PreciseError, pass::{Symbol, SymbolTable}, token::{Kind, Span, Token}, @@ -11,35 +7,6 @@ use crate::{ }; use std::{collections::HashMap, rc::Rc}; -macro_rules! symbol { - ($self:ident, $name:expr, $ty:ident, $s:literal) => { - match $self.symbol(&$name.to_string()) { - Some(Symbol::$ty(v)) => v.clone(), - Some(_) => { - $self.error( - $name.span, - format!("`{}` is not a {}", $name, $s), - "".into(), - ); - return Err(TypeError::NotType($name.clone(), $s)); - } - None => { - $self.error($name.span, format!("`{}` is not defined", $name), "".into()); - return Err(TypeError::Undefined); - } - } - }; -} - -macro_rules! cast { - ($id:expr, $res:expr, $pattern:pat) => { - match $id { - $pattern => $res, - _ => unimplemented!(), - } - }; -} - #[derive(thiserror::Error, Debug)] pub enum TypeError { #[error("undefined variable")] @@ -49,9 +16,9 @@ pub enum TypeError { #[error("cannot {0} {1}")] UnaryMismatch(&'static str, Type), #[error("expected {0}, got {1}")] - Mismatch(Type, Type), - #[error("{0} is not a property of {1}")] - NotProperty(Token, Type), + Mismatch(String, String), + #[error("{0:?} is not a property of {1}")] + NotProperty(Expr, Type), } #[derive(Debug)] @@ -80,12 +47,146 @@ struct TypeResolverContext<'a> { class: Option, } +#[derive(Debug)] +struct ResolvedType { + base: Symbol, + #[allow(dead_code)] + params: Vec, + meta: Type, +} + +impl ResolvedType { + fn new(base: Symbol, params: Vec, meta: Type) -> Self { + Self { base, params, meta } + } + + fn field(&self, symbols: &SymbolTable, field: &Expr) -> Option<(usize, node::Field)> { + match (&self.base, field) { + (Symbol::Class(_), Expr::Ident(ident)) => self + .base + .fields(symbols) + .iter() + .enumerate() + .find(|(_, f)| f.name.to_string() == ident.name.to_string()) + .map(|(i, f)| (i, f.clone())), + _ => None, + } + } + + fn method(&self, symbols: &SymbolTable, method: &Expr) -> Option> { + match (&self.base, method) { + #[allow(clippy::cmp_owned)] + (Symbol::Class(_), Expr::Call(call)) => self + .base + .methods(symbols) + .iter() + .map(|(label, method)| (label.rsplit_once('.').unwrap().1, method)) + .find(|(name, _)| *name == call.left.ident().name.to_string()) + .map(|(_, func)| func) + .cloned(), + _ => None, + } + } + + fn is_numeric(&self) -> bool { + matches!(self.base, Symbol::Int | Symbol::Float) + } + + fn is_bool(&self) -> bool { + matches!(self.base, Symbol::Bool) + } + + fn fake_meta(lexeme: &'static str) -> Type { + Type::new( + Token::new(Kind::Identifier, Some(lexeme), Span::default()), + vec![], + ) + } + + fn str() -> Self { + Self::new(Symbol::Str, vec![], Self::fake_meta("str")) + } + + fn float() -> Self { + Self::new(Symbol::Float, vec![], Self::fake_meta("float")) + } + + fn int() -> Self { + Self::new(Symbol::Int, vec![], Self::fake_meta("int")) + } + + fn bool() -> Self { + Self::new(Symbol::Bool, vec![], Self::fake_meta("bool")) + } + + fn void() -> Self { + Self::new(Symbol::Void, vec![], Self::fake_meta("void")) + } +} + +impl PartialEq for ResolvedType { + fn eq(&self, other: &Self) -> bool { + match &self.base { + Symbol::Int | Symbol::Float | Symbol::Str | Symbol::Bool | Symbol::Void => { + self.meta.base.to_string() == other.meta.base.to_string() + } + Symbol::Class(cls) => { + let other = other.meta.to_string(); + let cls = cls.name.to_string(); + cls == other + } + Symbol::Constant(c) => { + let other = other.meta.to_string(); + let c = c.name.to_string(); + c == other + } + Symbol::Function(f) => { + let other = other.meta.to_string(); + let f = f.name.to_string(); + f == other + } + Symbol::Variable(v) => { + let other = other.meta.to_string(); + let v = v.name.to_string(); + v == other + } + } + } +} + trait ResolveType { fn resolve( &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result; + ) -> Result; +} + +impl ResolveType for Type { + #[allow(clippy::only_used_in_recursion)] + fn resolve( + &self, + cx: &mut TypeResolverContext, + meta: &mut ResolvedMetaInfo, + ) -> Result { + Ok(ResolvedType::new( + if let Some(symbol) = cx.symbol(&self.base.to_string()) { + symbol.clone() + } else { + cx.error( + self.base.span, + format!("`{}` is not defined", self.base.lexeme.unwrap()), + String::new(), + ); + return Err(TypeError::Undefined); + }, + self.params + .iter() + .map(|p| p.resolve(cx, meta).unwrap()) + .collect(), + self.clone(), + )) + } } impl ResolveType for Decl { @@ -93,7 +194,7 @@ impl ResolveType for Decl { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { match self { Decl::Function(fun) => fun.resolve(cx, meta), Decl::Class(cls) => cls.resolve(cx, meta), @@ -107,7 +208,7 @@ impl ResolveType for Stmt { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { match self { Stmt::Var(v) => v.resolve(cx, meta), Stmt::Assign(a) => a.resolve(cx, meta), @@ -125,7 +226,7 @@ impl ResolveType for Expr { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { match self { Expr::Int(i) => i.resolve(cx, meta), Expr::Float(f) => f.resolve(cx, meta), @@ -147,13 +248,13 @@ impl ResolveType for node::ClassDecl { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { cx.class = Some(self.name.clone()); for method in &self.methods { let _ = Decl::Function(Rc::clone(method)).resolve(cx, meta); } cx.class = None; - Ok(Type::Void) + Ok(ResolvedType::void()) } } @@ -162,12 +263,12 @@ impl ResolveType for Rc { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { if self.name == "main" { if let Some(ty) = &self.ty { - if ty != "void" { + if !matches!(ty.resolve(cx, meta)?.base, Symbol::Void) { cx.error( - ty.span, + ty.span(), "main function must return void".into(), "try changing or removing this type".into(), ); @@ -192,7 +293,7 @@ impl ResolveType for Rc { } cx.end_scope(); cx.function = None; - Ok(Type::Void) + Ok(ResolvedType::void()) } } @@ -201,14 +302,14 @@ impl ResolveType for node::ConstantDecl { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let got = self.expr.resolve(cx, meta)?; - let expected = Type::from(&self.ty); + let expected = self.ty.resolve(cx, meta)?; if got != expected { cx.error( self.expr.span(), - format!("expected initializer to be of type {expected}"), - format!("expression of type {got}"), + format!("expected initializer to be of type {}", expected.meta), + format!("expression of type {}", got.meta), ); } Ok(expected) @@ -220,22 +321,25 @@ impl ResolveType for Rc { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let got = self.expr.resolve(cx, meta)?; - let expected = Type::from(&self.ty); - if let Type::UserDefined(ref cls) = got { - if got != expected && cx.cast(&expected, cls).is_none() { + let expected = self.ty.resolve(cx, meta)?; + if !matches!( + got.base, + Symbol::Bool | Symbol::Int | Symbol::Float | Symbol::Str | Symbol::Void + ) { + if got != expected && cx.cast(&expected, &got).is_none() { cx.error( self.expr.span(), - format!("{got} is not a subclass of {expected}"), - format!("expression of type {got}"), + format!("{} is not a subclass of {}", got.meta, expected.meta), + format!("expression of type {}", got.meta), ); } } else if got != expected { cx.error( self.expr.span(), - format!("expected initializer to be of type {expected}"), - format!("expression of type {got}"), + format!("expected initializer to be of type {}", expected.meta), + format!("expression of type {}", got.meta), ); } cx.scope_mut() @@ -249,14 +353,17 @@ impl ResolveType for node::For { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { self.iter.resolve(cx, meta)?; cx.begin_scope(); cx.scope_mut().insert( self.index.to_string(), Symbol::Variable(Rc::new(node::VarDecl { name: self.index.clone(), - ty: Token::new(Kind::Identifier, Some("int"), self.index.span), + ty: Type::new( + Token::new(Kind::Identifier, Some("int"), Span::default()), + vec![], + ), expr: self.iter.clone(), })), ); @@ -264,7 +371,7 @@ impl ResolveType for node::For { let _ = node.resolve(cx, meta); } cx.end_scope(); - Ok(Type::Void) + Ok(ResolvedType::void()) } } @@ -273,13 +380,13 @@ impl ResolveType for node::While { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let got = self.condition.resolve(cx, meta)?; - if got != Type::Bool { + if !got.is_bool() { cx.error( self.condition.span(), - format!("expected condition of type {}", Type::Bool), - format!("expression of type {got}"), + "expected condition of type bool".into(), + format!("expression of type {}", got.meta), ); } cx.begin_scope(); @@ -287,7 +394,7 @@ impl ResolveType for node::While { let _ = stmt.resolve(cx, meta); } cx.end_scope(); - Ok(Type::Void) + Ok(ResolvedType::void()) } } @@ -296,13 +403,13 @@ impl ResolveType for node::If { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let got = self.condition.resolve(cx, meta)?; - if got != Type::Bool { + if !got.is_bool() { cx.error( self.condition.span(), - format!("expected condition of type {}", Type::Bool), - format!("expression of type {got}"), + "expected condition of type bool".into(), + format!("expression of type {}", got.meta), ); } cx.begin_scope(); @@ -316,7 +423,7 @@ impl ResolveType for node::If { } cx.end_scope(); - Ok(Type::Void) + Ok(ResolvedType::void()) } } @@ -325,30 +432,30 @@ impl ResolveType for node::Unary { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let got = self.expr.resolve(cx, meta)?; match self.op.kind { Kind::Minus => { - if !matches!(got, Type::Int | Type::Float) { + if !got.is_numeric() { cx.error( self.expr.span(), - format!("cannot negate {got}"), - format!("expression of type {got}"), + format!("cannot negate {}", got.meta), + format!("expression of type {}", got.meta), ); - return Err(TypeError::UnaryMismatch("negate", got)); + return Err(TypeError::UnaryMismatch("negate", got.meta)); } Ok(got) } Kind::Bang => { - if got != Type::Bool { + if !got.is_bool() { cx.error( self.expr.span(), - format!("cannot invert {got}"), - format!("expression of type {got}"), + format!("cannot invert {}", got.meta), + format!("expression of type {}", got.meta), ); - return Err(TypeError::UnaryMismatch("invert", got)); + return Err(TypeError::UnaryMismatch("invert", got.meta)); } - Ok(Type::Bool) + Ok(ResolvedType::bool()) } _ => unimplemented!(), } @@ -360,12 +467,12 @@ impl ResolveType for node::Call { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let function = match &*self.left { Expr::Ident(ident) => { let name = ident.name.to_string(); match cx.symbol(&name) { - Some(Symbol::Function(f)) => f.as_ref(), + Some(Symbol::Function(f)) => f, Some(_) => { cx.error( ident.name.span, @@ -412,14 +519,17 @@ impl ResolveType for node::Call { for (i, arg) in self.args.iter().enumerate() { let got = arg.resolve(cx, meta)?; if i < params.len() { - let expected = Type::from(¶ms[i].ty); - if let Type::UserDefined(ref cls) = got { - let casted = cx.cast(&expected, cls); + let expected = params[i].ty.resolve(cx, meta)?; + if !matches!( + got.base, + Symbol::Bool | Symbol::Int | Symbol::Float | Symbol::Str + ) { + let casted = cx.cast(&expected, &got); if got != expected && casted.is_none() { cx.error( arg.span(), - format!("{got} is not a subclass of {expected}"), - format!("expression of type {got}"), + format!("{} is not a subclass of {}", got.meta, expected.meta), + format!("expression of type {}", got.meta), ); } if let Some(cls) = casted { @@ -430,17 +540,20 @@ impl ResolveType for node::Call { } else if got != expected { cx.error( arg.span(), - format!("expected argument of type {expected}, but found {got}"), - format!("expression of type {got}"), + format!( + "expected argument of type {}, but found {}", + expected.meta, got.meta + ), + format!("expression of type {}", got.meta), ); } } } - Ok(if let Some(ty) = ty { - Type::from(&ty) + if let Some(ty) = ty { + ty.resolve(cx, meta) } else { - Type::Void - }) + Ok(ResolvedType::void()) + } } } @@ -449,8 +562,15 @@ impl ResolveType for node::Init { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { - symbol!(cx, self.name, Class, "class"); // ensure class is defined + ) -> Result { + if cx.symbol(&self.name.to_string()).is_none() { + cx.error( + self.name.span, + format!("`{}` is not defined", self.name), + String::new(), + ); + return Err(TypeError::Undefined); + } let fields = cx .symbol(&self.name.to_string()) .unwrap() @@ -458,7 +578,7 @@ impl ResolveType for node::Init { for initializer in &self.initializers { let got = initializer.expr.resolve(cx, meta)?; let expected = if let Some(field) = fields.iter().find(|f| f.name == initializer.name) { - Type::from(&field.ty) + field.ty.resolve(cx, meta)? } else { cx.error( initializer.name.span, @@ -470,12 +590,20 @@ impl ResolveType for node::Init { if got != expected { cx.error( initializer.expr.span(), - format!("expected initializer to be of type {expected}"), - format!("expression of type {got}"), + format!("expected initializer to be of type {}", expected.meta), + format!("expression of type {}", got.meta), ); } } - Ok((&self.name).into()) + let symbol = { + let symbol = cx.symbol(&self.name.to_string()).cloned(); + symbol.unwrap() + }; + Ok(ResolvedType::new( + symbol, + vec![], + Type::new(self.name.clone(), vec![]), + )) } } @@ -484,18 +612,21 @@ impl ResolveType for node::Range { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let start = self.start.resolve(cx, meta)?; let end = self.end.resolve(cx, meta)?; - if start == end && start == Type::Int { - Ok(Type::Int) + if start == end && matches!(start.base, Symbol::Int) { + Ok(ResolvedType::int()) } else { cx.error( self.brackets.0.span, "expected range to be of type [int, int]".into(), - format!("expression of [{start}, {end}]"), + format!("expression of [{}, {}]", start.meta, end.meta), ); - Err(TypeError::Mismatch(Type::Int, start)) + Err(TypeError::Mismatch( + String::from("int"), + start.meta.to_string(), + )) } } } @@ -505,95 +636,59 @@ impl ResolveType for node::Assign { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let expected = self.target.resolve(cx, meta)?; let got = self.expr.resolve(cx, meta)?; if got != expected { cx.error( self.expr.span(), - format!("expected expression of type {expected}"), - format!("expression of type {got}"), + format!("expected expression of type {}", expected.meta), + format!("expression of type {}", got.meta), ); } - Ok(Type::Void) + Ok(ResolvedType::void()) } } -// TODO: make this prettier impl ResolveType for node::Access { fn resolve( &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { - fn err( - cx: &mut TypeResolverContext, - kind: &str, - ident: &Ident, - ty: Type, - ) -> Result { - cx.error( - ident.name.span, - format!("no {kind} `{}` on type `{}`", ident.name, ty), - String::new(), - ); - Err(TypeError::NotProperty(ident.name.clone(), ty)) - } - let mut ty = self.chain[0].resolve(cx, meta)?; + ) -> Result { let mut symbols = vec![]; let mut indices = vec![]; - let Some(mut symbol) = cx.symbol(&ty.to_string()).cloned() else { - return Err(TypeError::Undefined); - }; - symbols.push(symbol.clone()); - for (i, pair) in self.chain.windows(2).enumerate() { - let (left, right) = (&pair[0], &pair[1]); - if i != 0 { - let cls = cast!(symbol, r, Symbol::Class(ref r)); - let fields = cx.symbol(&cls.name.to_string()).unwrap().fields(cx.symbols); - if let Expr::Ident(ident) = left { - let field = fields.iter().find(|f| f.name == ident.name); - if let Some(field) = field { - symbol = cx.symbol(&field.ty.to_string()).cloned().unwrap(); - symbols.push(symbol.clone()); - } else { - return err(cx, "field", ident, ty); - } - } else { - todo!("support accesses after method calls") - } - } - let cls = cast!(symbol, r, Symbol::Class(ref r)); - if let Expr::Ident(ident) = right { - let fields = cx.symbol(&cls.name.to_string()).unwrap().fields(cx.symbols); - let index = fields.iter().position(|f| f.name == ident.name); - if let Some(index) = index { - indices.push(index); - ty = Type::from(&fields[index].ty); - } else { - return err(cx, "field", ident, ty); - } + let mut ty = self.chain[0].resolve(cx, meta)?; + for (n, window) in self.chain.windows(2).enumerate() { + let left = &window[0]; + let right = &window[1]; + let left = if n == 0 { left.resolve(cx, meta)? } else { ty }; + if let Some((index, field)) = left.field(cx.symbols, right) { + symbols.push(left.base.clone()); + indices.push(index); + ty = field.ty.resolve(cx, meta)?; + } else if let Some(method) = left.method(cx.symbols, right) { + symbols.push(left.base.clone()); + symbols.push(Symbol::Function(Rc::clone(&method))); + ty = match &method.ty { + Some(ty) => ty.resolve(cx, meta)?, + None => ResolvedType::void(), + }; } else { - let symbol = cx.symbol(&ty.to_string()).unwrap(); - let call = cast!(right, c, Expr::Call(c)); - let ident = call.left.ident(); - let methods = symbol.methods(cx.symbols); - let method = methods - .iter() - // make sure we find the most "specific" implementation - // (i.e. Y.method() before X.method()) - .rev() - .find(|(_, m)| m.name.to_string() == ident.name.to_string()); - if let Some((_, method)) = method { - symbols.push(Symbol::Function(Rc::clone(method))); - ty = Type::from(method.ty.as_ref()); - } else { - return err(cx, "method", ident, ty); - } + cx.error( + right.span(), + format!( + "undefined reference to `{}` (while reading `{}`)", + right.ident().name, + left.meta + ), + String::new(), + ); + return Err(TypeError::NotProperty(right.clone(), left.meta)); } } meta.access - .insert(self.id, Access::new(symbols, indices, ty.clone())); + .insert(self.id, Access::new(symbols, indices, ty.meta.clone())); Ok(ty) } } @@ -603,10 +698,12 @@ impl ResolveType for node::Binary { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let lhs = self.left.resolve(cx, meta)?; let rhs = self.right.resolve(cx, meta)?; if lhs != rhs { + let lhs = lhs.meta; + let rhs = rhs.meta; let heading = match self.op.kind { Kind::Plus => format!("cannot add {lhs} to {rhs}"), Kind::Minus => format!("cannot subtract {rhs} from {lhs}"), @@ -615,7 +712,7 @@ impl ResolveType for node::Binary { _ => format!("cannot compare {lhs} and {rhs}"), }; cx.error(self.op.span, heading, String::new()); - return Err(TypeError::Mismatch(lhs, rhs)); + return Err(TypeError::Mismatch(lhs.to_string(), rhs.to_string())); } if matches!( self.op.kind, @@ -623,7 +720,7 @@ impl ResolveType for node::Binary { ) { Ok(lhs) } else { - Ok(Type::Bool) + Ok(ResolvedType::bool()) } } } @@ -633,7 +730,7 @@ impl ResolveType for node::Return { &self, cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, - ) -> Result { + ) -> Result { let got = self.expr.resolve(cx, meta)?; match &cx.function { Some(function) => { @@ -641,22 +738,31 @@ impl ResolveType for node::Return { .class .as_ref() .map_or(function.to_string(), ToString::to_string); - let symbol = cx.symbol(&symb).unwrap(); + let symbol = cx.symbol(&symb).unwrap().clone(); let expected = match symbol { Symbol::Class(cls) => { let method = cls.methods.iter().find(|m| &m.name == function).unwrap(); - Type::from(method.ty.as_ref()) + method + .ty + .as_ref() + .map_or(ResolvedType::void(), |t| t.resolve(cx, meta).unwrap()) + } + Symbol::Function(f) => { + f.ty.as_ref() + .map_or(ResolvedType::void(), |t| t.resolve(cx, meta).unwrap()) } - Symbol::Function(f) => Type::from(f.ty.as_ref()), _ => unimplemented!(), }; if got != expected { cx.error( self.expr.span(), - format!("expected return type to be {expected}"), - format!("expression is of type {got}"), + format!("expected return type to be {}", expected.meta), + format!("expression is of type {}", got.meta), ); - return Err(TypeError::Mismatch(expected, got)); + return Err(TypeError::Mismatch( + expected.meta.base.to_string(), + got.meta.base.to_string(), + )); } } None => unimplemented!("disallowed by parser"), @@ -669,24 +775,24 @@ impl ResolveType for node::Ident { fn resolve( &self, cx: &mut TypeResolverContext, - _: &mut ResolvedMetaInfo, - ) -> Result { - Ok(match cx.symbol(&self.name.to_string()) { + meta: &mut ResolvedMetaInfo, + ) -> Result { + match cx.symbol(&self.name.to_string()).cloned() { Some(Symbol::Function(f)) => { let param = f.params.iter().find(|p| p.name == self.name).unwrap(); - Type::from(¶m.ty) + param.ty.resolve(cx, meta) } - Some(Symbol::Variable(v)) => Type::from(&v.ty), - Some(Symbol::Constant(c)) => Type::from(&c.ty), + Some(Symbol::Variable(v)) => v.ty.resolve(cx, meta), + Some(Symbol::Constant(c)) => c.ty.resolve(cx, meta), _ => { cx.error( self.name.span, format!("`{}` is not defined", &self.name), String::new(), ); - return Err(TypeError::Undefined); + Err(TypeError::Undefined) } - }) + } } } @@ -695,8 +801,8 @@ impl ResolveType for node::Literal { &self, _: &mut TypeResolverContext, _: &mut ResolvedMetaInfo, - ) -> Result { - Ok(Type::Bool) + ) -> Result { + Ok(ResolvedType::bool()) } } @@ -705,8 +811,8 @@ impl ResolveType for node::Literal { &self, _: &mut TypeResolverContext, _: &mut ResolvedMetaInfo, - ) -> Result { - Ok(Type::Int) + ) -> Result { + Ok(ResolvedType::int()) } } @@ -715,8 +821,8 @@ impl ResolveType for node::Literal { &self, _: &mut TypeResolverContext, _: &mut ResolvedMetaInfo, - ) -> Result { - Ok(Type::Float) + ) -> Result { + Ok(ResolvedType::float()) } } @@ -725,8 +831,8 @@ impl ResolveType for node::Literal<&'static str> { &self, _: &mut TypeResolverContext, _: &mut ResolvedMetaInfo, - ) -> Result { - Ok(Type::Str) + ) -> Result { + Ok(ResolvedType::str()) } } @@ -793,7 +899,14 @@ impl<'a> TypeResolverContext<'a> { return Some(definition); } } - self.symbols.get(name) + match &name[..] { + "int" => Some(&Symbol::Int), + "float" => Some(&Symbol::Float), + "str" => Some(&Symbol::Str), + "bool" => Some(&Symbol::Bool), + "void" => Some(&Symbol::Void), + _ => self.symbols.get(name), + } } fn error(&mut self, at: Span, heading: String, text: String) { @@ -802,13 +915,14 @@ impl<'a> TypeResolverContext<'a> { self.errors.push(error); } - fn cast(&self, expected: &Type, cls: &String) -> Option { - let cls = self.symbol(cls).unwrap().class(); + fn cast(&self, expected: &ResolvedType, got: &ResolvedType) -> Option { + let cls = self.symbol(&got.meta.to_string())?; + let cls = cls.class(); Symbol::superclasses(cls, self.symbols) .iter() .filter(|c| c.name != cls.name) .map(|c| c.name.to_string()) - .find(|cls| cls == &expected.to_string()) + .find(|cls| cls == &expected.meta.to_string()) } } diff --git a/crates/kyanite/tests/mod.rs b/crates/kyanite/tests/mod.rs index f5ecb31..54dc912 100644 --- a/crates/kyanite/tests/mod.rs +++ b/crates/kyanite/tests/mod.rs @@ -1,4 +1,5 @@ mod kyir; +#[cfg(feature = "llvm")] mod llvm; pub fn path(name: &str) -> Result> { diff --git a/examples/kyir/parametric-polymorphism.kya b/examples/kyir/parametric-polymorphism.kya new file mode 100644 index 0000000..98bee0e --- /dev/null +++ b/examples/kyir/parametric-polymorphism.kya @@ -0,0 +1,31 @@ +class Print { + fun print(self) { + println_str("implement me"); + } +} + +class Bar: Print { + x: int + + fun print() { + println_int(self.x); + } +} + +class Foo { + val: T + + fun print(self) { + self.val.print(); + } +} + +fun freeFunction(obj: T) { + obj.print(); +} + +fun main() { + let bar: Bar = Bar:init(x: 42); + let foo: Foo = Foo:init(val: bar); + foo.print(); +} \ No newline at end of file From 2e7ce26b514e3982d57d6813367f6c68eae119f7 Mon Sep 17 00:00:00 2001 From: Alaina <68250402+alaidriel@users.noreply.github.com> Date: Mon, 25 Mar 2024 10:20:45 -0500 Subject: [PATCH 2/5] feat(typecheck): explicitly require `self` params --- crates/kyac/src/pass/typecheck.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/crates/kyac/src/pass/typecheck.rs b/crates/kyac/src/pass/typecheck.rs index 7640603..e412ecb 100644 --- a/crates/kyac/src/pass/typecheck.rs +++ b/crates/kyac/src/pass/typecheck.rs @@ -284,6 +284,14 @@ impl ResolveType for Rc { "try removing some parameters".into(), ); } + let self_param = self.params.iter().position(|p| p.name == "self"); + if cx.class.is_some() && (self_param.is_none() || self_param.unwrap() != 0) { + cx.error( + self.name.span, + "first parameter must be `self`".into(), + "try adding `self` as the first parameter".into(), + ); + } for param in &self.params { cx.scope_mut() .insert(param.name.to_string(), Symbol::Function(Rc::clone(self))); From c41a78b0a602c8fa1685b2fa9d2c158330072aad Mon Sep 17 00:00:00 2001 From: Alaina <68250402+alaidriel@users.noreply.github.com> Date: Thu, 28 Mar 2024 10:46:00 -0500 Subject: [PATCH 3/5] feat: generics MVP (working!) --- crates/kyac/src/ast/mod.rs | 12 +- crates/kyac/src/backend/kyir/translate/mod.rs | 6 +- crates/kyac/src/pass/symbol.rs | 1 + crates/kyac/src/pass/typecheck.rs | 182 ++++++++++++++---- examples/kyir/parametric-polymorphism.kya | 3 +- 5 files changed, 158 insertions(+), 46 deletions(-) diff --git a/crates/kyac/src/ast/mod.rs b/crates/kyac/src/ast/mod.rs index ae780eb..885b771 100644 --- a/crates/kyac/src/ast/mod.rs +++ b/crates/kyac/src/ast/mod.rs @@ -5,7 +5,7 @@ mod strip; pub mod ty; use crate::{parse::Parser, token::Lexer, PipelineError, Source}; -use std::rc::Rc; +use std::{fmt, rc::Rc}; #[derive(Debug)] pub struct Ast { @@ -94,6 +94,16 @@ impl Expr { } } +impl fmt::Display for Expr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Expr::Call(call) => write!(f, "{}", call.left), + Expr::Ident(ident) => write!(f, "{}", ident.name), + _ => unimplemented!(), + } + } +} + macro_rules! assert_ast { ($($path:expr => $name:ident),*) => { #[cfg(test)] diff --git a/crates/kyac/src/backend/kyir/translate/mod.rs b/crates/kyac/src/backend/kyir/translate/mod.rs index e1218fe..1148646 100644 --- a/crates/kyac/src/backend/kyir/translate/mod.rs +++ b/crates/kyac/src/backend/kyir/translate/mod.rs @@ -170,11 +170,7 @@ impl Translate for ast::node::Call { Symbol::Function(fun) => fun.name.to_string(), Symbol::Class(c) => { cls = Some(c); - if let Some(cls) = translator.meta.call.get(&self.id) { - cls.to_string() - } else { - c.name.to_string() - } + c.name.to_string() } _ => unimplemented!(), }).fold(String::new(), |acc, item| format!("{acc}{item}.")); diff --git a/crates/kyac/src/pass/symbol.rs b/crates/kyac/src/pass/symbol.rs index 1b894f6..98a6631 100644 --- a/crates/kyac/src/pass/symbol.rs +++ b/crates/kyac/src/pass/symbol.rs @@ -13,6 +13,7 @@ pub enum Symbol { Function(Rc), Constant(Rc), Variable(Rc), + Opaque(String), Str, Int, Float, diff --git a/crates/kyac/src/pass/typecheck.rs b/crates/kyac/src/pass/typecheck.rs index e412ecb..19ed6c6 100644 --- a/crates/kyac/src/pass/typecheck.rs +++ b/crates/kyac/src/pass/typecheck.rs @@ -1,5 +1,10 @@ use crate::{ - ast::{node, span::Combined, ty::Type, Decl, Expr, Stmt}, + ast::{ + node, + span::Combined, + ty::{Type, TypeParameter}, + Decl, Expr, Stmt, + }, error::PreciseError, pass::{Symbol, SymbolTable}, token::{Kind, Span, Token}, @@ -42,7 +47,7 @@ struct TypeResolverContext<'a> { source: &'a Source, symbols: &'a SymbolTable, errors: Vec>, - scopes: Vec, + scopes: Vec, function: Option, class: Option, } @@ -60,15 +65,19 @@ impl ResolvedType { Self { base, params, meta } } - fn field(&self, symbols: &SymbolTable, field: &Expr) -> Option<(usize, node::Field)> { + fn field( + &self, + symbols: &SymbolTable, + field: &Expr, + ) -> Option<(usize, &Rc, node::Field)> { match (&self.base, field) { - (Symbol::Class(_), Expr::Ident(ident)) => self + (Symbol::Class(cls), Expr::Ident(ident)) => self .base .fields(symbols) .iter() .enumerate() .find(|(_, f)| f.name.to_string() == ident.name.to_string()) - .map(|(i, f)| (i, f.clone())), + .map(|(i, f)| (i, cls, f.clone())), _ => None, } } @@ -96,7 +105,7 @@ impl ResolvedType { matches!(self.base, Symbol::Bool) } - fn fake_meta(lexeme: &'static str) -> Type { + fn meta(lexeme: &'static str) -> Type { Type::new( Token::new(Kind::Identifier, Some(lexeme), Span::default()), vec![], @@ -104,23 +113,23 @@ impl ResolvedType { } fn str() -> Self { - Self::new(Symbol::Str, vec![], Self::fake_meta("str")) + Self::new(Symbol::Str, vec![], Self::meta("str")) } fn float() -> Self { - Self::new(Symbol::Float, vec![], Self::fake_meta("float")) + Self::new(Symbol::Float, vec![], Self::meta("float")) } fn int() -> Self { - Self::new(Symbol::Int, vec![], Self::fake_meta("int")) + Self::new(Symbol::Int, vec![], Self::meta("int")) } fn bool() -> Self { - Self::new(Symbol::Bool, vec![], Self::fake_meta("bool")) + Self::new(Symbol::Bool, vec![], Self::meta("bool")) } fn void() -> Self { - Self::new(Symbol::Void, vec![], Self::fake_meta("void")) + Self::new(Symbol::Void, vec![], Self::meta("void")) } } @@ -130,6 +139,7 @@ impl PartialEq for ResolvedType { Symbol::Int | Symbol::Float | Symbol::Str | Symbol::Bool | Symbol::Void => { self.meta.base.to_string() == other.meta.base.to_string() } + Symbol::Opaque(s) => s == other.meta.base.lexeme.unwrap(), Symbol::Class(cls) => { let other = other.meta.to_string(); let cls = cls.name.to_string(); @@ -172,6 +182,8 @@ impl ResolveType for Type { Ok(ResolvedType::new( if let Some(symbol) = cx.symbol(&self.base.to_string()) { symbol.clone() + } else if let Some(ty) = cx.ty(&self.base.to_string()) { + ty.base.clone() } else { cx.error( self.base.span, @@ -197,7 +209,12 @@ impl ResolveType for Decl { ) -> Result { match self { Decl::Function(fun) => fun.resolve(cx, meta), - Decl::Class(cls) => cls.resolve(cx, meta), + Decl::Class(cls) => { + cx.class = Some(cls.name.clone()); + let resolved = cls.resolve(cx, meta); + cx.class = None; + resolved + } Decl::Constant(c) => c.resolve(cx, meta), } } @@ -249,11 +266,12 @@ impl ResolveType for node::ClassDecl { cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, ) -> Result { - cx.class = Some(self.name.clone()); + cx.begin_scope(); + cx.set_type_parameters(meta, self.tp.as_ref()); for method in &self.methods { let _ = Decl::Function(Rc::clone(method)).resolve(cx, meta); } - cx.class = None; + cx.end_scope(); Ok(ResolvedType::void()) } } @@ -276,6 +294,7 @@ impl ResolveType for Rc { } } cx.begin_scope(); + cx.set_type_parameters(meta, Some(&self.tp)); cx.function = Some(self.name.clone()); if self.params.len() > 8 { cx.error( @@ -294,6 +313,7 @@ impl ResolveType for Rc { } for param in &self.params { cx.scope_mut() + .symbols .insert(param.name.to_string(), Symbol::Function(Rc::clone(self))); } for node in &self.body { @@ -351,6 +371,7 @@ impl ResolveType for Rc { ); } cx.scope_mut() + .symbols .insert(self.name.to_string(), Symbol::Variable(Rc::clone(self))); Ok(expected) } @@ -364,7 +385,7 @@ impl ResolveType for node::For { ) -> Result { self.iter.resolve(cx, meta)?; cx.begin_scope(); - cx.scope_mut().insert( + cx.scope_mut().symbols.insert( self.index.to_string(), Symbol::Variable(Rc::new(node::VarDecl { name: self.index.clone(), @@ -528,10 +549,7 @@ impl ResolveType for node::Call { let got = arg.resolve(cx, meta)?; if i < params.len() { let expected = params[i].ty.resolve(cx, meta)?; - if !matches!( - got.base, - Symbol::Bool | Symbol::Int | Symbol::Float | Symbol::Str - ) { + if matches!(got.base, Symbol::Class(_)) { let casted = cx.cast(&expected, &got); if got != expected && casted.is_none() { cx.error( @@ -540,11 +558,6 @@ impl ResolveType for node::Call { format!("expression of type {}", got.meta), ); } - if let Some(cls) = casted { - // This call actually refers to a parent (inherited) method call, so we need - // to switch out the branch label later - meta.call.insert(self.id, cls); - } } else if got != expected { cx.error( arg.span(), @@ -571,6 +584,7 @@ impl ResolveType for node::Init { cx: &mut TypeResolverContext, meta: &mut ResolvedMetaInfo, ) -> Result { + cx.begin_scope(); if cx.symbol(&self.name.to_string()).is_none() { cx.error( self.name.span, @@ -579,10 +593,23 @@ impl ResolveType for node::Init { ); return Err(TypeError::Undefined); } - let fields = cx + cx.set_type_parameters( + meta, + cx.symbol(&self.name.to_string()) + .unwrap() + .class() + .tp + .clone() + .as_ref(), + ); + let typ = cx .symbol(&self.name.to_string()) .unwrap() - .fields(cx.symbols); + .class() + .tp + .clone(); + let symbol = cx.symbol(&self.name.to_string()).unwrap(); + let fields = symbol.fields(cx.symbols); for initializer in &self.initializers { let got = initializer.expr.resolve(cx, meta)?; let expected = if let Some(field) = fields.iter().find(|f| f.name == initializer.name) { @@ -595,7 +622,33 @@ impl ResolveType for node::Init { ); continue; }; - if got != expected { + let valid = if let Some(ref typ) = typ { + if let Some(ty) = typ + .iter() + .find(|t| t.name == expected.meta.base.lexeme.unwrap()) + { + let uncastable = |bound: &Token| { + let raw_type = Type::new(bound.clone(), vec![]); + let expected = raw_type.resolve(cx, meta).unwrap(); + cx.cast(&expected, &got).is_none() + }; + if ty.bound.as_ref().is_some_and(uncastable) { + cx.error( + initializer.expr.span(), + format!("expected initializer to be of type {}", expected.meta), + format!("expression of type {}", got.meta), + ); + false + } else { + true + } + } else { + true + } + } else { + got == expected + }; + if !valid { cx.error( initializer.expr.span(), format!("expected initializer to be of type {}", expected.meta), @@ -607,6 +660,7 @@ impl ResolveType for node::Init { let symbol = cx.symbol(&self.name.to_string()).cloned(); symbol.unwrap() }; + cx.end_scope(); Ok(ResolvedType::new( symbol, vec![], @@ -670,16 +724,20 @@ impl ResolveType for node::Access { for (n, window) in self.chain.windows(2).enumerate() { let left = &window[0]; let right = &window[1]; - let left = if n == 0 { left.resolve(cx, meta)? } else { ty }; - if let Some((index, field)) = left.field(cx.symbols, right) { + let left = if n == 0 { + left.resolve(cx, meta).unwrap() + } else { + ty + }; + if let Some((index, _, field)) = left.field(cx.symbols, right) { symbols.push(left.base.clone()); indices.push(index); - ty = field.ty.resolve(cx, meta)?; + ty = field.ty.resolve(cx, meta).unwrap(); } else if let Some(method) = left.method(cx.symbols, right) { symbols.push(left.base.clone()); symbols.push(Symbol::Function(Rc::clone(&method))); ty = match &method.ty { - Some(ty) => ty.resolve(cx, meta)?, + Some(ty) => ty.resolve(cx, meta).unwrap(), None => ResolvedType::void(), }; } else { @@ -687,8 +745,7 @@ impl ResolveType for node::Access { right.span(), format!( "undefined reference to `{}` (while reading `{}`)", - right.ident().name, - left.meta + right, left.meta ), String::new(), ); @@ -847,14 +904,12 @@ impl ResolveType for node::Literal<&'static str> { #[derive(Debug)] pub struct ResolvedMetaInfo { pub access: HashMap, - pub call: HashMap, } impl ResolvedMetaInfo { pub fn new() -> Self { Self { access: HashMap::new(), - call: HashMap::new(), } } } @@ -889,21 +944,54 @@ impl<'a> TypeResolverContext<'a> { } } - fn scope_mut(&mut self) -> &mut SymbolTable { + fn scope_mut(&mut self) -> &mut Scope { self.scopes.last_mut().unwrap() } fn begin_scope(&mut self) { - self.scopes.push(SymbolTable::default()); + self.scopes.push(Scope::default()); } fn end_scope(&mut self) { self.scopes.pop(); } + fn set_type_parameters( + &mut self, + meta: &mut ResolvedMetaInfo, + tp: Option<&Vec>, + ) { + if let Some(tp) = tp { + for typ in tp { + let ty = match typ.bound { + Some(ref bound) => { + let raw_type = Type::new(bound.clone(), vec![]); + raw_type.resolve(self, meta).map_or_else( + |_| { + self.error( + bound.span, + format!("`{}` is not defined", bound.lexeme.unwrap()), + String::new(), + ); + None + }, + Some, + ) + } + None => Some(ResolvedType::new( + Symbol::Opaque(typ.name.to_string()), + vec![], + Type::new(typ.name.clone(), vec![]), + )), + }; + self.scope_mut().types.insert(typ.name.to_string(), ty); + } + } + } + fn symbol(&self, name: &String) -> Option<&Symbol> { for scope in self.scopes.iter().rev() { - if let Some(definition) = scope.get(name) { + if let Some(definition) = scope.symbols.get(name) { return Some(definition); } } @@ -913,7 +1001,17 @@ impl<'a> TypeResolverContext<'a> { "str" => Some(&Symbol::Str), "bool" => Some(&Symbol::Bool), "void" => Some(&Symbol::Void), - _ => self.symbols.get(name), + _ => self + .symbols + .get(name) + .or_else(|| self.ty(name).map(|t| &t.base)), + } + } + + fn ty(&self, name: &String) -> Option<&ResolvedType> { + match self.scopes.iter().rev().find_map(|s| s.types.get(name)) { + Some(Some(ty)) => Some(ty), + _ => None, } } @@ -934,6 +1032,12 @@ impl<'a> TypeResolverContext<'a> { } } +#[derive(Debug, Default)] +struct Scope { + symbols: SymbolTable, + types: HashMap>, +} + macro_rules! assert_typecheck { ($($path:expr => $name:ident),*) => { #[cfg(test)] diff --git a/examples/kyir/parametric-polymorphism.kya b/examples/kyir/parametric-polymorphism.kya index 98bee0e..70a4ead 100644 --- a/examples/kyir/parametric-polymorphism.kya +++ b/examples/kyir/parametric-polymorphism.kya @@ -7,7 +7,7 @@ class Print { class Bar: Print { x: int - fun print() { + fun print(self) { println_int(self.x); } } @@ -28,4 +28,5 @@ fun main() { let bar: Bar = Bar:init(x: 42); let foo: Foo = Foo:init(val: bar); foo.print(); + bar.print(); } \ No newline at end of file From 4cf0a643dc1c38f1c0764fd910a37171b0bd012b Mon Sep 17 00:00:00 2001 From: Alaina <68250402+alaidriel@users.noreply.github.com> Date: Thu, 28 Mar 2024 11:01:31 -0500 Subject: [PATCH 4/5] fix: method inheritance/overrides --- crates/kyac/src/backend/kyir/translate/mod.rs | 15 ++++++++++++--- crates/kyanite/tests/kyir/mod.rs | 2 +- examples/kyir/method-inheritance.kya | 4 ---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/crates/kyac/src/backend/kyir/translate/mod.rs b/crates/kyac/src/backend/kyir/translate/mod.rs index 1148646..234459a 100644 --- a/crates/kyac/src/backend/kyir/translate/mod.rs +++ b/crates/kyac/src/backend/kyir/translate/mod.rs @@ -191,7 +191,10 @@ impl Translate for ast::node::Call { let frame = translator.functions.get_mut(&id).unwrap(); let saved = frame.allocate(&temp, ptr); let mut stmts = vec![]; - let address = if cls.is_some_and(|cls| Symbol::has_subclass(cls, translator.symbols)) { + let address = if cls.is_some_and(|cls| { + Symbol::has_subclass(cls, translator.symbols) + || !Symbol::superclasses(cls, translator.symbols).is_empty() + }) { // This call could be overridden by a subclass in which case we need to use dynamic dispatch. // we just checked is_some_and, so this is safe. let cls = cls.unwrap(); @@ -199,8 +202,14 @@ impl Translate for ast::node::Call { let arr = Temp::next(); let address = Temp::next(); let (_, n) = name.rsplit_once('.').unwrap(); - let index = - F::word_size() * (cls.methods.iter().position(|m| m.name == n).unwrap() + 1); + let symbol = Symbol::Class(cls.clone()); + let index = F::word_size() + * (symbol + .methods(translator.symbols) + .iter() + .position(|(_, m)| m.name == n) + .unwrap() + + 1); stmts.append(&mut vec![ Move::wrapped( Temp::wrapped(arr.clone()), diff --git a/crates/kyanite/tests/kyir/mod.rs b/crates/kyanite/tests/kyir/mod.rs index 87ad3e2..b467f42 100644 --- a/crates/kyanite/tests/kyir/mod.rs +++ b/crates/kyanite/tests/kyir/mod.rs @@ -322,7 +322,7 @@ fn method_override() -> Result<(), Box> { let res = run("kyir/method-override.kya")?; assert_eq!( res.output, - "inside `X.show()`\n2\ninside `Y.show()`\n6\ninside `Z.show()`\n1\n5\n" + "inside `Y.show()`\n6\ninside `Y.show()`\n6\ninside `Z.show()`\n1\n5\n" ); Ok(()) } diff --git a/examples/kyir/method-inheritance.kya b/examples/kyir/method-inheritance.kya index 1494f74..ee87e74 100644 --- a/examples/kyir/method-inheritance.kya +++ b/examples/kyir/method-inheritance.kya @@ -20,10 +20,6 @@ class Y: X { } } -class Z { - y: int -} - fun main() { let y: Y = Y:init(y: 6, x: 2); % x is an implicit parameter to init % this is safe (and therefore valid) because `Y` is a subclass of `X` From 32b744ea14540ed1686dfd4bad7aa36190985ca6 Mon Sep 17 00:00:00 2001 From: Alaina <68250402+alaidriel@users.noreply.github.com> Date: Thu, 28 Mar 2024 11:01:51 -0500 Subject: [PATCH 5/5] chore: set `RUST_BACKTRACE` --- nix/package.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/nix/package.nix b/nix/package.nix index c644f16..aa33cae 100644 --- a/nix/package.nix +++ b/nix/package.nix @@ -41,6 +41,7 @@ in --prefix PATH : ${lib.makeBinPath [llvmPackages_15.libllvm]} ''; RUSTFLAGS = "-C link-arg=-lc++abi"; # https://github.com/NixOS/nixpkgs/issues/166205 + RUST_BACKTRACE = "1"; LLVM_SYS_150_PREFIX = llvmPackages_15.libllvm.dev; meta = with lib; { description = "A toy compiled programming language to learn more about PLs";