Skip to content

Commit

Permalink
improve location tracking (and test it)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jul 26, 2023
1 parent 49c5f5f commit 1cd37d6
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 33 deletions.
39 changes: 26 additions & 13 deletions mlir_utils/dialects/ext/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,22 @@
_is_index_type,
)
from mlir.ir import (
Attribute,
Context,
DenseElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
Location,
OpView,
Operation,
RankedTensorType,
Type,
Value,
IndexType,
RankedTensorType,
IntegerAttr,
IntegerType,
DenseElementsAttr,
register_attribute_builder,
Context,
Attribute,
)

from mlir_utils.util import get_result_or_results, maybe_cast
from mlir_utils.util import get_result_or_results, maybe_cast, get_user_code_loc

try:
from mlir_utils.dialects.arith import *
Expand All @@ -41,6 +42,8 @@ def constant(
value: Union[int, float, bool, np.ndarray],
type: Optional[Type] = None,
index: Optional[bool] = None,
*,
loc: Location = None,
) -> arith_dialect.ConstantOp:
"""Instantiate arith.constant with value `value`.
Expand All @@ -56,6 +59,8 @@ def constant(
Returns:
ir.OpView instance that corresponds to instantiated arith.constant op.
"""
if loc is None:
loc = get_user_code_loc()
if index is not None and index:
type = IndexType.get()
if type is None:
Expand All @@ -73,8 +78,9 @@ def constant(
value,
type=type,
)

return maybe_cast(get_result_or_results(arith_dialect.ConstantOp(type, value)))
return maybe_cast(
get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc))
)


class ArithValueMeta(type(Value)):
Expand Down Expand Up @@ -217,7 +223,12 @@ def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context):


def _binary_op(
lhs: "ArithValue", rhs: "ArithValue", op: str, predicate: str = None
lhs: "ArithValue",
rhs: "ArithValue",
op: str,
predicate: str = None,
*,
loc: Location = None,
) -> "ArithValue":
"""Generic for handling infix binary operator dispatch.
Expand All @@ -230,6 +241,8 @@ def _binary_op(
Returns:
Result of binary operation. This will be a handle to an arith(add|sub|mul) op.
"""
if loc is None:
loc = get_user_code_loc()
if not isinstance(rhs, lhs.__class__):
rhs = lhs.__class__(rhs, dtype=lhs.type)

Expand Down Expand Up @@ -258,9 +271,9 @@ def _binary_op(
predicate = "s" + predicate
else:
predicate = "u" + predicate
return lhs.__class__(op(predicate, lhs, rhs), dtype=lhs.dtype)
return lhs.__class__(op(predicate, lhs, rhs, loc=loc), dtype=lhs.dtype)
else:
return lhs.__class__(op(lhs, rhs), dtype=lhs.dtype)
return lhs.__class__(op(lhs, rhs, loc=loc), dtype=lhs.dtype)


class ArithValue(Value, metaclass=ArithValueMeta):
Expand Down
6 changes: 5 additions & 1 deletion mlir_utils/dialects/ext/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
TypeAttr,
FlatSymbolRefAttr,
Type,
Location,
)

from mlir_utils.util import (
Expand Down Expand Up @@ -103,13 +104,16 @@ def emit(self):
# this is the func op itself (funcs never have a resulting ssa value)
return maybe_cast(get_result_or_results(func_op))

def __call__(self, *call_args):
def __call__(self, *call_args, loc: Location = None):
if loc is None:
loc = get_user_code_loc()
if not self.emitted:
self.emit()
call_op = self.call_op_ctor(
[r.type for r in self.results],
FlatSymbolRefAttr.get(self.func_name),
call_args,
loc=loc,
)
return maybe_cast(get_result_or_results(call_op))

Expand Down
14 changes: 10 additions & 4 deletions mlir_utils/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import libcst as cst
import libcst.matchers as m
from bytecode import ConcreteBytecode, ConcreteInstr
from mlir.dialects import scf
from mlir.dialects.scf import IfOp, ForOp
from mlir.ir import InsertionPoint, Value, OpResultList, OpResult

from mlir_utils.ast.canonicalize import (
Expand All @@ -24,6 +24,7 @@
maybe_cast,
_update_caller_vars,
get_result_or_results,
get_user_code_loc,
)

logger = logging.getLogger(__name__)
Expand All @@ -49,7 +50,9 @@ def _for(
stop = constant(stop, index=True)
if isinstance(step, int):
step = constant(step, index=True)
return scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
if loc is None:
loc = get_user_code_loc()
return ForOp(start, stop, step, iter_args, loc=loc, ip=ip)


for_ = region_op(_for, terminator=yield__)
Expand Down Expand Up @@ -91,7 +94,9 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None):
results_ = []
if results_:
has_else = True
return scf.IfOp(cond, results_, hasElse=has_else, loc=loc, ip=ip)
if loc is None:
loc = get_user_code_loc()
return IfOp(cond, results_, hasElse=has_else, loc=loc, ip=ip)


if_ = region_op(_if, terminator=yield__)
Expand All @@ -100,7 +105,7 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None):


class IfStack:
__current_if_op: list[scf.IfOp] = []
__current_if_op: list[IfOp] = []
__if_ip: list[InsertionPoint] = []

@staticmethod
Expand Down Expand Up @@ -423,6 +428,7 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
f.__globals__[end_if.__name__] = end_if
f.__globals__[stack_if.__name__] = stack_if
f.__globals__[stack_yield.__name__] = stack_yield
f.__globals__[yield_.__name__] = yield_
f.__globals__["_placeholder_opaque_t"] = _placeholder_opaque_t
return code

Expand Down
21 changes: 7 additions & 14 deletions mlir_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def builder_wrapper(body_builder):
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
)

op.regions[0].blocks.append(*types)
arg_locs = [get_user_code_loc()] * len(sig.parameters)
op.regions[0].blocks.append(*types, arg_locs=arg_locs)
with InsertionPoint(op.regions[0].blocks[0]):
results = body_builder(
*[maybe_cast(a) for a in op.regions[0].blocks[0].arguments]
Expand Down Expand Up @@ -209,17 +210,9 @@ def get_user_code_loc():
mlir_utis_root_path = Path(mlir_utils.__path__[0])

prev_frame = inspect.currentframe().f_back
stack = traceback.StackSummary.extract(traceback.walk_stack(prev_frame))

user_frame = next(
(
fr
for fr in stack
if not Path(fr.filename).is_relative_to(mlir_utis_root_path)
),
None,
while Path(prev_frame.f_code.co_filename).is_relative_to(mlir_utis_root_path):
prev_frame = prev_frame.f_back
frame_info = inspect.getframeinfo(prev_frame)
return Location.file(
frame_info.filename, frame_info.lineno, frame_info.positions.col_offset
)
if user_frame is None:
warnings.warn("couldn't find user code frame")
return
return Location.file(user_frame.filename, user_frame.lineno, user_frame.colno or 0)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mlir-python-utils"
version = "0.0.2"
version = "0.0.3"
description = "The missing pieces (as far as boilerplate reduction goes) of the upstream MLIR python bindings."
requires-python = ">=3.11"
license = { file = "LICENSE" }
Expand Down
111 changes: 111 additions & 0 deletions tests/test_location_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from pathlib import Path
from textwrap import dedent
from os import sep
import pytest

from mlir_utils.ast.canonicalize import canonicalize
from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.ext.scf import (
canonicalizer,
stack_if,
)
from mlir_utils.dialects.ext.tensor import S
from mlir_utils.dialects.tensor import generate, yield_ as tensor_yield, rank

# noinspection PyUnresolvedReferences
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
from mlir_utils.types import f64_t, index_t, tensor_t

# needed since the fix isn't defined here nor conftest.py
pytest.mark.usefixtures("ctx")


THIS_DIR = str(Path(__file__).parent.absolute())


def get_asm(operation):
return operation.get_asm(enable_debug_info=True, pretty_debug_info=True).replace(
THIS_DIR, "THIS_DIR"
)


def test_if_replace_yield_5(ctx: MLIRContext):
@canonicalize(using=canonicalizer)
def iffoo():
one = constant(1.0)
two = constant(2.0)
if res := stack_if(one < two, (f64_t, f64_t, f64_t)):
three = constant(3.0)
yield three, three, three
else:
four = constant(4.0)
yield four, four, four
return

iffoo()
ctx.module.operation.verify()
correct = dedent(
f"""\
module {{
%cst = arith.constant 1.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:35:10
%cst_0 = arith.constant 2.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:36:10
%0 = arith.cmpf olt, %cst, %cst_0 : f64 THIS_DIR{sep}test_location_tracking.py:37:23
%1:3 = scf.if %0 -> (f64, f64, f64) {{
%cst_1 = arith.constant 3.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:38:16
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:39:8
}} else {{
%cst_1 = arith.constant 4.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:41:24
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:42:8
}} THIS_DIR{sep}test_location_tracking.py:37:14
}} [unknown]
#loc = [unknown]
#loc1 = THIS_DIR{sep}test_location_tracking.py:35:10
#loc2 = THIS_DIR{sep}test_location_tracking.py:36:10
#loc3 = THIS_DIR{sep}test_location_tracking.py:37:23
#loc4 = THIS_DIR{sep}test_location_tracking.py:37:14
#loc5 = THIS_DIR{sep}test_location_tracking.py:38:16
#loc6 = THIS_DIR{sep}test_location_tracking.py:39:8
#loc7 = THIS_DIR{sep}test_location_tracking.py:41:24
#loc8 = THIS_DIR{sep}test_location_tracking.py:42:8
"""
)
asm = get_asm(ctx.module.operation)
filecheck(correct, asm)


def test_block_args(ctx: MLIRContext):
one = constant(1, index_t)
two = constant(2, index_t)

@generate(tensor_t(S, 3, S, f64_t), dynamic_extents=[one, two])
def demo_fun1(i: index_t, j: index_t, k: index_t):
one = constant(1.0)
tensor_yield(one)

r = rank(demo_fun1)

ctx.module.operation.verify()

correct = dedent(
f"""\
#loc3 = THIS_DIR{sep}test_location_tracking.py:80:5
module {{
%c1 = arith.constant 1 : index THIS_DIR{sep}test_location_tracking.py:77:10
%c2 = arith.constant 2 : index THIS_DIR{sep}test_location_tracking.py:78:10
%generated = tensor.generate %c1, %c2 {{
^bb0(%arg0: index THIS_DIR{sep}test_location_tracking.py:80:5, %arg1: index THIS_DIR{sep}test_location_tracking.py:80:5, %arg2: index THIS_DIR{sep}test_location_tracking.py:80:5):
%cst = arith.constant 1.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:82:14
tensor.yield %cst : f64 THIS_DIR{sep}test_location_tracking.py:83:8
}} : tensor<?x3x?xf64> THIS_DIR{sep}test_location_tracking.py:80:5
%rank = tensor.rank %generated : tensor<?x3x?xf64> THIS_DIR{sep}test_location_tracking.py:85:8
}} [unknown]
#loc = [unknown]
#loc1 = THIS_DIR{sep}test_location_tracking.py:77:10
#loc2 = THIS_DIR{sep}test_location_tracking.py:78:10
#loc4 = THIS_DIR{sep}test_location_tracking.py:82:14
#loc5 = THIS_DIR{sep}test_location_tracking.py:83:8
#loc6 = THIS_DIR{sep}test_location_tracking.py:85:8
"""
)
asm = get_asm(ctx.module.operation)
filecheck(correct, asm)

0 comments on commit 1cd37d6

Please sign in to comment.