Skip to content

Commit

Permalink
added parsing for generic aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Dec 26, 2024
1 parent fea3ff3 commit e6a21a5
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 34 deletions.
6 changes: 6 additions & 0 deletions luisa_lang/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
import sys

# check if is python 3.12 or higher
if sys.version_info < (3, 12):
raise Exception("luisa_lang requires Python 3.12 or higher")

from luisa_lang.lang import *
from luisa_lang.lang_builtins import *
1 change: 0 additions & 1 deletion luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:


def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
import sourceinspect
assert inspect.isfunction(f), f"{f} is not a function"
# print(hir.GlobalContext.get)
ctx = hir.GlobalContext.get()
Expand Down
78 changes: 63 additions & 15 deletions luisa_lang/classinfo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from types import NoneType
from types import GenericAlias, NoneType
import types
import typing
from typing import (
Expand All @@ -10,21 +10,23 @@
Optional,
Set,
Tuple,
TypeAliasType,
TypeVar,
Generic,
Dict,
Type,
Union,
cast,
)
import functools
from dataclasses import dataclass


class GenericInstance:
origin: type
origin: 'VarType'
args: List["VarType"]

def __init__(self, origin: type, args: List["VarType"]):
def __init__(self, origin: 'VarType', args: List["VarType"]):
self.origin = origin
self.args = args

Expand All @@ -41,6 +43,9 @@ def __init__(self, types: List["VarType"]):
def __repr__(self):
return f"Union[{', '.join(map(repr, self.types))}]"

def substitute(self, env: Dict[TypeVar, 'VarType']) -> "UnionType":
return UnionType([subst_type(ty, env) for ty in self.types])


class AnyType:
def __repr__(self):
Expand All @@ -56,7 +61,8 @@ def __repr__(self):

def __eq__(self, other):
return isinstance(other, SelfType)



class LiteralType:
value: Any

Expand All @@ -70,7 +76,23 @@ def __eq__(self, other):
return isinstance(other, LiteralType) and self.value == other.value


VarType = Union[TypeVar, Type, GenericInstance, UnionType, SelfType, AnyType, LiteralType]
class AnnotatedType:
origin: 'VarType'
annotations: List[Any]

def __init__(self, origin: 'VarType', annotations: List[Any]):
self.origin = origin
self.annotations = annotations

def __repr__(self):
return f"Annotated[{self.origin}, {self.annotations}]"

def substitute(self, env: Dict[TypeVar, 'VarType']) -> "AnnotatedType":
return AnnotatedType(subst_type(self.origin, env), self.annotations)


type VarType = Union[TypeVar, Type, GenericInstance,
UnionType, SelfType, AnyType, LiteralType, AnnotatedType]


def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
Expand All @@ -79,6 +101,8 @@ def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
return env.get(ty, ty)
case GenericInstance(origin=origin, args=args):
return GenericInstance(origin, [subst_type(arg, env) for arg in args])
case MethodType() | UnionType() | AnnotatedType():
return ty.substitute(env)
case _:
return ty

Expand Down Expand Up @@ -140,7 +164,8 @@ def __repr__(self):
def instantiate(self, type_args: List[VarType]) -> "ClassType":
if len(type_args) != len(self.type_vars):
raise RuntimeError(
f"Expected {len(self.type_vars)} type arguments but got {len(type_args)}"
f"Expected {len(self.type_vars)}" +
f"type arguments but got {len(type_args)}"
)
env = dict(zip(self.type_vars, type_args))
return ClassType(
Expand Down Expand Up @@ -172,7 +197,8 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
for base in cls.__orig_bases__:
if hasattr(base, "__origin__"):
base_params = []
base_orig = base.__origin__
base_orig: Any = base.__origin__

if not _is_class_registered(base_orig) and base_orig not in _BUILTIN_ANNOTATION_BASES:
raise RuntimeError(
f"Base class {base_orig} of {cls} is not registered."
Expand All @@ -185,7 +211,8 @@ def _get_base_classinfo(cls: type, globalns) -> List[tuple[str, ClassType]]:
if base_orig in _BUILTIN_ANNOTATION_BASES:
pass
else:
base_info = class_typeinfo(base_orig)
assert isinstance(base_orig, type)
base_info = class_typeinfo(cast(type, base_orig))
info.append(
(base.__name__, base_info.instantiate(base_params)))
else:
Expand All @@ -210,19 +237,40 @@ def parse_type_hint(hint: Any) -> VarType:
return UnionType([parse_type_hint(arg) for arg in hint.__args__])
if hint is typing.Any:
return AnyType()
if isinstance(hint, TypeAliasType):
return parse_type_hint(hint.__value__)

origin = typing.get_origin(hint)
if origin:
if isinstance(origin, type):
# assert isinstance(origin, type), f"origin must be a type but got {origin}"
args = list(typing.get_args(hint))
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])
if origin is typing.Annotated:
annotate_args = typing.get_args(hint)
return AnnotatedType(parse_type_hint(annotate_args[0]), list(annotate_args[1:]))
elif origin is Union:
return UnionType([parse_type_hint(arg) for arg in typing.get_args(hint)])
elif origin is Literal:
return LiteralType(typing.get_args(hint)[0])
elif isinstance(origin, TypeAliasType):
def do() -> VarType:
assert isinstance(hint, GenericAlias)
args = list(typing.get_args(hint))
assert len(args) == len(origin.__parameters__), f"Expected {
len(origin.__parameters__)} type arguments but got {len(args)}"
true_origin = origin.__value__
parametric_args = origin.__parameters__
parsed_args = [parse_type_hint(arg) for arg in args]
env = dict(zip(parametric_args, parsed_args))
parsed_origin = parse_type_hint(true_origin)
return subst_type(parsed_origin, env)
return do()
elif isinstance(origin, type):
# assert isinstance(origin, type), f"origin must be a type but got {origin}"
args = list(typing.get_args(hint))
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])

else:
raise RuntimeError(f"Unsupported origin type: {origin}")

raise RuntimeError(f"Unsupported origin type: {
origin}, {type(origin), type(hint)}")

if isinstance(hint, type):
return hint
if hint == typing.Self:
Expand All @@ -242,7 +290,7 @@ def extract_type_vars_from_hint(hint: typing.Any) -> List[TypeVar]:


def get_type_vars(func: typing.Callable) -> List[TypeVar]:
type_hints = typing.get_type_hints(func)
type_hints = typing.get_type_hints(func, include_extras=True)
type_vars = []
for hint in type_hints.values():
type_vars.extend(extract_type_vars_from_hint(hint))
Expand Down
38 changes: 22 additions & 16 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def gen_impl(self, ty: hir.Type) -> str:
match ty:
case hir.IntType(bits=bits, signed=signed):
int_names = {
'8':'byte',
'16':'short',
'32':'int',
'64':'long',
'8': 'byte',
'16': 'short',
'32': 'int',
'64': 'long',
}
if signed:
return f"lc_{int_names[str(bits)]}"
Expand Down Expand Up @@ -85,11 +85,13 @@ def do():
return self.gen(ty.instantiated)
case hir.FunctionType():
name = f'func_{unique_hash(ty.func_like.name)}_t'
self.impl.writeln(f'struct {name} {{}}; // function type of {ty.func_like.name}')
self.impl.writeln(
f'struct {name} {{}}; // function type of {ty.func_like.name}')
return name
case hir.TypeConstructorType():
name = f'type_{unique_hash(self.gen(ty.inner))}_t'
self.impl.writeln(f'struct {name} {{}}; // type constructor of {ty.inner}')
self.impl.writeln(
f'struct {name} {{}}; // type constructor of {ty.inner}')
return name
case hir.OpaqueType():
def do():
Expand All @@ -98,7 +100,8 @@ def do():
elem_ty = self.gen(ty.extra_args[0])
return f'__builtin__Buffer<{elem_ty}>'
case _:
raise NotImplementedError(f"unsupported opaque type: {ty.name}")
raise NotImplementedError(
f"unsupported opaque type: {ty.name}")
return do()
case hir.GenericIntType():
return 'int'
Expand Down Expand Up @@ -225,7 +228,8 @@ def gen_function(self, func: hir.Function | Callable[..., Any]) -> str:
if callable(func):
dsl_func = get_dsl_func(func)
assert dsl_func is not None
assert not dsl_func.is_generic, f"Generic functions should be resolved before codegen: {func}"
assert not dsl_func.is_generic, f"Generic functions should be resolved before codegen: {
func}"
func_tmp = dsl_func.resolve([])
assert isinstance(
func_tmp, hir.Function), f"Expected function, got {func_tmp}"
Expand Down Expand Up @@ -268,8 +272,9 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
params = ",".join(self.gen_var(
p) for p in func.params)
assert func.return_type

self.signature = f'auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'

self.signature = f'auto {
self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
if func.export:
self.signature = f'extern "C" {self.signature}'
if func.inline_hint == True:
Expand Down Expand Up @@ -304,14 +309,15 @@ def gen_ref(self, ref: hir.Ref) -> str:
def do():
intrin_name = intrin.name
gened_args = [self.gen_value_or_ref(
arg) for arg in intrin.args]
arg) for arg in intrin.args]
match intrin_name:
case 'buffer.ref' | 'array.ref':
return f"{gened_args[0]}[{gened_args[1]}]"
case 'buffer.size' | 'array.size':
return f"{gened_args[0]}.size"
case _:
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
raise RuntimeError(
f"unsupported intrinsic reference: {intrin_name}")
return do()
case _:
raise NotImplementedError(f"unsupported reference: {ref}")
Expand All @@ -338,7 +344,7 @@ def gen_node_checked(self, node: hir.Node) -> str:
if isinstance(node, hir.TypedNode) and isinstance(node.type, (hir.TypeConstructorType, hir.FunctionType)):
assert node.type
return f'{self.base.type_cache.gen(node.type)}{{}}'

return self.node_map[node]

def gen_expr(self, expr: hir.Value) -> str:
Expand Down Expand Up @@ -440,7 +446,7 @@ def do():
'__sub__': '-',
'__mul__': '*',
'__truediv__': '/',
'__floordiv__': '/', # TODO: fix floordiv
'__floordiv__': '/', # TODO: fix floordiv
'__mod__': '%',
'__pow__': '**',
'__and__': '&',
Expand All @@ -460,7 +466,7 @@ def do():
'__isub__': '-=',
'__imul__': '*=',
'__itruediv__': '/=',
'__ifloordiv__': '/=', # TODO: fix floordiv
'__ifloordiv__': '/=', # TODO: fix floordiv
'__imod__': '%=',
'__ipow__': '**=',
'__iand__': '&=',
Expand Down Expand Up @@ -489,7 +495,7 @@ def do():
args_s = ','.join(gened_args)
self.body.writeln(
f"auto v{vid} = __intrin__{intrin_name}({args_s});")

do()
case _:
raise NotImplementedError(
Expand Down
1 change: 1 addition & 0 deletions luisa_lang/lang_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
N = TypeVar("N")



@func
def dispatch_id() -> uint3:
return intrinsic("dispatch_id", uint3)
Expand Down
5 changes: 3 additions & 2 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def handle_type_t():

def convert_func_signature(signature: classinfo.MethodType,
ctx_name: str,
props:hir.FuncProperties,
props: hir.FuncProperties,
globalns: Dict[str, Any],
type_var_ns: Dict[typing.TypeVar, hir.Type],
implicit_type_params: Dict[str, hir.Type],
Expand Down Expand Up @@ -194,7 +194,8 @@ def convert_func_signature(signature: classinfo.MethodType,
params.append(
Var(arg[0], implicit_type_params[arg[0]], span=None, semantic=semantic))
return_type = type_parser.parse_type_ext(signature.return_type)
assert return_type is not None, f"failed to parse return type {signature.return_type}"
assert return_type is not None, f"failed to parse return type {
signature.return_type}"
if isinstance(return_type, hir.AnyBound):
return_type = None
elif isinstance(return_type, hir.TypeBound):
Expand Down

0 comments on commit e6a21a5

Please sign in to comment.