Skip to content

Commit

Permalink
- refactor func
Browse files Browse the repository at this point in the history
- fix trampolines casing bug
- fix configuration again
- warn about TypeID instead of crash
  • Loading branch information
makslevental committed Jul 20, 2023
1 parent 074afc8 commit 70dbd55
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 30 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,15 @@ Workaround is to delete the prefix token before configuring, like so:

```shell
rm /home/mlevental/dev_projects/mlir_utils/mlir_utils/_configuration/__MLIR_PYTHON_PACKAGE_PREFIX__ && configure-mlir-python-utils mlir
```
```

## Gotchas

There's a `DefaultContext` created when this package is loaded. If you have weird things happen like

```
E error: unknown: 'arith.constant' op requires attribute 'value'
E note: unknown: see current operation: %0 = "arith.constant"() {value = 64 : i32} : () -> i32
```

which looks patently insane (because `value` is in fact there as an attribute), then you have a `Context`s problem.
7 changes: 7 additions & 0 deletions mlir_utils/_configuration/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def _get_mlir_package_prefix():

def alias_upstream_bindings():
if mlir_python_package_prefix := _get_mlir_package_prefix():
# check if valid package/module
try:
_host_bindings_mlir = __import__(f"{mlir_python_package_prefix}._mlir_libs")
except (ImportError, ModuleNotFoundError) as e:
print(f"couldn't import {mlir_python_package_prefix=} due to: {e}")
raise e

sys.meta_path.insert(
get_meta_path_insertion_index(),
AliasedModuleFinder({"mlir": mlir_python_package_prefix}),
Expand Down
16 changes: 8 additions & 8 deletions mlir_utils/_configuration/generate_trampolines.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,6 @@ def generate_op_trampoline(op_class):
for a in args.args:
a.arg = inflection.underscore(a.arg).lower()

for k in args.kwonlyargs:
k.arg = inflection.underscore(k.arg).lower()

keywords = [
ast.keyword(k.arg, ast.Name(k.arg))
for k, d in zip(args.kwonlyargs, args.kw_defaults)
]

fun_name = op_class.OPERATION_NAME.split(".")[-1].replace("-", "_")
if keyword.iskeyword(fun_name):
fun_name = fun_name + "_"
Expand All @@ -88,6 +80,11 @@ def generate_op_trampoline(op_class):
if len(args.args) == 1 and args.args[0].arg == "results_":
args.defaults.append(ast.Constant(None))
body += [ast.parse("results_ = results_ or []").body[0]]

keywords = [
ast.keyword(k.arg, ast.Name(inflection.underscore(k.arg).lower()))
for k, d in zip(args.kwonlyargs, args.kw_defaults)
]
if (
hasattr(op_class, "_ODS_REGIONS")
and op_class._ODS_REGIONS[0] == 1
Expand All @@ -103,6 +100,9 @@ def generate_op_trampoline(op_class):
).body[0]
]

for k in args.kwonlyargs:
k.arg = inflection.underscore(k.arg).lower()

args = copy.deepcopy(args)
oper_finder = FindOperands()
oper_finder.visit(init_fn)
Expand Down
4 changes: 3 additions & 1 deletion mlir_utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import mlir.ir

from mlir_utils import DefaultContext


@dataclass
class MLIRContext:
Expand All @@ -17,7 +19,7 @@ def __str__(self):
@contextmanager
def mlir_mod_ctx(
src: Optional[str] = None,
context: mlir.ir.Context = None,
context: mlir.ir.Context = DefaultContext,
location: mlir.ir.Location = None,
allow_unregistered_dialects=False,
) -> MLIRContext:
Expand Down
55 changes: 55 additions & 0 deletions mlir_utils/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
Value,
IndexType,
RankedTensorType,
IntegerAttr,
IntegerType,
DenseElementsAttr,
register_attribute_builder,
Context,
Attribute,
)

from mlir_utils.dialects.util import get_result_or_results, maybe_cast
Expand Down Expand Up @@ -244,3 +249,53 @@ def isinstance(other: Value):
or _is_index_type(other.type)
or _is_complex_type(other.type)
)


@register_attribute_builder("Arith_CmpIPredicateAttr")
def _arith_CmpIPredicateAttr(predicate: str | Attribute, context: Context):
predicates = {
"eq": 0,
"ne": 1,
"slt": 2,
"sle": 3,
"sgt": 4,
"sge": 5,
"ult": 6,
"ule": 7,
"ugt": 8,
"uge": 9,
}
if isinstance(predicate, Attribute):
return predicate
assert predicate in predicates, f"predicate {predicate} not in predicates"
return IntegerAttr.get(
IntegerType.get_signless(64, context=context), predicates[predicate]
)


@register_attribute_builder("Arith_CmpFPredicateAttr")
def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context):
predicates = {
"false": 0,
"oeq": 1,
"ogt": 2,
"oge": 3,
"olt": 4,
"ole": 5,
"one": 6,
"ord": 7,
"ueq": 8,
"ugt": 9,
"uge": 10,
"ult": 11,
"ule": 12,
"une": 13,
"uno": 14,
"true": 15,
}
if isinstance(predicate, Attribute):
return predicate
assert predicate in predicates, f"predicate {predicate} not in predicates"
return IntegerAttr.get(
IntegerType.get_signless(64, context=context), predicates[predicate]
)
60 changes: 47 additions & 13 deletions mlir_utils/dialects/ext/func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from functools import wraps
from functools import wraps, partial

from mlir.dialects.func import FuncOp, ReturnOp, CallOp
from mlir.ir import (
Expand All @@ -8,24 +8,45 @@
StringAttr,
TypeAttr,
FlatSymbolRefAttr,
Type,
)

from mlir_utils.dialects.util import (
get_result_or_results,
make_maybe_no_args_decorator,
maybe_cast,
)


@make_maybe_no_args_decorator
def func(sym_visibility=None, arg_attrs=None, res_attrs=None, loc=None, ip=None):
def func_base(
FuncOp,
ReturnOp,
CallOp,
sym_visibility=None,
arg_attrs=None,
res_attrs=None,
loc=None,
ip=None,
):
ip = ip or InsertionPoint.current

# if this is set to true then wrapper below won't emit a call op
# it is set below by a def emit fn that is attached to the body_builder
# wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands)
# and the func will be emitted.
_emit = False

def builder_wrapper(body_builder):
@wraps(body_builder)
def wrapper(*call_args):
sig = inspect.signature(body_builder)
implicit_return = sig.return_annotation is inspect._empty
input_types = [a.type for a in call_args]
input_types = [p.annotation for p in sig.parameters.values()]
if not (
len(input_types) == len(sig.parameters)
and all(isinstance(t, Type) for t in input_types)
):
input_types = [a.type for a in call_args]
function_type = TypeAttr.get(
FunctionType.get(
inputs=input_types,
Expand All @@ -34,7 +55,7 @@ def wrapper(*call_args):
)
# FuncOp is extended but we do really want the base
func_name = body_builder.__name__
func_op = FuncOp.__base__(
func_op = FuncOp(
func_name,
function_type,
sym_visibility=StringAttr.get(str(sym_visibility))
Expand All @@ -45,7 +66,7 @@ def wrapper(*call_args):
loc=loc,
ip=ip,
)
func_op.regions[0].blocks.append(*[a.type for a in call_args])
func_op.regions[0].blocks.append(*input_types)
with InsertionPoint(func_op.regions[0].blocks[0]):
results = get_result_or_results(
body_builder(*func_op.regions[0].blocks[0].arguments)
Expand All @@ -63,14 +84,27 @@ def wrapper(*call_args):
function_type = FunctionType.get(inputs=input_types, results=return_types)
func_op.attributes["function_type"] = TypeAttr.get(function_type)

call_op = CallOp(
[r.type for r in results], FlatSymbolRefAttr.get(func_name), call_args
)
if results is None:
return func_op
return get_result_or_results(call_op)
if _emit:
return maybe_cast(get_result_or_results(func_op))
else:
call_op = CallOp(
[r.type for r in results],
FlatSymbolRefAttr.get(func_name),
call_args,
)
return maybe_cast(get_result_or_results(call_op))

def emit():
nonlocal _emit
_emit = True
wrapper()

# wrapper.op = op
wrapper.emit = emit
return wrapper

return builder_wrapper


func = make_maybe_no_args_decorator(
partial(func_base, FuncOp=FuncOp.__base__, ReturnOp=ReturnOp, CallOp=CallOp)
)
12 changes: 11 additions & 1 deletion mlir_utils/dialects/util.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import ctypes
import inspect
import warnings
from collections import defaultdict
from functools import wraps
from typing import Callable

import mlir
from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
from mlir.ir import InsertionPoint, Value, Type, TypeID
from mlir.ir import InsertionPoint, Value, Type

try:
from mlir.ir import TypeID
except ImportError:
warnings.warn(
f"TypeID not supported by {mlir=}; value casting won't work correctly"
)
TypeID = object


def get_result_or_results(op):
Expand Down
35 changes: 35 additions & 0 deletions tests/test_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import inspect
from textwrap import dedent

import pytest

from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.ext.func import func

# noinspection PyUnresolvedReferences
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")


def test_emit(ctx: MLIRContext):
@func
def demo_fun1():
one = constant(1)
return one

assert hasattr(demo_fun1, "emit")
assert inspect.isfunction(demo_fun1.emit)
demo_fun1.emit()
correct = dedent(
"""\
module {
func.func @demo_fun1() -> i64 {
%c1_i64 = arith.constant 1 : i64
return %c1_i64 : i64
}
}
"""
)
filecheck(correct, ctx.module)
10 changes: 5 additions & 5 deletions tests/test_operator_overloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

def test_tensor_arithmetic(ctx: MLIRContext):
print()
one = constant(1, index_t)
one = constant(1)
assert isinstance(one, Scalar)
two = constant(2, index_t)
two = constant(2)
assert isinstance(two, Scalar)
three = one + two
assert isinstance(three, Scalar)
Expand All @@ -34,9 +34,9 @@ def test_tensor_arithmetic(ctx: MLIRContext):
dedent(
"""\
module {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = arith.addi %c1, %c2 : index
%c1_i64 = arith.constant 1 : i64
%c2_i64 = arith.constant 2 : i64
%0 = arith.addi %c1_i64, %c2_i64 : i64
%1 = tensor.empty() : tensor<10x10x10xf64>
%2 = tensor.empty() : tensor<10x10x10xf64>
%3 = arith.addf %1, %2 : tensor<10x10x10xf64>
Expand Down
1 change: 0 additions & 1 deletion tests/test_value_caster.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_caster_registration(ctx: MLIRContext):
assert repr(ten) == "Tensor(%0, tensor<?x3x?xf64>)"

def dummy_caster(val):
print(val)
return val

register_value_caster(RankedTensorType.static_typeid, dummy_caster)
Expand Down

0 comments on commit 70dbd55

Please sign in to comment.