Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Turn bool into a numeric type #263

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
15 changes: 15 additions & 0 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from guppylang.tys.builtin import (
bool_type_def,
callable_type_def,
float_type_def,
int_type_def,
linst_type_def,
list_type_def,
none_type_def,
Expand All @@ -24,6 +26,7 @@
ExistentialTypeVar,
FunctionType,
NoneType,
NumericType,
OpaqueType,
StructType,
SumType,
Expand Down Expand Up @@ -67,6 +70,8 @@ def default() -> "Globals":
tuple_type_def,
none_type_def,
bool_type_def,
int_type_def,
float_type_def,
list_type_def,
linst_type_def,
]
Expand All @@ -85,6 +90,16 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None
pass
case BoundTypeVar() | ExistentialTypeVar() | SumType():
return None
case NumericType(kind):
match kind:
case NumericType.Kind.Bool:
type_defn = bool_type_def
case NumericType.Kind.Int:
type_defn = int_type_def
case NumericType.Kind.Float:
type_defn = float_type_def
case kind:
return assert_never(kind)
case FunctionType():
type_defn = callable_type_def
case OpaqueType() as ty:
Expand Down
104 changes: 55 additions & 49 deletions guppylang/prelude/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,27 @@
from hugr.serialization import ops, tys
from pydantic import BaseModel

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.ast_util import AstNode, get_type, with_loc
from guppylang.checker.core import Context
from guppylang.checker.expr_checker import ExprSynthesizer, check_num_args
from guppylang.checker.expr_checker import (
ExprSynthesizer,
check_call,
check_num_args,
synthesize_call,
)
from guppylang.definition.custom import (
CustomCallChecker,
CustomCallCompiler,
CustomFunctionDef,
DefaultCallChecker,
)
from guppylang.definition.ty import TypeDef
from guppylang.definition.value import CallableDef
from guppylang.error import GuppyError, GuppyTypeError
from guppylang.hugr_builder.hugr import UNDEFINED, OutPortV
from guppylang.nodes import GlobalCall
from guppylang.tys.builtin import bool_type, list_type
from guppylang.tys.builtin import bool_type, int_type, list_type
from guppylang.tys.subst import Subst
from guppylang.tys.ty import FunctionType, OpaqueType, Type, unify

INT_WIDTH = 6 # 2^6 = 64 bit


hugr_int_type = tys.Type(
tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))],
bound=tys.TypeBound.Eq,
)
)


hugr_float_type = tys.Type(
tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)
)
from guppylang.tys.ty import FunctionType, NumericType, Type, unify


class ConstInt(BaseModel):
Expand Down Expand Up @@ -76,9 +58,9 @@ def int_value(i: int) -> ops.Value:
return ops.Value(
ops.ExtensionValue(
extensions=["arithmetic.int.types"],
typ=hugr_int_type,
typ=NumericType(NumericType.Kind.Int).to_hugr(),
value=ops.CustomConst(
c="ConstInt", v=ConstInt(log_width=INT_WIDTH, value=i)
c="ConstInt", v=ConstInt(log_width=NumericType.INT_WIDTH, value=i)
),
)
)
Expand All @@ -89,7 +71,7 @@ def float_value(f: float) -> ops.Value:
return ops.Value(
ops.ExtensionValue(
extensions=["arithmetic.float.types"],
typ=hugr_float_type,
typ=NumericType(NumericType.Kind.Float).to_hugr(),
value=ops.CustomConst(c="ConstF64", v=ConstF64(value=f)),
)
)
Expand All @@ -116,16 +98,16 @@ def logic_op(op_name: str, args: list[tys.TypeArg] | None = None) -> ops.OpType:


def int_op(
op_name: str, ext: str = "arithmetic.int", num_params: int = 1
op_name: str,
ext: str = "arithmetic.int",
args: list[tys.TypeArg] | None = None,
num_params: int = 1,
) -> ops.OpType:
"""Utility method to create Hugr integer arithmetic ops."""
if args is None:
args = num_params * [tys.TypeArg(tys.BoundedNatArg(n=NumericType.INT_WIDTH))]
return ops.OpType(
ops.CustomOp(
extension=ext,
op_name=op_name,
args=num_params * [tys.TypeArg(tys.BoundedNatArg(n=INT_WIDTH))],
parent=UNDEFINED,
)
ops.CustomOp(extension=ext, op_name=op_name, args=args, parent=UNDEFINED)
)


Expand All @@ -140,20 +122,12 @@ class CoercingChecker(DefaultCallChecker):
"""Function call type checker that automatically coerces arguments to float."""

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
from .builtins import Int

for i in range(len(args)):
args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i])
if isinstance(ty, OpaqueType) and ty.defn == self.ctx.globals["int"]:
call = with_loc(
self.node,
GlobalCall(def_id=Int.__float__.id, args=[args[i]], type_args=[]),
)
float_defn = self.ctx.globals["float"]
assert isinstance(float_defn, TypeDef)
args[i] = with_type(
float_defn.check_instantiate([], self.ctx.globals), call
)
if isinstance(ty, NumericType) and ty.kind != NumericType.Kind.Float:
to_float = self.ctx.globals.get_instance_func(ty, "__float__")
assert to_float is not None
args[i], _ = to_float.synthesize_call([args[i]], self.node, self.ctx)
return super().synthesize(args)


Expand Down Expand Up @@ -272,6 +246,38 @@ def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
return args, subst


class BoolArithChecker(DefaultCallChecker):
"""Function call checker for arithmetic operations on bools.

Converts all bools into ints and calls the corresponding int arithmetic method with
the same name.
"""

def _prepare_args(self, args: list[ast.expr]) -> list[ast.expr]:
# Cast all inputs to int
to_int = self.ctx.globals.get_instance_func(bool_type(), "__int__")
assert to_int is not None
return [to_int.synthesize_call([arg], arg, self.ctx)[0] for arg in args]

def _get_func(self) -> CallableDef:
# Get the int function with the same name
func = self.ctx.globals.get_instance_func(int_type(), self.func.name)
assert func is not None
return func

def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
args, _, inst = synthesize_call(self.func.ty, args, self.node, self.ctx)
assert not inst # `self.func.ty` is not generic
args = self._prepare_args(args)
return self._get_func().synthesize_call(args, self.node, self.ctx)

def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]:
args, _, inst = check_call(self.func.ty, args, ty, self.node, self.ctx)
assert not inst # `self.func.ty` is not generic
args = self._prepare_args(args)
return self._get_func().check_call(args, ty, self.node, self.ctx)


class IntTruedivCompiler(CustomCallCompiler):
"""Compiler for the `int.__truediv__` method."""

Expand Down
Loading