Skip to content

Commit

Permalink
new intrinsics and builtins
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 10, 2024
1 parent 8a24236 commit 48fba8d
Show file tree
Hide file tree
Showing 7 changed files with 1,665 additions and 1,599 deletions.
38 changes: 24 additions & 14 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,23 @@
Any,
)

T = TypeVar('T')
_T = TypeVar("_T", bound=type)
_F = TypeVar("_F", bound=Callable[..., Any])


def intrinsic(name: str, ret_type: type[T], *args, **kwargs) -> T:
raise NotImplementedError(
"intrinsic functions should not be called in host-side Python code. "
"Did you mistakenly called a DSL function?"
)

def builtin(s: str) -> Callable[[_F], _F]:
def wrapper(func: _F) -> _F:
setattr(func, "__luisa_builtin__", s)
return func
return wrapper

def byref(value: T) -> T:
"""pass a value by ref"""

def _intrinsic_impl(*args, **kwargs) -> Any:
raise NotImplementedError(
"intrinsic functions should not be called in host-side Python code. "
"byref should not be called in host-side Python code. "
"Did you mistakenly called a DSL function?"
)

Expand Down Expand Up @@ -128,7 +130,7 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
# return cast(_T, f)


def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any]) -> type[_TT]:
def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any], ir_ty_override: hir.Type | None = None) -> type[_TT]:
ctx = hir.GlobalContext.get()

register_class(cls)
Expand Down Expand Up @@ -166,13 +168,16 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type)
self_ty.instantiated.methods[name] = template
else:
self_ty.methods[name] = template
ir_ty: hir.Type
if ir_ty_override is not None:
ir_ty = ir_ty_override
else:
ir_ty = hir.StructType(
f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, [])
type_parser = parse.TypeParser(
cls.__qualname__, globalns, {}, ir_ty, 'parse')

ir_ty: hir.Type = hir.StructType(
f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, [])
type_parser = parse.TypeParser(
cls.__qualname__, globalns, {}, ir_ty, 'parse')

parse_fields(type_parser, ir_ty)
parse_fields(type_parser, ir_ty)
is_generic = len(cls_info.type_vars) > 0
if is_generic:
def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
Expand Down Expand Up @@ -229,6 +234,11 @@ def volume(self) -> float:
return _dsl_decorator_impl(cls, _ObjKind.STRUCT, {})


def builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[type[_TT]], _TT]:
def decorator(cls: type[_TT]) -> _TT:
return typing.cast(_TT, _dsl_struct_impl(cls, {}, ir_ty_override=ty))
return decorator

_KernelType = TypeVar("_KernelType", bound=Callable[..., None])


Expand Down
70 changes: 45 additions & 25 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import typing
from typing_extensions import override
from luisa_lang import classinfo
from luisa_lang.utils import Span
from luisa_lang.utils import Span, round_to_align
from abc import ABC, abstractmethod

PATH_PREFIX = "luisa_lang"
Expand Down Expand Up @@ -306,31 +306,26 @@ def __str__(self) -> str:
class VectorType(Type):
element: Type
count: int
_align: int
_size: int

def __init__(self, element: Type, count: int) -> None:
def __init__(self, element: Type, count: int, align:int|None=None) -> None:
super().__init__()
if align is None:
align = element.align()
self.element = element
self.count = count
self._align = align
assert (self.element.size() * self.count) % self._align == 0
self._size = round_to_align(self.element.size() * self.count, self._align)

def _special_size_align(self) -> Optional[Tuple[int, int]]:
if self.count != 3:
return None
if self.element.size() == 4:
return (16, 16)
return None

def size(self) -> int:
special = self._special_size_align()
if special is not None:
return special[0]
return self.element.size() * self.count
return self._size

def align(self) -> int:
special = self._special_size_align()
if special is not None:
return special[1]
return self.element.align()

return self._align

def __eq__(self, value: object) -> bool:
return (
isinstance(value, VectorType)
Expand Down Expand Up @@ -460,7 +455,7 @@ class StructType(Type):
_field_dict: Dict[str, Type]
# _monomorphification_cache: Dict[Tuple['GenericParameter', Type | 'Value'], Type]

def __init__(self, name: str, display_name: str, fields: List[Tuple[str, Type]]) -> None:
def __init__(self, name: str, display_name: str, fields: List[Tuple[str, Type]]) -> None:
super().__init__()
self.name = name
self._fields = fields
Expand Down Expand Up @@ -506,26 +501,38 @@ def __hash__(self) -> int:


class TypeBound:
pass
@abstractmethod
def satisfied_by(self, ty: Type) -> bool:
pass


class AnyBound(TypeBound):
pass
@override
def satisfied_by(self, ty: Type) -> bool:
return True


class SubtypeBound(TypeBound):
super_type: Type
exact_match: bool

def __init__(self, super_type: Type) -> None:
def __init__(self, super_type: Type, exact_match:bool) -> None:
self.super_type = super_type
self.exact_match = exact_match

def __repr__(self) -> str:
return f"SubtypeBound({self.super_type})"

def __eq__(self, value: object) -> bool:
return isinstance(value, SubtypeBound) and value.super_type == self.super_type


@override
def satisfied_by(self, ty: Type) -> bool:
if self.exact_match:
return is_type_compatible_to(ty, self.super_type)
else:
raise NotImplementedError()

class UnionBound(TypeBound):
bounds: List[SubtypeBound]

Expand All @@ -537,6 +544,10 @@ def __repr__(self) -> str:

def __eq__(self, value: object) -> bool:
return isinstance(value, UnionBound) and value.bounds == self.bounds

@override
def satisfied_by(self, ty: Type) -> bool:
return any(b.satisfied_by(ty) for b in self.bounds)


class GenericParameter:
Expand Down Expand Up @@ -1199,9 +1210,18 @@ def unify(a: Type | ComptimeValue, b: Type | ComptimeValue):
case SymbolicType():
if a.param.name in mapping:
return unify(mapping[a.param], b)
if isinstance(b, GenericFloatType) or isinstance(b, GenericIntType):
raise TypeInferenceError(None,
f"float/int literal cannot be used to infer generic type for `{a.param.name}` directly, wrap it with a concrete type")
if a.param.bound is None:
if isinstance(b, GenericFloatType) or isinstance(b, GenericIntType):
raise TypeInferenceError(None,
f"float/int literal cannot be used to infer generic type for `{a.param.name}` directly, wrap it with a concrete type")
else:
if not a.param.bound.satisfied_by(b):
raise TypeInferenceError(None, f"{b} does not satisfy bound {a.param.bound}")
if isinstance(a.param.bound, UnionBound):
for bound in a.param.bound.bounds:
if bound.satisfied_by(b) and bound.super_type.is_concrete():
mapping[a.param] = bound.super_type
return
mapping[a.param] = b
return
case VectorType():
Expand Down
Loading

0 comments on commit 48fba8d

Please sign in to comment.