Skip to content

Commit

Permalink
refactoring intrinsics
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 9, 2024
1 parent e3fec07 commit 49f912a
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 16 deletions.
7 changes: 5 additions & 2 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def make_builtin():
return decorator





def builtin(s: str) -> Callable[[_F], _F]:
def wrapper(func: _F) -> _F:
setattr(func, "__luisa_builtin__", s)
Expand Down Expand Up @@ -207,7 +210,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
params = [v[0] for v in func_sig.args]
is_generic = len(func_sig_converted.generic_params) > 0
# print(
# f"func {func_name} is_generic: {is_generic} {func_sig_converted.generic_params}")
# f"func {func_name} is_generic: {is_generic} {func_sig_converted.generic_params}")
return hir.FunctionTemplate(func_name, params, parsing_func, is_generic)


Expand Down Expand Up @@ -303,7 +306,7 @@ def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
pass
ctx.types[cls] = ir_ty
if not is_generic:
parse_methods({},ir_ty)
parse_methods({}, ir_ty)
return cls


Expand Down
6 changes: 6 additions & 0 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ def impl() -> None:
f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};")
case hir.TypeValue():
pass
case hir.Intrinsic():
intrin_name = expr.name.replace('.', '_')
args_s = ','.join(self.gen_value_or_ref(
arg) for arg in expr.args)
self.body.writeln(
f"auto v{vid} = __intrin__{intrin_name}({args_s});")
case _:
raise NotImplementedError(
f"unsupported expression: {expr}")
Expand Down
33 changes: 30 additions & 3 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,21 @@ def is_concrete(self) -> bool:
def __len__(self) -> int:
return 1

class AnyType(Type):
def size(self) -> int:
raise RuntimeError("AnyType has no size")

def align(self) -> int:
raise RuntimeError("AnyType has no align")

def __eq__(self, value: object) -> bool:
return isinstance(value, AnyType)

def __hash__(self) -> int:
return hash(AnyType)

def __str__(self) -> str:
return "AnyType"

class UnitType(Type):
def size(self) -> int:
Expand Down Expand Up @@ -807,7 +822,6 @@ def __init__(self, value: 'Value') -> None:
class Value(TypedNode):
pass


class Unit(Value):
def __init__(self) -> None:
super().__init__(UnitType())
Expand Down Expand Up @@ -907,7 +921,10 @@ class TypeValue(Value):
def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
super().__init__(TypeConstructorType(ty), span)


def inner_type(self) -> Type:
assert isinstance(self.type, TypeConstructorType)
return self.type.inner

class Alloca(Ref):
"""
A temporary variable
Expand All @@ -917,6 +934,8 @@ def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
super().__init__(ty, span)




# class Init(Value):
# init_call: 'Call'

Expand All @@ -931,6 +950,14 @@ def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) -
super().__init__(type, span)
self.args = args

class Intrinsic(Value):
name: str
args: List[Value]

def __init__(self, name: str, args: List[Value], type: Type, span: Optional[Span] = None) -> None:
super().__init__(type, span)
self.name = name
self.args = args

class Call(Value):
op: FunctionLike
Expand Down Expand Up @@ -1335,7 +1362,7 @@ def get_dsl_type(cls: type) -> Optional[Type]:


def is_type_compatible_to(ty: Type, target: Type) -> bool:
if ty == target:
if ty == target or isinstance(ty, AnyType):
return True
if isinstance(target, FloatType):
return isinstance(ty, GenericFloatType)
Expand Down
16 changes: 15 additions & 1 deletion luisa_lang/lang_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
N = TypeVar("N", int, u32, u64)


def intrinsic(name: str, ret_type: type[T], *args, **kwargs) -> T:
raise NotImplementedError(
"intrinsic functions should not be called in host-side Python code. "
"Did you mistakenly called a DSL function?"
)


@builtin("dispatch_id")
def dispatch_id() -> uint3:
return _intrinsic_impl()
Expand Down Expand Up @@ -103,7 +110,10 @@ def comptime(a):
return a


parse.comptime = comptime
parse._add_special_function("comptime", comptime)
parse._add_special_function("intrinsic", intrinsic)
parse._add_special_function("range", range)
parse._add_special_function('reveal_type', typing.reveal_type)


def static_assert(cond: Any, msg: str = ""):
Expand Down Expand Up @@ -160,6 +170,7 @@ def __setitem__(self, index: int | u32 | u64, value: T) -> None:
def __len__(self) -> u32 | u64:
return _intrinsic_impl()


def __buffer_ty():
t = hir.GenericParameter("T", "luisa_lang.lang")
return hir.ParametricType(
Expand All @@ -171,6 +182,8 @@ def __buffer_ty():
# # "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer")
# # )
# )


class Buffer(Generic[T]):
def __getitem__(self, index: int | u32 | u64) -> T:
return _intrinsic_impl()
Expand Down Expand Up @@ -216,4 +229,5 @@ def value(self, value: T) -> None:
"dispatch_id",
"thread_id",
"block_id",
"intrinsic",
]
38 changes: 28 additions & 10 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from luisa_lang.hir import get_dsl_type, ComptimeValue
import luisa_lang.classinfo as classinfo

comptime: Any = None


def _implicit_typevar_name(v: str) -> str:
return f"T#{v}"
Expand Down Expand Up @@ -156,11 +154,14 @@ def convert_func_signature(signature: classinfo.MethodType,
return hir.FunctionSignature(type_parser.generic_params, params, return_type), type_parser


SPECIAL_FUNCTIONS: Set[Callable[..., Any]] = {
comptime,
reveal_type,
range
}
SPECIAL_FUNCTIONS_DICT: Dict[str, Callable[..., Any]] = {}
SPECIAL_FUNCTIONS: Set[Callable[..., Any]] = set()


def _add_special_function(name: str, f: Callable[..., Any]) -> None:
SPECIAL_FUNCTIONS_DICT[name] = f
SPECIAL_FUNCTIONS.add(f)


NewVarHint = Literal[False, 'dsl', 'comptime']

Expand Down Expand Up @@ -390,7 +391,8 @@ def parse_type_arg(expr: ast.expr) -> hir.Type:
case _:
type_args.append(parse_type_arg(expr.slice))
# print(f"Type args: {type_args}")
assert isinstance(value.type, hir.TypeConstructorType) and isinstance(value.type.inner, hir.ParametricType)
assert isinstance(value.type, hir.TypeConstructorType) and isinstance(
value.type.inner, hir.ParametricType)
return hir.TypeValue(
hir.BoundType(value.type.inner, type_args, value.type.inner.instantiate(type_args)))

Expand Down Expand Up @@ -476,7 +478,23 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct
raise NotImplementedError() # unreachable

def handle_special_functions(self, f: Callable[..., Any], expr: ast.Call) -> hir.Value | ComptimeValue:
if f is comptime:
if f is SPECIAL_FUNCTIONS_DICT['intrinsic']:
def do() -> hir.Intrinsic:
intrinsic_name = expr.args[0]
if not isinstance(intrinsic_name, ast.Constant) or not isinstance(intrinsic_name.value, str):
raise hir.ParsingError(
expr, "intrinsic function expects a string literal as its first argument")
args = [self.parse_expr(arg) for arg in expr.args[1:]]
ret_type = args[0]
if not isinstance(ret_type, hir.TypeValue):
raise hir.ParsingError(
expr, f"intrinsic function expects a type as its second argument but found {ret_type}")
if any([not isinstance(arg, hir.Value) for arg in args[1:]]):
raise hir.ParsingError(
expr, "intrinsic function expects values as its arguments")
return hir.Intrinsic(intrinsic_name.value, cast(List[hir.Value], args[1:]), ret_type.inner_type(), hir.Span.from_ast(expr))
return do()
elif f is SPECIAL_FUNCTIONS_DICT['comptime']:
if len(expr.args) != 1:
raise hir.ParsingError(
expr, f"when used in expressions, lc.comptime function expects exactly one argument")
Expand Down Expand Up @@ -563,7 +581,7 @@ def collect_args() -> List[hir.Value | hir.Ref]:
arg, hir.Span.from_ast(expr.args[i]))
return cast(List[hir.Value | hir.Ref], args)

if isinstance(func.type, hir.TypeConstructorType):
if isinstance(func.type, hir.TypeConstructorType):
# TypeConstructorType is unique for each type
# so if any value has this type, it must be referring to the same underlying type
# even if it comes from a very complex expression, it's still fine
Expand Down

0 comments on commit 49f912a

Please sign in to comment.