From 97385cd16ad260d856ca063fd5960bb544075b81 Mon Sep 17 00:00:00 2001 From: Alex Hansen Date: Thu, 14 Nov 2024 06:19:15 -0800 Subject: [PATCH] Add class constraints for built-in classes (#2007) This PR exposes some relevant built-in classes via class constraints. # Changes - Add `TypeParameter`, `ClassConstraint`, and `ConstraintParameter` AST nodes to represent the notion of a generic type that might have class constraints, and constraints that may have parameters (e.g. `Exp[Int]`). - update visitors accordingly - `Ty` can now contain arbitrary constraints - Add notion of primitive classes to the AST - Update constraint checking algorithm to check type parameters for primitive constraints at callable decl time - Update constraint generation algorithm to add additional user-specified constraints to arguments that are passed in to callables (`fn constrained_ty`) - Add completions for primitive classes to the language service - Adds a sample describing class constraints tiny changes: - Add docstrings discretionally - Remove some unnecessary derives - Tried to clarify and unify jargon around classes, constraints, parameters, and arguments in general - Some expect_tests were updated in minor ways as their underlying types changed (`Ty` and `TyParam`, for example) - Improve an error message for `MissingTy` (it mentioned something about global types not being inferred when 99% of the time it does not apply to a global type, it applies to a callable parameter with a missing type) --- compiler/qsc_ast/src/ast.rs | 146 +++- compiler/qsc_ast/src/mut_visit.rs | 30 +- compiler/qsc_ast/src/visit.rs | 31 +- compiler/qsc_codegen/src/qsharp.rs | 6 +- compiler/qsc_doc_gen/src/display.rs | 60 +- compiler/qsc_eval/src/lib.rs | 4 +- compiler/qsc_fir/src/fir.rs | 8 +- compiler/qsc_fir/src/ty.rs | 91 +- compiler/qsc_formatter/src/formatter.rs | 4 +- compiler/qsc_frontend/src/lower.rs | 271 ++++-- compiler/qsc_frontend/src/lower/tests.rs | 6 +- compiler/qsc_frontend/src/resolve.rs | 60 +- compiler/qsc_frontend/src/resolve/tests.rs | 22 +- compiler/qsc_frontend/src/typeck.rs | 91 +- compiler/qsc_frontend/src/typeck/check.rs | 40 +- compiler/qsc_frontend/src/typeck/convert.rs | 292 +++++-- compiler/qsc_frontend/src/typeck/infer.rs | 352 +++++++- compiler/qsc_frontend/src/typeck/rules.rs | 56 +- compiler/qsc_frontend/src/typeck/tests.rs | 21 +- .../src/typeck/tests/bounded_polymorphism.rs | 824 ++++++++++++++++++ compiler/qsc_hir/src/hir.rs | 8 +- compiler/qsc_hir/src/ty.rs | 151 +++- compiler/qsc_lowerer/src/lib.rs | 71 +- .../qsc_parse/src/completion/word_kinds.rs | 15 +- compiler/qsc_parse/src/item/tests.rs | 32 +- compiler/qsc_parse/src/ty.rs | 56 +- compiler/qsc_parse/src/ty/tests.rs | 6 +- compiler/qsc_passes/src/loop_unification.rs | 4 +- language_service/src/completion.rs | 43 + language_service/src/completion/tests.rs | 2 + .../src/completion/tests/class_completions.rs | 182 ++++ language_service/src/name_locator.rs | 19 +- language_service/src/protocol.rs | 1 + language_service/src/references.rs | 14 +- npm/qsharp/test/basics.js | 2 +- playground/src/main.tsx | 3 + samples/language/ClassConstraints.qs | 58 ++ samples_test/src/tests/language.rs | 16 + vscode/src/completion.ts | 3 + wasm/src/language_service.rs | 3 +- wasm/src/tests.rs | 2 +- 41 files changed, 2725 insertions(+), 381 deletions(-) create mode 100644 compiler/qsc_frontend/src/typeck/tests/bounded_polymorphism.rs create mode 100644 language_service/src/completion/tests/class_completions.rs create mode 100644 samples/language/ClassConstraints.qs diff --git a/compiler/qsc_ast/src/ast.rs b/compiler/qsc_ast/src/ast.rs index 021e107891..21b3fff218 100644 --- a/compiler/qsc_ast/src/ast.rs +++ b/compiler/qsc_ast/src/ast.rs @@ -486,7 +486,7 @@ pub struct CallableDecl { /// The name of the callable. pub name: Box, /// The generic parameters to the callable. - pub generics: Box<[Box]>, + pub generics: Box<[TypeParameter]>, /// The input to the callable. pub input: Box, /// The return type of the callable. @@ -510,9 +510,13 @@ impl Display for CallableDecl { if !self.generics.is_empty() { write!(indent, "\ngenerics:")?; indent = set_indentation(indent, 2); + let mut buf = Vec::with_capacity(self.generics.len()); for param in &self.generics { - write!(indent, "\n{param}")?; + buf.push(format!("{param}")); } + + let buf = buf.join(",\n"); + write!(indent, "\n{buf}")?; indent = set_indentation(indent, 1); } write!(indent, "\ninput: {}", self.input)?; @@ -674,7 +678,7 @@ pub enum TyKind { /// A named type. Path(PathKind), /// A type parameter. - Param(Box), + Param(TypeParameter), /// A tuple type. Tuple(Box<[Ty]>), /// An invalid type. @@ -1968,3 +1972,139 @@ impl ImportOrExportItem { } } } + +/// A [`TypeParameter`] is a generic type variable with optional bounds (constraints). +#[derive(Default, Debug, PartialEq, Eq, Clone, Hash)] +pub struct TypeParameter { + /// Class constraints specified for this type parameter -- any type variable passed in + /// as an argument to these parameters must satisfy these constraints. + pub constraints: ClassConstraints, + /// The name of the type parameter. + pub ty: Ident, + /// The span of the full type parameter, including its name and its constraints. + pub span: Span, +} + +impl WithSpan for TypeParameter { + fn with_span(self, span: Span) -> Self { + Self { span, ..self } + } +} + +impl TypeParameter { + /// Instantiates a new `TypeParameter` with the given type name, constraints, and span. + #[must_use] + pub fn new(ty: Ident, bounds: ClassConstraints, span: Span) -> Self { + Self { + ty, + constraints: bounds, + span, + } + } +} + +impl std::fmt::Display for TypeParameter { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // 'A: Eq + Ord + Clone + write!( + f, + "{}{}", + self.ty.name, + if self.constraints.0.is_empty() { + Default::default() + } else { + format!(": {}", self.constraints) + } + ) + } +} + +/// A list of class constraints, used when constraining a type parameter. +#[derive(Default, Debug, PartialEq, Eq, Clone, Hash)] +pub struct ClassConstraints(pub Box<[ClassConstraint]>); + +/// An individual class constraint, used when constraining a type parameter. +/// To understand this concept, think of parameters in a function signature -- the potential arguments that can +/// be passed to them are constrained by what type is specified. Type-level parameters are no different, and +/// the type variables that are passed to a type parameter must satisfy the constraints specified in the type parameter. +#[derive(PartialEq, Eq, Clone, Hash, Debug)] +pub struct ClassConstraint { + /// The name of the constraint. + pub name: Ident, + /// Parameters for a constraint. For example, `Iterator` has a parameter `T` in `Iterator` -- this + /// is the type of the item that is coming out of the iterator. + pub parameters: Box<[ConstraintParameter]>, +} + +impl std::fmt::Display for ClassConstraint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // Iterator + write!( + f, + "{}{}", + self.name.name, + if self.parameters.is_empty() { + String::new() + } else { + format!( + "[{}]", + self.parameters + .iter() + .map(|x| x.ty.to_string()) + .collect::>() + .join(", ") + ) + } + ) + } +} + +/// An individual constraint parameter is a type that is passed to a constraint, such as `T` in `Iterator`. +/// #[derive(Default, `PartialEq`, Eq, Clone, Hash, Debug)] +#[derive(Default, PartialEq, Eq, Clone, Hash, Debug)] +pub struct ConstraintParameter { + /// The type variable being passed as a constraint parameter. + pub ty: Ty, +} + +impl WithSpan for ConstraintParameter { + fn with_span(self, span: Span) -> Self { + Self { + ty: self.ty.with_span(span), + } + } +} + +impl ClassConstraint { + /// Getter for the `span` field of the `name` field (the name of the class constraint). + #[must_use] + pub fn span(&self) -> Span { + self.name.span + } +} + +impl ClassConstraints { + /// The conjoined span of all of the bounds + #[must_use] + pub fn span(&self) -> Span { + Span { + lo: self.0.first().map(|i| i.span().lo).unwrap_or_default(), + hi: self.0.last().map(|i| i.span().hi).unwrap_or_default(), + } + } +} + +impl std::fmt::Display for ClassConstraints { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + // A + B + C + D + write!( + f, + "{}", + self.0 + .iter() + .map(|x| format!("{}", x.name.name,)) + .collect::>() + .join(" + "), + ) + } +} diff --git a/compiler/qsc_ast/src/mut_visit.rs b/compiler/qsc_ast/src/mut_visit.rs index dafac164c0..92da11515e 100644 --- a/compiler/qsc_ast/src/mut_visit.rs +++ b/compiler/qsc_ast/src/mut_visit.rs @@ -5,7 +5,7 @@ use crate::ast::{ Attr, Block, CallableBody, CallableDecl, Expr, ExprKind, FieldAccess, FieldAssign, FieldDef, FunctorExpr, FunctorExprKind, Ident, Item, ItemKind, Namespace, Package, Pat, PatKind, Path, PathKind, QubitInit, QubitInitKind, SpecBody, SpecDecl, Stmt, StmtKind, StringComponent, - StructDecl, TopLevelNode, Ty, TyDef, TyDefKind, TyKind, + StructDecl, TopLevelNode, Ty, TyDef, TyDefKind, TyKind, TypeParameter, }; use qsc_data_structures::span::Span; @@ -164,7 +164,17 @@ pub fn walk_ty_def(vis: &mut impl MutVisitor, def: &mut TyDef) { pub fn walk_callable_decl(vis: &mut impl MutVisitor, decl: &mut CallableDecl) { vis.visit_span(&mut decl.span); vis.visit_ident(&mut decl.name); - decl.generics.iter_mut().for_each(|p| vis.visit_ident(p)); + decl.generics.iter_mut().for_each(|p| { + vis.visit_ident(&mut p.ty); + p.constraints.0.iter_mut().for_each(|b| { + vis.visit_ident(&mut b.name); + b.parameters + .iter_mut() + .for_each(|crate::ast::ConstraintParameter { ty, .. }| { + vis.visit_ty(ty); + }); + }); + }); vis.visit_pat(&mut decl.input); vis.visit_ty(&mut decl.output); decl.functors @@ -226,7 +236,21 @@ pub fn walk_ty(vis: &mut impl MutVisitor, ty: &mut Ty) { } TyKind::Hole | TyKind::Err => {} TyKind::Paren(ty) => vis.visit_ty(ty), - TyKind::Param(name) => vis.visit_ident(name), + TyKind::Param(TypeParameter { + ty, + constraints: bounds, + .. + }) => { + for bound in &mut bounds.0 { + vis.visit_ident(&mut bound.name); + bound.parameters.iter_mut().for_each( + |crate::ast::ConstraintParameter { ref mut ty, .. }| { + vis.visit_ty(ty); + }, + ); + } + vis.visit_ident(ty); + } TyKind::Path(path) => vis.visit_path_kind(path), TyKind::Tuple(tys) => tys.iter_mut().for_each(|t| vis.visit_ty(t)), } diff --git a/compiler/qsc_ast/src/visit.rs b/compiler/qsc_ast/src/visit.rs index 4e87499ad7..e680d1e972 100644 --- a/compiler/qsc_ast/src/visit.rs +++ b/compiler/qsc_ast/src/visit.rs @@ -5,7 +5,7 @@ use crate::ast::{ Attr, Block, CallableBody, CallableDecl, Expr, ExprKind, FieldAccess, FieldAssign, FieldDef, FunctorExpr, FunctorExprKind, Ident, Item, ItemKind, Namespace, Package, Pat, PatKind, Path, PathKind, QubitInit, QubitInitKind, SpecBody, SpecDecl, Stmt, StmtKind, StringComponent, - StructDecl, TopLevelNode, Ty, TyDef, TyDefKind, TyKind, + StructDecl, TopLevelNode, Ty, TyDef, TyDefKind, TyKind, TypeParameter, }; pub trait Visitor<'a>: Sized { @@ -149,7 +149,17 @@ pub fn walk_ty_def<'a>(vis: &mut impl Visitor<'a>, def: &'a TyDef) { pub fn walk_callable_decl<'a>(vis: &mut impl Visitor<'a>, decl: &'a CallableDecl) { vis.visit_ident(&decl.name); - decl.generics.iter().for_each(|p| vis.visit_ident(p)); + decl.generics.iter().for_each(|p| { + vis.visit_ident(&p.ty); + p.constraints.0.iter().for_each(|b| { + vis.visit_ident(&b.name); + b.parameters + .iter() + .for_each(|crate::ast::ConstraintParameter { ty, .. }| { + vis.visit_ty(ty); + }); + }); + }); vis.visit_pat(&decl.input); vis.visit_ty(&decl.output); decl.functors.iter().for_each(|f| vis.visit_functor_expr(f)); @@ -201,7 +211,22 @@ pub fn walk_ty<'a>(vis: &mut impl Visitor<'a>, ty: &'a Ty) { TyKind::Hole | TyKind::Err => {} TyKind::Paren(ty) => vis.visit_ty(ty), TyKind::Path(path) => vis.visit_path_kind(path), - TyKind::Param(name) => vis.visit_ident(name), + TyKind::Param(TypeParameter { + ty, + constraints: bounds, + .. + }) => { + for bound in &bounds.0 { + vis.visit_ident(&bound.name); + + bound.parameters.iter().for_each( + |crate::ast::ConstraintParameter { ty, .. }| { + vis.visit_ty(ty); + }, + ); + } + vis.visit_ident(ty); + } TyKind::Tuple(tys) => tys.iter().for_each(|t| vis.visit_ty(t)), } } diff --git a/compiler/qsc_codegen/src/qsharp.rs b/compiler/qsc_codegen/src/qsharp.rs index 70bf39512d..9b4495df8a 100644 --- a/compiler/qsc_codegen/src/qsharp.rs +++ b/compiler/qsc_codegen/src/qsharp.rs @@ -231,10 +231,10 @@ impl Visitor<'_> for QSharpGen { self.write("<"); if let Some((last, most)) = decl.generics.split_last() { for i in most { - self.visit_ident(i); + self.visit_ident(&i.ty); self.write(", "); } - self.visit_ident(last); + self.visit_ident(&last.ty); } self.write(">"); @@ -349,7 +349,7 @@ impl Visitor<'_> for QSharpGen { self.write(")"); } TyKind::Path(path) => self.visit_path_kind(path), - TyKind::Param(name) => self.visit_ident(name), + TyKind::Param(name) => self.visit_ident(&name.ty), TyKind::Tuple(tys) => { if tys.is_empty() { self.write("()"); diff --git a/compiler/qsc_doc_gen/src/display.rs b/compiler/qsc_doc_gen/src/display.rs index f579b40d3d..c63ebe6ce6 100644 --- a/compiler/qsc_doc_gen/src/display.rs +++ b/compiler/qsc_doc_gen/src/display.rs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use qsc_ast::ast::{self, Idents}; +use qsc_ast::ast::{self, Idents, TypeParameter as AstTypeParameter}; use qsc_frontend::resolve; use qsc_hir::{ hir::{self, PackageId}, - ty::{self, GenericParam}, + ty::{self, TypeParameter as HirTypeParameter}, }; use regex_lite::Regex; use std::{ @@ -218,7 +218,45 @@ impl<'a> Display for AstCallableDecl<'a> { .decl .generics .iter() - .map(|p| p.name.clone()) + .map( + |AstTypeParameter { + ty, constraints, .. + }| { + format!( + "{}{}", + ty.name, + if constraints.0.is_empty() { + Default::default() + } else { + format!( + ": {}", + constraints + .0 + .iter() + .map(|bound| { + let constraint_parameters = bound + .parameters + .iter() + .map(|x| format!("{}", AstTy { ty: &x.ty })) + .collect::>() + .join(", "); + format!( + "{}{}", + bound.name.name, + if constraint_parameters.is_empty() { + Default::default() + } else { + format!("[{constraint_parameters}]") + } + ) + }) + .collect::>() + .join(" + ") + ) + } + ) + }, + ) .collect::>() .join(", "); write!(f, "<{type_params}>")?; @@ -517,7 +555,7 @@ impl<'a> Display for AstTy<'a> { ast::TyKind::Hole => write!(f, "_"), ast::TyKind::Paren(ty) => write!(f, "{}", AstTy { ty }), ast::TyKind::Path(path) => write!(f, "{}", AstPathKind { path }), - ast::TyKind::Param(id) => write!(f, "{}", id.name), + ast::TyKind::Param(AstTypeParameter { ty, .. }) => write!(f, "{}", ty.name), ast::TyKind::Tuple(tys) => fmt_tuple(f, tys, |ty| AstTy { ty }), ast::TyKind::Err => write!(f, "?"), } @@ -615,12 +653,20 @@ where write!(formatter, "}}") } -fn display_type_params(generics: &[GenericParam]) -> String { +fn display_type_params(generics: &[HirTypeParameter]) -> String { let type_params = generics .iter() .filter_map(|generic| match generic { - GenericParam::Ty(name) => Some(name.name.clone()), - GenericParam::Functor(_) => None, + HirTypeParameter::Ty { name, bounds } => Some(format!( + "{}{}", + name, + if bounds.is_empty() { + Default::default() + } else { + format!(": {bounds}") + } + )), + HirTypeParameter::Functor(_) => None, }) .collect::>() .join(", "); diff --git a/compiler/qsc_eval/src/lib.rs b/compiler/qsc_eval/src/lib.rs index d40b91a570..e938c85c1f 100644 --- a/compiler/qsc_eval/src/lib.rs +++ b/compiler/qsc_eval/src/lib.rs @@ -1103,7 +1103,9 @@ impl State { (record, Field::Path(path)) => { follow_field_path(record, &path.indices).expect("field path should be valid") } - _ => panic!("invalid field access"), + (ref value, ref field) => { + panic!("invalid field access. value: {value:?}, field: {field:?}") + } }; self.set_val_register(val); } diff --git a/compiler/qsc_fir/src/fir.rs b/compiler/qsc_fir/src/fir.rs index 0da9c1481b..c877a972e0 100644 --- a/compiler/qsc_fir/src/fir.rs +++ b/compiler/qsc_fir/src/fir.rs @@ -8,7 +8,7 @@ #![warn(missing_docs)] -use crate::ty::{Arrow, FunctorSet, FunctorSetValue, GenericArg, GenericParam, Scheme, Ty, Udt}; +use crate::ty::{Arrow, FunctorSet, FunctorSetValue, GenericArg, Scheme, Ty, TypeParameter, Udt}; use indenter::{indented, Indented}; use num_bigint::BigInt; use qsc_data_structures::{ @@ -295,7 +295,7 @@ impl Display for ItemId { /// A resolution. This connects a usage of a name with the declaration of that name by uniquely /// identifying the node that declared it. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] pub enum Res { /// An invalid resolution. Err, @@ -747,7 +747,7 @@ pub struct CallableDecl { /// The name of the callable. pub name: Ident, /// The generic parameters to the callable. - pub generics: Vec, + pub generics: Vec, /// The input to the callable. pub input: PatId, /// The return type of the callable. @@ -1589,7 +1589,7 @@ pub enum Visibility { } /// A callable kind. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] pub enum CallableKind { /// A function. Function, diff --git a/compiler/qsc_fir/src/ty.rs b/compiler/qsc_fir/src/ty.rs index 831a3e6bed..0b107c297c 100644 --- a/compiler/qsc_fir/src/ty.rs +++ b/compiler/qsc_fir/src/ty.rs @@ -85,20 +85,20 @@ impl Display for Ty { /// A type scheme. pub struct Scheme { - params: Vec, + params: Vec, ty: Box, } impl Scheme { /// Creates a new type scheme. #[must_use] - pub fn new(params: Vec, ty: Box) -> Self { + pub fn new(params: Vec, ty: Box) -> Self { Self { params, ty } } /// The generic parameters to the type. #[must_use] - pub fn params(&self) -> &[GenericParam] { + pub fn params(&self) -> &[TypeParameter] { &self.params } @@ -178,20 +178,95 @@ fn instantiate_arrow_ty<'a>( }) } -impl Display for GenericParam { +impl Display for TypeParameter { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - GenericParam::Ty => write!(f, "type"), - GenericParam::Functor(min) => write!(f, "functor ({min})"), + TypeParameter::Ty { name, bounds } => { + write!(f, "type ({name}){bounds}") + } + TypeParameter::Functor(min) => write!(f, "functor ({min})"), + } + } +} +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct ClassConstraints(pub Box<[ClassConstraint]>); + +impl std::fmt::Display for ClassConstraints { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.0.is_empty() { + Ok(()) + } else { + let bounds = self + .0 + .iter() + .map(std::string::ToString::to_string) + .collect::>() + .join(", "); + write!(f, "{bounds}") + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ClassConstraint { + /// Whether or not 'T can be compared via Eq to values of the same domain. + Eq, + /// Whether or not 'T can be added to values of the same domain via the + operator. + Add, + Exp { + // `base` is inferred to be the self type + power: Ty, + }, + /// If 'T is iterable, then it can be iterated over and the items inside are yielded (of type `item`). + Iterable { item: Ty }, + /// Whether or not 'T can be divided by values of the same domain via the / operator. + Div, + /// Whether or not 'T can be subtracted from values of the same domain via the - operator. + Sub, + /// Whether or not 'T can be multiplied by values of the same domain via the * operator. + Mul, + /// Whether or not 'T can be taken modulo values of the same domain via the % operator. + Mod, + /// Whether or not 'T can be compared via Ord to values of the same domain. + Ord, + /// Whether or not 'T can be signed. + Signed, + /// Whether or not 'T is an integral type (can be used in bit shifting operators). + Integral, + /// Whether or not 'T can be displayed as a string (converted to a string). + Show, + /// A class that is not built-in to the compiler. + NonNativeClass(Rc), +} + +impl std::fmt::Display for ClassConstraint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + ClassConstraint::Eq => write!(f, "Eq"), + ClassConstraint::NonNativeClass(name) => write!(f, "{name}"), + ClassConstraint::Exp { power } => write!(f, "Exp<{power}>"), + ClassConstraint::Iterable { item } => write!(f, "Iterable<{item}>"), + ClassConstraint::Add => write!(f, "Add"), + ClassConstraint::Integral => write!(f, "Integral"), + ClassConstraint::Show => write!(f, "Show"), + ClassConstraint::Div => write!(f, "Div"), + ClassConstraint::Sub => write!(f, "Sub"), + ClassConstraint::Mul => write!(f, "Mul"), + ClassConstraint::Mod => write!(f, "Mod"), + ClassConstraint::Ord => write!(f, "Ord"), + ClassConstraint::Signed => write!(f, "Signed"), } } } /// The kind of a generic parameter. #[derive(Clone, Debug, PartialEq)] -pub enum GenericParam { +pub enum TypeParameter { /// A type parameter. - Ty, + Ty { + name: Rc, + bounds: ClassConstraints, + }, /// A functor parameter with a lower bound. Functor(FunctorSetValue), } diff --git a/compiler/qsc_formatter/src/formatter.rs b/compiler/qsc_formatter/src/formatter.rs index 40846fcc6d..c3e4022a40 100644 --- a/compiler/qsc_formatter/src/formatter.rs +++ b/compiler/qsc_formatter/src/formatter.rs @@ -562,7 +562,7 @@ impl<'a> Formatter<'a> { left_kind: &ConcreteTokenKind, right_kind: &ConcreteTokenKind, ) { - use qsc_frontend::keyword::Keyword; + use qsc_frontend::{keyword::Keyword, lex::cooked::ClosedBinOp}; use ConcreteTokenKind::*; use TokenKind::*; @@ -599,7 +599,7 @@ impl<'a> Formatter<'a> { { self.type_param_state = TypeParameterListState::InTypeParamList; } - Syntax(AposIdent | Comma | Gt) + Syntax(AposIdent | Comma | Gt | ClosedBinOp(ClosedBinOp::Plus) | Ident | Colon) if matches!( self.type_param_state, TypeParameterListState::InTypeParamList diff --git a/compiler/qsc_frontend/src/lower.rs b/compiler/qsc_frontend/src/lower.rs index 788f2ec179..bb3bf0554a 100644 --- a/compiler/qsc_frontend/src/lower.rs +++ b/compiler/qsc_frontend/src/lower.rs @@ -7,7 +7,10 @@ mod tests; use crate::{ closure::{self, Lambda, PartialApp}, resolve::{self, Names}, - typeck::{self, convert}, + typeck::{ + self, + convert::{self, synthesize_functor_params}, + }, }; use miette::Diagnostic; use qsc_ast::ast::{self, FieldAccess, Ident, Idents, PathKind}; @@ -16,11 +19,13 @@ use qsc_hir::{ assigner::Assigner, hir::{self, LocalItemId, Visibility}, mut_visit::MutVisitor, - ty::{Arrow, FunctorSetValue, GenericArg, Ty}, + ty::{Arrow, FunctorSetValue, GenericArg, ParamId, Ty, TypeParameter}, }; use std::{clone::Clone, rc::Rc, str::FromStr, vec}; use thiserror::Error; +use self::convert::TyConversionError; + #[derive(Clone, Debug, Diagnostic, Error)] pub(super) enum Error { #[error("unknown attribute {0}")] @@ -46,6 +51,59 @@ pub(super) enum Error { #[error("invalid pattern for specialization declaration")] #[diagnostic(code("Qsc.LowerAst.InvalidSpecPat"))] InvalidSpecPat(#[label] Span), + #[error("missing type in item signature")] + #[diagnostic(help("a type must be provided for this item"))] + #[diagnostic(code("Qsc.LowerAst.MissingTy"))] + MissingTy { + #[label] + span: Span, + }, + #[error("unrecognized class constraint {name}")] + #[help("supported classes are Eq, Add, Exp, Integral, Num, and Show")] + #[diagnostic(code("Qsc.LowerAst.UnrecognizedClass"))] + UnrecognizedClass { + #[label] + span: Span, + name: String, + }, + #[error("class constraint is recursive via {name}")] + #[help("if a type refers to itself via its constraints, it is self-referential and cannot ever be resolved")] + #[diagnostic(code("Qsc.LowerAst.RecursiveClassConstraint"))] + RecursiveClassConstraint { + #[label] + span: Span, + name: String, + }, + #[error("expected {expected} parameters for constraint, found {found}")] + #[diagnostic(code("Qsc.TypeCk.IncorrectNumberOfConstraintParameters"))] + IncorrectNumberOfConstraintParameters { + expected: usize, + found: usize, + #[label] + span: Span, + }, +} + +impl From for Error { + fn from(err: TyConversionError) -> Self { + use TyConversionError::*; + match err { + MissingTy { span } => Error::MissingTy { span }, + UnrecognizedClass { span, name } => Error::UnrecognizedClass { span, name }, + RecursiveClassConstraint { span, name } => { + Error::RecursiveClassConstraint { span, name } + } + IncorrectNumberOfConstraintParameters { + expected, + found, + span, + } => Error::IncorrectNumberOfConstraintParameters { + expected, + found, + span, + }, + } + } } pub(super) struct Lowerer { @@ -199,73 +257,77 @@ impl With<'_> { _otherwise => None, }; - let (id, kind) = match &*item.kind { - ast::ItemKind::Err | ast::ItemKind::Open(..) => return None, + let (id, kind) = + match &*item.kind { + ast::ItemKind::Err | ast::ItemKind::Open(..) => return None, - ast::ItemKind::ImportOrExport(item) => { - if item.is_import() { + ast::ItemKind::ImportOrExport(item) => { + if item.is_import() { + return None; + } + for item in &item.items { + let Some(item_name) = item.name() else { + continue; + }; + let Some((id, alias)) = resolve_id(item_name.id) else { + continue; + }; + let is_reexport = id.package.is_some() || alias.is_some(); + // if the package is Some, then this is a re-export and we + // need to preserve the reference to the original `ItemId` + if is_reexport { + let mut name = self.lower_ident(item_name); + name.id = self.assigner.next_node(); + let kind = hir::ItemKind::Export(name, id); + self.lowerer.items.push(hir::Item { + id: self.assigner.next_item(), + span: item.span, + parent: self.lowerer.parent, + doc: "".into(), + // attrs on exports not supported + attrs: Vec::new(), + visibility: Visibility::Public, + kind, + }); + } + } return None; } - for item in &item.items { - let Some(item_name) = item.name() else { - continue; - }; - let Some((id, alias)) = resolve_id(item_name.id) else { - continue; - }; - let is_reexport = id.package.is_some() || alias.is_some(); - // if the package is Some, then this is a re-export and we - // need to preserve the reference to the original `ItemId` - if is_reexport { - let mut name = self.lower_ident(item_name); - name.id = self.assigner.next_node(); - let kind = hir::ItemKind::Export(name, id); - self.lowerer.items.push(hir::Item { - id: self.assigner.next_item(), - span: item.span, - parent: self.lowerer.parent, - doc: "".into(), - // attrs on exports not supported - attrs: Vec::new(), - visibility: Visibility::Public, - kind, - }); - } + ast::ItemKind::Callable(callable) => { + let (id, _) = resolve_id(callable.name.id)?; + let grandparent = self.lowerer.parent; + self.lowerer.parent = Some(id.item); + let (callable, errs) = self.lower_callable_decl(callable, &attrs); + self.lowerer.errors.extend(errs.into_iter().map(|err| { + Into::::into(Into::::into(err)) + })); + self.lowerer.parent = grandparent; + (id, hir::ItemKind::Callable(callable)) } - return None; - } - ast::ItemKind::Callable(callable) => { - let (id, _) = resolve_id(callable.name.id)?; - let grandparent = self.lowerer.parent; - self.lowerer.parent = Some(id.item); - let callable = self.lower_callable_decl(callable, &attrs); - self.lowerer.parent = grandparent; - (id, hir::ItemKind::Callable(callable)) - } - ast::ItemKind::Ty(name, _) => { - let (id, _) = resolve_id(name.id)?; - let udt = self - .tys - .udts - .get(&id) - .expect("type item should have lowered UDT"); - - (id, hir::ItemKind::Ty(self.lower_ident(name), udt.clone())) - } - ast::ItemKind::Struct(decl) => { - let (id, _) = resolve_id(decl.name.id)?; - let strct = self - .tys - .udts - .get(&id) - .expect("type item should have lowered struct"); - - ( - id, - hir::ItemKind::Ty(self.lower_ident(&decl.name), strct.clone()), - ) - } - }; + ast::ItemKind::Ty(name, _) => { + let (id, _) = resolve_id(name.id)?; + let udt = self + .tys + .udts + .get(&id) + .expect("type item should have lowered UDT"); + + (id, hir::ItemKind::Ty(self.lower_ident(name), udt.clone())) + } + ast::ItemKind::Struct(decl) => { + let (id, _) = resolve_id(decl.name.id)?; + let strct = self + .tys + .udts + .get(&id) + .expect("type item should have lowered struct"); + + ( + id, + hir::ItemKind::Ty(self.lower_ident(&decl.name), strct.clone()), + ) + } + }; let export_info = exported_ids.iter().find(|(hir_id, _)| hir_id == &id); let visibility = match export_info { @@ -387,17 +449,55 @@ impl With<'_> { } } + /// Generates generic parameters for the functors, if there were generics on the original callable. + /// Basically just creates new generic params for the purpose of being used in functor callable + /// decls. + pub(crate) fn synthesize_callable_generics( + &mut self, + generics: &[ast::TypeParameter], + input: &mut hir::Pat, + ) -> (Vec, Vec) { + let (mut params, errs) = convert::type_parameters_for_ast_callable(self.names, generics); + let mut functor_params = + Self::synthesize_functor_params_in_pat(&mut params.len().into(), input); + params.append(&mut functor_params); + (params, errs) + } + + fn synthesize_functor_params_in_pat( + next_param: &mut ParamId, + pat: &mut hir::Pat, + ) -> Vec { + match &mut pat.kind { + hir::PatKind::Discard | hir::PatKind::Err | hir::PatKind::Bind(_) => { + synthesize_functor_params(next_param, &mut pat.ty) + } + hir::PatKind::Tuple(items) => { + let mut params = Vec::new(); + for item in &mut *items { + params.append(&mut Self::synthesize_functor_params_in_pat( + next_param, item, + )); + } + if !params.is_empty() { + pat.ty = Ty::Tuple(items.iter().map(|i| i.ty.clone()).collect()); + } + params + } + } + } + pub(super) fn lower_callable_decl( &mut self, decl: &ast::CallableDecl, attrs: &[qsc_hir::hir::Attr], - ) -> hir::CallableDecl { + ) -> (hir::CallableDecl, Vec) { let id = self.lower_id(decl.id); let kind = self.lower_callable_kind(decl.kind, attrs, decl.name.span); let name = self.lower_ident(&decl.name); let mut input = self.lower_pat(&decl.input); - let output = convert::ty_from_ast(self.names, &decl.output).0; - let generics = convert::synthesize_callable_generics(&decl.generics, &mut input); + let output = convert::ty_from_ast(self.names, &decl.output, &mut Default::default()).0; + let (generics, errs) = self.synthesize_callable_generics(&decl.generics, &mut input); let functors = convert::ast_callable_functors(decl); let (body, adj, ctl, ctl_adj) = match decl.body.as_ref() { @@ -425,21 +525,24 @@ impl With<'_> { } }; - hir::CallableDecl { - id, - span: decl.span, - kind, - name, - generics, - input, - output, - functors, - body, - adj, - ctl, - ctl_adj, - attrs: attrs.to_vec(), - } + ( + hir::CallableDecl { + id, + span: decl.span, + kind, + name, + generics, + input, + output, + functors, + body, + adj, + ctl, + ctl_adj, + attrs: attrs.to_vec(), + }, + errs, + ) } fn check_invalid_attrs_on_function(&mut self, attrs: &[hir::Attr], span: Span) { @@ -910,7 +1013,7 @@ impl With<'_> { // Exported items are just pass-throughs to the items they reference, and should be // treated as Res to that original item. Some(&resolve::Res::ExportedItem(item_id, _)) => hir::Res::Item(item_id), - Some(resolve::Res::PrimTy(_) | resolve::Res::UnitTy | resolve::Res::Param(_)) + Some(resolve::Res::PrimTy(_) | resolve::Res::UnitTy | resolve::Res::Param { .. }) | None => hir::Res::Err, } } diff --git a/compiler/qsc_frontend/src/lower/tests.rs b/compiler/qsc_frontend/src/lower/tests.rs index 0e8bbd2658..408fb42c7e 100644 --- a/compiler/qsc_frontend/src/lower/tests.rs +++ b/compiler/qsc_frontend/src/lower/tests.rs @@ -2329,7 +2329,7 @@ fn nested_params() { Callable 0 [17-55] (function): name: Ident 1 [26-29] "Foo" generics: - 0: type [30-32] "'T" + 0: type 'T 1: functor (empty set) input: Pat 2 [34-45] [Type (Param<"'T": 0> => Unit is Param<1>)]: Bind: Ident 3 [34-35] "f" output: Unit @@ -2420,8 +2420,8 @@ fn duplicate_commas_in_generics() { Callable 0 [21-57] (function): name: Ident 1 [30-33] "Foo" generics: - 0: type [34-36] "'T" - 1: type [37-37] "" + 0: type 'T + 1: type input: Pat 2 [40-46] [Type Param<"'T": 0>]: Bind: Ident 3 [40-41] "x" output: Unit functors: empty set diff --git a/compiler/qsc_frontend/src/resolve.rs b/compiler/qsc_frontend/src/resolve.rs index dded3c8515..2843be41e6 100644 --- a/compiler/qsc_frontend/src/resolve.rs +++ b/compiler/qsc_frontend/src/resolve.rs @@ -7,8 +7,8 @@ mod tests; use miette::Diagnostic; use qsc_ast::{ ast::{ - self, CallableBody, CallableDecl, Ident, Idents, NodeId, PathKind, SpecBody, SpecGen, - TopLevelNode, + self, CallableBody, CallableDecl, ClassConstraints, Ident, Idents, NodeId, PathKind, + SpecBody, SpecGen, TopLevelNode, TypeParameter, }, visit::{self as ast_visit, walk_attr, Visitor as AstVisitor}, }; @@ -63,7 +63,10 @@ pub enum Res { /// A local variable. Local(NodeId), /// A type/functor parameter in the generics section of the parent callable decl. - Param(ParamId), + Param { + id: ParamId, + bounds: ClassConstraints, + }, /// A primitive type. PrimTy(Prim), /// The unit type. @@ -212,7 +215,7 @@ pub struct Scope { /// it is missed in the list. vars: FxHashMap, (u32, NodeId)>, /// Type parameters. - ty_vars: FxHashMap, ParamId>, + ty_vars: FxHashMap, (ParamId, ClassConstraints)>, } #[derive(Debug, Clone)] @@ -1164,13 +1167,26 @@ impl Resolver { } } + /// For a given callable declaration, bind the names of the type parameters + /// into the current scope. Tracks the constraints defined on the type parameters + /// as well, for later use in type checking. fn bind_type_parameters(&mut self, decl: &CallableDecl) { - decl.generics.iter().enumerate().for_each(|(ix, ident)| { - self.current_scope_mut() - .ty_vars - .insert(Rc::clone(&ident.name), ix.into()); - self.names.insert(ident.id, Res::Param(ix.into())); - }); + decl.generics + .iter() + .enumerate() + .for_each(|(ix, type_parameter)| { + self.current_scope_mut().ty_vars.insert( + Rc::clone(&type_parameter.ty.name), + (ix.into(), type_parameter.constraints.clone()), + ); + self.names.insert( + type_parameter.ty.id, + Res::Param { + id: ix.into(), + bounds: type_parameter.constraints.clone(), + }, + ); + }); } fn push_scope(&mut self, span: Span, kind: ScopeKind) { @@ -1396,8 +1412,8 @@ impl AstVisitor<'_> for With<'_> { self.resolver.errors.push(e); } } - ast::TyKind::Param(ident) => { - self.resolver.resolve_ident(NameKind::Ty, ident); + ast::TyKind::Param(TypeParameter { ty, .. }) => { + self.resolver.resolve_ident(NameKind::Ty, ty); } _ => ast_visit::walk_ty(self, ty), } @@ -2334,8 +2350,11 @@ fn resolve_scope_locals( } } NameKind::Ty => { - if let Some(&id) = scope.ty_vars.get(name) { - return Some(Res::Param(id)); + if let Some((id, bounds)) = scope.ty_vars.get(name) { + return Some(Res::Param { + id: *id, + bounds: bounds.clone(), + }); } } } @@ -2373,10 +2392,15 @@ fn get_scope_locals(scope: &Scope, offset: u32, vars: bool) -> Vec { } })); - names.extend(scope.ty_vars.iter().map(|id| Local { - name: id.0.clone(), - kind: LocalKind::TyParam(*id.1), - })); + names.extend( + scope + .ty_vars + .iter() + .map(|(name, (id, _constraints))| Local { + name: name.clone(), + kind: LocalKind::TyParam(*id), + }), + ); } // items diff --git a/compiler/qsc_frontend/src/resolve/tests.rs b/compiler/qsc_frontend/src/resolve/tests.rs index 1a2b53dce1..e010f0e9a0 100644 --- a/compiler/qsc_frontend/src/resolve/tests.rs +++ b/compiler/qsc_frontend/src/resolve/tests.rs @@ -81,7 +81,7 @@ impl<'a> Renamer<'a> { Res::Local(node) => format!("local{node}"), Res::PrimTy(prim) => format!("{prim:?}"), Res::UnitTy => "Unit".to_string(), - Res::Param(id) => format!("param{id}"), + Res::Param { id, .. } => format!("param{id}"), Res::ExportedItem(item, _) => match item.package { None => format!("exported_item{}", item.item), Some(package) => format!("reexport_from_{package}:{}", item.item), @@ -5269,3 +5269,23 @@ fn export_of_item_with_same_name_as_namespace_resolves_to_item_even_when_before_ "#]], ); } + +#[test] +fn ty_param_name_is_in_scope() { + check( + indoc! {r#" + namespace Foo { + operation Foo<'T: Eq>(a: 'T) : Unit { + let x: 'T = a; + } + } + "#}, + &expect![[r#" + namespace namespace3 { + operation item1(local10: param0) : Unit { + let local19: param0 = local10; + } + } + "#]], + ); +} diff --git a/compiler/qsc_frontend/src/typeck.rs b/compiler/qsc_frontend/src/typeck.rs index 2ae44c2f40..c17cf08592 100644 --- a/compiler/qsc_frontend/src/typeck.rs +++ b/compiler/qsc_frontend/src/typeck.rs @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +//! Type checks a Q# AST and produces a typed HIR. +//! `check`ing references `rules` within contexts to produce context-aware constraints. The inferrer is used +//! within `rules` to assist in the production of constraints from rules. +//! For example, a rule might say that if a statement is an expression, it must +//! return `Unit`. The inferrer would then be used to get the inferred type out of +//! the expression, giving us a type id, which we can then constrain to `Unit`. mod check; pub(super) mod convert; mod infer; @@ -8,6 +14,7 @@ mod rules; #[cfg(test)] mod tests; +use convert::TyConversionError; use miette::Diagnostic; use qsc_ast::ast::NodeId; use qsc_data_structures::{index_map::IndexMap, span::Span}; @@ -21,6 +28,8 @@ use thiserror::Error; pub(super) use check::{Checker, GlobalTable}; +/// This [`Table`] builds up mappings from items to typed HIR UDTs _and_ nodes to +/// their term HIR type and generic arguments, if any exist. #[derive(Debug, Default, Clone)] pub struct Table { pub udts: FxHashMap, @@ -93,10 +102,29 @@ enum ErrorKind { #[diagnostic(help("only arrays and ranges are iterable"))] #[diagnostic(code("Qsc.TypeCk.MissingClassIterable"))] MissingClassIterable(String, #[label] Span), - #[error("type {0} is not a number")] + #[error("Type {0} cannot be used in subtraction")] #[diagnostic(help("only BigInt, Double, and Int are numbers"))] - #[diagnostic(code("Qsc.TypeCk.MissingClassNum"))] - MissingClassNum(String, #[label] Span), + #[diagnostic(code("Qsc.TypeCk.MissingClassSub"))] + MissingClassSub(String, #[label] Span), + #[error("Type {0} cannot be used in multiplication")] + #[diagnostic(help("only BigInt, Double, and Int are numbers"))] + #[diagnostic(code("Qsc.TypeCk.MissingClassMul"))] + MissingClassMul(String, #[label] Span), + #[error("Type {0} cannot be used in division")] + #[diagnostic(help("only BigInt, Double, and Int are numbers"))] + #[diagnostic(code("Qsc.TypeCk.MissingClassDiv"))] + MissingClassDiv(String, #[label] Span), + #[error("Type {0} cannot be used with comparison operators (less than/greater than)")] + #[diagnostic(code("Qsc.TypeCk.MissingClassOrd"))] + MissingClassOrd(String, #[label] Span), + #[error("Type {0} cannot be used with the modulo operator")] + #[diagnostic(help("only BigInt and Int are numbers"))] + #[diagnostic(code("Qsc.TypeCk.MissingClassMod"))] + MissingClassMod(String, #[label] Span), + #[error("Type {0} cannot have a sign applied to it")] + #[diagnostic(help("only BigInt, Double, and Int are numbers"))] + #[diagnostic(code("Qsc.TypeCk.MissingClassSigned"))] + MissingClassSigned(String, #[label] Span), #[error("type {0} cannot be converted into a string")] #[diagnostic(code("Qsc.TypeCk.MissingClassShow"))] MissingClassShow(String, #[label] Span), @@ -107,10 +135,6 @@ enum ErrorKind { #[error("expected superset of {0}, found {1}")] #[diagnostic(code("Qsc.TypeCk.MissingFunctor"))] MissingFunctor(FunctorSet, FunctorSet, #[label] Span), - #[error("missing type in item signature")] - #[diagnostic(help("types cannot be inferred for global declarations"))] - #[diagnostic(code("Qsc.TypeCk.MissingItemTy"))] - MissingItemTy(#[label] Span), #[error("found hole with type {0}")] #[diagnostic(help("replace this hole with an expression of the expected type"))] #[diagnostic(code("Qsc.TypeCk.TyHole"))] @@ -119,4 +143,57 @@ enum ErrorKind { #[diagnostic(help("provide a type annotation"))] #[diagnostic(code("Qsc.TypeCk.AmbiguousTy"))] AmbiguousTy(#[label] Span), + #[error("missing type in item signature")] + #[diagnostic(help("a type must be provided for this item"))] + #[diagnostic(code("Qsc.TypeCk.MissingTy"))] + MissingTy { + #[label] + span: Span, + }, + #[error("unrecognized class constraint {name}")] + #[help("supported classes are Eq, Add, Exp, Integral, Num, and Show")] + #[diagnostic(code("Qsc.TypeCk.UnrecognizedClass"))] + UnrecognizedClass { + #[label] + span: Span, + name: String, + }, + #[error("class constraint is recursive via {name}")] + #[help("if a type refers to itself via its constraints, it is self-referential and cannot ever be resolved")] + #[diagnostic(code("Qsc.TypeCk.RecursiveClassConstraint"))] + RecursiveClassConstraint { + #[label] + span: Span, + name: String, + }, + #[error("expected {expected} parameters for constraint, found {found}")] + #[diagnostic(code("Qsc.TypeCk.IncorrectNumberOfConstraintParameters"))] + IncorrectNumberOfConstraintParameters { + expected: usize, + found: usize, + #[label] + span: Span, + }, +} + +impl From for Error { + fn from(err: TyConversionError) -> Self { + use TyConversionError::*; + match err { + MissingTy { span } => Error(ErrorKind::MissingTy { span }), + UnrecognizedClass { span, name } => Error(ErrorKind::UnrecognizedClass { span, name }), + RecursiveClassConstraint { span, name } => { + Error(ErrorKind::RecursiveClassConstraint { span, name }) + } + IncorrectNumberOfConstraintParameters { + expected, + found, + span, + } => Error(ErrorKind::IncorrectNumberOfConstraintParameters { + expected, + found, + span, + }), + } + } } diff --git a/compiler/qsc_frontend/src/typeck/check.rs b/compiler/qsc_frontend/src/typeck/check.rs index c53982eed6..40a65d6d8a 100644 --- a/compiler/qsc_frontend/src/typeck/check.rs +++ b/compiler/qsc_frontend/src/typeck/check.rs @@ -8,7 +8,7 @@ use super::{ }; use crate::{ resolve::{Names, Res}, - typeck::convert::{self, MissingTyError}, + typeck::convert::{self}, }; use qsc_ast::{ ast::{self, NodeId, TopLevelNode}, @@ -94,6 +94,9 @@ impl GlobalTable { } } +/// This struct is the entry point of the type checker. Constructed with [`Checker::new`], it +/// exposes a method [`Checker::check_package`] that will type check a given [`ast::Package`] and +/// populate its own fields with the results. pub(crate) struct Checker { globals: FxHashMap, table: Table, @@ -157,7 +160,7 @@ impl Checker { fn check_callable_decl(&mut self, names: &Names, decl: &ast::CallableDecl) { self.check_callable_signature(names, decl); - let output = convert::ty_from_ast(names, &decl.output).0; + let output = convert::ty_from_ast(names, &decl.output, &mut Default::default()).0; match &*decl.body { ast::CallableBody::Block(block) => self.check_spec( names, @@ -192,7 +195,7 @@ impl Checker { fn check_callable_signature(&mut self, names: &Names, decl: &ast::CallableDecl) { if convert::ast_callable_functors(decl) != FunctorSetValue::Empty { - let output = convert::ty_from_ast(names, &decl.output).0; + let output = convert::ty_from_ast(names, &decl.output, &mut Default::default()).0; match &output { Ty::Tuple(items) if items.is_empty() => {} _ => self.errors.push(Error(ErrorKind::TyMismatch( @@ -204,6 +207,9 @@ impl Checker { } } + /// Used to check all callable bodies + /// Note that a regular function block callable body is still checked by + /// this function fn check_spec(&mut self, names: &Names, spec: SpecImpl) { self.errors.append(&mut rules::spec( names, @@ -224,6 +230,8 @@ impl Checker { } } +/// Populates `Checker` with definitions and errors, while referring to the `Names` table to get +/// definitions. struct ItemCollector<'a> { checker: &'a mut Checker, names: &'a Names, @@ -243,11 +251,9 @@ impl Visitor<'_> for ItemCollector<'_> { panic!("callable should have item ID"); }; - let (scheme, errors) = convert::ast_callable_scheme(self.names, decl); - for MissingTyError(span) in errors { - self.checker - .errors - .push(Error(ErrorKind::MissingItemTy(span))); + let (scheme, errors) = convert::scheme_for_ast_callable(self.names, decl); + for err in errors { + self.checker.errors.push(err.into()); } self.checker.globals.insert(item, scheme); @@ -261,12 +267,9 @@ impl Visitor<'_> for ItemCollector<'_> { let (cons, cons_errors) = convert::ast_ty_def_cons(self.names, &name.name, item, def); let (udt_def, def_errors) = convert::ast_ty_def(self.names, def); - self.checker.errors.extend( - cons_errors - .into_iter() - .chain(def_errors) - .map(|MissingTyError(span)| Error(ErrorKind::MissingItemTy(span))), - ); + self.checker + .errors + .extend(cons_errors.into_iter().chain(def_errors).map(Into::into)); self.checker.table.udts.insert( item, @@ -289,12 +292,9 @@ impl Visitor<'_> for ItemCollector<'_> { let (cons, cons_errors) = convert::ast_ty_def_cons(self.names, &decl.name.name, item, &def); let (udt_def, def_errors) = convert::ast_ty_def(self.names, &def); - self.checker.errors.extend( - cons_errors - .into_iter() - .chain(def_errors) - .map(|MissingTyError(span)| Error(ErrorKind::MissingItemTy(span))), - ); + self.checker + .errors + .extend(cons_errors.into_iter().chain(def_errors).map(Into::into)); self.checker.table.udts.insert( item, diff --git a/compiler/qsc_frontend/src/typeck/convert.rs b/compiler/qsc_frontend/src/typeck/convert.rs index dc97f6493e..bc36d7c8b3 100644 --- a/compiler/qsc_frontend/src/typeck/convert.rs +++ b/compiler/qsc_frontend/src/typeck/convert.rs @@ -1,33 +1,62 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +//! Ascribe types to the AST and output HIR items. Put another way, converts the AST to the HIR. use std::rc::Rc; use crate::resolve::{self, Names}; + use qsc_ast::ast::{ - self, CallableBody, CallableDecl, CallableKind, FunctorExpr, FunctorExprKind, Ident, Pat, - PatKind, Path, PathKind, SetOp, Spec, StructDecl, TyDef, TyDefKind, TyKind, + self, CallableBody, CallableDecl, CallableKind, FunctorExpr, FunctorExprKind, Pat, PatKind, + Path, PathKind, SetOp, Spec, StructDecl, TyDef, TyDefKind, TyKind, + TypeParameter as AstTypeParameter, }; use qsc_data_structures::span::Span; use qsc_hir::{ - hir, + hir::{self}, ty::{ - Arrow, FunctorSet, FunctorSetValue, GenericParam, ParamId, Scheme, Ty, TypeParamName, + Arrow, FunctorSet, FunctorSetValue, ParamId, Scheme, Ty, TypeParameter as HirTypeParameter, UdtDef, UdtDefKind, UdtField, }, }; +use rustc_hash::FxHashSet; -pub(crate) struct MissingTyError(pub(super) Span); +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) enum TyConversionError { + MissingTy { + span: Span, + }, + UnrecognizedClass { + span: Span, + name: String, + }, + RecursiveClassConstraint { + span: Span, + name: String, + }, + IncorrectNumberOfConstraintParameters { + expected: usize, + found: usize, + span: Span, + }, +} -pub(crate) fn ty_from_ast(names: &Names, ty: &ast::Ty) -> (Ty, Vec) { +/// Given an `ast::Ty` and a list of resolved `Names`, convert the `ast::Ty` to an `hir::Ty`. +pub(crate) fn ty_from_ast( + names: &Names, + ty: &ast::Ty, + stack: &mut FxHashSet, +) -> (Ty, Vec) { match &*ty.kind { TyKind::Array(item) => { - let (item, errors) = ty_from_ast(names, item); + let (item, errors) = ty_from_ast(names, item, stack); (Ty::Array(Box::new(item)), errors) } TyKind::Arrow(kind, input, output, functors) => { - let (input, mut errors) = ty_from_ast(names, input); - let (output, output_errors) = ty_from_ast(names, output); + // shadow the stack as a new empty one, since we are in a new arrow type + let mut stack = Default::default(); + let (input, mut errors) = ty_from_ast(names, input, &mut stack); + let (output, output_errors) = ty_from_ast(names, output, &mut stack); errors.extend(output_errors); let functors = functors .as_ref() @@ -40,22 +69,34 @@ pub(crate) fn ty_from_ast(names: &Names, ty: &ast::Ty) -> (Ty, Vec (Ty::Err, vec![MissingTyError(ty.span)]), - TyKind::Paren(inner) => ty_from_ast(names, inner), - TyKind::Path(PathKind::Ok(path)) => (ty_from_path(names, path), Vec::new()), - TyKind::Param(name) => match names.get(name.id) { - Some(resolve::Res::Param(id)) => (Ty::Param(name.name.clone(), *id), Vec::new()), - Some(_) => unreachable!( - "A parameter should never resolve to a non-parameter type, as there \ - is syntactic differentiation" + TyKind::Hole => ( + Ty::Err, + vec![TyConversionError::MissingTy { span: ty.span }], + ), + TyKind::Paren(inner) => ty_from_ast(names, inner, stack), + TyKind::Param(AstTypeParameter { ty, .. }) => match names.get(ty.id) { + Some(resolve::Res::Param { id, bounds }) => { + let (bounds, errors) = class_constraints_from_ast(names, bounds, stack); + ( + Ty::Param { + name: ty.name.clone(), + id: *id, + bounds, + }, + errors, + ) + } + Some(_) | None => ( + Ty::Err, + vec![TyConversionError::MissingTy { span: ty.span }], ), - None => (Ty::Err, Vec::new()), }, + TyKind::Path(PathKind::Ok(path)) => (ty_from_path(names, path), Vec::new()), TyKind::Tuple(items) => { let mut tys = Vec::new(); let mut errors = Vec::new(); for item in items { - let (item_ty, item_errors) = ty_from_ast(names, item); + let (item_ty, item_errors) = ty_from_ast(names, item, stack); tys.push(item_ty); errors.extend(item_errors); } @@ -77,7 +118,7 @@ pub(super) fn ty_from_path(names: &Names, path: &Path) -> Ty { // A path can also never resolve to an export, because in typeck/check, // we resolve exports to their original definition. Some( - resolve::Res::Local(_) | resolve::Res::Param(_) | resolve::Res::ExportedItem(_, _), + resolve::Res::Local(_) | resolve::Res::Param { .. } | resolve::Res::ExportedItem(_, _), ) => { unreachable!( "A path should never resolve \ @@ -113,7 +154,7 @@ pub(super) fn ast_ty_def_cons( ty_name: &Rc, id: hir::ItemId, def: &TyDef, -) -> (Scheme, Vec) { +) -> (Scheme, Vec) { let (input, errors) = ast_ty_def_base(names, def); let ty = Arrow { kind: hir::CallableKind::Function, @@ -125,9 +166,9 @@ pub(super) fn ast_ty_def_cons( (scheme, errors) } -fn ast_ty_def_base(names: &Names, def: &TyDef) -> (Ty, Vec) { +fn ast_ty_def_base(names: &Names, def: &TyDef) -> (Ty, Vec) { match &*def.kind { - TyDefKind::Field(_, ty) => ty_from_ast(names, ty), + TyDefKind::Field(_, ty) => ty_from_ast(names, ty, &mut Default::default()), TyDefKind::Paren(inner) => ast_ty_def_base(names, inner), TyDefKind::Tuple(items) => { let mut tys = Vec::new(); @@ -144,7 +185,9 @@ fn ast_ty_def_base(names: &Names, def: &TyDef) -> (Ty, Vec) { } } -pub(super) fn ast_ty_def(names: &Names, def: &TyDef) -> (UdtDef, Vec) { +/// Given a type definition from the AST ([`TyDef`]), convert it to a HIR type definition ([`UdtDef`]). +/// Relies on `names` having been correctly populated to resolve any pending names referred to in the definition. +pub(super) fn ast_ty_def(names: &Names, def: &TyDef) -> (UdtDef, Vec) { if let TyDefKind::Paren(inner) = &*def.kind { return ast_ty_def(names, inner); } @@ -154,7 +197,7 @@ pub(super) fn ast_ty_def(names: &Names, def: &TyDef) -> (UdtDef, Vec { - let (ty, item_errors) = ty_from_ast(names, ty); + let (ty, item_errors) = ty_from_ast(names, ty, &mut Default::default()); errors.extend(item_errors); let (name_span, name) = match name { Some(name) => (Some(name.span), Some(name.name.clone())), @@ -189,18 +232,49 @@ pub(super) fn ast_ty_def(names: &Names, def: &TyDef) -> (UdtDef, Vec (Vec, Vec) { + let mut errors = Vec::new(); + let mut generics_buf = Vec::with_capacity(generics.len()); + for param in generics { + let (bounds, new_errors) = + class_constraints_from_ast(names, ¶m.constraints, &mut Default::default()); + errors.extend(new_errors); + generics_buf.push(HirTypeParameter::Ty { + name: param.ty.name.clone(), + bounds, + }); + } + (generics_buf, errors) +} + +/// Given an AST callable, convert it to a HIR callable scheme (type scheme). +pub(super) fn scheme_for_ast_callable( names: &Names, callable: &CallableDecl, -) -> (Scheme, Vec) { +) -> (Scheme, Vec) { + let (mut type_parameters, errors) = type_parameters_for_ast_callable(names, &callable.generics); + let mut errors = errors + .into_iter() + .map(TyConversionError::from) + .collect::>(); let kind = callable_kind_from_ast(callable.kind); - let (mut input, mut errors) = ast_pat_ty(names, &callable.input); - let (output, output_errors) = ty_from_ast(names, &callable.output); + + let (mut input, new_errors) = ast_pat_ty(names, &callable.input); + errors.extend(&mut new_errors.into_iter()); + + let (output, output_errors) = ty_from_ast(names, &callable.output, &mut Default::default()); + errors.extend(output_errors); - let mut params = ast_callable_generics(&callable.generics); - let mut functor_params = synthesize_functor_params(&mut params.len().into(), &mut input); - params.append(&mut functor_params); + let mut functor_params = + synthesize_functor_params(&mut type_parameters.len().into(), &mut input); + + type_parameters.append(&mut functor_params); let ty = Arrow { kind, @@ -209,25 +283,20 @@ pub(super) fn ast_callable_scheme( functors: FunctorSet::Value(ast_callable_functors(callable)), }; - (Scheme::new(params, Box::new(ty)), errors) + (Scheme::new(type_parameters, Box::new(ty)), errors) } -pub(crate) fn synthesize_callable_generics( - generics: &[Box], - input: &mut hir::Pat, -) -> Vec { - let mut params = ast_callable_generics(generics); - let mut functor_params = synthesize_functor_params_in_pat(&mut params.len().into(), input); - params.append(&mut functor_params); - params -} - -fn synthesize_functor_params(next_param: &mut ParamId, ty: &mut Ty) -> Vec { +/// Given a [`Ty`], find all arrow types and create type parameters, if necessary, for them. +/// Recurses into container types to find all arrow types contained within the type. +pub(crate) fn synthesize_functor_params( + next_param: &mut ParamId, + ty: &mut Ty, +) -> Vec { match ty { Ty::Array(item) => synthesize_functor_params(next_param, item), Ty::Arrow(arrow) => match arrow.functors { FunctorSet::Value(functors) if arrow.kind == hir::CallableKind::Operation => { - let param = GenericParam::Functor(functors); + let param = HirTypeParameter::Functor(functors); arrow.functors = FunctorSet::Param(*next_param, functors); *next_param = next_param.successor(); vec![param] @@ -238,49 +307,19 @@ fn synthesize_functor_params(next_param: &mut ParamId, ty: &mut Ty) -> Vec Vec::new(), - } -} - -fn synthesize_functor_params_in_pat( - next_param: &mut ParamId, - pat: &mut hir::Pat, -) -> Vec { - match &mut pat.kind { - hir::PatKind::Discard | hir::PatKind::Err | hir::PatKind::Bind(_) => { - synthesize_functor_params(next_param, &mut pat.ty) - } - hir::PatKind::Tuple(items) => { - let mut params = Vec::new(); - for item in &mut *items { - params.append(&mut synthesize_functor_params_in_pat(next_param, item)); - } - if !params.is_empty() { - pat.ty = Ty::Tuple(items.iter().map(|i| i.ty.clone()).collect()); - } - params - } + Ty::Infer(_) | Ty::Param { .. } | Ty::Prim(_) | Ty::Udt(_, _) | Ty::Err => Vec::new(), } } -fn ast_callable_generics(generics: &[Box]) -> Vec { - generics - .iter() - .map(|param| { - GenericParam::Ty(TypeParamName { - span: param.span, - name: param.name.clone(), - }) - }) - .collect() -} - -pub(crate) fn ast_pat_ty(names: &Names, pat: &Pat) -> (Ty, Vec) { +pub(crate) fn ast_pat_ty(names: &Names, pat: &Pat) -> (Ty, Vec) { match &*pat.kind { - PatKind::Bind(_, None) | PatKind::Discard(None) | PatKind::Elided => { - (Ty::Err, vec![MissingTyError(pat.span)]) + PatKind::Bind(_, None) | PatKind::Discard(None) | PatKind::Elided => ( + Ty::Err, + vec![TyConversionError::MissingTy { span: pat.span }], + ), + PatKind::Bind(_, Some(ty)) | PatKind::Discard(Some(ty)) => { + ty_from_ast(names, ty, &mut Default::default()) } - PatKind::Bind(_, Some(ty)) | PatKind::Discard(Some(ty)) => ty_from_ast(names, ty), PatKind::Paren(inner) => ast_pat_ty(names, inner), PatKind::Tuple(items) => { let mut tys = Vec::new(); @@ -339,3 +378,90 @@ pub(crate) fn eval_functor_expr(expr: &FunctorExpr) -> FunctorSetValue { FunctorExprKind::Paren(inner) => eval_functor_expr(inner), } } + +/// Convert an AST type bound to an HIR type bound. +pub(crate) fn class_constraints_from_ast( + names: &Names, + bounds: &qsc_ast::ast::ClassConstraints, + // used to check for recursive types + stack: &mut FxHashSet, +) -> (qsc_hir::ty::ClassConstraints, Vec) { + let mut bounds_buf = Vec::new(); + let mut errors = FxHashSet::default(); + + for ast_bound in &bounds.0 { + if stack.contains(ast_bound) { + errors.insert(TyConversionError::RecursiveClassConstraint { + span: ast_bound.span(), + name: ast_bound.name.name.to_string(), + }); + continue; + } + stack.insert(ast_bound.clone()); + if check_param_length(ast_bound, &mut errors) { + continue; + }; + let bound_result = match &*ast_bound.name.name { + "Eq" => Ok(qsc_hir::ty::ClassConstraint::Eq), + "Add" => Ok(qsc_hir::ty::ClassConstraint::Add), + "Iterable" => { + let (item, item_errors) = ty_from_ast(names, &ast_bound.parameters[0].ty, stack); + errors.extend(item_errors.into_iter()); + Ok(qsc_hir::ty::ClassConstraint::Iterable { item }) + } + "Exp" => { + let (power, power_errors) = ty_from_ast(names, &ast_bound.parameters[0].ty, stack); + errors.extend(power_errors.into_iter()); + Ok(qsc_hir::ty::ClassConstraint::Exp { power }) + } + "Integral" => Ok(qsc_hir::ty::ClassConstraint::Integral), + "Mul" => Ok(qsc_hir::ty::ClassConstraint::Mul), + "Sub" => Ok(qsc_hir::ty::ClassConstraint::Sub), + "Mod" => Ok(qsc_hir::ty::ClassConstraint::Mod), + "Div" => Ok(qsc_hir::ty::ClassConstraint::Div), + "Signed" => Ok(qsc_hir::ty::ClassConstraint::Signed), + "Show" => Ok(qsc_hir::ty::ClassConstraint::Show), + otherwise => Err(TyConversionError::UnrecognizedClass { + span: ast_bound.span(), + name: otherwise.to_string(), + }), + }; + + match bound_result { + Ok(hir_bound) => { + bounds_buf.push(hir_bound); + } + Err(e) => { + errors.insert(e); + } + } + } + + ( + qsc_hir::ty::ClassConstraints(bounds_buf.into_boxed_slice()), + errors.into_iter().collect(), + ) +} + +/// returns `true` if the param length is incorrect +fn check_param_length( + bound: &ast::ClassConstraint, + errors: &mut FxHashSet, +) -> bool { + let num_given_parameters = bound.parameters.len(); + let num_parameters = match &*bound.name.name { + "Eq" | "Add" | "Integral" | "Num" | "Show" => 0, + "Iterable" | "Exp" => 1, + _ => return false, + }; + if num_parameters == num_given_parameters { + false + } else { + errors.insert(TyConversionError::IncorrectNumberOfConstraintParameters { + expected: num_parameters, + found: num_given_parameters, + span: bound.span(), + }); + true + } +} diff --git a/compiler/qsc_frontend/src/typeck/infer.rs b/compiler/qsc_frontend/src/typeck/infer.rs index 29cb0c1b27..41b3ab3900 100644 --- a/compiler/qsc_frontend/src/typeck/infer.rs +++ b/compiler/qsc_frontend/src/typeck/infer.rs @@ -6,14 +6,15 @@ use qsc_data_structures::{index_map::IndexMap, span::Span}; use qsc_hir::{ hir::{ItemId, PrimField, Res}, ty::{ - Arrow, FunctorSet, FunctorSetValue, GenericArg, GenericParam, InferFunctorId, InferTyId, - Prim, Scheme, Ty, Udt, + Arrow, ClassConstraint, FunctorSet, FunctorSetValue, GenericArg, InferFunctorId, InferTyId, + Prim, Scheme, Ty, TypeParameter, Udt, }, }; use rustc_hash::{FxHashMap, FxHashSet}; use std::{ - collections::{hash_map::Entry, VecDeque}, + collections::{hash_map::Entry, BTreeSet, VecDeque}, fmt::Debug, + rc::Rc, }; const MAX_TY_RECURSION_DEPTH: i8 = 100; @@ -63,12 +64,21 @@ pub(super) enum Class { container: Ty, item: Ty, }, - Num(Ty), + Mul(Ty), + Sub(Ty), + Div(Ty), + Ord(Ty), + Mod(Ty), + Signed(Ty), Show(Ty), Unwrap { wrapper: Ty, base: Ty, }, + // A user-defined class + // When we actually support this, and don't just use it to generate an error, + // it should have an ID here instead of a name + NonPrimitive(Rc), } impl Class { @@ -78,7 +88,12 @@ impl Class { | Self::Adj(ty) | Self::Eq(ty) | Self::Integral(ty) - | Self::Num(ty) + | Self::Mul(ty) + | Self::Sub(ty) + | Self::Div(ty) + | Self::Mod(ty) + | Self::Ord(ty) + | Self::Signed(ty) | Self::Show(ty) | Self::Struct(ty) => { vec![ty] @@ -92,6 +107,7 @@ impl Class { } => vec![container, index], Self::Iterable { container, .. } => vec![container], Self::Unwrap { wrapper, .. } => vec![wrapper], + Self::NonPrimitive(_) => Vec::new(), } } @@ -146,12 +162,19 @@ impl Class { container: f(container), item: f(item), }, - Self::Num(ty) => Self::Num(f(ty)), + Self::Sub(ty) => Self::Sub(f(ty)), + Self::Mul(ty) => Self::Mul(f(ty)), + Self::Div(ty) => Self::Div(f(ty)), + Self::Ord(ty) => Self::Ord(f(ty)), + Self::Mod(ty) => Self::Mod(f(ty)), + Self::Signed(ty) => Self::Signed(f(ty)), + Self::Show(ty) => Self::Show(f(ty)), Self::Unwrap { wrapper, base } => Self::Unwrap { wrapper: f(wrapper), base: f(base), }, + Self::NonPrimitive(name) => Self::NonPrimitive(name), } } @@ -191,17 +214,67 @@ impl Class { vec![Error(ErrorKind::MissingClassInteger(ty.display(), span))], ), Class::Iterable { container, item } => check_iterable(container, item, span), - Class::Num(ty) if check_num(&ty) => (Vec::new(), Vec::new()), - Class::Num(ty) => ( + Class::Sub(ty) if check_sub(&ty) => (Vec::new(), Vec::new()), + Class::Sub(ty) => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassSub(ty.display(), span))], + ), + Class::Mul(ty) if check_mul(&ty) => (Vec::new(), Vec::new()), + Class::Mul(ty) => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassMul(ty.display(), span))], + ), + Class::Div(ty) if check_div(&ty) => (Vec::new(), Vec::new()), + Class::Div(ty) => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassDiv(ty.display(), span))], + ), + Class::Ord(ty) if check_ord(&ty) => (Vec::new(), Vec::new()), + Class::Ord(ty) => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassOrd(ty.display(), span))], + ), + Class::Signed(ty) if check_signed(&ty) => (Vec::new(), Vec::new()), + Class::Signed(ty) => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassSigned(ty.display(), span))], + ), + Class::Mod(ty) if check_mod(&ty) => (Vec::new(), Vec::new()), + Class::Mod(ty) => ( Vec::new(), - vec![Error(ErrorKind::MissingClassNum(ty.display(), span))], + vec![Error(ErrorKind::MissingClassMod(ty.display(), span))], ), Class::Show(ty) => check_show(ty, span), Class::Unwrap { wrapper, base } => check_unwrap(udts, &wrapper, base, span), + Class::NonPrimitive(_) => (vec![], vec![]), } } } +fn check_mod(ty: &Ty) -> bool { + check_num_constraint(&ClassConstraint::Mod, ty) +} + +fn check_signed(ty: &Ty) -> bool { + check_num_constraint(&ClassConstraint::Signed, ty) +} + +fn check_ord(ty: &Ty) -> bool { + check_num_constraint(&ClassConstraint::Ord, ty) +} + +fn check_div(ty: &Ty) -> bool { + check_num_constraint(&ClassConstraint::Div, ty) +} + +fn check_mul(ty: &Ty) -> bool { + check_num_constraint(&ClassConstraint::Mul, ty) +} + +fn check_sub(ty: &Ty) -> bool { + check_num_constraint(&ClassConstraint::Sub, ty) +} + /// Meta-level descriptions about the source of a type. /// The compiler uses the notion of "unresolved types" to /// represent both divergent types (return expressions, similar to @@ -211,6 +284,7 @@ impl Class { /// so we need to track where types came from. This `TySource` /// struct allows us to know if a type originates from a divergent /// source, and if it doesn't, we generate an ambiguous type error. +#[derive(Debug)] pub(super) enum TySource { Divergent, NotDivergent { span: Span }, @@ -226,7 +300,9 @@ impl TySource { } } -/// An argument type and tags describing the call syntax. +/// An argument type and tags describing the call syntax. This represents the type of something +/// that appears in a _call_ to a _callable_, and an argument can be a hole, a given argument, or, +/// in the most standard case, a tuple. Foo(1, 2, 3) is [`ArgTy::Tuple`], not [`ArgTy::Given`]. #[derive(Clone, Debug)] pub(super) enum ArgTy { /// A missing argument, indicating partial application. @@ -238,6 +314,7 @@ pub(super) enum ArgTy { } impl ArgTy { + /// Applies a function `f` to each type in the argument type. fn map(self, f: &mut impl FnMut(Ty) -> Ty) -> Self { match self { Self::Hole(ty) => Self::Hole(f(ty)), @@ -246,8 +323,14 @@ impl ArgTy { } } + /// Applies the argument type to a parameter type, generating constraints and errors. fn apply(&self, param: &Ty, span: Span) -> App { match (self, param) { + // If `arg` is a hole, then it doesn't matter what the param is, + // because the hole can be anything. + // However, we do know that the type of Arg must be Eq to the type of Param, so we + // add that to the constraints. + // Preserve the hole. (Self::Hole(arg), _) => App { holes: vec![param.clone()], constraints: vec![Constraint::Eq { @@ -257,6 +340,10 @@ impl ArgTy { }], errors: Vec::new(), }, + // If `arg` is a hole, then it doesn't matter what the param is, + // because the hole can be anything. + // However, we do know that the type of Arg must be Eq to the type of Param, so we + // add that to the constraints. (Self::Given(arg), _) => App { holes: Vec::new(), constraints: vec![Constraint::Eq { @@ -266,6 +353,8 @@ impl ArgTy { }], errors: Vec::new(), }, + // if both the arg and the param are tuples, then we must check + // the types of each element in the tuple and generate iterative applications. (Self::Tuple(args), Ty::Tuple(params)) => { let mut errors = Vec::new(); if args.len() != params.len() { @@ -295,6 +384,7 @@ impl ArgTy { errors, } } + (Self::Tuple(_), Ty::Infer(_)) => App { holes: Vec::new(), constraints: vec![Constraint::Eq { @@ -335,8 +425,10 @@ struct App { } #[derive(Debug)] -enum Constraint { +pub(super) enum Constraint { + // Constraint that says a type must satisfy a class Class(Class, Span), + // Constraint that says two types must be the same Eq { expected: Ty, actual: Ty, @@ -349,6 +441,7 @@ enum Constraint { }, } +#[derive(Debug)] pub(super) struct Inferrer { solver: Solver, constraints: VecDeque, @@ -383,6 +476,20 @@ impl Inferrer { self.constraints.push_back(Constraint::Class(class, span)); } + /// Returns a unique type variable with specified constraints. + fn constrained_ty( + &mut self, + meta: TySource, + with_constraints: impl Fn(Ty) -> Box<[Constraint]>, + ) -> Ty { + let fresh = self.next_ty; + self.next_ty = fresh.successor(); + self.ty_metadata.insert(fresh, meta); + let constraints = with_constraints(Ty::Infer(fresh)); + self.constraints.extend(constraints); + Ty::Infer(fresh) + } + /// Returns a unique unconstrained type variable. pub(super) fn fresh_ty(&mut self, meta: TySource) -> Ty { let fresh = self.next_ty; @@ -404,8 +511,16 @@ impl Inferrer { .params() .iter() .map(|param| match param { - GenericParam::Ty(_) => GenericArg::Ty(self.fresh_ty(TySource::not_divergent(span))), - GenericParam::Functor(expected) => { + TypeParameter::Ty { bounds, .. } => { + GenericArg::Ty(self.constrained_ty(TySource::not_divergent(span), |ty| { + bounds + .0 + .iter() + .map(|x| into_constraint(ty.clone(), x, span)) + .collect() + })) + } + TypeParameter::Functor(expected) => { let actual = self.fresh_functor(); self.constraints.push_back(Constraint::Superset { expected: *expected, @@ -469,6 +584,10 @@ impl Inferrer { pub(super) fn substitute_functor(&mut self, functors: &mut FunctorSet) { substitute_functor(&self.solver.solution, functors); } + + pub(super) fn report_error(&mut self, error: impl Into) { + self.solver.errors.push(error.into()); + } } #[derive(Debug)] @@ -489,6 +608,9 @@ impl Solver { } } + /// Given a constraint, attempts to narrow the constraint by either + /// generating more specific constraints, or, if it cannot be narrowed further, + /// returns an empty vector. fn constrain( &mut self, udts: &FxHashMap, @@ -512,15 +634,20 @@ impl Solver { } } + /// Attempts to narrow a class constraint, returning more specific constraints if any. fn class( &mut self, udts: &FxHashMap, class: Class, span: Span, ) -> Vec { + // true if a dependency of this class constraint is currently unknown, meaning we + // have to come back to it later. + // false if we know everything we need to know and this is solved let unknown_dependency = class.dependencies().into_iter().any(|ty| { if ty == &Ty::Err { true + // if this needs to be inferred further, `unknown_ty` returns `Some(ty_id)` } else if let Some(infer) = unknown_ty(&self.solution.tys, ty) { self.pending_tys .entry(infer) @@ -621,7 +748,36 @@ impl Solver { (&Ty::Infer(infer), ty) | (ty, &Ty::Infer(infer)) if !contains_infer_ty(infer, ty) => { self.bind_ty(infer, ty.clone(), span) } - (Ty::Param(_, name1), Ty::Param(_, name2)) if name1 == name2 => Vec::new(), + ( + Ty::Param { + name: name1, + id: id1, + bounds: bounds1, + }, + Ty::Param { + name: _name2, + id: id2, + bounds: bounds2, + }, + ) if id1 == id2 => { + // concat the two sets of bounds + let bounds: BTreeSet = bounds1 + .0 + .iter() + .chain(bounds2.0.iter()) + .map(Clone::clone) + .collect(); + + let merged_ty = Ty::Param { + name: name1.clone(), + id: *id1, + bounds: qsc_hir::ty::ClassConstraints(bounds.clone().into_iter().collect()), + }; + bounds + .into_iter() + .map(|x| into_constraint(merged_ty.clone(), &x, span)) + .collect() + } (Ty::Prim(prim1), Ty::Prim(prim2)) if prim1 == prim2 => Vec::new(), (Ty::Tuple(items1), Ty::Tuple(items2)) => { if items1.len() != items2.len() { @@ -704,6 +860,8 @@ impl Solver { } } +/// Replaces inferred tys with the underlying type that they refer to, if it has been solved +/// already. fn substitute_ty(solution: &Solution, ty: &mut Ty) -> bool { fn substitute_ty_recursive(solution: &Solution, ty: &mut Ty, limit: i8) -> bool { if limit == 0 { @@ -714,7 +872,7 @@ fn substitute_ty(solution: &Solution, ty: &mut Ty) -> bool { return false; } match ty { - Ty::Err | Ty::Param(_, _) | Ty::Prim(_) | Ty::Udt(_, _) => true, + Ty::Err | Ty::Param { .. } | Ty::Prim(_) | Ty::Udt(_, _) => true, Ty::Array(item) => substitute_ty_recursive(solution, item, limit - 1), Ty::Arrow(arrow) => { let a = substitute_ty_recursive(solution, &mut arrow.input, limit - 1); @@ -757,19 +915,25 @@ fn substitute_functor(solution: &Solution, functors: &mut FunctorSet) { } } -fn unknown_ty(tys: &IndexMap, ty: &Ty) -> Option { - match ty { - &Ty::Infer(infer) => match tys.get(infer) { +// `Some(ty)` if `given_type` has not been solved for yet, `None` if it is fully known/non-inferred +fn unknown_ty(solved_types: &IndexMap, given_type: &Ty) -> Option { + match given_type { + // if the given type is an inference type, check if we have solved for it + &Ty::Infer(infer) => match solved_types.get(infer) { + // if we have not solved for it, then indeed this is an unknown type None => Some(infer), - Some(ty) => unknown_ty(tys, ty), + // if we have solved for it, then we check if that solved type is itself + // solved. It could have been solved to another inference type + Some(solved_type) => unknown_ty(solved_types, solved_type), }, + // the given type is not an inference type so it is not unknown _ => None, } } fn contains_infer_ty(id: InferTyId, ty: &Ty) -> bool { match ty { - Ty::Err | Ty::Param(_, _) | Ty::Prim(_) | Ty::Udt(_, _) => false, + Ty::Err | Ty::Param { .. } | Ty::Prim(_) | Ty::Udt(_, _) => false, Ty::Array(item) => contains_infer_ty(id, item), Ty::Arrow(arrow) => { contains_infer_ty(id, &arrow.input) || contains_infer_ty(id, &arrow.output) @@ -780,10 +944,14 @@ fn contains_infer_ty(id: InferTyId, ty: &Ty) -> bool { } fn check_add(ty: &Ty) -> bool { - matches!( - ty, - Ty::Prim(Prim::BigInt | Prim::Double | Prim::Int | Prim::String) | Ty::Array(_) - ) + match ty { + Ty::Prim(Prim::BigInt | Prim::Double | Prim::Int | Prim::String) | Ty::Array(_) => true, + Ty::Param { ref bounds, .. } => bounds + .0 + .iter() + .any(|bound| matches!(bound, ClassConstraint::Add)), + _ => false, + } } fn check_adj(ty: Ty, span: Span) -> (Vec, Vec) { @@ -811,6 +979,9 @@ fn check_call(callee: Ty, input: &ArgTy, output: Ty, span: Span) -> (Vec 1 { Ty::Arrow(Box::new(Arrow { @@ -870,6 +1041,7 @@ fn check_ctl(op: Ty, with_ctls: Ty, span: Span) -> (Vec, Vec) ) } +/// Checks that the class `Eq` is implemented for the given type. fn check_eq(ty: Ty, span: Span) -> (Vec, Vec) { match ty { Ty::Prim( @@ -891,6 +1063,21 @@ fn check_eq(ty: Ty, span: Span) -> (Vec, Vec) { .collect(), Vec::new(), ), + Ty::Param { ref bounds, .. } => { + // check if the bounds contain Eq + + match bounds + .0 + .iter() + .find(|bound| matches!(bound, ClassConstraint::Eq)) + { + Some(_) => (Vec::new(), Vec::new()), + None => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassEq(ty.display(), span))], + ), + } + } _ => ( Vec::new(), vec![Error(ErrorKind::MissingClassEq(ty.display(), span))], @@ -898,12 +1085,12 @@ fn check_eq(ty: Ty, span: Span) -> (Vec, Vec) { } } -fn check_exp(base: Ty, power: Ty, span: Span) -> (Vec, Vec) { +fn check_exp(base: Ty, given_power: Ty, span: Span) -> (Vec, Vec) { match base { Ty::Prim(Prim::BigInt) => ( vec![Constraint::Eq { expected: Ty::Prim(Prim::Int), - actual: power, + actual: given_power, span, }], Vec::new(), @@ -911,11 +1098,35 @@ fn check_exp(base: Ty, power: Ty, span: Span) -> (Vec, Vec) { Ty::Prim(Prim::Double | Prim::Int) => ( vec![Constraint::Eq { expected: base, - actual: power, + actual: given_power, span, }], Vec::new(), ), + Ty::Param { ref bounds, .. } => { + // check if the bounds contain Exp + + match bounds + .0 + .iter() + .find(|bound| matches!(bound, ClassConstraint::Exp { .. })) + { + Some(ClassConstraint::Exp { + power: power_from_param, + }) => ( + vec![Constraint::Eq { + actual: given_power, + expected: power_from_param.clone(), + span, + }], + Vec::new(), + ), + _ => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassExp(base.display(), span))], + ), + } + } _ => ( Vec::new(), vec![Error(ErrorKind::MissingClassExp(base.display(), span))], @@ -923,6 +1134,9 @@ fn check_exp(base: Ty, power: Ty, span: Span) -> (Vec, Vec) { } } +// i'm using the wildcard below to enforce that Ty::Param is always matched in the err branch, as +// it shouldn't be constrained by HasField as long as we don't support structural typing +#[allow(clippy::wildcard_in_or_patterns)] fn check_has_field( udts: &FxHashMap, record: &Ty, @@ -966,7 +1180,9 @@ fn check_has_field( ), } } - _ => ( + // `HasField` cannot be used to constrain an arbitrary type parameter, it is used + // internally only, so it will never resolve to a ty param. + (_, Ty::Param { .. }) | _ => ( Vec::new(), vec![Error(ErrorKind::MissingClassHasField( record.display(), @@ -1083,7 +1299,14 @@ fn check_has_index( } fn check_integral(ty: &Ty) -> bool { - matches!(ty, Ty::Prim(Prim::BigInt | Prim::Int)) + match ty { + Ty::Prim(Prim::BigInt | Prim::Int) => true, + Ty::Param { ref bounds, .. } => bounds + .0 + .iter() + .any(|bound| matches!(bound, ClassConstraint::Integral)), + _ => false, + } } fn check_iterable(container: Ty, item: Ty, span: Span) -> (Vec, Vec) { @@ -1104,6 +1327,13 @@ fn check_iterable(container: Ty, item: Ty, span: Span) -> (Vec, Vec< }], Vec::new(), ), + Ty::Param { .. } => ( + Vec::default(), + vec![Error(ErrorKind::UnrecognizedClass { + span, + name: "Iterable".into(), + })], + ), _ => ( Vec::new(), vec![Error(ErrorKind::MissingClassIterable( @@ -1114,8 +1344,17 @@ fn check_iterable(container: Ty, item: Ty, span: Span) -> (Vec, Vec< } } -fn check_num(ty: &Ty) -> bool { - matches!(ty, Ty::Prim(Prim::BigInt | Prim::Double | Prim::Int)) +/// Some constraints are just true if the type is numeric, this used to be the class Num, but now +/// we support different operators as separate classes. +fn check_num_constraint(constraint: &ClassConstraint, ty: &Ty) -> bool { + match ty { + Ty::Prim(Prim::BigInt | Prim::Double | Prim::Int) => true, + Ty::Param { ref bounds, .. } => { + // check if the bounds contain Num + bounds.0.iter().any(|bound| *bound == *constraint) + } + _ => false, + } } fn check_show(ty: Ty, span: Span) -> (Vec, Vec) { @@ -1132,6 +1371,20 @@ fn check_show(ty: Ty, span: Span) -> (Vec, Vec) { .collect(), Vec::new(), ), + Ty::Param { ref bounds, .. } => { + // check if the bounds contain Show + match bounds + .0 + .iter() + .find(|bound| matches!(bound, ClassConstraint::Show)) + { + Some(_) => (Vec::new(), Vec::new()), + None => ( + Vec::new(), + vec![Error(ErrorKind::MissingClassShow(ty.display(), span))], + ), + } + } _ => ( Vec::new(), vec![Error(ErrorKind::MissingClassShow(ty.display(), span))], @@ -1169,3 +1422,40 @@ fn check_unwrap( ))], ) } + +/// Given an HIR class constraint, produce an actual type system constraint. +fn into_constraint(ty: Ty, bound: &ClassConstraint, span: Span) -> Constraint { + match bound { + ClassConstraint::Eq => Constraint::Class(Class::Eq(ty), span), + ClassConstraint::Exp { power } => Constraint::Class( + Class::Exp { + // `ty` here is basically `Self` -- so Exp[Double] is a type that can be raised to + // the power of a double. + // Exponentiation is a _closed_ operation, meaning the domain and codomain are the + // same. + base: ty.clone(), + power: power.clone(), + }, + span, + ), + ClassConstraint::Add => Constraint::Class(Class::Add(ty), span), + ClassConstraint::Iterable { item } => Constraint::Class( + Class::Iterable { + item: item.clone(), + container: ty.clone(), + }, + span, + ), + ClassConstraint::NonNativeClass(name) => { + Constraint::Class(Class::NonPrimitive(name.clone()), span) + } + ClassConstraint::Show => Constraint::Class(Class::Show(ty), span), + ClassConstraint::Integral => Constraint::Class(Class::Integral(ty), span), + ClassConstraint::Ord => Constraint::Class(Class::Ord(ty), span), + ClassConstraint::Mul => Constraint::Class(Class::Mul(ty), span), + ClassConstraint::Div => Constraint::Class(Class::Div(ty), span), + ClassConstraint::Sub => Constraint::Class(Class::Sub(ty), span), + ClassConstraint::Signed => Constraint::Class(Class::Signed(ty), span), + ClassConstraint::Mod => Constraint::Class(Class::Mod(ty), span), + } +} diff --git a/compiler/qsc_frontend/src/typeck/rules.rs b/compiler/qsc_frontend/src/typeck/rules.rs index 582177d744..2afeaaf8e0 100644 --- a/compiler/qsc_frontend/src/typeck/rules.rs +++ b/compiler/qsc_frontend/src/typeck/rules.rs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +//! Defines type system rules for Q#. The checker calls these rules on the AST. +//! These rules use the inferrer to know what types to apply constraints to. + use super::{ convert, infer::{ArgTy, Class, Inferrer, TySource}, @@ -10,7 +13,7 @@ use crate::resolve::{self, Names, Res}; use qsc_ast::ast::{ self, BinOp, Block, Expr, ExprKind, FieldAccess, Functor, Ident, Idents, Lit, NodeId, Pat, PatKind, Path, PathKind, QubitInit, QubitInitKind, Spec, Stmt, StmtKind, StringComponent, - TernOp, TyKind, UnOp, + TernOp, TyKind, TypeParameter, UnOp, }; use qsc_data_structures::span::Span; use qsc_hir::{ @@ -36,12 +39,16 @@ impl Partial { } } +/// Contexts are currently only generated for exprs, stmts, and specs, +/// They provide a context within which types are solved for. +#[derive(Debug)] struct Context<'a> { names: &'a Names, globals: &'a FxHashMap, table: &'a mut Table, return_ty: Option, typed_holes: Vec<(NodeId, Span)>, + /// New nodes that will be introduced into the parent `Context` after this context terminates new: Vec, inferrer: &'a mut Inferrer, } @@ -119,15 +126,31 @@ impl<'a> Context<'a> { // we resolve exports to their original definition. Some( resolve::Res::Local(_) - | resolve::Res::Param(_) + | resolve::Res::Param { .. } | resolve::Res::ExportedItem(_, _), ) => unreachable!( "A path should never resolve \ to a local or a parameter, as there is syntactic differentiation." ), }, - TyKind::Param(name) => match self.names.get(name.id) { - Some(Res::Param(id)) => Ty::Param(name.name.clone(), *id), + TyKind::Param(TypeParameter { + ty, constraints: _, .. + }) => match self.names.get(ty.id) { + Some(Res::Param { id, bounds }) => { + let (bounds, errs) = convert::class_constraints_from_ast( + self.names, + bounds, + &mut Default::default(), + ); + for err in errs { + self.inferrer.report_error(err); + } + Ty::Param { + name: ty.name.clone(), + id: *id, + bounds, + } + } None => Ty::Err, Some(_) => unreachable!( "A parameter should never resolve to a non-parameter type, as there \ @@ -615,7 +638,7 @@ impl<'a> Context<'a> { self.table.generics.insert(expr.id, args); converge(Ty::Arrow(Box::new(ty))) } - Some(Res::PrimTy(_) | Res::UnitTy | Res::Param(_)) => { + Some(Res::PrimTy(_) | Res::UnitTy | Res::Param { .. }) => { panic!("expression should not resolve to type reference") } }, @@ -702,7 +725,7 @@ impl<'a> Context<'a> { converge(with_ctls) } UnOp::Neg | UnOp::Pos => { - self.inferrer.class(span, Class::Num(operand.ty.clone())); + self.inferrer.class(span, Class::Signed(operand.ty.clone())); operand } UnOp::NotB => { @@ -757,7 +780,7 @@ impl<'a> Context<'a> { } BinOp::Gt | BinOp::Gte | BinOp::Lt | BinOp::Lte => { self.inferrer.eq(rhs_span, lhs.ty.clone(), rhs.ty); - self.inferrer.class(lhs_span, Class::Num(lhs.ty)); + self.inferrer.class(lhs_span, Class::Ord(lhs.ty)); converge(Ty::Prim(Prim::Bool)) } BinOp::AndB | BinOp::OrB | BinOp::XorB => { @@ -766,9 +789,24 @@ impl<'a> Context<'a> { .class(lhs_span, Class::Integral(lhs.ty.clone())); lhs } - BinOp::Div | BinOp::Mod | BinOp::Mul | BinOp::Sub => { + BinOp::Div => { + self.inferrer.eq(rhs_span, lhs.ty.clone(), rhs.ty); + self.inferrer.class(lhs_span, Class::Div(lhs.ty.clone())); + lhs + } + BinOp::Mul => { + self.inferrer.eq(rhs_span, lhs.ty.clone(), rhs.ty); + self.inferrer.class(lhs_span, Class::Mul(lhs.ty.clone())); + lhs + } + BinOp::Sub => { + self.inferrer.eq(rhs_span, lhs.ty.clone(), rhs.ty); + self.inferrer.class(lhs_span, Class::Sub(lhs.ty.clone())); + lhs + } + BinOp::Mod => { self.inferrer.eq(rhs_span, lhs.ty.clone(), rhs.ty); - self.inferrer.class(lhs_span, Class::Num(lhs.ty.clone())); + self.inferrer.class(lhs_span, Class::Mod(lhs.ty.clone())); lhs } BinOp::Exp => { diff --git a/compiler/qsc_frontend/src/typeck/tests.rs b/compiler/qsc_frontend/src/typeck/tests.rs index 68e1fbdbc0..1933bdc07b 100644 --- a/compiler/qsc_frontend/src/typeck/tests.rs +++ b/compiler/qsc_frontend/src/typeck/tests.rs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +mod bounded_polymorphism; + use crate::{ compile::{self, Offsetter}, resolve::{self, Resolver}, @@ -123,7 +125,6 @@ fn compile( let mut errors = globals.add_local_package(&mut assigner, &package); let mut resolver = Resolver::new(globals, Vec::new()); resolver.bind_and_resolve_imports_and_exports(&package); - resolver.with(&mut assigner).visit_package(&package); let (names, _, mut resolve_errors, _namespaces) = resolver.into_result(); errors.append(&mut resolve_errors); @@ -138,7 +139,6 @@ fn compile( .chain(ty_errors.into_iter().map(Into::into)) .map(compile::Error) .collect(); - (package, tys, errors) } @@ -1441,11 +1441,11 @@ fn unop_neg_bool() { check( "", "-false", - &expect![[r#" + &expect![[r##" #1 0-6 "-false" : Bool #2 1-6 "false" : Bool - Error(Type(Error(MissingClassNum("Bool", Span { lo: 1, hi: 6 })))) - "#]], + Error(Type(Error(MissingClassSigned("Bool", Span { lo: 1, hi: 6 })))) + "##]], ); } @@ -1454,11 +1454,11 @@ fn unop_pos_bool() { check( "", "+false", - &expect![[r#" + &expect![[r##" #1 0-6 "+false" : Bool #2 1-6 "false" : Bool - Error(Type(Error(MissingClassNum("Bool", Span { lo: 1, hi: 6 })))) - "#]], + Error(Type(Error(MissingClassSigned("Bool", Span { lo: 1, hi: 6 })))) + "##]], ); } @@ -4203,12 +4203,13 @@ fn undeclared_generic_param() { check( r#"namespace c{operation y(g: 'U): Unit {} }"#, "", - &expect![[r#" + &expect![[r##" #6 23-30 "(g: 'U)" : ? #7 24-29 "g: 'U" : ? #14 37-39 "{}" : Unit Error(Resolve(NotFound("'U", Span { lo: 27, hi: 29 }))) - "#]], + Error(Type(Error(MissingTy { span: Span { lo: 27, hi: 29 } }))) + "##]], ); } diff --git a/compiler/qsc_frontend/src/typeck/tests/bounded_polymorphism.rs b/compiler/qsc_frontend/src/typeck/tests/bounded_polymorphism.rs new file mode 100644 index 0000000000..616188f781 --- /dev/null +++ b/compiler/qsc_frontend/src/typeck/tests/bounded_polymorphism.rs @@ -0,0 +1,824 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use expect_test::expect; + +use super::check; + +#[test] +fn eq() { + check( + r#" + namespace A { + function Foo<'T: Eq>(a: 'T, b: 'T) : Bool { + a == b + } + } + "#, + "", + &expect![[r##" + #8 55-69 "(a: 'T, b: 'T)" : (Param<"'T": 0>, Param<"'T": 0>) + #9 56-61 "a: 'T" : Param<"'T": 0> + #13 63-68 "b: 'T" : Param<"'T": 0> + #20 77-115 "{\n a == b\n }" : Bool + #22 95-101 "a == b" : Bool + #23 95-96 "a" : Param<"'T": 0> + #26 100-101 "b" : Param<"'T": 0> + "##]], + ); +} + +#[test] +fn exp() { + check( + r#" + namespace A { + function Foo<'T: Exp[Int]>(a: 'T, b: Int) : 'T { + a ^ b + } + } + "#, + "", + &expect![[r##" + #11 61-76 "(a: 'T, b: Int)" : (Param<"'T": 0>, Int) + #12 62-67 "a: 'T" : Param<"'T": 0> + #16 69-75 "b: Int" : Int + #23 82-119 "{\n a ^ b\n }" : Param<"'T": 0> + #25 100-105 "a ^ b" : Param<"'T": 0> + #26 100-101 "a" : Param<"'T": 0> + #29 104-105 "b" : Int + "##]], + ); +} + +#[test] +fn exp_fail() { + check( + r#" + namespace A { + function Foo<'T: Exp[Int]>(a: 'T, b: Bool) : 'T { + a ^ b + } + } + "#, + "", + &expect![[r##" + #11 61-77 "(a: 'T, b: Bool)" : (Param<"'T": 0>, Bool) + #12 62-67 "a: 'T" : Param<"'T": 0> + #16 69-76 "b: Bool" : Bool + #23 83-120 "{\n a ^ b\n }" : Param<"'T": 0> + #25 101-106 "a ^ b" : Param<"'T": 0> + #26 101-102 "a" : Param<"'T": 0> + #29 105-106 "b" : Bool + Error(Type(Error(TyMismatch("Int", "Bool", Span { lo: 101, hi: 106 })))) + "##]], + ); +} +#[test] +fn extra_arg_to_exp() { + check( + r#" + namespace A { + function Foo<'E, 'T: Exp['E, Int]>(a: 'E, b: Int) : 'T { + a ^ b + } + function Foo2<'E, 'T: Exp>(a: 'E, b: Int) : 'T { + a ^ b + } + } + "#, + "", + &expect![[r##" + #14 69-84 "(a: 'E, b: Int)" : (Param<"'E": 0>, Int) + #15 70-75 "a: 'E" : Param<"'E": 0> + #19 77-83 "b: Int" : Int + #26 90-127 "{\n a ^ b\n }" : Param<"'E": 0> + #28 108-113 "a ^ b" : Param<"'E": 0> + #29 108-109 "a" : Param<"'E": 0> + #32 112-113 "b" : Int + #41 166-181 "(a: 'E, b: Int)" : (Param<"'E": 0>, Int) + #42 167-172 "a: 'E" : Param<"'E": 0> + #46 174-180 "b: Int" : Int + #53 187-224 "{\n a ^ b\n }" : Param<"'E": 0> + #55 205-210 "a ^ b" : Param<"'E": 0> + #56 205-206 "a" : Param<"'E": 0> + #59 209-210 "b" : Int + Error(Type(Error(IncorrectNumberOfConstraintParameters { expected: 1, found: 2, span: Span { lo: 56, hi: 59 } }))) + Error(Type(Error(IncorrectNumberOfConstraintParameters { expected: 1, found: 2, span: Span { lo: 56, hi: 59 } }))) + Error(Type(Error(IncorrectNumberOfConstraintParameters { expected: 1, found: 0, span: Span { lo: 162, hi: 165 } }))) + Error(Type(Error(IncorrectNumberOfConstraintParameters { expected: 1, found: 0, span: Span { lo: 162, hi: 165 } }))) + Error(Type(Error(MissingClassExp("'E", Span { lo: 108, hi: 113 })))) + Error(Type(Error(TyMismatch("'T", "'E", Span { lo: 108, hi: 113 })))) + Error(Type(Error(MissingClassExp("'E", Span { lo: 205, hi: 210 })))) + Error(Type(Error(TyMismatch("'T", "'E", Span { lo: 205, hi: 210 })))) + "##]], + ); +} + +#[test] +fn example_should_fail() { + check( + r#" + namespace A { + function Foo<'T: Eq, 'O: Eq>(a: 'T, b: 'O) : Bool { + // should fail because we can't compare two different types + a == b + } + } + "#, + "", + &expect![[r##" + #10 63-77 "(a: 'T, b: 'O)" : (Param<"'T": 0>, Param<"'O": 1>) + #11 64-69 "a: 'T" : Param<"'T": 0> + #15 71-76 "b: 'O" : Param<"'O": 1> + #22 85-195 "{\n // should fail because we can't compare two different types\n a == b\n }" : Bool + #24 175-181 "a == b" : Bool + #25 175-176 "a" : Param<"'T": 0> + #28 180-181 "b" : Param<"'O": 1> + Error(Type(Error(TyMismatch("'T", "'O", Span { lo: 180, hi: 181 })))) + "##]], + ); +} + +// This test ensures that we show a pretty error for polymorphism bounds that are not supported +// yet. +#[test] +fn iter() { + check( + r#" + namespace A { + function Foo<'T: Iterable[Bool]>(a: 'T) : Bool { + for item in a { + return item; + } + } + + function Main() : Unit { + let x = Foo([true]); + } + } + "#, + "", + &expect![[r##" + #11 67-74 "(a: 'T)" : Param<"'T": 0> + #12 68-73 "a: 'T" : Param<"'T": 0> + #19 82-180 "{\n for item in a {\n return item;\n }\n }" : Bool + #21 100-166 "for item in a {\n return item;\n }" : Bool + #22 104-108 "item" : Bool + #24 112-113 "a" : Param<"'T": 0> + #27 114-166 "{\n return item;\n }" : Unit + #29 136-147 "return item" : Unit + #30 143-147 "item" : Bool + #36 207-209 "()" : Unit + #40 217-269 "{\n let x = Foo([true]);\n }" : Unit + #42 239-240 "x" : Bool + #44 243-254 "Foo([true])" : Bool + #45 243-246 "Foo" : (Bool[] -> Bool) + #48 246-254 "([true])" : Bool[] + #49 247-253 "[true]" : Bool[] + #50 248-252 "true" : Bool + Error(Type(Error(UnrecognizedClass { span: Span { lo: 112, hi: 113 }, name: "Iterable" }))) + "##]], + ); +} + +#[test] +fn signed() { + check( + r#" + namespace A { + function Foo<'T: Signed>(a: 'T) : 'T { + -a + } + + function Main() : Unit { + let x: Int = Foo(1); + let y: Double = Foo(1.0); + let z: BigInt = Foo(10L); + } + } + "#, + "", + &expect![[r##" + #8 59-66 "(a: 'T)" : Param<"'T": 0> + #9 60-65 "a: 'T" : Param<"'T": 0> + #15 72-106 "{\n -a\n }" : Param<"'T": 0> + #17 90-92 "-a" : Param<"'T": 0> + #18 91-92 "a" : Param<"'T": 0> + #24 133-135 "()" : Unit + #28 143-279 "{\n let x: Int = Foo(1);\n let y: Double = Foo(1.0);\n let z: BigInt = Foo(10L);\n }" : Unit + #30 165-171 "x: Int" : Int + #35 174-180 "Foo(1)" : Int + #36 174-177 "Foo" : (Int -> Int) + #39 177-180 "(1)" : Int + #40 178-179 "1" : Int + #42 202-211 "y: Double" : Double + #47 214-222 "Foo(1.0)" : Double + #48 214-217 "Foo" : (Double -> Double) + #51 217-222 "(1.0)" : Double + #52 218-221 "1.0" : Double + #54 244-253 "z: BigInt" : BigInt + #59 256-264 "Foo(10L)" : BigInt + #60 256-259 "Foo" : (BigInt -> BigInt) + #63 259-264 "(10L)" : BigInt + #64 260-263 "10L" : BigInt + "##]], + ); +} + +#[test] +fn signed_fail() { + check( + r#" + namespace A { + function Foo<'T: Eq>(a: 'T) : 'T { + -a + } + + function Main() : Unit { + let x: Int = Foo(1); + let y: Double = Foo(1.0); + let z: BigInt = Foo(10L); + } + } + "#, + "", + &expect![[r##" + #8 55-62 "(a: 'T)" : Param<"'T": 0> + #9 56-61 "a: 'T" : Param<"'T": 0> + #15 68-102 "{\n -a\n }" : Param<"'T": 0> + #17 86-88 "-a" : Param<"'T": 0> + #18 87-88 "a" : Param<"'T": 0> + #24 129-131 "()" : Unit + #28 139-275 "{\n let x: Int = Foo(1);\n let y: Double = Foo(1.0);\n let z: BigInt = Foo(10L);\n }" : Unit + #30 161-167 "x: Int" : Int + #35 170-176 "Foo(1)" : Int + #36 170-173 "Foo" : (Int -> Int) + #39 173-176 "(1)" : Int + #40 174-175 "1" : Int + #42 198-207 "y: Double" : Double + #47 210-218 "Foo(1.0)" : Double + #48 210-213 "Foo" : (Double -> Double) + #51 213-218 "(1.0)" : Double + #52 214-217 "1.0" : Double + #54 240-249 "z: BigInt" : BigInt + #59 252-260 "Foo(10L)" : BigInt + #60 252-255 "Foo" : (BigInt -> BigInt) + #63 255-260 "(10L)" : BigInt + #64 256-259 "10L" : BigInt + Error(Type(Error(MissingClassSigned("'T", Span { lo: 87, hi: 88 })))) + "##]], + ); +} + +#[test] +fn transitive_class_check() { + check( + r#" + namespace A { + function Foo<'T: Mul>(a: 'T) : 'T { + a * a + } + + function Bar<'F: Mul>(a: 'F) : 'F { + Foo(a) + } + + function Main() : Unit { + let x: Int = Bar(1); + let y: Double = Bar(1.0); + let z: BigInt = Bar(10L); + } + } + "#, + "", + &expect![[r##" + #8 56-63 "(a: 'T)" : Param<"'T": 0> + #9 57-62 "a: 'T" : Param<"'T": 0> + #15 69-106 "{\n a * a\n }" : Param<"'T": 0> + #17 87-92 "a * a" : Param<"'T": 0> + #18 87-88 "a" : Param<"'T": 0> + #21 91-92 "a" : Param<"'T": 0> + #29 141-148 "(a: 'F)" : Param<"'F": 0> + #30 142-147 "a: 'F" : Param<"'F": 0> + #36 154-192 "{\n Foo(a)\n }" : Param<"'F": 0> + #38 172-178 "Foo(a)" : Param<"'F": 0> + #39 172-175 "Foo" : (Param<"'F": 0> -> Param<"'F": 0>) + #42 175-178 "(a)" : Param<"'F": 0> + #43 176-177 "a" : Param<"'F": 0> + #49 219-221 "()" : Unit + #53 229-365 "{\n let x: Int = Bar(1);\n let y: Double = Bar(1.0);\n let z: BigInt = Bar(10L);\n }" : Unit + #55 251-257 "x: Int" : Int + #60 260-266 "Bar(1)" : Int + #61 260-263 "Bar" : (Int -> Int) + #64 263-266 "(1)" : Int + #65 264-265 "1" : Int + #67 288-297 "y: Double" : Double + #72 300-308 "Bar(1.0)" : Double + #73 300-303 "Bar" : (Double -> Double) + #76 303-308 "(1.0)" : Double + #77 304-307 "1.0" : Double + #79 330-339 "z: BigInt" : BigInt + #84 342-350 "Bar(10L)" : BigInt + #85 342-345 "Bar" : (BigInt -> BigInt) + #88 345-350 "(10L)" : BigInt + #89 346-349 "10L" : BigInt + "##]], + ); +} + +#[test] +fn transitive_class_check_fail() { + check( + r#" + namespace A { + function Foo<'T: Integral>(a: 'T) : 'T { + a + } + + function Bar<'F>(a: 'F) : 'F { + // below should be an error as 'F has no + // Integral bound + Foo(a) + } + + function Main() : Unit { + let x: Int = Foo(1); + // below should be an error as it is a double and not an integral type + let y: Double = Foo(1.0); + let z: BigInt = Foo(10L); + } + } + "#, + "", + &expect![[r##" + #8 61-68 "(a: 'T)" : Param<"'T": 0> + #9 62-67 "a: 'T" : Param<"'T": 0> + #15 74-107 "{\n a\n }" : Param<"'T": 0> + #17 92-93 "a" : Param<"'T": 0> + #24 137-144 "(a: 'F)" : Param<"'F": 0> + #25 138-143 "a: 'F" : Param<"'F": 0> + #31 150-279 "{\n // below should be an error as 'F has no\n // Integral bound\n Foo(a)\n }" : Param<"'F": 0> + #33 259-265 "Foo(a)" : Param<"'F": 0> + #34 259-262 "Foo" : (Param<"'F": 0> -> Param<"'F": 0>) + #37 262-265 "(a)" : Param<"'F": 0> + #38 263-264 "a" : Param<"'F": 0> + #44 306-308 "()" : Unit + #48 316-539 "{\n let x: Int = Foo(1);\n // below should be an error as it is a double and not an integral type\n let y: Double = Foo(1.0);\n let z: BigInt = Foo(10L);\n }" : Unit + #50 338-344 "x: Int" : Int + #55 347-353 "Foo(1)" : Int + #56 347-350 "Foo" : (Int -> Int) + #59 350-353 "(1)" : Int + #60 351-352 "1" : Int + #62 462-471 "y: Double" : Double + #67 474-482 "Foo(1.0)" : Double + #68 474-477 "Foo" : (Double -> Double) + #71 477-482 "(1.0)" : Double + #72 478-481 "1.0" : Double + #74 504-513 "z: BigInt" : BigInt + #79 516-524 "Foo(10L)" : BigInt + #80 516-519 "Foo" : (BigInt -> BigInt) + #83 519-524 "(10L)" : BigInt + #84 520-523 "10L" : BigInt + Error(Type(Error(MissingClassInteger("'F", Span { lo: 259, hi: 265 })))) + Error(Type(Error(MissingClassInteger("Double", Span { lo: 474, hi: 482 })))) + "##]], + ); +} + +#[test] +fn transitive_class_check_superset() { + check( + r#" + namespace A { + function Foo<'T: Sub>(a: 'T) : 'T { + a - a + } + + function Bar<'F: Sub + Eq>(a: 'F) : 'F { + Foo(a) + } + + function Main() : Unit { + let x: Int = Bar(1); + let y: Double = Bar(1.0); + let z: BigInt = Bar(10L); + } + } + "#, + "", + &expect![[r##" + #8 56-63 "(a: 'T)" : Param<"'T": 0> + #9 57-62 "a: 'T" : Param<"'T": 0> + #15 69-106 "{\n a - a\n }" : Param<"'T": 0> + #17 87-92 "a - a" : Param<"'T": 0> + #18 87-88 "a" : Param<"'T": 0> + #21 91-92 "a" : Param<"'T": 0> + #30 146-153 "(a: 'F)" : Param<"'F": 0> + #31 147-152 "a: 'F" : Param<"'F": 0> + #37 159-197 "{\n Foo(a)\n }" : Param<"'F": 0> + #39 177-183 "Foo(a)" : Param<"'F": 0> + #40 177-180 "Foo" : (Param<"'F": 0> -> Param<"'F": 0>) + #43 180-183 "(a)" : Param<"'F": 0> + #44 181-182 "a" : Param<"'F": 0> + #50 224-226 "()" : Unit + #54 234-370 "{\n let x: Int = Bar(1);\n let y: Double = Bar(1.0);\n let z: BigInt = Bar(10L);\n }" : Unit + #56 256-262 "x: Int" : Int + #61 265-271 "Bar(1)" : Int + #62 265-268 "Bar" : (Int -> Int) + #65 268-271 "(1)" : Int + #66 269-270 "1" : Int + #68 293-302 "y: Double" : Double + #73 305-313 "Bar(1.0)" : Double + #74 305-308 "Bar" : (Double -> Double) + #77 308-313 "(1.0)" : Double + #78 309-312 "1.0" : Double + #80 335-344 "z: BigInt" : BigInt + #85 347-355 "Bar(10L)" : BigInt + #86 347-350 "Bar" : (BigInt -> BigInt) + #89 350-355 "(10L)" : BigInt + #90 351-354 "10L" : BigInt + "##]], + ); +} +#[test] +fn show() { + check( + r#" + namespace A { + function Foo<'T: Show>(a: 'T) : String { + let x = $"Value: {a}"; + x + } + + function Main() : Unit { + let x: String = Foo(1); + let y: String = Foo(1.0); + let z: String = Foo(true); + } + } + "#, + "", + &expect![[r##" + #8 57-64 "(a: 'T)" : Param<"'T": 0> + #9 58-63 "a: 'T" : Param<"'T": 0> + #16 74-146 "{\n let x = $\"Value: {a}\";\n x\n }" : String + #18 96-97 "x" : String + #20 100-113 "$\"Value: {a}\"" : String + #21 110-111 "a" : Param<"'T": 0> + #25 131-132 "x" : String + #31 173-175 "()" : Unit + #35 183-323 "{\n let x: String = Foo(1);\n let y: String = Foo(1.0);\n let z: String = Foo(true);\n }" : Unit + #37 205-214 "x: String" : String + #42 217-223 "Foo(1)" : String + #43 217-220 "Foo" : (Int -> String) + #46 220-223 "(1)" : Int + #47 221-222 "1" : Int + #49 245-254 "y: String" : String + #54 257-265 "Foo(1.0)" : String + #55 257-260 "Foo" : (Double -> String) + #58 260-265 "(1.0)" : Double + #59 261-264 "1.0" : Double + #61 287-296 "z: String" : String + #66 299-308 "Foo(true)" : String + #67 299-302 "Foo" : (Bool -> String) + #70 302-308 "(true)" : Bool + #71 303-307 "true" : Bool + "##]], + ); +} + +#[test] +fn show_fail() { + check( + r#" + namespace A { + function Foo<'T>(a: 'T) : String { + $"Value: {a}" + } + + function Main() : Unit { + let x = Foo(1); + let y = Foo(1.0); + let z = Foo(true); + } + } + "#, + "", + &expect![[r##" + #7 51-58 "(a: 'T)" : Param<"'T": 0> + #8 52-57 "a: 'T" : Param<"'T": 0> + #15 68-112 "{\n $\"Value: {a}\"\n }" : String + #17 85-98 "$\"Value: {a}\"" : String + #18 95-96 "a" : Param<"'T": 0> + #24 139-141 "()" : Unit + #28 149-265 "{\n let x = Foo(1);\n let y = Foo(1.0);\n let z = Foo(true);\n }" : Unit + #30 171-172 "x" : String + #32 175-181 "Foo(1)" : String + #33 175-178 "Foo" : (Int -> String) + #36 178-181 "(1)" : Int + #37 179-180 "1" : Int + #39 203-204 "y" : String + #41 207-215 "Foo(1.0)" : String + #42 207-210 "Foo" : (Double -> String) + #45 210-215 "(1.0)" : Double + #46 211-214 "1.0" : Double + #48 237-238 "z" : String + #50 241-250 "Foo(true)" : String + #51 241-244 "Foo" : (Bool -> String) + #54 244-250 "(true)" : Bool + #55 245-249 "true" : Bool + Error(Type(Error(MissingClassShow("'T", Span { lo: 95, hi: 96 })))) + "##]], + ); +} + +#[test] +fn integral() { + check( + r#" + namespace A { + function Foo<'T: Integral>(a: 'T) : 'T { + a ^^^ a + } + + function Main() : Unit { + let x: Int = Foo(1); + let y: BigInt = Foo(10L); + } + } + "#, + "", + &expect![[r##" + #8 61-68 "(a: 'T)" : Param<"'T": 0> + #9 62-67 "a: 'T" : Param<"'T": 0> + #15 74-113 "{\n a ^^^ a\n }" : Param<"'T": 0> + #17 92-99 "a ^^^ a" : Param<"'T": 0> + #18 92-93 "a" : Param<"'T": 0> + #21 98-99 "a" : Param<"'T": 0> + #27 140-142 "()" : Unit + #31 150-244 "{\n let x: Int = Foo(1);\n let y: BigInt = Foo(10L);\n }" : Unit + #33 172-178 "x: Int" : Int + #38 181-187 "Foo(1)" : Int + #39 181-184 "Foo" : (Int -> Int) + #42 184-187 "(1)" : Int + #43 185-186 "1" : Int + #45 209-218 "y: BigInt" : BigInt + #50 221-229 "Foo(10L)" : BigInt + #51 221-224 "Foo" : (BigInt -> BigInt) + #54 224-229 "(10L)" : BigInt + #55 225-228 "10L" : BigInt + "##]], + ); +} +#[test] +fn integral_fail() { + check( + r#" + namespace A { + function Foo<'T: Integral>(a: 'T) : 'T { + a ^^^ a + } + + function Main() : Unit { + let x = Foo(1.0); + let y = Foo(true); + } + } + "#, + "", + &expect![[r##" + #8 61-68 "(a: 'T)" : Param<"'T": 0> + #9 62-67 "a: 'T" : Param<"'T": 0> + #15 74-113 "{\n a ^^^ a\n }" : Param<"'T": 0> + #17 92-99 "a ^^^ a" : Param<"'T": 0> + #18 92-93 "a" : Param<"'T": 0> + #21 98-99 "a" : Param<"'T": 0> + #27 140-142 "()" : Unit + #31 150-234 "{\n let x = Foo(1.0);\n let y = Foo(true);\n }" : Unit + #33 172-173 "x" : Double + #35 176-184 "Foo(1.0)" : Double + #36 176-179 "Foo" : (Double -> Double) + #39 179-184 "(1.0)" : Double + #40 180-183 "1.0" : Double + #42 206-207 "y" : Bool + #44 210-219 "Foo(true)" : Bool + #45 210-213 "Foo" : (Bool -> Bool) + #48 213-219 "(true)" : Bool + #49 214-218 "true" : Bool + Error(Type(Error(MissingClassInteger("Double", Span { lo: 176, hi: 184 })))) + Error(Type(Error(MissingClassInteger("Bool", Span { lo: 210, hi: 219 })))) + "##]], + ); +} + +#[test] +fn constraint_arguments_for_class_with_no_args() { + check( + r#" + namespace A { + function Foo<'T: Eq[Int]>() : Bool { + true + } + } + "#, + "", + &expect![[r##" + #11 60-62 "()" : Unit + #15 70-106 "{\n true\n }" : Bool + #17 88-92 "true" : Bool + Error(Type(Error(IncorrectNumberOfConstraintParameters { expected: 0, found: 1, span: Span { lo: 52, hi: 54 } }))) + "##]], + ); +} + +#[test] +fn show_and_eq() { + check( + r#" + namespace A { + function Foo<'T: Eq + Show>(a: 'T, b: 'T) : String { + if a == b { + $"Value: {a}" + } else { + $"Value: {b}" + } + } + + function Main() : Unit { + let x = Foo(1, 1); + let y = Foo(1, 2); + } + } + "#, + "", + &expect![[r##" + #9 62-76 "(a: 'T, b: 'T)" : (Param<"'T": 0>, Param<"'T": 0>) + #10 63-68 "a: 'T" : Param<"'T": 0> + #14 70-75 "b: 'T" : Param<"'T": 0> + #21 86-240 "{\n if a == b {\n $\"Value: {a}\"\n } else {\n $\"Value: {b}\"\n }\n }" : String + #23 104-226 "if a == b {\n $\"Value: {a}\"\n } else {\n $\"Value: {b}\"\n }" : String + #24 107-113 "a == b" : Bool + #25 107-108 "a" : Param<"'T": 0> + #28 112-113 "b" : Param<"'T": 0> + #31 114-167 "{\n $\"Value: {a}\"\n }" : String + #33 136-149 "$\"Value: {a}\"" : String + #34 146-147 "a" : Param<"'T": 0> + #37 168-226 "else {\n $\"Value: {b}\"\n }" : String + #38 173-226 "{\n $\"Value: {b}\"\n }" : String + #40 195-208 "$\"Value: {b}\"" : String + #41 205-206 "b" : Param<"'T": 0> + #47 267-269 "()" : Unit + #51 277-362 "{\n let x = Foo(1, 1);\n let y = Foo(1, 2);\n }" : Unit + #53 299-300 "x" : String + #55 303-312 "Foo(1, 1)" : String + #56 303-306 "Foo" : ((Int, Int) -> String) + #59 306-312 "(1, 1)" : (Int, Int) + #60 307-308 "1" : Int + #61 310-311 "1" : Int + #63 334-335 "y" : String + #65 338-347 "Foo(1, 2)" : String + #66 338-341 "Foo" : ((Int, Int) -> String) + #69 341-347 "(1, 2)" : (Int, Int) + #70 342-343 "1" : Int + #71 345-346 "2" : Int + "##]], + ); +} + +#[test] +fn show_and_eq_should_fail() { + check( + r#" + namespace A { + function Foo<'T: Eq + Show>(a: 'T, b: 'T) : String { + if a == b { + $"Value: {a}" + } else { + $"Value: {b}" + } + } + + function Main() : Unit { + let x = Foo(1, true); + let y = Foo(1, "2"); + } + } + "#, + "", + &expect![[r##" + #9 62-76 "(a: 'T, b: 'T)" : (Param<"'T": 0>, Param<"'T": 0>) + #10 63-68 "a: 'T" : Param<"'T": 0> + #14 70-75 "b: 'T" : Param<"'T": 0> + #21 86-240 "{\n if a == b {\n $\"Value: {a}\"\n } else {\n $\"Value: {b}\"\n }\n }" : String + #23 104-226 "if a == b {\n $\"Value: {a}\"\n } else {\n $\"Value: {b}\"\n }" : String + #24 107-113 "a == b" : Bool + #25 107-108 "a" : Param<"'T": 0> + #28 112-113 "b" : Param<"'T": 0> + #31 114-167 "{\n $\"Value: {a}\"\n }" : String + #33 136-149 "$\"Value: {a}\"" : String + #34 146-147 "a" : Param<"'T": 0> + #37 168-226 "else {\n $\"Value: {b}\"\n }" : String + #38 173-226 "{\n $\"Value: {b}\"\n }" : String + #40 195-208 "$\"Value: {b}\"" : String + #41 205-206 "b" : Param<"'T": 0> + #47 267-269 "()" : Unit + #51 277-367 "{\n let x = Foo(1, true);\n let y = Foo(1, \"2\");\n }" : Unit + #53 299-300 "x" : String + #55 303-315 "Foo(1, true)" : String + #56 303-306 "Foo" : ((Int, Int) -> String) + #59 306-315 "(1, true)" : (Int, Bool) + #60 307-308 "1" : Int + #61 310-314 "true" : Bool + #63 337-338 "y" : String + #65 341-352 "Foo(1, \"2\")" : String + #66 341-344 "Foo" : ((Int, Int) -> String) + #69 344-352 "(1, \"2\")" : (Int, String) + #70 345-346 "1" : Int + #71 348-351 "\"2\"" : String + Error(Type(Error(TyMismatch("Int", "Bool", Span { lo: 303, hi: 315 })))) + Error(Type(Error(TyMismatch("Int", "String", Span { lo: 341, hi: 352 })))) + "##]], + ); +} + +#[test] +fn unknown_class() { + check( + r#" + namespace A { + function Foo<'T: Unknown>(a: 'T) : 'T { + a + } + + function Main() : Unit { + let x = Foo(1); + } + }"#, + "", + &expect![[r##" + #8 60-67 "(a: 'T)" : Param<"'T": 0> + #9 61-66 "a: 'T" : Param<"'T": 0> + #15 73-106 "{\n a\n }" : Param<"'T": 0> + #17 91-92 "a" : Param<"'T": 0> + #23 133-135 "()" : Unit + #27 143-190 "{\n let x = Foo(1);\n }" : Unit + #29 165-166 "x" : Int + #31 169-175 "Foo(1)" : Int + #32 169-172 "Foo" : (Int -> Int) + #35 172-175 "(1)" : Int + #36 173-174 "1" : Int + Error(Type(Error(UnrecognizedClass { span: Span { lo: 52, hi: 59 }, name: "Unknown" }))) + Error(Type(Error(UnrecognizedClass { span: Span { lo: 52, hi: 59 }, name: "Unknown" }))) + Error(Type(Error(UnrecognizedClass { span: Span { lo: 52, hi: 59 }, name: "Unknown" }))) + Error(Type(Error(UnrecognizedClass { span: Span { lo: 52, hi: 59 }, name: "Unknown" }))) + "##]], + ); +} + +#[test] +fn class_constraint_in_lambda() { + check( + r#" + namespace A { + function Foo<'T: Eq>(a: 'T -> Bool, b: 'T) : Bool { + a(b); + b == b + } + } + "#, + "", + &expect![[r##" + #8 55-77 "(a: 'T -> Bool, b: 'T)" : ((Param<"'T": 0> -> Bool), Param<"'T": 0>) + #9 56-69 "a: 'T -> Bool" : (Param<"'T": 0> -> Bool) + #17 71-76 "b: 'T" : Param<"'T": 0> + #24 85-145 "{\n a(b);\n b == b\n }" : Bool + #26 103-107 "a(b)" : Bool + #27 103-104 "a" : (Param<"'T": 0> -> Bool) + #30 104-107 "(b)" : Param<"'T": 0> + #31 105-106 "b" : Param<"'T": 0> + #35 125-131 "b == b" : Bool + #36 125-126 "b" : Param<"'T": 0> + #39 130-131 "b" : Param<"'T": 0> + "##]], + ); +} + +#[test] +fn test_harness_use_case() { + check( + r#" + namespace A { + function Test<'T: Eq>(test_cases: (() => 'T)[], answers: 'T[]) : Unit { + } + } + "#, + "", + &expect![[r##" + #8 56-97 "(test_cases: (() => 'T)[], answers: 'T[])" : ((Unit => Param<"'T": 0>)[], Param<"'T": 0>[]) + #9 57-81 "test_cases: (() => 'T)[]" : (Unit => Param<"'T": 0>)[] + #17 83-96 "answers: 'T[]" : Param<"'T": 0>[] + #25 105-120 "{\n }" : Unit + "##]], + ); +} diff --git a/compiler/qsc_hir/src/hir.rs b/compiler/qsc_hir/src/hir.rs index fd2799fa93..cb42b9be3b 100644 --- a/compiler/qsc_hir/src/hir.rs +++ b/compiler/qsc_hir/src/hir.rs @@ -4,7 +4,7 @@ //! The high-level intermediate representation for Q#. HIR is lowered from the AST. #![warn(missing_docs)] -use crate::ty::{Arrow, FunctorSet, FunctorSetValue, GenericArg, GenericParam, Scheme, Ty, Udt}; +use crate::ty::{Arrow, FunctorSet, FunctorSetValue, GenericArg, Scheme, Ty, TypeParameter, Udt}; use indenter::{indented, Indented}; use num_bigint::BigInt; use qsc_data_structures::{index_map::IndexMap, span::Span}; @@ -214,7 +214,7 @@ impl ItemStatus { /// A resolution. This connects a usage of a name with the declaration of that name by uniquely /// identifying the node that declared it. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] pub enum Res { /// An invalid resolution. Err, @@ -376,7 +376,7 @@ pub struct CallableDecl { /// The name of the callable. pub name: Ident, /// The generic parameters to the callable. - pub generics: Vec, + pub generics: Vec, /// The input to the callable. pub input: Pat, /// The return type of the callable. @@ -1439,7 +1439,7 @@ pub enum Visibility { } /// A callable kind. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] pub enum CallableKind { /// A function. Function, diff --git a/compiler/qsc_hir/src/ty.rs b/compiler/qsc_hir/src/ty.rs index 0ea18ed9a9..2f1d5b2e04 100644 --- a/compiler/qsc_hir/src/ty.rs +++ b/compiler/qsc_hir/src/ty.rs @@ -24,7 +24,7 @@ fn set_indentation<'a, 'b>( } /// A type. -#[derive(Clone, Debug, Default, Eq, PartialEq)] +#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] pub enum Ty { /// An array type. Array(Box), @@ -33,7 +33,11 @@ pub enum Ty { /// A placeholder type variable used during type inference. Infer(InferTyId), /// A type parameter. - Param(Rc, ParamId), + Param { + name: Rc, + id: ParamId, + bounds: ClassConstraints, + }, /// A primitive type. Prim(Prim), /// A tuple type. @@ -45,6 +49,92 @@ pub enum Ty { Err, } +/// Container type for a collection of class constraints, so we can define methods on it. +#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] +pub struct ClassConstraints(pub Box<[ClassConstraint]>); + +impl ClassConstraints { + #[must_use] + pub fn contains_iterable_bound(&self) -> bool { + self.0 + .iter() + .any(|bound| matches!(bound, ClassConstraint::Iterable { .. })) + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl std::fmt::Display for ClassConstraints { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.0.is_empty() { + Ok(()) + } else { + let bounds = self + .0 + .iter() + .map(|bound| format!("{bound}")) + .collect::>() + .join(" + "); + write!(f, "{bounds}") + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +pub enum ClassConstraint { + /// Whether or not 'T can be compared via Eq to values of the same domain. + Eq, + /// Whether or not 'T can be added to values of the same domain via the + operator. + Add, + Exp { + // `base` is inferred to be the self type + power: Ty, + }, + /// If 'T is iterable, then it can be iterated over and the items inside are yielded (of type `item`). + Iterable { item: Ty }, + /// Whether or not 'T can be divided by values of the same domain via the / operator. + Div, + /// Whether or not 'T can be subtracted from values of the same domain via the - operator. + Sub, + /// Whether or not 'T can be multiplied by values of the same domain via the * operator. + Mul, + /// Whether or not 'T can be taken modulo values of the same domain via the % operator. + Mod, + /// Whether or not 'T can be compared via Ord to values of the same domain. + Ord, + /// Whether or not 'T can be signed. + Signed, + /// Whether or not 'T is an integral type (can be used in bit shifting operators). + Integral, + /// Whether or not 'T can be displayed as a string (converted to a string). + Show, + /// A class that is not built-in to the compiler. + NonNativeClass(Rc), +} + +impl std::fmt::Display for ClassConstraint { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + ClassConstraint::Eq => write!(f, "Eq"), + ClassConstraint::NonNativeClass(name) => write!(f, "{name}"), + ClassConstraint::Add => write!(f, "Add"), + ClassConstraint::Exp { power } => write!(f, "Exp[{power}]"), + ClassConstraint::Iterable { item } => write!(f, "Iterable<{item}>"), + ClassConstraint::Integral => write!(f, "Integral"), + ClassConstraint::Show => write!(f, "Show"), + ClassConstraint::Div => write!(f, "Div"), + ClassConstraint::Sub => write!(f, "Sub"), + ClassConstraint::Mul => write!(f, "Mul"), + ClassConstraint::Mod => write!(f, "Mod"), + ClassConstraint::Ord => write!(f, "Ord"), + ClassConstraint::Signed => write!(f, "Signed"), + } + } +} + impl Ty { /// The unit type. pub const UNIT: Self = Self::Tuple(Vec::new()); @@ -52,7 +142,7 @@ impl Ty { #[must_use] pub fn with_package(&self, package: PackageId) -> Self { match self { - Ty::Infer(_) | Ty::Param(_, _) | Ty::Prim(_) | Ty::Err => self.clone(), + Ty::Infer(_) | Ty::Param { .. } | Ty::Prim(_) | Ty::Err => self.clone(), Ty::Array(item) => Ty::Array(Box::new(item.with_package(package))), Ty::Arrow(arrow) => Ty::Arrow(Box::new(arrow.with_package(package))), Ty::Tuple(items) => Ty::Tuple( @@ -93,7 +183,7 @@ impl Ty { ) } Ty::Infer(_) | Ty::Err => "?".to_string(), - Ty::Param(name, _) | Ty::Udt(name, _) => name.to_string(), + Ty::Param { name, .. } | Ty::Udt(name, _) => name.to_string(), Ty::Prim(prim) => format!("{prim:?}"), Ty::Tuple(items) => { if items.is_empty() { @@ -116,8 +206,8 @@ impl Display for Ty { Ty::Array(item) => write!(f, "{item}[]"), Ty::Arrow(arrow) => Display::fmt(arrow, f), Ty::Infer(infer) => Display::fmt(infer, f), - Ty::Param(name, param_id) => { - write!(f, "Param<\"{name}\": {param_id}>") + Ty::Param { name, id, .. } => { + write!(f, "Param<\"{name}\": {id}>") } Ty::Prim(prim) => Debug::fmt(prim, f), Ty::Tuple(items) => { @@ -150,14 +240,14 @@ impl Display for Ty { #[derive(Debug)] /// A type scheme. pub struct Scheme { - params: Vec, + params: Vec, ty: Box, } impl Scheme { /// Creates a new type scheme. #[must_use] - pub fn new(params: Vec, ty: Box) -> Self { + pub fn new(params: Vec, ty: Box) -> Self { Self { params, ty } } @@ -176,7 +266,7 @@ impl Scheme { /// The generic parameters to the type. #[must_use] - pub fn params(&self) -> &[GenericParam] { + pub fn params(&self) -> &[TypeParameter] { &self.params } @@ -208,6 +298,8 @@ pub enum InstantiationError { Arity, /// A generic argument does not match the kind of its corresponding generic parameter. Kind(ParamId), + /// An in invalid type bound was provided. + Bound(ParamId), } fn instantiate_ty<'a>( @@ -218,9 +310,9 @@ fn instantiate_ty<'a>( Ty::Err | Ty::Infer(_) | Ty::Prim(_) | Ty::Udt(_, _) => Ok(ty.clone()), Ty::Array(item) => Ok(Ty::Array(Box::new(instantiate_ty(arg, item)?))), Ty::Arrow(arrow) => Ok(Ty::Arrow(Box::new(instantiate_arrow_ty(arg, arrow)?))), - Ty::Param(_, param) => match arg(param) { + Ty::Param { id, .. } => match arg(id) { Some(GenericArg::Ty(ty_arg)) => Ok(ty_arg.clone()), - Some(_) => Err(InstantiationError::Kind(*param)), + Some(_) => Err(InstantiationError::Kind(*id)), None => Ok(ty.clone()), }, Ty::Tuple(items) => Ok(Ty::Tuple( @@ -256,21 +348,34 @@ fn instantiate_arrow_ty<'a>( }) } -impl Display for GenericParam { +impl Display for TypeParameter { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - GenericParam::Ty(name) => write!(f, "type {name}"), - GenericParam::Functor(min) => write!(f, "functor ({min})"), + TypeParameter::Ty { name, bounds, .. } => write!( + f, + "type {name}{}", + if bounds.0.is_empty() { + String::new() + } else { + format!(" bounds: {bounds}",) + } + ), + TypeParameter::Functor(min) => write!(f, "functor ({min})"), } } } /// The kind of a generic parameter. #[derive(Clone, Debug, PartialEq)] -pub enum GenericParam { +pub enum TypeParameter { /// A type parameter. - Ty(TypeParamName), - /// A functor parameter with a lower bound. + Ty { + name: Rc, + bounds: ClassConstraints, + }, + /// A functor parameter with a minimal set (lower bound) of functors. + /// if `'T is Adj` then `functor ('T)` is the minimal set of functors. + /// This can currently only occur on lambda expressions. Functor(FunctorSetValue), } @@ -290,7 +395,7 @@ impl Display for TypeParamName { } /// A generic parameter ID. -#[derive(Clone, Copy, Default, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Default, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] pub struct ParamId(u32); impl ParamId { @@ -342,7 +447,7 @@ impl Display for GenericArg { } /// An arrow type: `->` for a function or `=>` for an operation. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] pub struct Arrow { /// Whether the callable is a function or an operation. pub kind: CallableKind, @@ -382,7 +487,7 @@ impl Display for Arrow { } /// A primitive type. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] pub enum Prim { /// The big integer type. BigInt, @@ -411,7 +516,7 @@ pub enum Prim { } /// A set of functors. -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord)] pub enum FunctorSet { /// An evaluated set. Value(FunctorSetValue), @@ -447,7 +552,7 @@ impl Display for FunctorSet { } /// The value of a functor set. -#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] pub enum FunctorSetValue { /// The empty set. #[default] @@ -746,7 +851,7 @@ impl Display for UdtField { } /// A placeholder type variable used during type inference. -#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] pub struct InferTyId(usize); impl InferTyId { diff --git a/compiler/qsc_lowerer/src/lib.rs b/compiler/qsc_lowerer/src/lib.rs index 7495daf527..3fb0084127 100644 --- a/compiler/qsc_lowerer/src/lib.rs +++ b/compiler/qsc_lowerer/src/lib.rs @@ -257,7 +257,7 @@ impl Lowerer { let kind = lower_callable_kind(decl.kind); let name = self.lower_ident(&decl.name); let input = self.lower_pat(&decl.input); - let generics = lower_generics(&decl.generics); + let generics = self.lower_type_parameters(&decl.generics); let output = self.lower_ty(&decl.output); let functors = lower_functors(decl.functors); let implementation = if decl.body.body == SpecBody::Gen(SpecGen::Intrinsic) { @@ -859,7 +859,7 @@ impl Lowerer { qsc_hir::ty::Ty::Infer(id) => { qsc_fir::ty::Ty::Infer(qsc_fir::ty::InferTyId::from(usize::from(*id))) } - qsc_hir::ty::Ty::Param(_, id) => { + qsc_hir::ty::Ty::Param { id, .. } => { qsc_fir::ty::Ty::Param(qsc_fir::ty::ParamId::from(usize::from(*id))) } qsc_hir::ty::Ty::Prim(prim) => qsc_fir::ty::Ty::Prim(lower_ty_prim(*prim)), @@ -878,10 +878,62 @@ impl Lowerer { name_span: field.name_span, } } -} -fn lower_generics(generics: &[qsc_hir::ty::GenericParam]) -> Vec { - generics.iter().map(lower_generic_param).collect() + fn lower_type_parameters( + &mut self, + generics: &[qsc_hir::ty::TypeParameter], + ) -> Vec { + generics + .iter() + .map(|x| self.lower_generic_param(x)) + .collect() + } + + fn lower_generic_param( + &mut self, + g: &qsc_hir::ty::TypeParameter, + ) -> qsc_fir::ty::TypeParameter { + match g { + qsc_hir::ty::TypeParameter::Ty { name, bounds } => qsc_fir::ty::TypeParameter::Ty { + name: name.clone(), + bounds: self.lower_class_constraints(bounds), + }, + qsc_hir::ty::TypeParameter::Functor(value) => { + qsc_fir::ty::TypeParameter::Functor(lower_functor_set_value(*value)) + } + } + } + + fn lower_class_constraints( + &mut self, + bounds: &qsc_hir::ty::ClassConstraints, + ) -> qsc_fir::ty::ClassConstraints { + qsc_fir::ty::ClassConstraints(bounds.0.iter().map(|x| self.lower_ty_bound(x)).collect()) + } + + fn lower_ty_bound(&mut self, b: &qsc_hir::ty::ClassConstraint) -> qsc_fir::ty::ClassConstraint { + use qsc_fir::ty::ClassConstraint as FirClass; + use qsc_hir::ty::ClassConstraint as HirClass; + match b { + HirClass::Eq => FirClass::Eq, + HirClass::Exp { power } => FirClass::Exp { + power: self.lower_ty(power), + }, + HirClass::Add => FirClass::Add, + HirClass::NonNativeClass(name) => FirClass::NonNativeClass(name.clone()), + HirClass::Iterable { item } => FirClass::Iterable { + item: self.lower_ty(item), + }, + HirClass::Integral => FirClass::Integral, + HirClass::Show => FirClass::Show, + HirClass::Mul => FirClass::Mul, + HirClass::Div => FirClass::Div, + HirClass::Sub => FirClass::Sub, + HirClass::Mod => FirClass::Mod, + HirClass::Signed => FirClass::Signed, + HirClass::Ord => FirClass::Ord, + } + } } fn lower_attrs(attrs: &[hir::Attr]) -> Vec { @@ -900,15 +952,6 @@ fn lower_functors(functors: qsc_hir::ty::FunctorSetValue) -> qsc_fir::ty::Functo lower_functor_set_value(functors) } -fn lower_generic_param(g: &qsc_hir::ty::GenericParam) -> qsc_fir::ty::GenericParam { - match g { - qsc_hir::ty::GenericParam::Ty(_) => qsc_fir::ty::GenericParam::Ty, - qsc_hir::ty::GenericParam::Functor(value) => { - qsc_fir::ty::GenericParam::Functor(lower_functor_set_value(*value)) - } - } -} - fn lower_field(field: &hir::Field) -> fir::Field { match field { hir::Field::Err => fir::Field::Err, diff --git a/compiler/qsc_parse/src/completion/word_kinds.rs b/compiler/qsc_parse/src/completion/word_kinds.rs index 64266a20b8..1d04a59220 100644 --- a/compiler/qsc_parse/src/completion/word_kinds.rs +++ b/compiler/qsc_parse/src/completion/word_kinds.rs @@ -48,9 +48,11 @@ bitflags! { const PathSegment = 1 << 5; /// A type parameter, without the leading `'`. const TyParam = 1 << 6; + /// A primitive class. + const PrimitiveClass = 1 << 7; /// A field name. Can follow a `.` or `::` in a field access expression, /// or can be in a field assignment. - const Field = 1 << 7; + const Field = 1 << 8; // // End names. @@ -61,11 +63,11 @@ bitflags! { // /// An attribute, without the leading `@`. - const Attr = 1 << 8; + const Attr = 1 << 9; /// The word `Qubit`. - const Qubit = 1 << 9; + const Qubit = 1 << 10; /// The word `size`. - const Size = 1 << 10; + const Size = 1 << 11; // // End hardcoded identifiers. @@ -132,7 +134,7 @@ bitflags! { } } -const KEYWORDS_START: u8 = 11; +const KEYWORDS_START: u8 = 12; const fn keyword_bit(k: Keyword) -> u128 { 1 << (k as u8 + KEYWORDS_START) } @@ -155,6 +157,7 @@ impl WordKinds { WordKinds::PathSegment => Some(NameKind::PathSegment), WordKinds::TyParam => Some(NameKind::TyParam), WordKinds::Field => Some(NameKind::Field), + WordKinds::PrimitiveClass => Some(NameKind::PrimitiveClass), _ => None, }) } @@ -202,6 +205,8 @@ pub enum NameKind { TyParam, /// A field name that follows a `.` or `::` in a field access expression. Field, + /// A primitive class, like Eq, Exp, or Add. + PrimitiveClass, } /// A path (see: [`Predictions`]) diff --git a/compiler/qsc_parse/src/item/tests.rs b/compiler/qsc_parse/src/item/tests.rs index 642f93f8f0..e94d7168c8 100644 --- a/compiler/qsc_parse/src/item/tests.rs +++ b/compiler/qsc_parse/src/item/tests.rs @@ -709,7 +709,7 @@ fn function_one_ty_param() { Callable _id_ [0-45] (Function): name: Ident _id_ [9-12] "Foo" generics: - Ident _id_ [13-15] "'T" + 'T input: Pat _id_ [16-18]: Unit output: Type _id_ [21-25]: Path: Path _id_ [21-25] (Ident _id_ [21-25] "Unit") body: Specializations: @@ -727,8 +727,8 @@ fn function_two_ty_params() { Callable _id_ [0-49] (Function): name: Ident _id_ [9-12] "Foo" generics: - Ident _id_ [13-15] "'T" - Ident _id_ [17-19] "'U" + 'T, + 'U input: Pat _id_ [20-22]: Unit output: Type _id_ [25-29]: Path: Path _id_ [25-29] (Ident _id_ [25-29] "Unit") body: Specializations: @@ -746,8 +746,8 @@ fn function_duplicate_comma_in_ty_param() { Callable _id_ [0-47] (Function): name: Ident _id_ [9-12] "Foo" generics: - Ident _id_ [13-15] "'T" - Ident _id_ [16-16] "" + 'T, + input: Pat _id_ [18-20]: Unit output: Type _id_ [23-27]: Path: Path _id_ [23-27] (Ident _id_ [23-27] "Unit") body: Specializations: @@ -2267,6 +2267,24 @@ fn missing_semi_between_items() { ); } +#[test] +fn allow_class_bound_on_type_param() { + check( + parse, + "operation Foo<'T: Eq + Ord, 'E: Eq>() : Unit {}", + &expect![[r#" + Item _id_ [0-47]: + Callable _id_ [0-47] (Operation): + name: Ident _id_ [10-13] "Foo" + generics: + 'T: Eq + Ord, + 'E: Eq + input: Pat _id_ [35-37]: Unit + output: Type _id_ [40-44]: Path: Path _id_ [40-44] (Ident _id_ [40-44] "Unit") + body: Block: Block _id_ [45-47]: "#]], + ); +} + #[test] fn callable_decl_no_return_type_or_body_recovery() { check( @@ -2277,7 +2295,7 @@ fn callable_decl_no_return_type_or_body_recovery() { Callable _id_ [0-22] (Operation): name: Ident _id_ [10-13] "Foo" generics: - Ident _id_ [14-16] "'T" + 'T input: Pat _id_ [17-19]: Unit output: Type _id_ [22-22]: Err body: Block: Block _id_ [22-22]: @@ -2307,7 +2325,7 @@ fn callable_decl_broken_return_type_no_body_recovery() { Callable _id_ [0-28] (Operation): name: Ident _id_ [10-13] "Foo" generics: - Ident _id_ [14-16] "'T" + 'T input: Pat _id_ [17-19]: Unit output: Type _id_ [22-28]: Arrow (Operation): param: Type _id_ [22-24]: Unit diff --git a/compiler/qsc_parse/src/ty.rs b/compiler/qsc_parse/src/ty.rs index cdec62f051..0eae1cceea 100644 --- a/compiler/qsc_parse/src/ty.rs +++ b/compiler/qsc_parse/src/ty.rs @@ -14,11 +14,12 @@ use crate::{ completion::WordKinds, item::throw_away_doc, lex::{ClosedBinOp, Delim, TokenKind}, - prim::{parse_or_else, recovering_path}, + prim::{ident, parse_or_else, recovering_path}, ErrorKind, }; use qsc_ast::ast::{ - CallableKind, Functor, FunctorExpr, FunctorExprKind, Ident, NodeId, SetOp, Ty, TyKind, + CallableKind, ClassConstraint, ClassConstraints, ConstraintParameter, Functor, FunctorExpr, + FunctorExprKind, NodeId, SetOp, Ty, TyKind, TypeParameter, }; pub(super) fn ty(s: &mut ParserContext) -> Result { @@ -72,9 +73,53 @@ pub(super) fn array_or_arrow(s: &mut ParserContext<'_>, mut lhs: Ty, lo: u32) -> } } -pub(super) fn param(s: &mut ParserContext) -> Result> { +pub(super) fn param(s: &mut ParserContext) -> Result { throw_away_doc(s); - apos_ident(s) + let lo = s.peek().span.lo; + let generic = apos_ident(s)?; + let bounds = if token(s, TokenKind::Colon).is_ok() { + Some(class_constraints(s)?) + } else { + None + }; + + Ok(TypeParameter::new( + *generic, + bounds.unwrap_or_else(|| ClassConstraints(Box::new([]))), + s.span(lo), + )) +} + +/// Parses the bounds of a type parameter, which are a list of class names separated by `+`. +/// This occurs after a `:` in a generic type: +/// `T: Eq + Iterator[Bool] + Class3` +/// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bounds +fn class_constraints(s: &mut ParserContext) -> Result { + let mut bounds: Vec = Vec::new(); + loop { + s.expect(WordKinds::PrimitiveClass); + let bound_name = ident(s)?; + // if there's a less-than sign, or "open square bracket", try to parse type parameters for + // the class + // e.g. `Iterator[Bool]` + let mut ty_parameters = Vec::new(); + if token(s, TokenKind::Open(Delim::Bracket)).is_ok() { + let (tys, _final_sep) = seq(s, ty)?; + ty_parameters = tys; + token(s, TokenKind::Close(Delim::Bracket))?; + } + bounds.push(ClassConstraint { + name: *bound_name, + parameters: ty_parameters + .into_iter() + .map(|ty| ConstraintParameter { ty }) + .collect(), + }); + if token(s, TokenKind::ClosedBinOp(ClosedBinOp::Plus)).is_err() { + break; + } + } + Ok(ClassConstraints(bounds.into_boxed_slice())) } fn array(s: &mut ParserContext) -> Result<()> { @@ -97,6 +142,9 @@ fn arrow(s: &mut ParserContext) -> Result { } } +/// the base type of a type, which can be a hole, a type parameter, a path, or a parenthesized type +/// (or a tuple) +/// This parses the part before the arrow or array in a type, if an arrow or array is present. fn base(s: &mut ParserContext) -> Result { throw_away_doc(s); let lo = s.peek().span.lo; diff --git a/compiler/qsc_parse/src/ty/tests.rs b/compiler/qsc_parse/src/ty/tests.rs index 8d38def154..c1cb008126 100644 --- a/compiler/qsc_parse/src/ty/tests.rs +++ b/compiler/qsc_parse/src/ty/tests.rs @@ -97,11 +97,7 @@ fn ty_unit() { #[test] fn ty_param() { - check( - ty, - "'T", - &expect![[r#"Type _id_ [0-2]: Type Param: Ident _id_ [0-2] "'T""#]], - ); + check(ty, "'T", &expect!["Type _id_ [0-2]: Type Param: 'T"]); } #[test] diff --git a/compiler/qsc_passes/src/loop_unification.rs b/compiler/qsc_passes/src/loop_unification.rs index 28427fc188..e39c1fa636 100644 --- a/compiler/qsc_passes/src/loop_unification.rs +++ b/compiler/qsc_passes/src/loop_unification.rs @@ -336,9 +336,9 @@ impl MutVisitor for LoopUni<'_> { Ty::Prim(Prim::Range) => { *expr = self.visit_for_range(iter, iterable, block, expr.span); } - _ => { + a => { // This scenario should have been caught by type-checking earlier - panic!("The type of the iterable must be either array or range.") + panic!("The type of the iterable must be either array or range, but it is an {a:?}") } } } diff --git a/language_service/src/completion.rs b/language_service/src/completion.rs index 7e3e41e7a0..8b94525ab1 100644 --- a/language_service/src/completion.rs +++ b/language_service/src/completion.rs @@ -224,6 +224,49 @@ fn collect_names( groups.push(fields.fields()); } + NameKind::PrimitiveClass => { + // we know the types of the primitive classes, so we can just return them + // hard coded here. + // If we ever support user-defined primitive classes, we'll need to change this. + + // this is here to force us to update completions if a new primitive class + // constraint is supported + use qsc::hir::ty::ClassConstraint::*; + match Add { + Add + | Eq + | Exp { .. } + | Iterable { .. } + | NonNativeClass(_) + | Integral + | Mod + | Sub + | Mul + | Div + | Signed + | Ord + | Show => (), + } + + groups.push(vec![ + Completion::new("Add".to_string(), CompletionItemKind::Class), + Completion::new("Eq".to_string(), CompletionItemKind::Class), + Completion::with_detail( + "Exp".to_string(), + CompletionItemKind::Class, + Some("Exp['Power]".into()), + ), + Completion::new("Num".to_string(), CompletionItemKind::Class), + Completion::new("Integral".to_string(), CompletionItemKind::Class), + Completion::new("Show".to_string(), CompletionItemKind::Class), + Completion::new("Signed".to_string(), CompletionItemKind::Class), + Completion::new("Ord".to_string(), CompletionItemKind::Class), + Completion::new("Mod".to_string(), CompletionItemKind::Class), + Completion::new("Sub".to_string(), CompletionItemKind::Class), + Completion::new("Mul".to_string(), CompletionItemKind::Class), + Completion::new("Div".to_string(), CompletionItemKind::Class), + ]); + } }; } groups diff --git a/language_service/src/completion/tests.rs b/language_service/src/completion/tests.rs index cd1a00856c..582b0caa83 100644 --- a/language_service/src/completion/tests.rs +++ b/language_service/src/completion/tests.rs @@ -13,6 +13,8 @@ use crate::{ use expect_test::{expect, Expect}; use indoc::indoc; +mod class_completions; + fn check(source_with_cursor: &str, completions_to_check: &[&str], expect: &Expect) { let (compilation, cursor_position, _) = compile_with_markers(source_with_cursor, true); let actual_completions = diff --git a/language_service/src/completion/tests/class_completions.rs b/language_service/src/completion/tests/class_completions.rs new file mode 100644 index 0000000000..a868765106 --- /dev/null +++ b/language_service/src/completion/tests/class_completions.rs @@ -0,0 +1,182 @@ +use super::check; +use expect_test::expect; + +// the `Iterable` class should not be in completions until we support it +#[test] +fn iterable_not_included_in_completions() { + check( + r"namespace Test { + operation Test<'T: ↘ + }", + &["Iterable"], + &expect![[r#" + [ + None, + ] + "#]], + ); +} + +#[test] +fn all_prim_classes_in_completions() { + check( + r"namespace Test { + operation Test<'T: ↘ + }", + &["Eq", "Add", "Exp", "Integral", "Num", "Show"], + &expect![[r#" + [ + Some( + CompletionItem { + label: "Eq", + kind: Class, + sort_text: Some( + "0100Eq", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Add", + kind: Class, + sort_text: Some( + "0100Add", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Exp", + kind: Class, + sort_text: Some( + "0100Exp", + ), + detail: Some( + "Exp['Power]", + ), + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Integral", + kind: Class, + sort_text: Some( + "0100Integral", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Num", + kind: Class, + sort_text: Some( + "0100Num", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Show", + kind: Class, + sort_text: Some( + "0100Show", + ), + detail: None, + additional_text_edits: None, + }, + ), + ] + "#]], + ); +} + +#[test] +fn classes_appear_after_plus_too() { + check( + r"namespace Test { + operation Test<'T: Add + ↘ + }", + &["Eq", "Add", "Exp", "Integral", "Num", "Show"], + &expect![[r#" + [ + Some( + CompletionItem { + label: "Eq", + kind: Class, + sort_text: Some( + "0100Eq", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Add", + kind: Class, + sort_text: Some( + "0100Add", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Exp", + kind: Class, + sort_text: Some( + "0100Exp", + ), + detail: Some( + "Exp['Power]", + ), + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Integral", + kind: Class, + sort_text: Some( + "0100Integral", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Num", + kind: Class, + sort_text: Some( + "0100Num", + ), + detail: None, + additional_text_edits: None, + }, + ), + Some( + CompletionItem { + label: "Show", + kind: Class, + sort_text: Some( + "0100Show", + ), + detail: None, + additional_text_edits: None, + }, + ), + ] + "#]], + ); +} diff --git a/language_service/src/name_locator.rs b/language_service/src/name_locator.rs index 782a897d56..22d3a8f2cd 100644 --- a/language_service/src/name_locator.rs +++ b/language_service/src/name_locator.rs @@ -190,10 +190,10 @@ impl<'inner, 'package, T: Handler<'package>> Visitor<'package> for Locator<'inne // walk callable decl decl.generics.iter().for_each(|p| { if p.span.touches(self.offset) { - if let Some(resolve::Res::Param(param_id)) = - self.compilation.get_res(p.id) + if let Some(resolve::Res::Param { id, .. }) = + self.compilation.get_res(p.ty.id) { - self.inner.at_type_param_def(&self.context, p, *param_id); + self.inner.at_type_param_def(&self.context, &p.ty, *id); } } }); @@ -308,11 +308,16 @@ impl<'inner, 'package, T: Handler<'package>> Visitor<'package> for Locator<'inne fn visit_ty(&mut self, ty: &'package ast::Ty) { if ty.span.touches(self.offset) { if let ast::TyKind::Param(param) = &*ty.kind { - if let Some(resolve::Res::Param(param_id)) = self.compilation.get_res(param.id) { + if let Some(resolve::Res::Param { id, .. }) = self.compilation.get_res(param.ty.id) + { if let Some(curr) = self.context.current_callable { - if let Some(def_name) = curr.generics.get(usize::from(*param_id)) { - self.inner - .at_type_param_ref(&self.context, param, *param_id, def_name); + if let Some(def_name) = curr.generics.get(usize::from(*id)) { + self.inner.at_type_param_ref( + &self.context, + ¶m.ty, + *id, + &def_name.ty, + ); } } } diff --git a/language_service/src/protocol.rs b/language_service/src/protocol.rs index 014d222027..4a75b701ca 100644 --- a/language_service/src/protocol.rs +++ b/language_service/src/protocol.rs @@ -69,6 +69,7 @@ pub enum CompletionItemKind { Variable, TypeParameter, Field, + Class, } #[derive(Debug, Default)] diff --git a/language_service/src/references.rs b/language_service/src/references.rs index 30d0445fc3..59026da5b3 100644 --- a/language_service/src/references.rs +++ b/language_service/src/references.rs @@ -511,10 +511,10 @@ impl<'a> Visitor<'_> for FindTyParamLocations<'a> { fn visit_callable_decl(&mut self, decl: &ast::CallableDecl) { if self.include_declaration { decl.generics.iter().for_each(|p| { - let res = self.compilation.get_res(p.id); - if let Some(resolve::Res::Param(param_id)) = res { - if *param_id == self.param_id { - self.locations.push(p.span); + let res = self.compilation.get_res(p.ty.id); + if let Some(resolve::Res::Param { id, .. }) = res { + if *id == self.param_id { + self.locations.push(p.ty.span); } } }); @@ -524,9 +524,9 @@ impl<'a> Visitor<'_> for FindTyParamLocations<'a> { fn visit_ty(&mut self, ty: &ast::Ty) { if let ast::TyKind::Param(param) = &*ty.kind { - let res = self.compilation.get_res(param.id); - if let Some(resolve::Res::Param(param_id)) = res { - if *param_id == self.param_id { + let res = self.compilation.get_res(param.ty.id); + if let Some(resolve::Res::Param { id, .. }) = res { + if *id == self.param_id { self.locations.push(param.span); } } diff --git a/npm/qsharp/test/basics.js b/npm/qsharp/test/basics.js index 0e7c78546d..44488d9b4e 100644 --- a/npm/qsharp/test/basics.js +++ b/npm/qsharp/test/basics.js @@ -131,7 +131,7 @@ test("error with newlines", async () => { ); assert.equal( diags[0].message, - "type error: missing type in item signature\n\nhelp: types cannot be inferred for global declarations", + "type error: missing type in item signature\n\nhelp: a type must be provided for this item", ); }); diff --git a/playground/src/main.tsx b/playground/src/main.tsx index 0c0ad32a47..3d945f18ff 100644 --- a/playground/src/main.tsx +++ b/playground/src/main.tsx @@ -287,6 +287,9 @@ function registerMonacoLanguageServiceProviders( case "field": kind = monaco.languages.CompletionItemKind.Field; break; + case "class": + kind = monaco.languages.CompletionItemKind.Class; + break; } return { label: i.label, diff --git a/samples/language/ClassConstraints.qs b/samples/language/ClassConstraints.qs new file mode 100644 index 0000000000..aa4303aa8a --- /dev/null +++ b/samples/language/ClassConstraints.qs @@ -0,0 +1,58 @@ +// # Sample +// Class Constraints +// +// # Description +// Q# supports constraining generic types via _class constraints_. The formal term for this concept is bounded polymorphism, +// or parametric polymorphism. +// The currently supported classes are `Exp`, for exponentiation; `Eq`, for comparison via the `==` operator; `Add`, for addition via the `+` operator; +// `Num`, if a type is numeric; `Integral`, if a type is a form of integer; and `Show`, if a type can be rendered as a string. + +// A generic type, or type parameter, is specified on a callable declaration to signify that a function can take multiple types of data as input. +// For a generic type parameter to be useful, we need to be able to know enough about it to operate on it. This is where class constraints come in. By specifying +// class constraints for a type parameter, we are limiting what types can be passed as arguments to a subset with known properties. + +// Classes that Q# currently supports are: +// - `Eq`: denotes that a type can be compared to other values of the same type via the `==` operator. +// - `Add`: denotes that a type can be added to other values of the same type via the `+` operator. +// - `Mod`: denotes that a type can be used with the modulo (`%`) operator. +// - `Sub`: denotes that a type can subtracted from other values of the same type via the `-` operator. +// - `Div`: denotes that a type can subtracted from other values of the same type via the `/` operator. +// - `Signed`: denotes that a type can be negated or made positive with `+` or `-` unary operator prefixes. +// - `Ord`: denotes that a type can be used with comparison operators (`>`, `<`, `>=`, `<=`). +// - `Num`: denotes that a type can be used in `>`, `>=`, `<`, `<=`, `/`, `%`, `*`, and `-` operator expressions. +// - `Show`: denotes that a type can be converted to a string via format strings (`$"number: {num}"`). +// - `Exp['T]`: denotes that a type can be raised to a power of type `'T`. The return type of exponentiation is the type of the base. +// - `Integral`: denotes that a type is an integer-ish type, i.e., can be used in following expressions using the following operators: `&&&`, `|||`, `^^^`, `<<<`, and `>>>`. + +// For example, we may want to write a function that checks if a list is full of entirely the same item. `f([3, 3, 3])` would be `true` and `f([3, 4])` would be false. +function AllEqual<'T : Eq>(items : 'T[]) : Bool { + let allEqual = true; + for i in 1..Length(items) - 1 { + if items[i] != items[i - 1] { + return false; + } + } + return true; +} + +function Main() : Unit { + let is_equal = AllEqual([1, 1, 1]); + Message($"{is_equal}"); + + let is_equal = AllEqual([1, 2, 3]); + Message($"{is_equal}"); + + // Because we wrote this function generically, we are able to pass in different types, as + // long as they can be compared via the class `Eq`. + let is_equal = AllEqual([true, true, false]); + Message($"{is_equal}"); + + let is_equal = AllEqual(["a", "b"]); + Message($"{is_equal}"); + + let is_equal = AllEqual([[], [1]]); + Message($"{is_equal}"); + + let is_equal = AllEqual([[1], [1]]); + Message($"{is_equal}"); +} diff --git a/samples_test/src/tests/language.rs b/samples_test/src/tests/language.rs index 6a4b80470a..4b51b10169 100644 --- a/samples_test/src/tests/language.rs +++ b/samples_test/src/tests/language.rs @@ -344,3 +344,19 @@ pub const WHILELOOPS_EXPECT: Expect = expect!["()"]; pub const WHILELOOPS_EXPECT_DEBUG: Expect = expect!["()"]; pub const WITHINAPPLY_EXPECT: Expect = expect!["()"]; pub const WITHINAPPLY_EXPECT_DEBUG: Expect = expect!["()"]; +pub const CLASSCONSTRAINTS_EXPECT: Expect = expect![[r#" + true + false + false + false + false + true + ()"#]]; +pub const CLASSCONSTRAINTS_EXPECT_DEBUG: Expect = expect![[r#" + true + false + false + false + false + true + ()"#]]; diff --git a/vscode/src/completion.ts b/vscode/src/completion.ts index 8753ee685b..dd6600ae54 100644 --- a/vscode/src/completion.ts +++ b/vscode/src/completion.ts @@ -76,6 +76,9 @@ class QSharpCompletionItemProvider implements vscode.CompletionItemProvider { case "field": kind = vscode.CompletionItemKind.Field; break; + case "class": + kind = vscode.CompletionItemKind.Class; + break; } const item = new CompletionItem(c.label, kind); item.sortText = c.sortText; diff --git a/wasm/src/language_service.rs b/wasm/src/language_service.rs index f189b50762..5e60a3f9fb 100644 --- a/wasm/src/language_service.rs +++ b/wasm/src/language_service.rs @@ -164,6 +164,7 @@ impl LanguageService { qsls::protocol::CompletionItemKind::Variable => "variable", qsls::protocol::CompletionItemKind::TypeParameter => "typeParameter", qsls::protocol::CompletionItemKind::Field => "field", + qsls::protocol::CompletionItemKind::Class => "class", }) .to_string(), sortText: i.sort_text, @@ -408,7 +409,7 @@ serializable_type! { }, r#"export interface ICompletionItem { label: string; - kind: "function" | "interface" | "keyword" | "module" | "property" | "variable" | "typeParameter" | "field" ; + kind: "function" | "interface" | "keyword" | "module" | "property" | "variable" | "typeParameter" | "field" | "class"; sortText?: string; detail?: string; additionalTextEdits?: ITextEdit[]; diff --git a/wasm/src/tests.rs b/wasm/src/tests.rs index db4124b381..3b6d79f075 100644 --- a/wasm/src/tests.rs +++ b/wasm/src/tests.rs @@ -37,7 +37,7 @@ fn test_missing_type() { let _ = run_internal( SourceMap::new([("test.qs".into(), code.into())], Some(expr.into())), |msg| { - expect![[r#"{"result":{"code":"Qsc.TypeCk.MissingItemTy","message":"type error: missing type in item signature\n\nhelp: types cannot be inferred for global declarations","range":{"end":{"character":33,"line":0},"start":{"character":32,"line":0}},"severity":"error"},"success":false,"type":"Result"}"#]].assert_eq(msg); + expect![[r#"{"result":{"code":"Qsc.TypeCk.MissingTy","message":"type error: missing type in item signature\n\nhelp: a type must be provided for this item","range":{"end":{"character":33,"line":0},"start":{"character":32,"line":0}},"severity":"error"},"success":false,"type":"Result"}"#]].assert_eq(msg); count.set(count.get() + 1); }, 1,