Skip to content

Commit

Permalink
added for loop
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 5, 2024
1 parent 4b4e4f1 commit 87e7e52
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 10 deletions.
9 changes: 6 additions & 3 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,10 +1034,10 @@ def __init__(self, value: Optional[Value], span: Optional[Span] = None) -> None:

class Range(Value):
start: Value
step: Optional[Value]
stop: Optional[Value]
step: Value
stop: Value

def __init__(self, start: Value, stop: Optional[Value] = None, step: Optional[Value] = None, span: Optional[Span] = None) -> None:
def __init__(self, start: Value, stop: Value, step: Value, span: Optional[Span] = None) -> None:
super().__init__(None, span)
self.start = start
self.stop = stop
Expand All @@ -1057,6 +1057,9 @@ def update(self, value: Any) -> None:
self.update_func(value)
else:
raise RuntimeError("unable to update comptime value")

def __str__(self) -> str:
return f"ComptimeValue({self.value})"


class BuiltinFunction:
Expand Down
66 changes: 59 additions & 7 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing
import luisa_lang
from luisa_lang.lang_builtins import comptime
import luisa_lang.math_types
from luisa_lang.utils import get_typevar_constrains_and_bounds, unwrap
import luisa_lang.hir as hir
import sys
Expand Down Expand Up @@ -49,7 +50,7 @@ class TypeParser:
implicit_type_params: Dict[str, hir.Type]

def __init__(self, ctx_name: str, globalns: Dict[str, Any], type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue], self_type: Optional[Type] = None) -> None:
self.globalns = copy(globalns)
self.globalns = globalns
self.self_type = self_type
self.type_var_ns = type_var_ns
self.ctx_name = ctx_name
Expand Down Expand Up @@ -161,7 +162,7 @@ def __init__(self, name: str,
self.name = name
self.func = func
self.signature = signature
self.globalns = globalns
self.globalns = copy(globalns)
obj_ast, _obj_file = retrieve_ast_and_filename(func)
print(ast.dump(obj_ast))
assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module"
Expand Down Expand Up @@ -247,6 +248,9 @@ def convert_any_to_value(self, a: Any, span: hir.Span | None) -> hir.Value | Com
def parse_name(self, name: ast.Name, new_var_hint: NewVarHint) -> hir.Ref | hir.Value | ComptimeValue:
span = hir.Span.from_ast(name)
var = self.vars.get(name.id)
# print(__builtins__)
# print('range' in __builtins__)
# assert hasattr(__builtins__, 'range')
if var is not None:
return var
if new_var_hint == 'dsl':
Expand All @@ -258,9 +262,11 @@ def parse_name(self, name: ast.Name, new_var_hint: NewVarHint) -> hir.Ref | hir.
if name.id in self.globalns:
resolved = self.globalns[name.id]
return self.convert_any_to_value(resolved, span)
elif name.id in __builtins__: # type: ignore
resolved = __builtins__[name.id] # type: ignore
return self.convert_any_to_value(resolved, span)
elif new_var_hint == 'comptime':
self.globalns[name.id] = None

def update_fn(value: Any) -> None:
self.globalns[name.id] = value
return ComptimeValue(None, update_fn)
Expand Down Expand Up @@ -461,10 +467,12 @@ def handle_range() -> hir.Value | ComptimeValue:
args[i] = self.try_convert_comptime_value(
arg, hir.Span.from_ast(expr.args[i]))
converted_args = cast(List[hir.Value], args)
def make_int(i: int) -> hir.Value:
return hir.Constant(i, type=hir.GenericIntType())
if len(args) == 1:
return hir.Range(converted_args[0])
return hir.Range(make_int(0), converted_args[0], make_int(1))
elif len(args) == 2:
return hir.Range(converted_args[0], converted_args[1])
return hir.Range(converted_args[0], converted_args[1], make_int(1))
elif len(args) == 3:
return hir.Range(converted_args[0], converted_args[1], converted_args[2])
else:
Expand Down Expand Up @@ -684,7 +692,7 @@ def check(i: int, val_type: hir.Type) -> None:
case hir.ArrayType() as at:
assert isinstance(at.count, int)
do_unpack(at.count, lambda values, i, target: self.cur_bb().append(
hir.Index(values, hir.Constant(i, type=hir.IntType(32, True)), type=at.element, span=hir.Span.from_ast(target))))
hir.Index(values, hir.Constant(i, type=luisa_lang.typeof(luisa_lang.i32)), type=at.element, span=hir.Span.from_ast(target))))
case hir.StructType() as st:
do_unpack(len(st.fields), lambda values, i, target: self.cur_bb().append(
hir.Member(values, st.fields[i][0], type=st.fields[i][1], span=hir.Span.from_ast(target)))
Expand Down Expand Up @@ -792,7 +800,51 @@ def parse_stmt(self, stmt: ast.stmt) -> None:
hir.Loop(prepare, cond, body, update, merge, span))
self.bb_stack.append(merge)
case ast.For():
pass
iter_val = self.parse_expr(stmt.iter)
if not isinstance(iter_val, hir.Value) or not isinstance(iter_val, hir.Range):
raise hir.ParsingError(
stmt, f"for loop iterable must be a range object but found {iter_val}")
pred_bb = self.cur_bb()
self.bb_stack.pop()
loop_var = self.parse_ref(stmt.target, new_var_hint='dsl')
if not isinstance(loop_var, hir.Ref):
raise hir.ParsingError(
stmt, "for loop target must be a DSL variable")
if not loop_var.type:
loop_var.type = luisa_lang.typeof(luisa_lang.i32)
if not isinstance(loop_var.type, hir.IntType):
raise hir.ParsingError(
stmt, "for loop target must be an integer variable")
loop_range: hir.Range = iter_val

prepare = hir.BasicBlock(span)
self.bb_stack.append(prepare)
int_lt = loop_var.type.method("__lt__")
assert int_lt is not None
cmp_result = self.parse_call_impl(
span, int_lt, [loop_var, loop_range.stop])
assert isinstance(cmp_result, hir.Value)
assert cmp_result.type == hir.BoolType()
self.bb_stack.pop()
body = hir.BasicBlock(span)
self.bb_stack.append(body)
for s in stmt.body:
self.parse_stmt(s)
body = self.bb_stack.pop()
update = hir.BasicBlock(span)
self.bb_stack.append(update)
inc =loop_range.step
int_add = loop_var.type.method("__add__")
assert int_add is not None
add = self.parse_call_impl(
span, int_add, [loop_var, inc])
assert isinstance(add, hir.Value)
self.cur_bb().append(hir.Assign(loop_var, add))
self.bb_stack.pop()
merge = hir.BasicBlock(span)
pred_bb.append(
hir.Loop(prepare, cmp_result, body, update, merge, span))
self.bb_stack.append(merge)
case ast.Return():
def check_return_type(ty: hir.Type) -> None:
assert self.parsed_func
Expand Down

0 comments on commit 87e7e52

Please sign in to comment.