Skip to content

Commit

Permalink
good
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Feb 4, 2025
1 parent c87945b commit dfbdf3b
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 20 deletions.
20 changes: 18 additions & 2 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,10 @@ def impl() -> None:
ty = self.base.type_cache.gen(expr.type)
self.body.writeln(
f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};")
case hir.Intrinsic() as intrin:
case hir.Intrinsic() as intrin:
def do():
assert intrin.type
intrin_ty_s = self.base.type_cache.gen(intrin.type)
intrin_name = intrin.name
comps = intrin_name.split('.')
gened_args = [self.gen_value_or_ref(
Expand All @@ -426,6 +428,12 @@ def do():
ty = self.base.type_cache.gen(expr.type)
self.body.writeln(
f"{ty} v{vid}{{ {','.join(gened_args)} }};")
elif comps[0] == 'cast':
self.body.writeln(
f"auto v{vid} = static_cast<{intrin_ty_s}>({gened_args[0]});")
elif comps[0] == 'bitcast':
self.body.writeln(
f"auto v{vid} = lc_bit_cast<{intrin_ty_s}>({gened_args[0]});")
elif comps[0] == 'cmp':
cmp_dict = {
'__eq__': '==',
Expand Down Expand Up @@ -592,11 +600,19 @@ def gen_node(self, node: hir.Node) -> Optional[hir.BasicBlock]:
ty = self.base.type_cache.gen(alloca.type.remove_ref())
self.body.writeln(f"{ty} v{vid}{{}}; // alloca")
self.node_map[alloca] = f"v{vid}"
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue():
case hir.Print() as print_stmt:
raise NotImplementedError("print statement")
case hir.Assert() as assert_stmt:
raise NotImplementedError("assert statement")
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue() | hir.VarValue():
if isinstance(node, hir.TypedNode) and node.is_ref():
pass
else:
self.gen_expr(node)
case hir.VarRef():
pass
case _:
raise NotImplementedError(f"unsupported node: {node}")
return None

def gen_bb(self, bb: hir.BasicBlock):
Expand Down
87 changes: 75 additions & 12 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]:

def is_concrete(self) -> bool:
return True

def is_addressable(self) -> bool:
return True

Expand All @@ -163,8 +163,10 @@ class RefType(Type):

def __init__(self, element: Type) -> None:
super().__init__()
assert element.is_addressable(), f"RefType element {element} is not addressable"
assert not isinstance(element, (OpaqueType, RefType, FunctionType,TypeConstructorType))
assert element.is_addressable(), f"RefType element {
element} is not addressable"
assert not isinstance(
element, (OpaqueType, RefType, FunctionType, TypeConstructorType))
self.element = element
self.methods = element.methods

Expand All @@ -189,18 +191,19 @@ def member(self, field: Any) -> Optional['Type']:
ty = self.element.member(field)
if ty is None:
return None
if isinstance(ty,FunctionType):
if isinstance(ty, FunctionType):
return ty
return RefType(ty)

@override
def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]:
return self.element.method(name)

@override
def is_addressable(self) -> bool:
return False


class LiteralType(Type):
value: Any

Expand All @@ -221,7 +224,7 @@ def is_concrete(self) -> bool:
@override
def is_addressable(self) -> bool:
return False

def __eq__(self, value: object) -> bool:
return isinstance(value, LiteralType) and value.value == self.value

Expand Down Expand Up @@ -349,6 +352,7 @@ def is_concrete(self) -> bool:
def is_addressable(self) -> bool:
return False


class GenericIntType(ScalarType):
@override
def __eq__(self, value: object) -> bool:
Expand Down Expand Up @@ -382,6 +386,7 @@ def is_concrete(self) -> bool:
def is_addressable(self) -> bool:
return False


class FloatType(ScalarType):
bits: int

Expand Down Expand Up @@ -695,6 +700,7 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return f"~{self.name}@{self.ctx_name}"


class OpaqueType(Type):
name: str
extra_args: List[Any]
Expand Down Expand Up @@ -722,7 +728,7 @@ def __str__(self) -> str:
@override
def is_concrete(self) -> bool:
return False

@override
def is_addressable(self) -> bool:
return False
Expand Down Expand Up @@ -800,18 +806,19 @@ def __eq__(self, value: object) -> bool:

def __hash__(self) -> int:
return hash((ParametricType, tuple(self.params), self.body))

def __str__(self) -> str:
return f"{self.body}[{', '.join(str(p) for p in self.params)}]"

@override
def is_concrete(self) -> bool:
return self.body.is_concrete()

@override
def is_addressable(self) -> bool:
return self.body.is_addressable()


class BoundType(Type):
"""
An instance of a parametric type, e.g. Foo[int]
Expand Down Expand Up @@ -841,7 +848,7 @@ def __eq__(self, value: object) -> bool:

def __hash__(self):
return hash((BoundType, self.generic, tuple(self.args)))

def __str__(self) -> str:
return f"{self.generic}[{', '.join(str(a) for a in self.args)}]"

Expand All @@ -862,11 +869,12 @@ def method(self, name) -> Optional[Union["Function", FunctionTemplate]]:
@override
def is_addressable(self) -> bool:
return self.generic.is_addressable()

@override
def is_concrete(self) -> bool:
return self.generic.is_concrete()


class TypeConstructorType(Type):
inner: Type

Expand Down Expand Up @@ -910,6 +918,7 @@ def size(self) -> int:
def align(self) -> int:
raise RuntimeError("FunctionType has no align")


class Node:
"""
Base class for all nodes in the HIR. A node could be a value, a reference, or a statement.
Expand Down Expand Up @@ -999,13 +1008,15 @@ def __init__(
self.name = name
self.semantic = semantic


class VarValue(Value):
var: Var

def __init__(self, var: Var, span: Optional[Span]) -> None:
super().__init__(var.type, span)
self.var = var


class VarRef(Value):
var: Var

Expand Down Expand Up @@ -1155,6 +1166,8 @@ def __str__(self) -> str:
return f"Template matching error:\n\t{self.message}"
return f"Template matching error at {self.span}:\n\t{self.message}"

class ComptimeCallStack:
pass

class SpannedError(Exception):
span: Span | None
Expand Down Expand Up @@ -1200,7 +1213,8 @@ class Assign(Node):
value: Value

def __init__(self, ref: Value, value: Value, span: Optional[Span] = None) -> None:
assert not isinstance(value.type, (FunctionType, TypeConstructorType, RefType))
assert not isinstance(
value.type, (FunctionType, TypeConstructorType, RefType))
if not isinstance(ref.type, RefType):
raise ParsingError(
ref, f"cannot assign to a non-reference variable")
Expand All @@ -1209,6 +1223,24 @@ def __init__(self, ref: Value, value: Value, span: Optional[Span] = None) -> Non
self.value = value


class Assert(Node):
cond: Value
msg: List[Union[Value, str]]

def __init__(self, cond: Value, msg: List[Union[Value, str]], span: Optional[Span] = None) -> None:
super().__init__(span)
self.cond = cond
self.msg = msg


class Print(Node):
args: List[Union[Value, str]]

def __init__(self, args: List[Union[Value, str]], span: Optional[Span] = None) -> None:
super().__init__(span)
self.args = args


class Terminator(Node):
pass

Expand Down Expand Up @@ -1559,6 +1591,7 @@ def __init__(self, func: Function, args: List[Value], body: BasicBlock, span: Op
self.mapping[param] = arg
for v in func.locals:
if v in self.mapping:
# skip function parameters
continue
assert v.type
assert v.type.is_addressable()
Expand Down Expand Up @@ -1631,6 +1664,33 @@ def do():
self.mapping[intrin] = body.append(
Intrinsic(intrin.name, args, intrin.type, node.span))
do()
case If():
cond = self.mapping.get(node.cond)
assert isinstance(cond, Value)
then_body = BasicBlock()
else_body = BasicBlock()
merge = BasicBlock()
body.append(If(cond, then_body, else_body, merge))
self.do_inline(node.then_body, then_body)
if node.else_body:
self.do_inline(node.else_body, else_body)
body.append(merge)
case Loop():
prepare = BasicBlock()
if node.cond:
cond = self.mapping.get(node.cond)
else:
cond = None
assert cond is None or isinstance(cond, Value)
body_ = BasicBlock()
update = BasicBlock()
merge = BasicBlock()
body.append(Loop(prepare, cond, body_, update, merge))
self.do_inline(node.prepare, prepare)
self.do_inline(node.body, body_)
if node.update:
self.do_inline(node.update, update)
body.append(merge)
case Return():
if self.ret is not None:
raise InlineError(node, "multiple return statement")
Expand All @@ -1646,6 +1706,9 @@ def do():
@staticmethod
def inline(func: Function, args: List[Value], body: BasicBlock, span: Optional[Span] = None) -> Value:
inliner = FunctionInliner(func, args, body, span)
assert func.return_type
if func.return_type == UnitType():
return Unit()
assert inliner.ret
return inliner.ret

Expand Down
33 changes: 32 additions & 1 deletion luisa_lang/lang_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
Any,
Annotated
)
from luisa_lang._builtin_decor import func, intrinsic, opaque, builtin_generic_type, byref
from luisa_lang._builtin_decor import func, intrinsic, opaque, builtin_generic_type, byref, struct
from luisa_lang import parse

T = TypeVar("T")
Expand Down Expand Up @@ -317,6 +317,37 @@ def __sub__(self, offset: i32 | i64 | u32 | u64) -> 'Pointer[T]':
return intrinsic("pointer.sub", Pointer[T], self, offset)


@struct
class RtxRay:
o: float3
d: float3
tmin: float
tmax: float

def __init__(self, o: float3, d: float3, tmin: float, tmax: float) -> None:
self.o = o
self.d = d
self.tmin = tmin
self.tmax = tmax


@struct
class RtxHit:
inst_id: u32
prim_id: u32
bary: float2

def __init__(self, inst_id: u32, prim_id: u32, bary: float2) -> None:
self.inst_id = inst_id
self.prim_id = prim_id
self.bary = bary


@func
def ray_query_pipeline(ray: RtxRay, on_surface_hit, on_procedural_hit) -> RtxHit:
return intrinsic("ray_query_pipeline", RtxHit, ray, on_surface_hit, on_procedural_hit)


__all__: List[str] = [
# 'Pointer',
'Buffer',
Expand Down
Loading

0 comments on commit dfbdf3b

Please sign in to comment.