From 179b115507cf08d7e566001e7db9353c7b141e5e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 28 May 2024 09:34:06 +0100 Subject: [PATCH 1/8] feat: Add a nat type and make int/float core types --- guppylang/checker/core.py | 19 ++++++++ guppylang/prelude/_internal.py | 41 +++------------- guppylang/prelude/builtins.py | 19 ++++++-- guppylang/tys/builtin.py | 50 ++++++++++++++----- guppylang/tys/printing.py | 5 ++ guppylang/tys/ty.py | 72 +++++++++++++++++++++++++++- tests/integration/test_arithmetic.py | 9 ++++ 7 files changed, 163 insertions(+), 52 deletions(-) diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index cba8930f..c092381f 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -14,8 +14,11 @@ from guppylang.tys.builtin import ( bool_type_def, callable_type_def, + float_type_def, + int_type_def, linst_type_def, list_type_def, + nat_type_def, none_type_def, tuple_type_def, ) @@ -24,6 +27,7 @@ ExistentialTypeVar, FunctionType, NoneType, + NumericType, OpaqueType, StructType, SumType, @@ -67,6 +71,9 @@ def default() -> "Globals": tuple_type_def, none_type_def, bool_type_def, + nat_type_def, + int_type_def, + float_type_def, list_type_def, linst_type_def, ] @@ -85,6 +92,18 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None pass case BoundTypeVar() | ExistentialTypeVar() | SumType(): return None + case NumericType(kind): + match kind: + case NumericType.Kind.Bool: + type_defn = bool_type_def + case NumericType.Kind.Nat: + type_defn = nat_type_def + case NumericType.Kind.Int: + type_defn = int_type_def + case NumericType.Kind.Float: + type_defn = float_type_def + case kind: + return assert_never(kind) case FunctionType(): type_defn = callable_type_def case OpaqueType() as ty: diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 9db3f70e..58450b2e 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -12,36 +12,13 @@ CustomFunctionDef, DefaultCallChecker, ) -from guppylang.definition.ty import TypeDef from guppylang.definition.value import CallableDef from guppylang.error import GuppyError, GuppyTypeError from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV from guppylang.nodes import GlobalCall from guppylang.tys.builtin import bool_type, list_type from guppylang.tys.subst import Subst -from guppylang.tys.ty import FunctionType, OpaqueType, Type, unify - -INT_WIDTH = 6 # 2^6 = 64 bit - - -hugr_int_type = tys.Type( - tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))], - bound=tys.TypeBound.Eq, - ) -) - - -hugr_float_type = tys.Type( - tys.Opaque( - extension="arithmetic.float.types", - id="float64", - args=[], - bound=tys.TypeBound.Copyable, - ) -) +from guppylang.tys.ty import FunctionType, NumericType, Type, unify class ConstInt(BaseModel): @@ -77,9 +54,9 @@ def int_value(i: int) -> ops.Value: return ops.Value( ops.ExtensionValue( extensions=["arithmetic.int.types"], - typ=hugr_int_type, + typ=NumericType(NumericType.Kind.Nat).to_hugr(), value=ops.CustomConst( - c="ConstInt", v=ConstInt(log_width=INT_WIDTH, value=i) + c="ConstInt", v=ConstInt(log_width=NumericType.INT_WIDTH, value=i) ), ) ) @@ -90,7 +67,7 @@ def float_value(f: float) -> ops.Value: return ops.Value( ops.ExtensionValue( extensions=["arithmetic.float.types"], - typ=hugr_float_type, + typ=NumericType(NumericType.Kind.Float).to_hugr(), value=ops.CustomConst(c="ConstF64", v=ConstF64(value=f)), ) ) @@ -124,7 +101,7 @@ def int_op( ops.CustomOp( extension=ext, op_name=op_name, - args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))], + args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))], parent=UNDEFINED, ) ) @@ -145,16 +122,12 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: for i in range(len(args)): args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) - if isinstance(ty, OpaqueType) and ty.defn == self.ctx.globals["int"]: + if isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Int: call = with_loc( self.node, GlobalCall(def_id=Int.__float__.id, args=[args[i]], type_args=[]), ) - float_defn = self.ctx.globals["float"] - assert isinstance(float_defn, TypeDef) - args[i] = with_type( - float_defn.check_instantiate([], self.ctx.globals), call - ) + args[i] = with_type(NumericType(NumericType.Kind.Float), call) return super().synthesize(args) diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 1d017fe6..ce485e97 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -21,12 +21,17 @@ ReversingChecker, UnsupportedChecker, float_op, - hugr_float_type, - hugr_int_type, int_op, logic_op, ) -from guppylang.tys.builtin import bool_type_def, linst_type_def, list_type_def +from guppylang.tys.builtin import ( + bool_type_def, + float_type_def, + int_type_def, + linst_type_def, + list_type_def, + nat_type_def, +) builtins = GuppyModule("builtins", import_builtins=False) @@ -34,6 +39,10 @@ L = guppy.type_var(builtins, "L", linear=True) +# Define the nat type so scripts can import it +nat = nat_type_def + + @guppy.extend_type(builtins, bool_type_def) class Bool: @guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))])) @@ -52,7 +61,7 @@ def __new__(x): ... def __or__(self: bool, other: bool) -> bool: ... -@guppy.type(builtins, hugr_int_type, name="int") +@guppy.extend_type(builtins, int_type_def) class Int: @guppy.hugr_op(builtins, int_op("iabs")) # TODO: Maybe wrong? (signed vs unsigned!) def __abs__(self: int) -> int: ... @@ -191,7 +200,7 @@ def __trunc__(self: int) -> int: ... def __xor__(self: int, other: int) -> int: ... -@guppy.type(builtins, hugr_float_type, name="float", bound=tys.TypeBound.Copyable) +@guppy.extend_type(builtins, float_type_def) class Float: @guppy.hugr_op(builtins, float_op("fabs"), CoercingChecker()) def __abs__(self: float) -> float: ... diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 2c984496..fe37996d 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -10,7 +10,14 @@ from guppylang.error import GuppyError from guppylang.tys.arg import Argument, TypeArg from guppylang.tys.param import TypeParam -from guppylang.tys.ty import FunctionType, NoneType, OpaqueType, TupleType, Type +from guppylang.tys.ty import ( + FunctionType, + NoneType, + NumericType, + OpaqueType, + TupleType, + Type, +) if TYPE_CHECKING: from guppylang.checker.core import Globals @@ -79,6 +86,23 @@ def check_instantiate( return NoneType() +@dataclass(frozen=True) +class _NumericTypeDef(TypeDef): + """Type definition associated with the builtin `None` type. + + Any impls on None can be registered with this definition. + """ + + ty: NumericType + + def check_instantiate( + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + ) -> NumericType: + if args: + raise GuppyError(f"Type `{self.name}` is not parameterized", loc) + return self.ty + + @dataclass(frozen=True) class _ListTypeDef(OpaqueTypeDef): """Type definition associated with the builtin `list` type. @@ -115,13 +139,17 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: callable_type_def = _CallableTypeDef(DefId.fresh(), None) tuple_type_def = _TupleTypeDef(DefId.fresh(), None) none_type_def = _NoneTypeDef(DefId.fresh(), None) -bool_type_def = OpaqueTypeDef( - id=DefId.fresh(), - name="bool", - defined_at=None, - params=[], - always_linear=False, - to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))), +bool_type_def = _NumericTypeDef( + DefId.fresh(), "bool", None, NumericType(NumericType.Kind.Bool) +) +nat_type_def = _NumericTypeDef( + DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat) +) +int_type_def = _NumericTypeDef( + DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int) +) +float_type_def = _NumericTypeDef( + DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float) ) linst_type_def = OpaqueTypeDef( id=DefId.fresh(), @@ -141,8 +169,8 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: ) -def bool_type() -> OpaqueType: - return OpaqueType([], bool_type_def) +def bool_type() -> NumericType: + return NumericType(NumericType.Kind.Bool) def list_type(element_ty: Type) -> OpaqueType: @@ -154,7 +182,7 @@ def linst_type(element_ty: Type) -> OpaqueType: def is_bool_type(ty: Type) -> bool: - return isinstance(ty, OpaqueType) and ty.defn == bool_type_def + return isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Bool def is_list_type(ty: Type) -> bool: diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 8ac82be6..2d0efb44 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -6,6 +6,7 @@ from guppylang.tys.ty import ( FunctionType, NoneType, + NumericType, OpaqueType, StructType, SumType, @@ -106,6 +107,10 @@ def _visit_SumType(self, ty: SumType, inside_row: bool) -> str: def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str: return "None" + @_visit.register + def _visit_NumericType(self, ty: NumericType, inside_row: bool) -> str: + return ty.kind.value + @_visit.register def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str: # TODO: Print linearity? diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 5b176120..2a58f75d 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass, field +from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, TypeAlias, cast +from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast from hugr.serialization import tys from hugr.serialization.tys import TypeBound @@ -234,6 +235,69 @@ def transform(self, transformer: Transformer) -> "Type": return transformer.transform(self) or self +@dataclass(frozen=True) +class NumericType(TypeBase): + """Numeric types like `int` and `float`.""" + + kind: "Kind" + + class Kind(Enum): + """The different kinds of numeric types.""" + + Bool = "bool" + Nat = "nat" + Int = "int" + Float = "float" + + INT_WIDTH: ClassVar[int] = 6 + + @property + def linear(self) -> bool: + """Whether this type should be treated linearly.""" + return False + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + match self.kind: + case NumericType.Kind.Bool: + return SumType([NoneType(), NoneType()]).to_hugr() + case NumericType.Kind.Nat | NumericType.Kind.Int: + return tys.Type( + tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))], + bound=tys.TypeBound.Eq, + ) + ) + case NumericType.Kind.Float: + return tys.Type( + tys.Opaque( + extension="arithmetic.float.types", + id="float64", + args=[], + bound=tys.TypeBound.Copyable, + ) + ) + + @property + def hugr_bound(self) -> tys.TypeBound: + """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" + match self.kind: + case NumericType.Kind.Float: + return tys.TypeBound.Copyable + case _: + return tys.TypeBound.Eq + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this type.""" + visitor.visit(self) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or self + + @dataclass(frozen=True, init=False) class FunctionType(ParametrizedTypeBase): """Type of (potentially generic) functions.""" @@ -493,7 +557,9 @@ def transform(self, transformer: Transformer) -> "Type": #: This might become obsolete in case the @sealed decorator is added: #: * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types #: * https://github.com/johnthagen/sealed-typing-pep -Type: TypeAlias = BoundTypeVar | ExistentialTypeVar | NoneType | ParametrizedType +Type: TypeAlias = ( + BoundTypeVar | ExistentialTypeVar | NumericType | NoneType | ParametrizedType +) #: An immutable row of Guppy types. TypeRow: TypeAlias = Sequence[Type] @@ -545,6 +611,8 @@ def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None": return _unify_var(t, s, subst) case BoundTypeVar(idx=s_idx), BoundTypeVar(idx=t_idx) if s_idx == t_idx: return subst + case NumericType(kind=s_kind), NumericType(kind=t_kind) if s_kind == t_kind: + return subst case NoneType(), NoneType(): return subst case FunctionType() as s, FunctionType() as t if s.params == t.params: diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index a31db662..1abb4c57 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -1,3 +1,4 @@ +from guppylang.prelude.builtins import nat from tests.util import compile_guppy @@ -82,3 +83,11 @@ def foo(x: bool, y: int) -> bool: return z validate(foo) + + +def test_nat(validate): + @compile_guppy + def add(x: nat) -> nat: + return x + + validate(add) From 0ea6ac2bfa917292d8705e374e1fa165194c2dc8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 May 2024 10:33:42 +0100 Subject: [PATCH 2/8] Revert addition of nat type --- guppylang/checker/core.py | 4 ---- guppylang/prelude/_internal.py | 2 +- guppylang/prelude/builtins.py | 5 ----- guppylang/tys/builtin.py | 3 --- guppylang/tys/ty.py | 3 +-- tests/integration/test_arithmetic.py | 9 --------- 6 files changed, 2 insertions(+), 24 deletions(-) diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index c092381f..4ebb4a33 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -18,7 +18,6 @@ int_type_def, linst_type_def, list_type_def, - nat_type_def, none_type_def, tuple_type_def, ) @@ -71,7 +70,6 @@ def default() -> "Globals": tuple_type_def, none_type_def, bool_type_def, - nat_type_def, int_type_def, float_type_def, list_type_def, @@ -96,8 +94,6 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None match kind: case NumericType.Kind.Bool: type_defn = bool_type_def - case NumericType.Kind.Nat: - type_defn = nat_type_def case NumericType.Kind.Int: type_defn = int_type_def case NumericType.Kind.Float: diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 58450b2e..1f9c47e4 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -54,7 +54,7 @@ def int_value(i: int) -> ops.Value: return ops.Value( ops.ExtensionValue( extensions=["arithmetic.int.types"], - typ=NumericType(NumericType.Kind.Nat).to_hugr(), + typ=NumericType(NumericType.Kind.Int).to_hugr(), value=ops.CustomConst( c="ConstInt", v=ConstInt(log_width=NumericType.INT_WIDTH, value=i) ), diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index ce485e97..4ce8ca8a 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -30,7 +30,6 @@ int_type_def, linst_type_def, list_type_def, - nat_type_def, ) builtins = GuppyModule("builtins", import_builtins=False) @@ -39,10 +38,6 @@ L = guppy.type_var(builtins, "L", linear=True) -# Define the nat type so scripts can import it -nat = nat_type_def - - @guppy.extend_type(builtins, bool_type_def) class Bool: @guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))])) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index fe37996d..1668314a 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -142,9 +142,6 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: bool_type_def = _NumericTypeDef( DefId.fresh(), "bool", None, NumericType(NumericType.Kind.Bool) ) -nat_type_def = _NumericTypeDef( - DefId.fresh(), "nat", None, NumericType(NumericType.Kind.Nat) -) int_type_def = _NumericTypeDef( DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int) ) diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 2a58f75d..1c70038d 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -245,7 +245,6 @@ class Kind(Enum): """The different kinds of numeric types.""" Bool = "bool" - Nat = "nat" Int = "int" Float = "float" @@ -261,7 +260,7 @@ def to_hugr(self) -> tys.Type: match self.kind: case NumericType.Kind.Bool: return SumType([NoneType(), NoneType()]).to_hugr() - case NumericType.Kind.Nat | NumericType.Kind.Int: + case NumericType.Kind.Int: return tys.Type( tys.Opaque( extension="arithmetic.int.types", diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index 1abb4c57..a31db662 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -1,4 +1,3 @@ -from guppylang.prelude.builtins import nat from tests.util import compile_guppy @@ -83,11 +82,3 @@ def foo(x: bool, y: int) -> bool: return z validate(foo) - - -def test_nat(validate): - @compile_guppy - def add(x: nat) -> nat: - return x - - validate(add) From 506de84658f7c63fea264bac68d9be9ad42588eb Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 May 2024 10:35:50 +0100 Subject: [PATCH 3/8] Fix _NumericTypeDef docstring --- guppylang/tys/builtin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 1668314a..4bbd179e 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -88,9 +88,9 @@ def check_instantiate( @dataclass(frozen=True) class _NumericTypeDef(TypeDef): - """Type definition associated with the builtin `None` type. + """Type definition associated with the builtin numeric types. - Any impls on None can be registered with this definition. + Any impls on numerics can be registered with these definitions. """ ty: NumericType From 1a61ca82cd187aeb6cb544153b4f02ec152b4ade Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 May 2024 10:43:31 +0100 Subject: [PATCH 4/8] Support float coercion for all numeric types --- guppylang/prelude/_internal.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 1f9c47e4..37c9b06e 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -118,16 +118,12 @@ class CoercingChecker(DefaultCallChecker): """Function call type checker that automatically coerces arguments to float.""" def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - from .builtins import Int - for i in range(len(args)): args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) - if isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Int: - call = with_loc( - self.node, - GlobalCall(def_id=Int.__float__.id, args=[args[i]], type_args=[]), - ) - args[i] = with_type(NumericType(NumericType.Kind.Float), call) + if isinstance(ty, NumericType) and ty.kind != NumericType.Kind.Float: + to_float = self.ctx.globals.get_instance_func(ty, "__float__") + assert to_float is not None + args[i], _ = to_float.synthesize_call([args[i]], self.node, self.ctx) return super().synthesize(args) From 078aa25d91b4a2eabb5477ecbe5f0d51ec8fec18 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 30 May 2024 13:46:23 +0100 Subject: [PATCH 5/8] Turn bool into a numeric type --- guppylang/prelude/_internal.py | 57 +++++++-- guppylang/prelude/builtins.py | 130 +++++++++++++++++++- guppylang/tys/builtin.py | 4 + tests/error/type_errors/invert_not_int.err | 6 +- tests/error/type_errors/invert_not_int.py | 2 +- tests/error/type_errors/unary_not_arith.err | 6 +- tests/error/type_errors/unary_not_arith.py | 2 +- tests/integration/test_arithmetic.py | 12 +- 8 files changed, 197 insertions(+), 22 deletions(-) diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 37c9b06e..601d608b 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -3,9 +3,14 @@ from hugr.serialization import ops, tys from pydantic import BaseModel -from guppylang.ast_util import AstNode, get_type, with_loc, with_type +from guppylang.ast_util import AstNode, get_type, with_loc from guppylang.checker.core import Context -from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args +from guppylang.checker.expr_checker import ( + ExprSynthesizer, + check_call, + check_num_args, + synthesize_call, +) from guppylang.definition.custom import ( CustomCallChecker, CustomCallCompiler, @@ -16,7 +21,7 @@ from guppylang.error import GuppyError, GuppyTypeError from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV from guppylang.nodes import GlobalCall -from guppylang.tys.builtin import bool_type, list_type +from guppylang.tys.builtin import bool_type, int_type, list_type from guppylang.tys.subst import Subst from guppylang.tys.ty import FunctionType, NumericType, Type, unify @@ -94,16 +99,16 @@ def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType: def int_op( - op_name: str, ext: str = "arithmetic.int", num_params: int = 1 + op_name: str, + ext: str = "arithmetic.int", + args: list[tys.TypeArg] | None = None, + num_params: int = 1, ) -> ops.OpType: """Utility method to create Hugr integer arithmetic ops.""" + if args is None: + args = num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))] return ops.OpType( - ops.CustomOp( - extension=ext, - op_name=op_name, - args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))], - parent=UNDEFINED, - ) + ops.CustomOp(extension=ext, op_name=op_name, args=args, parent=UNDEFINED) ) @@ -242,6 +247,38 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: return args, subst +class BoolArithChecker(DefaultCallChecker): + """Function call checker for arithmetic operations on bools. + + Converts all bools into ints and calls the corresponding int arithmetic method with + the same name. + """ + + def _prepare_args(self, args: list[ast.expr]) -> list[ast.expr]: + # Cast all inputs to int + to_int = self.ctx.globals.get_instance_func(bool_type(), "__int__") + assert to_int is not None + return [to_int.synthesize_call([arg], arg, self.ctx)[0] for arg in args] + + def _get_func(self) -> CallableDef: + # Get the int function with the same + func = self.ctx.globals.get_instance_func(int_type(), self.func.name) + assert func is not None + return func + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: + args, _, inst = synthesize_call(self.func.ty, args, self.node, self.ctx) + assert not inst # `self.func.ty` is not generic + args = self._prepare_args(args) + return self._get_func().synthesize_call(args, self.node, self.ctx) + + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: + args, _, inst = check_call(self.func.ty, args, ty, self.node, self.ctx) + assert not inst # `self.func.ty` is not generic + args = self._prepare_args(args) + return self._get_func().check_call(args, ty, self.node, self.ctx) + + class IntTruedivCompiler(CustomCallCompiler): """Compiler for the `int.__truediv__` method.""" diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 4ce8ca8a..2ab5e241 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -9,6 +9,7 @@ from guppylang.hugr_builder.hugr import DummyOp from guppylang.module import GuppyModule from guppylang.prelude._internal import ( + BoolArithChecker, CallableChecker, CoercingChecker, DunderChecker, @@ -40,21 +41,146 @@ @guppy.extend_type(builtins, bool_type_def) class Bool: + @guppy.custom(builtins, NoopCompiler()) + def __abs__(self: int) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __add__(self: bool, other: bool) -> int: ... + @guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))])) def __and__(self: bool, other: bool) -> bool: ... @guppy.custom(builtins, NoopCompiler()) def __bool__(self: bool) -> bool: ... - @guppy.hugr_op(builtins, int_op("ifrombool")) + @guppy.custom(builtins, checker=BoolArithChecker()) + def __ceil__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __divmod__(self: bool, other: bool) -> tuple[int, int]: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __eq__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __float__(self: bool) -> float: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __floor__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __floordiv__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __ge__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __gt__(self: bool, other: bool) -> bool: ... + + @guppy.hugr_op(builtins, DummyOp("ifrombool")) # TODO: Widen to INT_WIDTH def __int__(self: bool) -> int: ... + @guppy.custom(builtins, checker=BoolArithChecker()) + def __invert__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __le__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __lshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __lt__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __mod__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __mul__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __ne__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __neg__(self: bool) -> int: ... + @guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False) def __new__(x): ... @guppy.hugr_op(builtins, logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))])) def __or__(self: bool, other: bool) -> bool: ... + @guppy.custom(builtins, checker=BoolArithChecker()) + def __pos__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __pow__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __radd__(self: bool, other: bool) -> int: ... + + @guppy.hugr_op( + builtins, + logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]), + ReversingChecker(), + ) + def __rand__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rdivmod__(self: bool, other: bool) -> tuple[int, int]: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rfloordiv__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rlshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rmod__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rmul__(self: bool, other: bool) -> int: ... + + @guppy.hugr_op( + builtins, + logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))]), + ReversingChecker(), + ) + def __ror__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __round__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rpow__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rrshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rsub__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rtruediv__(self: bool, other: bool) -> float: ... + + @guppy.hugr_op(builtins, DummyOp("Xor"), ReversingChecker()) # TODO + def __rxor__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __sub__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __truediv__(self: bool, other: bool) -> float: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __trunc__(self: bool) -> int: ... + + @guppy.hugr_op(builtins, DummyOp("Xor")) # TODO + def __xor__(self: bool, other: bool) -> bool: ... + @guppy.extend_type(builtins, int_type_def) class Int: @@ -98,7 +224,7 @@ def __gt__(self: int, other: int) -> bool: ... def __int__(self: int) -> int: ... @guppy.hugr_op(builtins, int_op("inot")) - def __invert__(self: int) -> bool: ... + def __invert__(self: int) -> int: ... @guppy.hugr_op(builtins, int_op("ile_s")) def __le__(self: int, other: int) -> bool: ... diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 4bbd179e..7cfc0c9d 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -170,6 +170,10 @@ def bool_type() -> NumericType: return NumericType(NumericType.Kind.Bool) +def int_type() -> NumericType: + return NumericType(NumericType.Kind.Int) + + def list_type(element_ty: Type) -> OpaqueType: return OpaqueType([TypeArg(element_ty)], list_type_def) diff --git a/tests/error/type_errors/invert_not_int.err b/tests/error/type_errors/invert_not_int.err index 16524a1f..a26cae58 100644 --- a/tests/error/type_errors/invert_not_int.err +++ b/tests/error/type_errors/invert_not_int.err @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6 4: @compile_guppy 5: def foo() -> int: -6: return ~True - ^^^^ -GuppyTypeError: Unary operator `~` not defined for argument of type `bool` +6: return ~() + ^^ +GuppyTypeError: Unary operator `~` not defined for argument of type `()` diff --git a/tests/error/type_errors/invert_not_int.py b/tests/error/type_errors/invert_not_int.py index 39399e81..958456d9 100644 --- a/tests/error/type_errors/invert_not_int.py +++ b/tests/error/type_errors/invert_not_int.py @@ -3,4 +3,4 @@ @compile_guppy def foo() -> int: - return ~True + return ~() diff --git a/tests/error/type_errors/unary_not_arith.err b/tests/error/type_errors/unary_not_arith.err index ae82206b..09ae6847 100644 --- a/tests/error/type_errors/unary_not_arith.err +++ b/tests/error/type_errors/unary_not_arith.err @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6 4: @compile_guppy 5: def foo() -> int: -6: return -True - ^^^^ -GuppyTypeError: Unary operator `-` not defined for argument of type `bool` +6: return -() + ^^ +GuppyTypeError: Unary operator `-` not defined for argument of type `()` diff --git a/tests/error/type_errors/unary_not_arith.py b/tests/error/type_errors/unary_not_arith.py index bf9ce2cb..0862b902 100644 --- a/tests/error/type_errors/unary_not_arith.py +++ b/tests/error/type_errors/unary_not_arith.py @@ -3,4 +3,4 @@ @compile_guppy def foo() -> int: - return -True + return -() diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index a31db662..1ac41e14 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -26,10 +26,18 @@ def add(x: int) -> int: validate(add) +def test_bool(validate): + @compile_guppy + def add(x: bool, y: bool) -> int: + return x + y + + validate(add) + + def test_float_coercion(validate): @compile_guppy - def coerce(x: int, y: float) -> float: - return x * y + def coerce(x: int, y: float, z: bool) -> float: + return x * y + z validate(coerce) From 999f121e01e6e3220d1d395bb0948363f4f0c4a6 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 18 Jun 2024 17:14:50 +0100 Subject: [PATCH 6/8] Fix comment Co-authored-by: Craig Roy --- guppylang/prelude/_internal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 601d608b..8325faa6 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -261,7 +261,7 @@ def _prepare_args(self, args: list[ast.expr]) -> list[ast.expr]: return [to_int.synthesize_call([arg], arg, self.ctx)[0] for arg in args] def _get_func(self) -> CallableDef: - # Get the int function with the same + # Get the int function with the same name func = self.ctx.globals.get_instance_func(int_type(), self.func.name) assert func is not None return func From cf8a5299e7b317de1f30f2c84c1b665bfba8c904 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 24 Jun 2024 14:37:07 +0100 Subject: [PATCH 7/8] Undo turning bool into a numeric type --- guppylang/checker/core.py | 2 - guppylang/prelude/_internal.py | 41 +------ guppylang/prelude/builtins.py | 128 +------------------- guppylang/tys/builtin.py | 15 ++- guppylang/tys/ty.py | 3 - tests/error/type_errors/invert_not_int.err | 6 +- tests/error/type_errors/invert_not_int.py | 2 +- tests/error/type_errors/unary_not_arith.err | 6 +- tests/error/type_errors/unary_not_arith.py | 2 +- tests/integration/test_arithmetic.py | 12 +- 10 files changed, 23 insertions(+), 194 deletions(-) diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 4ebb4a33..72fef4db 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -92,8 +92,6 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None return None case NumericType(kind): match kind: - case NumericType.Kind.Bool: - type_defn = bool_type_def case NumericType.Kind.Int: type_defn = int_type_def case NumericType.Kind.Float: diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index fc62385c..6d02f78f 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -5,12 +5,7 @@ from guppylang.ast_util import AstNode, get_type, with_loc from guppylang.checker.core import Context -from guppylang.checker.expr_checker import ( - ExprSynthesizer, - check_call, - check_num_args, - synthesize_call, -) +from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args from guppylang.definition.custom import ( CustomCallChecker, CustomCallCompiler, @@ -21,7 +16,7 @@ from guppylang.error import GuppyError, GuppyTypeError from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV from guppylang.nodes import GlobalCall -from guppylang.tys.builtin import bool_type, int_type, list_type +from guppylang.tys.builtin import bool_type, list_type from guppylang.tys.subst import Subst from guppylang.tys.ty import FunctionType, NumericType, Type, unify @@ -246,38 +241,6 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: return args, subst -class BoolArithChecker(DefaultCallChecker): - """Function call checker for arithmetic operations on bools. - - Converts all bools into ints and calls the corresponding int arithmetic method with - the same name. - """ - - def _prepare_args(self, args: list[ast.expr]) -> list[ast.expr]: - # Cast all inputs to int - to_int = self.ctx.globals.get_instance_func(bool_type(), "__int__") - assert to_int is not None - return [to_int.synthesize_call([arg], arg, self.ctx)[0] for arg in args] - - def _get_func(self) -> CallableDef: - # Get the int function with the same name - func = self.ctx.globals.get_instance_func(int_type(), self.func.name) - assert func is not None - return func - - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: - args, _, inst = synthesize_call(self.func.ty, args, self.node, self.ctx) - assert not inst # `self.func.ty` is not generic - args = self._prepare_args(args) - return self._get_func().synthesize_call(args, self.node, self.ctx) - - def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: - args, _, inst = check_call(self.func.ty, args, ty, self.node, self.ctx) - assert not inst # `self.func.ty` is not generic - args = self._prepare_args(args) - return self._get_func().check_call(args, ty, self.node, self.ctx) - - class IntTruedivCompiler(CustomCallCompiler): """Compiler for the `int.__truediv__` method.""" diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 65a5abb0..03af4d75 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -12,7 +12,6 @@ from guppylang.hugr_builder.hugr import DummyOp from guppylang.module import GuppyModule from guppylang.prelude._internal import ( - BoolArithChecker, CallableChecker, CoercingChecker, DunderChecker, @@ -53,146 +52,21 @@ def py(*_args: Any) -> Any: @guppy.extend_type(builtins, bool_type_def) class Bool: - @guppy.custom(builtins, NoopCompiler()) - def __abs__(self: int) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __add__(self: bool, other: bool) -> int: ... - @guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))])) def __and__(self: bool, other: bool) -> bool: ... @guppy.custom(builtins, NoopCompiler()) def __bool__(self: bool) -> bool: ... - @guppy.custom(builtins, checker=BoolArithChecker()) - def __ceil__(self: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __divmod__(self: bool, other: bool) -> tuple[int, int]: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __eq__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __float__(self: bool) -> float: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __floor__(self: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __floordiv__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __ge__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __gt__(self: bool, other: bool) -> bool: ... - - @guppy.hugr_op(builtins, DummyOp("ifrombool")) # TODO: Widen to INT_WIDTH + @guppy.hugr_op(builtins, int_op("ifrombool")) def __int__(self: bool) -> int: ... - @guppy.custom(builtins, checker=BoolArithChecker()) - def __invert__(self: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __le__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __lshift__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __lt__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __mod__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __mul__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __ne__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __neg__(self: bool) -> int: ... - @guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False) def __new__(x): ... @guppy.hugr_op(builtins, logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))])) def __or__(self: bool, other: bool) -> bool: ... - @guppy.custom(builtins, checker=BoolArithChecker()) - def __pos__(self: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __pow__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __radd__(self: bool, other: bool) -> int: ... - - @guppy.hugr_op( - builtins, - logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]), - ReversingChecker(), - ) - def __rand__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rdivmod__(self: bool, other: bool) -> tuple[int, int]: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rfloordiv__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rlshift__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rmod__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rmul__(self: bool, other: bool) -> int: ... - - @guppy.hugr_op( - builtins, - logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))]), - ReversingChecker(), - ) - def __ror__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __round__(self: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rpow__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rrshift__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rshift__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rsub__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __rtruediv__(self: bool, other: bool) -> float: ... - - @guppy.hugr_op(builtins, DummyOp("Xor"), ReversingChecker()) # TODO - def __rxor__(self: bool, other: bool) -> bool: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __sub__(self: bool, other: bool) -> int: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __truediv__(self: bool, other: bool) -> float: ... - - @guppy.custom(builtins, checker=BoolArithChecker()) - def __trunc__(self: bool) -> int: ... - - @guppy.hugr_op(builtins, DummyOp("Xor")) # TODO - def __xor__(self: bool, other: bool) -> bool: ... - @guppy.extend_type(builtins, int_type_def) class Int: diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 7cfc0c9d..e8441343 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -139,8 +139,13 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: callable_type_def = _CallableTypeDef(DefId.fresh(), None) tuple_type_def = _TupleTypeDef(DefId.fresh(), None) none_type_def = _NoneTypeDef(DefId.fresh(), None) -bool_type_def = _NumericTypeDef( - DefId.fresh(), "bool", None, NumericType(NumericType.Kind.Bool) +bool_type_def = OpaqueTypeDef( + id=DefId.fresh(), + name="bool", + defined_at=None, + params=[], + always_linear=False, + to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))), ) int_type_def = _NumericTypeDef( DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int) @@ -166,8 +171,8 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: ) -def bool_type() -> NumericType: - return NumericType(NumericType.Kind.Bool) +def bool_type() -> OpaqueType: + return OpaqueType([], bool_type_def) def int_type() -> NumericType: @@ -183,7 +188,7 @@ def linst_type(element_ty: Type) -> OpaqueType: def is_bool_type(ty: Type) -> bool: - return isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Bool + return isinstance(ty, OpaqueType) and ty.defn == bool_type_def def is_list_type(ty: Type) -> bool: diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 1c70038d..75b3700f 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -244,7 +244,6 @@ class NumericType(TypeBase): class Kind(Enum): """The different kinds of numeric types.""" - Bool = "bool" Int = "int" Float = "float" @@ -258,8 +257,6 @@ def linear(self) -> bool: def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" match self.kind: - case NumericType.Kind.Bool: - return SumType([NoneType(), NoneType()]).to_hugr() case NumericType.Kind.Int: return tys.Type( tys.Opaque( diff --git a/tests/error/type_errors/invert_not_int.err b/tests/error/type_errors/invert_not_int.err index a26cae58..16524a1f 100644 --- a/tests/error/type_errors/invert_not_int.err +++ b/tests/error/type_errors/invert_not_int.err @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6 4: @compile_guppy 5: def foo() -> int: -6: return ~() - ^^ -GuppyTypeError: Unary operator `~` not defined for argument of type `()` +6: return ~True + ^^^^ +GuppyTypeError: Unary operator `~` not defined for argument of type `bool` diff --git a/tests/error/type_errors/invert_not_int.py b/tests/error/type_errors/invert_not_int.py index 958456d9..39399e81 100644 --- a/tests/error/type_errors/invert_not_int.py +++ b/tests/error/type_errors/invert_not_int.py @@ -3,4 +3,4 @@ @compile_guppy def foo() -> int: - return ~() + return ~True diff --git a/tests/error/type_errors/unary_not_arith.err b/tests/error/type_errors/unary_not_arith.err index 09ae6847..ae82206b 100644 --- a/tests/error/type_errors/unary_not_arith.err +++ b/tests/error/type_errors/unary_not_arith.err @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6 4: @compile_guppy 5: def foo() -> int: -6: return -() - ^^ -GuppyTypeError: Unary operator `-` not defined for argument of type `()` +6: return -True + ^^^^ +GuppyTypeError: Unary operator `-` not defined for argument of type `bool` diff --git a/tests/error/type_errors/unary_not_arith.py b/tests/error/type_errors/unary_not_arith.py index 0862b902..bf9ce2cb 100644 --- a/tests/error/type_errors/unary_not_arith.py +++ b/tests/error/type_errors/unary_not_arith.py @@ -3,4 +3,4 @@ @compile_guppy def foo() -> int: - return -() + return -True diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index 1ac41e14..a31db662 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -26,18 +26,10 @@ def add(x: int) -> int: validate(add) -def test_bool(validate): - @compile_guppy - def add(x: bool, y: bool) -> int: - return x + y - - validate(add) - - def test_float_coercion(validate): @compile_guppy - def coerce(x: int, y: float, z: bool) -> float: - return x * y + z + def coerce(x: int, y: float) -> float: + return x * y validate(coerce) From 1c31b0682e1a66f95ec402677064128c9c15e467 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 24 Jun 2024 14:39:58 +0100 Subject: [PATCH 8/8] feat: Turn bool into a numeric type This reverts commit cf8a5299e7b317de1f30f2c84c1b665bfba8c904. --- guppylang/checker/core.py | 2 + guppylang/prelude/_internal.py | 41 ++++++- guppylang/prelude/builtins.py | 128 +++++++++++++++++++- guppylang/tys/builtin.py | 15 +-- guppylang/tys/ty.py | 3 + tests/error/type_errors/invert_not_int.err | 6 +- tests/error/type_errors/invert_not_int.py | 2 +- tests/error/type_errors/unary_not_arith.err | 6 +- tests/error/type_errors/unary_not_arith.py | 2 +- tests/integration/test_arithmetic.py | 12 +- 10 files changed, 194 insertions(+), 23 deletions(-) diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 72fef4db..4ebb4a33 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -92,6 +92,8 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None return None case NumericType(kind): match kind: + case NumericType.Kind.Bool: + type_defn = bool_type_def case NumericType.Kind.Int: type_defn = int_type_def case NumericType.Kind.Float: diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 6d02f78f..fc62385c 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -5,7 +5,12 @@ from guppylang.ast_util import AstNode, get_type, with_loc from guppylang.checker.core import Context -from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args +from guppylang.checker.expr_checker import ( + ExprSynthesizer, + check_call, + check_num_args, + synthesize_call, +) from guppylang.definition.custom import ( CustomCallChecker, CustomCallCompiler, @@ -16,7 +21,7 @@ from guppylang.error import GuppyError, GuppyTypeError from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV from guppylang.nodes import GlobalCall -from guppylang.tys.builtin import bool_type, list_type +from guppylang.tys.builtin import bool_type, int_type, list_type from guppylang.tys.subst import Subst from guppylang.tys.ty import FunctionType, NumericType, Type, unify @@ -241,6 +246,38 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: return args, subst +class BoolArithChecker(DefaultCallChecker): + """Function call checker for arithmetic operations on bools. + + Converts all bools into ints and calls the corresponding int arithmetic method with + the same name. + """ + + def _prepare_args(self, args: list[ast.expr]) -> list[ast.expr]: + # Cast all inputs to int + to_int = self.ctx.globals.get_instance_func(bool_type(), "__int__") + assert to_int is not None + return [to_int.synthesize_call([arg], arg, self.ctx)[0] for arg in args] + + def _get_func(self) -> CallableDef: + # Get the int function with the same name + func = self.ctx.globals.get_instance_func(int_type(), self.func.name) + assert func is not None + return func + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: + args, _, inst = synthesize_call(self.func.ty, args, self.node, self.ctx) + assert not inst # `self.func.ty` is not generic + args = self._prepare_args(args) + return self._get_func().synthesize_call(args, self.node, self.ctx) + + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: + args, _, inst = check_call(self.func.ty, args, ty, self.node, self.ctx) + assert not inst # `self.func.ty` is not generic + args = self._prepare_args(args) + return self._get_func().check_call(args, ty, self.node, self.ctx) + + class IntTruedivCompiler(CustomCallCompiler): """Compiler for the `int.__truediv__` method.""" diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 03af4d75..65a5abb0 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -12,6 +12,7 @@ from guppylang.hugr_builder.hugr import DummyOp from guppylang.module import GuppyModule from guppylang.prelude._internal import ( + BoolArithChecker, CallableChecker, CoercingChecker, DunderChecker, @@ -52,21 +53,146 @@ def py(*_args: Any) -> Any: @guppy.extend_type(builtins, bool_type_def) class Bool: + @guppy.custom(builtins, NoopCompiler()) + def __abs__(self: int) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __add__(self: bool, other: bool) -> int: ... + @guppy.hugr_op(builtins, logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))])) def __and__(self: bool, other: bool) -> bool: ... @guppy.custom(builtins, NoopCompiler()) def __bool__(self: bool) -> bool: ... - @guppy.hugr_op(builtins, int_op("ifrombool")) + @guppy.custom(builtins, checker=BoolArithChecker()) + def __ceil__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __divmod__(self: bool, other: bool) -> tuple[int, int]: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __eq__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __float__(self: bool) -> float: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __floor__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __floordiv__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __ge__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __gt__(self: bool, other: bool) -> bool: ... + + @guppy.hugr_op(builtins, DummyOp("ifrombool")) # TODO: Widen to INT_WIDTH def __int__(self: bool) -> int: ... + @guppy.custom(builtins, checker=BoolArithChecker()) + def __invert__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __le__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __lshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __lt__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __mod__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __mul__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __ne__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __neg__(self: bool) -> int: ... + @guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False) def __new__(x): ... @guppy.hugr_op(builtins, logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))])) def __or__(self: bool, other: bool) -> bool: ... + @guppy.custom(builtins, checker=BoolArithChecker()) + def __pos__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __pow__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __radd__(self: bool, other: bool) -> int: ... + + @guppy.hugr_op( + builtins, + logic_op("And", [tys.TypeArg(tys.BoundedNatArg(n=2))]), + ReversingChecker(), + ) + def __rand__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rdivmod__(self: bool, other: bool) -> tuple[int, int]: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rfloordiv__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rlshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rmod__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rmul__(self: bool, other: bool) -> int: ... + + @guppy.hugr_op( + builtins, + logic_op("Or", [tys.TypeArg(tys.BoundedNatArg(n=2))]), + ReversingChecker(), + ) + def __ror__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __round__(self: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rpow__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rrshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rshift__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rsub__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __rtruediv__(self: bool, other: bool) -> float: ... + + @guppy.hugr_op(builtins, DummyOp("Xor"), ReversingChecker()) # TODO + def __rxor__(self: bool, other: bool) -> bool: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __sub__(self: bool, other: bool) -> int: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __truediv__(self: bool, other: bool) -> float: ... + + @guppy.custom(builtins, checker=BoolArithChecker()) + def __trunc__(self: bool) -> int: ... + + @guppy.hugr_op(builtins, DummyOp("Xor")) # TODO + def __xor__(self: bool, other: bool) -> bool: ... + @guppy.extend_type(builtins, int_type_def) class Int: diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index e8441343..7cfc0c9d 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -139,13 +139,8 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: callable_type_def = _CallableTypeDef(DefId.fresh(), None) tuple_type_def = _TupleTypeDef(DefId.fresh(), None) none_type_def = _NoneTypeDef(DefId.fresh(), None) -bool_type_def = OpaqueTypeDef( - id=DefId.fresh(), - name="bool", - defined_at=None, - params=[], - always_linear=False, - to_hugr=lambda _: tys.Type(tys.SumType(tys.UnitSum(size=2))), +bool_type_def = _NumericTypeDef( + DefId.fresh(), "bool", None, NumericType(NumericType.Kind.Bool) ) int_type_def = _NumericTypeDef( DefId.fresh(), "int", None, NumericType(NumericType.Kind.Int) @@ -171,8 +166,8 @@ def _list_to_hugr(args: Sequence[Argument]) -> tys.Type: ) -def bool_type() -> OpaqueType: - return OpaqueType([], bool_type_def) +def bool_type() -> NumericType: + return NumericType(NumericType.Kind.Bool) def int_type() -> NumericType: @@ -188,7 +183,7 @@ def linst_type(element_ty: Type) -> OpaqueType: def is_bool_type(ty: Type) -> bool: - return isinstance(ty, OpaqueType) and ty.defn == bool_type_def + return isinstance(ty, NumericType) and ty.kind == NumericType.Kind.Bool def is_list_type(ty: Type) -> bool: diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 75b3700f..1c70038d 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -244,6 +244,7 @@ class NumericType(TypeBase): class Kind(Enum): """The different kinds of numeric types.""" + Bool = "bool" Int = "int" Float = "float" @@ -257,6 +258,8 @@ def linear(self) -> bool: def to_hugr(self) -> tys.Type: """Computes the Hugr representation of the type.""" match self.kind: + case NumericType.Kind.Bool: + return SumType([NoneType(), NoneType()]).to_hugr() case NumericType.Kind.Int: return tys.Type( tys.Opaque( diff --git a/tests/error/type_errors/invert_not_int.err b/tests/error/type_errors/invert_not_int.err index 16524a1f..a26cae58 100644 --- a/tests/error/type_errors/invert_not_int.err +++ b/tests/error/type_errors/invert_not_int.err @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6 4: @compile_guppy 5: def foo() -> int: -6: return ~True - ^^^^ -GuppyTypeError: Unary operator `~` not defined for argument of type `bool` +6: return ~() + ^^ +GuppyTypeError: Unary operator `~` not defined for argument of type `()` diff --git a/tests/error/type_errors/invert_not_int.py b/tests/error/type_errors/invert_not_int.py index 39399e81..958456d9 100644 --- a/tests/error/type_errors/invert_not_int.py +++ b/tests/error/type_errors/invert_not_int.py @@ -3,4 +3,4 @@ @compile_guppy def foo() -> int: - return ~True + return ~() diff --git a/tests/error/type_errors/unary_not_arith.err b/tests/error/type_errors/unary_not_arith.err index ae82206b..09ae6847 100644 --- a/tests/error/type_errors/unary_not_arith.err +++ b/tests/error/type_errors/unary_not_arith.err @@ -2,6 +2,6 @@ Guppy compilation failed. Error in file $FILE:6 4: @compile_guppy 5: def foo() -> int: -6: return -True - ^^^^ -GuppyTypeError: Unary operator `-` not defined for argument of type `bool` +6: return -() + ^^ +GuppyTypeError: Unary operator `-` not defined for argument of type `()` diff --git a/tests/error/type_errors/unary_not_arith.py b/tests/error/type_errors/unary_not_arith.py index bf9ce2cb..0862b902 100644 --- a/tests/error/type_errors/unary_not_arith.py +++ b/tests/error/type_errors/unary_not_arith.py @@ -3,4 +3,4 @@ @compile_guppy def foo() -> int: - return -True + return -() diff --git a/tests/integration/test_arithmetic.py b/tests/integration/test_arithmetic.py index a31db662..1ac41e14 100644 --- a/tests/integration/test_arithmetic.py +++ b/tests/integration/test_arithmetic.py @@ -26,10 +26,18 @@ def add(x: int) -> int: validate(add) +def test_bool(validate): + @compile_guppy + def add(x: bool, y: bool) -> int: + return x + y + + validate(add) + + def test_float_coercion(validate): @compile_guppy - def coerce(x: int, y: float) -> float: - return x * y + def coerce(x: int, y: float, z: bool) -> float: + return x * y + z validate(coerce)