Skip to content

Commit

Permalink
fix passing arguments by ref
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Dec 17, 2024
1 parent 0e83814 commit 385ea44
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
4 changes: 2 additions & 2 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optiona
func_sig = classinfo.parse_func_signature(f, func_globals, [])

func_sig_converted, sig_parser = parse.convert_func_signature(
func_sig, func_name, func_globals, foreign_type_var_ns, {}, self_type)
func_sig, func_name, props, func_globals, foreign_type_var_ns, {}, self_type)
implicit_type_params = sig_parser.implicit_type_params
implicit_generic_params: Set[hir.GenericParameter] = set()
for p in implicit_type_params.values():
Expand Down Expand Up @@ -119,7 +119,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
mapped_implicit_type_params[name] = mapped_type

func_sig_instantiated, _p = parse.convert_func_signature(
func_sig, func_name, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
func_sig, func_name, props, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
# print(func_name, func_sig)
assert len(
func_sig_instantiated.generic_params) == 0, f"generic params should be resolved but found {func_sig_instantiated.generic_params}"
Expand Down
16 changes: 10 additions & 6 deletions luisa_lang/lang_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ def device_assert(cond: bool, msg: str = "") -> typing.NoReturn:
raise NotImplementedError(
"device_assert should not be called in host-side Python code. ")


def sizeof(t: type[T]) -> u64:
raise NotImplementedError("sizeof should not be called in host-side Python code. ")
raise NotImplementedError(
"sizeof should not be called in host-side Python code. ")


@overload
def range(n: T) -> List[T]: ...
Expand Down Expand Up @@ -208,11 +211,11 @@ class Array(Generic[T, N]):
def __init__(self) -> None:
self = intrinsic("init.array", Array[T, N])

def __getitem__(self, index: int | u32 | u64) -> T:
def __getitem__(self, index: int | i32 | u32 | i64 | u64) -> T:
return intrinsic("array.ref", T, byref(self), index) # type: ignore

def __setitem__(self, index: int | u32 | u64, value: T) -> None:
pass
def __setitem__(self, index: int | i32 | u32 | i64 | u64, value: T | int | float) -> None:
"""value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""

def __len__(self) -> u64:
return intrinsic("array.size", u64, self) # type: ignore
Expand All @@ -233,10 +236,11 @@ def __len__(self) -> u64:

@opaque("Buffer")
class Buffer(Generic[T]):
def __getitem__(self, index: int | u32 | u64) -> T:
def __getitem__(self, index: int | i32 | u32 | i64 | u64) -> T:
return intrinsic("buffer.ref", T, self, index) # type: ignore

def __setitem__(self, index: int | u32 | u64, value: T) -> None:
def __setitem__(self, index: int | i32 | u32 | i64 | u64, value: T | int | float) -> None:
"""value: T | int | float annotation is to make mypy happy. this function is ignored by the compiler"""
pass

def __len__(self) -> u64:
Expand Down
3 changes: 3 additions & 0 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def handle_type_t():

def convert_func_signature(signature: classinfo.MethodType,
ctx_name: str,
props:hir.FuncProperties,
globalns: Dict[str, Any],
type_var_ns: Dict[typing.TypeVar, hir.Type],
implicit_type_params: Dict[str, hir.Type],
Expand All @@ -173,6 +174,8 @@ def convert_func_signature(signature: classinfo.MethodType,
assert self_type is not None
param_type = self_type
semantic = hir.ParameterSemantic.BYREF
if arg[0] in props.byref:
semantic = hir.ParameterSemantic.BYREF
if param_type is None:
raise RuntimeError(
f"Unable to parse type of parameter {arg[0]}: {arg[1]}")
Expand Down

0 comments on commit 385ea44

Please sign in to comment.