Skip to content

Commit

Permalink
refactor func to be a class hierarchy
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jul 21, 2023
1 parent 91cb08f commit a137e4d
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 89 deletions.
3 changes: 3 additions & 0 deletions mlir_utils/_configuration/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from base64 import urlsafe_b64encode
from importlib.metadata import distribution, packages_distributions
from importlib.resources import files
from importlib.resources.readers import MultiplexedPath
from pathlib import Path

from .module_alias_map import get_meta_path_insertion_index, AliasedModuleFinder

__MLIR_PYTHON_PACKAGE_PREFIX__ = "__MLIR_PYTHON_PACKAGE_PREFIX__"
PACKAGE = __package__.split(".")[0]
PACKAGE_ROOT_PATH = files(PACKAGE)
if isinstance(PACKAGE_ROOT_PATH, MultiplexedPath):
PACKAGE_ROOT_PATH = PACKAGE_ROOT_PATH._paths[0]
DIST = distribution(packages_distributions()[PACKAGE][0])
MLIR_PYTHON_PACKAGE_PREFIX_TOKEN_PATH = (
Path(__file__).parent / __MLIR_PYTHON_PACKAGE_PREFIX__
Expand Down
185 changes: 102 additions & 83 deletions mlir_utils/dialects/ext/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,95 +18,114 @@
)


def func_base(
FuncOp,
ReturnOp,
CallOp,
sym_visibility=None,
arg_attrs=None,
res_attrs=None,
loc=None,
ip=None,
):
ip = ip or InsertionPoint.current

# if this is set to true then wrapper below won't emit a call op
# it is set below by a def emit fn that is attached to the body_builder
# wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands)
# and the func will be emitted.
_emit = False

def builder_wrapper(body_builder):
@wraps(body_builder)
def wrapper(*call_args):
# TODO(max): implement constexpr ie enable passing constants that skip being
# part of the signature
sig = inspect.signature(body_builder)
implicit_return = sig.return_annotation is inspect._empty
input_types = [p.annotation for p in sig.parameters.values()]
if not (
len(input_types) == len(sig.parameters)
and all(isinstance(t, Type) for t in input_types)
):
input_types = [a.type for a in call_args]
function_type = TypeAttr.get(
FunctionType.get(
inputs=input_types,
results=[] if implicit_return else sig.return_annotation,
)
class FuncOpMeta(type):
def __call__(cls, *args, **kwargs):
cls_obj = cls.__new__(cls)
if len(args) == 1 and len(kwargs) == 0 and inspect.isfunction(args[0]):
return cls.__init__(cls_obj, args[0])
else:

def init_wrapper(f):
cls.__init__(cls_obj, f, *args, **kwargs)
return cls_obj

return lambda f: init_wrapper(f)


class FuncBase(metaclass=FuncOpMeta):
def __init__(
self,
body_builder,
func_op_ctor,
return_op_ctor,
call_op_ctor,
sym_visibility=None,
arg_attrs=None,
res_attrs=None,
loc=None,
ip=None,
):
assert inspect.isfunction(body_builder), body_builder
assert inspect.isclass(func_op_ctor), func_op_ctor
assert inspect.isclass(return_op_ctor), return_op_ctor
assert inspect.isclass(call_op_ctor), call_op_ctor

self.body_builder = body_builder
self.func_name = self.body_builder.__name__

self.func_op_ctor = func_op_ctor
self.return_op_ctor = return_op_ctor
self.call_op_ctor = call_op_ctor
self.sym_visibility = (
StringAttr.get(str(sym_visibility)) if sym_visibility is not None else None
)
self.arg_attrs = arg_attrs
self.res_attrs = res_attrs
self.loc = loc
self.ip = ip or InsertionPoint.current
self.emitted = False

def __str__(self):
return str(f"{self.__class__} {self.__dict__}")

def body_builder_wrapper(self, *call_args):
sig = inspect.signature(self.body_builder)
implicit_return = sig.return_annotation is inspect._empty
input_types = [p.annotation for p in sig.parameters.values()]
if not (
len(input_types) == len(sig.parameters)
and all(isinstance(t, Type) for t in input_types)
):
input_types = [a.type for a in call_args]
function_type = TypeAttr.get(
FunctionType.get(
inputs=input_types,
results=[] if implicit_return else sig.return_annotation,
)
# FuncOp is extended but we do really want the base
func_name = body_builder.__name__
func_op = FuncOp(
func_name,
function_type,
sym_visibility=StringAttr.get(str(sym_visibility))
if sym_visibility is not None
else None,
arg_attrs=arg_attrs,
res_attrs=res_attrs,
loc=loc,
ip=ip,
)
func_op = self.func_op_ctor(
self.func_name,
function_type,
sym_visibility=self.sym_visibility,
arg_attrs=self.arg_attrs,
res_attrs=self.res_attrs,
loc=self.loc,
ip=self.ip,
)
func_op.regions[0].blocks.append(*input_types)
with InsertionPoint(func_op.regions[0].blocks[0]):
results = get_result_or_results(
self.body_builder(*func_op.regions[0].blocks[0].arguments)
)
func_op.regions[0].blocks.append(*input_types)
with InsertionPoint(func_op.regions[0].blocks[0]):
results = get_result_or_results(
body_builder(*func_op.regions[0].blocks[0].arguments)
)
if results is not None:
if isinstance(results, (tuple, list)):
results = list(results)
else:
results = [results]
if results is not None:
if isinstance(results, (tuple, list)):
results = list(results)
else:
results = []
ReturnOp(results)
# Recompute the function type.
return_types = [v.type for v in results]
function_type = FunctionType.get(inputs=input_types, results=return_types)
func_op.attributes["function_type"] = TypeAttr.get(function_type)

if _emit:
return maybe_cast(get_result_or_results(func_op))
results = [results]
else:
call_op = CallOp(
[r.type for r in results],
FlatSymbolRefAttr.get(func_name),
call_args,
)
return maybe_cast(get_result_or_results(call_op))
results = []
self.return_op_ctor(results)

def emit():
nonlocal _emit
_emit = True
wrapper()
return results, input_types, func_op

wrapper.emit = emit
return wrapper
def emit(self):
self.results, input_types, func_op = self.body_builder_wrapper()
return_types = [v.type for v in self.results]
function_type = FunctionType.get(inputs=input_types, results=return_types)
func_op.attributes["function_type"] = TypeAttr.get(function_type)
self.emitted = True
# this is the func op itself (funcs never have a resulting ssa value)
return maybe_cast(get_result_or_results(func_op))

return builder_wrapper
def __call__(self, *call_args):
if not self.emitted:
self.emit()
call_op = CallOp(
[r.type for r in self.results],
FlatSymbolRefAttr.get(self.func_name),
call_args,
)
return maybe_cast(get_result_or_results(call_op))


func = make_maybe_no_args_decorator(
partial(func_base, FuncOp=FuncOp.__base__, ReturnOp=ReturnOp, CallOp=CallOp)
)
func = FuncBase(FuncOp.__base__, ReturnOp, CallOp.__base__)
3 changes: 1 addition & 2 deletions mlir_utils/dialects/ext/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Union, Tuple, Sequence

import numpy as np
from mlir.dialects.tensor import EmptyOp
from mlir.dialects.tensor import EmptyOp, GenerateOp
from mlir.ir import Type, Value, RankedTensorType, DenseElementsAttr, ShapedType

from mlir_utils.dialects.ext.arith import ArithValue
Expand Down Expand Up @@ -62,7 +62,6 @@ def empty(
shape: Union[list[Union[int, Value]], tuple[Union[int, Value], ...]],
el_type: Type,
) -> "Tensor":

return cls(EmptyOp(shape, el_type).result)


Expand Down
63 changes: 61 additions & 2 deletions tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from textwrap import dedent

import pytest

from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.ext.func import func

Expand All @@ -20,7 +19,7 @@ def demo_fun1():
return one

assert hasattr(demo_fun1, "emit")
assert inspect.isfunction(demo_fun1.emit)
assert inspect.ismethod(demo_fun1.emit)
demo_fun1.emit()
correct = dedent(
"""\
Expand All @@ -33,3 +32,63 @@ def demo_fun1():
"""
)
filecheck(correct, ctx.module)


def test_func_base_meta(ctx: MLIRContext):
print()

@func
def foo1():
one = constant(1)
return one

# print("wrapped foo", foo1)
foo1.emit()
correct = dedent(
"""\
module {
func.func @foo1() -> i64 {
%c1_i64 = arith.constant 1 : i64
return %c1_i64 : i64
}
}
"""
)
filecheck(correct, ctx.module)

foo1()
correct = dedent(
"""\
module {
func.func @foo1() -> i64 {
%c1_i64 = arith.constant 1 : i64
return %c1_i64 : i64
}
%0 = func.call @foo1() : () -> i64
}
"""
)
filecheck(correct, ctx.module)


def test_func_base_meta2(ctx: MLIRContext):
print()

@func
def foo1():
one = constant(1)
return one

foo1()
correct = dedent(
"""\
module {
func.func @foo1() -> i64 {
%c1_i64 = arith.constant 1 : i64
return %c1_i64 : i64
}
%0 = func.call @foo1() : () -> i64
}
"""
)
filecheck(correct, ctx.module)
2 changes: 1 addition & 1 deletion tests/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.ext.func import func
from mlir_utils.dialects.ext.tensor import Tensor, S, rank
from mlir_utils.dialects.ext.tensor import S, rank

# noinspection PyUnresolvedReferences
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
Expand Down
3 changes: 2 additions & 1 deletion tests/test_value_caster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
from mlir.ir import OpResult

from mlir_utils.dialects.ext.tensor import S, empty
from mlir_utils.dialects.ext.arith import constant
Expand All @@ -9,6 +8,8 @@
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
from mlir_utils.types import f64_t, RankedTensorType

from mlir.ir import OpResult

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")

Expand Down

0 comments on commit a137e4d

Please sign in to comment.