Skip to content

Commit

Permalink
make value_caster extensible
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jul 19, 2023
1 parent 22e50f9 commit 074afc8
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 33 deletions.
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ This package is meant to work in concert with the upstream bindings.
Practically speaking that means you need to have *some* package installed that includes mlir python bindings.
In addition, you have to do one of two things to **configure this package** (after installing it):

1. `$ configure-mlir-python-utils -y <MLIR_PYTHON_PACKAGE_PREFIX>`, where `MLIR_PYTHON_PACKAGE_PREFIX` is (as it says) the
1. `$ configure-mlir-python-utils -y <MLIR_PYTHON_PACKAGE_PREFIX>`, where `MLIR_PYTHON_PACKAGE_PREFIX` is (as it says)
the
package prefix for your chosen upstream bindings. So for example, for `torch-mlir`, you would
execute `configure-mlir-python-utils torch_mlir`, since `torch-mlir`'s bindings are the root of the `torch-mlir` python
execute `configure-mlir-python-utils torch_mlir`, since `torch-mlir`'s bindings are the root of the `torch-mlir`
python
package. **When in doubt about this prefix**, it is everything up until `ir` (e.g., as
in `from torch_mlir import ir`).
2. `$ export MLIR_PYTHON_PACKAGE_PREFIX=<MLIR_PYTHON_PACKAGE_PREFIX>`, i.e., you can set this string as an environment
Expand All @@ -49,4 +51,12 @@ pip install setuptools -U
pip install -e .[torch-mlir-test] \
-f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest \
-f https://llvm.github.io/torch-mlir/package-index/
```

There's an annoying bug where if you try to register to a different set of host bindings it won't work the first
time (e.g., going from `torch-mlir` to `mlir`).
Workaround is to delete the prefix token before configuring, like so:

```shell
rm /home/mlevental/dev_projects/mlir_utils/mlir_utils/_configuration/__MLIR_PYTHON_PACKAGE_PREFIX__ && configure-mlir-python-utils mlir
```
8 changes: 8 additions & 0 deletions mlir_utils/_configuration/module_alias_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,11 @@ def find_spec(
)
else:
return None


def maybe_remove_alias_module_loader():
    """Remove the first ``AliasedModuleFinder`` installed on ``sys.meta_path``.

    Only the first matching finder is removed; if none is installed this is a
    no-op. Used to undo a previous bindings registration before re-configuring
    against a different set of host bindings.
    """
    for i, finder in enumerate(sys.meta_path):
        if isinstance(finder, AliasedModuleFinder):
            # Deleting by index while iterating is safe here because we
            # return immediately after the deletion.
            del sys.meta_path[i]
            return
2 changes: 1 addition & 1 deletion mlir_utils/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __call__(cls, *args, **kwargs):


class ArithValue(Value, metaclass=ArithValueMeta):
"""Mixin class for functionality shared by Value subclasses that support
"""Class for functionality shared by Value subclasses that support
arithmetic operations.
Note, since we bind the ArithValueMeta here, it is here that the __new__ and
Expand Down
28 changes: 3 additions & 25 deletions mlir_utils/dialects/ext/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from mlir.ir import Type, Value, RankedTensorType, DenseElementsAttr, ShapedType

from mlir_utils.dialects.ext.arith import ArithValue
from mlir_utils.dialects.util import register_value_caster

try:
from mlir_utils.dialects.tensor import *
Expand Down Expand Up @@ -64,28 +65,5 @@ def empty(

return cls(EmptyOp(shape, el_type).result)

def __class_getitem__(
    cls, dim_sizes_dtype: Tuple[Union[list[int], tuple[int, ...]], Type]
) -> Type:
    """A convenience method for creating RankedTensorType.

    Args:
        dim_sizes_dtype: A tuple of both the shape of the type and the dtype.

    Returns:
        An instance of RankedTensorType.

    Raises:
        ValueError: If the argument is not a two-element (shape, dtype) pair,
            or if the dtype is not an mlir ``Type``.
    """
    if len(dim_sizes_dtype) != 2:
        raise ValueError(
            f"Wrong type of argument to {cls.__name__}: {dim_sizes_dtype=}"
        )
    dim_sizes, dtype = dim_sizes_dtype
    if not isinstance(dtype, Type):
        raise ValueError(f"{dtype=} is not {Type=}")
    # Any non-int dimension (e.g., a dynamic-size sentinel) becomes a
    # dynamic dimension in the resulting type.
    static_sizes = [
        s if isinstance(s, int) else ShapedType.get_dynamic_size()
        for s in dim_sizes
    ]
    return RankedTensorType.get(static_sizes, dtype)

register_value_caster(RankedTensorType.static_typeid, Tensor)
44 changes: 39 additions & 5 deletions mlir_utils/dialects/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import ctypes
from functools import wraps
import inspect
from collections import defaultdict
from functools import wraps
from typing import Callable

from mlir.dialects._ods_common import get_op_result_or_value, get_op_results_or_values
from mlir.ir import InsertionPoint, Value, Type
from mlir.ir import InsertionPoint, Value, Type, TypeID


def get_result_or_results(op):
Expand Down Expand Up @@ -31,20 +33,52 @@ def maybe_no_args(*args, **kwargs):
return maybe_no_args


# Registry of value casters keyed by the TypeID of the mlir type they handle.
# Casters in each list are tried in order; a caster may return None to decline.
__VALUE_CASTERS: defaultdict[
    TypeID, list[Callable[[Value], Value | None]]
] = defaultdict(list)


def register_value_caster(
    typeid: TypeID, caster: Callable[[Value], Value], priority: int | None = None
):
    """Register a caster for values whose type has the given TypeID.

    Args:
        typeid: TypeID of the mlir type the caster handles.
        caster: Callable mapping a Value to a wrapped Value (or a falsy
            result to decline, letting lower-priority casters run).
        priority: Optional list position for the caster; ``None`` appends
            (lowest priority), ``0`` prepends (highest priority).

    Raises:
        ValueError: If ``typeid`` is not a TypeID.
    """
    if not isinstance(typeid, TypeID):
        raise ValueError(f"{typeid=} is not a TypeID")
    if priority is None:
        __VALUE_CASTERS[typeid].append(caster)
    else:
        __VALUE_CASTERS[typeid].insert(priority, caster)


def has_value_caster(typeid: TypeID):
    """Return True if at least one caster is registered for ``typeid``.

    Raises:
        ValueError: If ``typeid`` is not a TypeID.
    """
    if not isinstance(typeid, TypeID):
        raise ValueError(f"{typeid=} is not a TypeID")
    # Plain membership test; does not create a default entry.
    return typeid in __VALUE_CASTERS


def get_value_caster(typeid: TypeID):
    """Return the list of casters registered for ``typeid``.

    Raises:
        ValueError: If no caster has been registered for ``typeid``.
    """
    if has_value_caster(typeid):
        return __VALUE_CASTERS[typeid]
    raise ValueError(f"no registered caster for {typeid=}")


def maybe_cast(val: Value):
    """Maybe cast an ir.Value to one of Tensor, Scalar.

    Args:
        val: The ir.Value to maybe cast.
    """
    # Imported locally — presumably to avoid a circular import, since the
    # ext modules import from this module; TODO confirm.
    from mlir_utils.dialects.ext.tensor import Tensor
    from mlir_utils.dialects.ext.arith import Scalar

    # Non-Value objects pass through untouched.
    if not isinstance(val, Value):
        return val

    if Tensor.isinstance(val):
        return Tensor(val)
    # Registered casters are tried in order; the first one returning a
    # truthy result wins.  NOTE(review): if casters exist but all decline,
    # this raises instead of falling through to the Scalar check — confirm
    # that is intended.
    if has_value_caster(val.type.typeid):
        for caster in get_value_caster(val.type.typeid):
            if casted := caster(val):
                return casted
        raise ValueError(f"no successful casts for {val=}")
    if Scalar.isinstance(val):
        return Scalar(val)
    return val
Expand Down
35 changes: 35 additions & 0 deletions tests/test_value_caster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest
from mlir.ir import OpResult

from mlir_utils.dialects.ext.tensor import S, empty
from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.util import register_value_caster

# noinspection PyUnresolvedReferences
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
from mlir_utils.types import f64_t, RankedTensorType

# The `ctx` fixture isn't defined here nor in conftest.py, so pull it in for
# the whole module.  NOTE: a bare `pytest.mark.usefixtures("ctx")` expression
# statement is a no-op — the mark must be assigned to the module-level
# `pytestmark` name for pytest to apply it.
pytestmark = pytest.mark.usefixtures("ctx")


def test_caster_registration(ctx: MLIRContext):
    """Registered casters run in priority order; a prepended caster can
    override the default Tensor wrapping."""
    shape = (S, 3, S)
    ten = empty(shape, f64_t)
    assert repr(ten) == "Tensor(%0, tensor<?x3x?xf64>)"

    def passthrough_caster(val):
        print(val)
        return val

    # Appending (no priority) leaves the existing Tensor caster first, so
    # the result is still wrapped as a Tensor.
    register_value_caster(RankedTensorType.static_typeid, passthrough_caster)
    ten = empty(shape, f64_t)
    assert repr(ten) == "Tensor(%1, tensor<?x3x?xf64>)"

    # Priority 0 puts the passthrough first; it returns the raw value, so
    # the result is no longer a Tensor.
    register_value_caster(RankedTensorType.static_typeid, passthrough_caster, 0)
    ten = empty(shape, f64_t)
    assert repr(ten) != "Tensor(%1, tensor<?x3x?xf64>)"
    assert isinstance(ten, OpResult)

    one = constant(1)
    assert repr(one) == "Scalar(%3, i64)"

0 comments on commit 074afc8

Please sign in to comment.