From 587d10723c38fb3914be8629dcb3debcb654eff0 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 27 Feb 2025 10:22:16 +0100 Subject: [PATCH] bug[next]: Respect evaluation order in `InlineCenterDerefLiftVars` (#1883) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes the `InlineCenterDerefLiftVars` pass to respect evaluation order by lazily evaluating the inlined values. Consider the following case that is common in boundary conditions: ``` let(var, (↑deref)(it2))(if ·on_bc then 0 else ·var) ``` Then var should only be dereferenced in case `·on_bc` evaluates to False. Previously we just evaluated all values unconditionally: ``` let(_icdlv_1, ·it2)(if ·on_bc then 0 else _icdlv_1) ``` Now we instead create a 0-ary lambda function for `_icdlv_1` and evaluate it when the value is needed. ``` let(_icdlv_1, λ() → ·it2)(if ·on_bc then 0 else _icdlv_1()) ``` Note that as a result we do evaluate the function multiple times. To avoid redundant recomputations, usage of the common subexpression elimination is required. --- .../inline_center_deref_lift_vars.py | 38 +++++++++++++------ .../transforms_tests/test_fuse_as_fieldop.py | 8 ++-- .../test_inline_center_deref_lift_vars.py | 38 ++++++++++++++++--- 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 7bd26d0f19..c0a8c9f1b7 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -36,16 +36,25 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): `let(var, (↑stencil)(it))(·var + ·var)` Directly inlining `var` would increase the size of the tree and duplicate the calculation. 
- Instead, this pass computes the value at the current location once and replaces all previous - references to `var` by an applied lift which captures this value. + Instead, this pass first takes the iterator `(↑stencil)(it)` and transforms it into a + 0-ary function that evaluates to the value at the current location. - `let(_icdlv_1, stencil(it))(·(↑(λ() → _icdlv_1) + ·(↑(λ() → _icdlv_1))` + `λ() → ·(↑stencil)(it)` + + Then all previous occurrences of `var` are replaced by this function. + + `let(_icdlv_1, λ() → ·(↑stencil)(it))(·(↑(λ() → _icdlv_1()))() + ·(↑(λ() → _icdlv_1()))())` The lift inliner can then later easily transform this into a nice expression: - `let(_icdlv_1, stencil(it))(_icdlv_1 + _icdlv_1)` + `let(_icdlv_1, λ() → stencil(it))(_icdlv_1() + _icdlv_1())` + + Finally, recomputation is avoided by using the common subexpression elimination and lambda + inlining (can be configured opcount preserving). Both are up to the caller to do later. + + `(λ(_cs_1) → _cs_1 + _cs_1)(stencil(it))` - Note: This pass uses and preserves the `recorded_shifts` annex. + Note: This pass uses and preserves the `domain` and `recorded_shifts` annex. """ PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts") @@ -78,20 +87,25 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): assert isinstance(node.fun, itir.Lambda) # to make mypy happy eligible_params = [False] * len(node.fun.params) new_args = [] - bound_scalars: dict[str, itir.Expr] = {} + # values are 0-ary lambda functions that evaluate to the derefed argument. 
We don't put + # the values themselves here as they might be inside of an if to protected from an oob + # access + evaluators: dict[str, itir.Expr] = {} for i, (param, arg) in enumerate(zip(node.fun.params, node.args)): if cpm.is_applied_lift(arg) and is_center_derefed_only(param): eligible_params[i] = True - bound_arg_name = self.uids.sequential_id(prefix="_icdlv") - capture_lift = im.promote_to_const_iterator(bound_arg_name) + bound_arg_evaluator = self.uids.sequential_id(prefix="_icdlv") + capture_lift = im.promote_to_const_iterator(im.call(bound_arg_evaluator)()) trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift) new_args.append(capture_lift) # since we deref an applied lift here we can (but don't need to) immediately # inline - bound_scalars[bound_arg_name] = InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - ).visit(im.deref(arg), recurse=False) + evaluators[bound_arg_evaluator] = im.lambda_()( + InlineLifts(flags=InlineLifts.Flag.INLINE_DEREF_LIFT).visit( + im.deref(arg), recurse=False + ) + ) else: new_args.append(arg) @@ -100,6 +114,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): im.call(node.fun)(*new_args), eligible_params=eligible_params ) # TODO(tehrengruber): propagate let outwards - return im.let(*bound_scalars.items())(new_node) + return im.let(*evaluators.items())(new_node) return node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index fd884e239f..14aebd032c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -145,8 +145,8 @@ def test_staged_inlining(): ) expected = im.as_fieldop( im.lambda_("a", "b")( - im.let("_icdlv_1", im.plus(im.deref("a"), im.deref("b")))( - im.plus(im.plus("_icdlv_1", 1), im.plus("_icdlv_1", 2)) + 
im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("a"), im.deref("b"))))( + im.plus(im.plus(im.call("_icdlv_1")(), 1), im.plus(im.call("_icdlv_1")(), 2)) ) ), d, @@ -328,8 +328,8 @@ def test_chained_fusion(): ) expected = im.as_fieldop( im.lambda_("inp1", "inp2")( - im.let("_icdlv_1", im.plus(im.deref("inp1"), im.deref("inp2")))( - im.plus("_icdlv_1", "_icdlv_1") + im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("inp1"), im.deref("inp2"))))( + im.plus(im.call("_icdlv_1")(), im.call("_icdlv_1")()) ) ), d, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py index 6cc2f7cd28..2caa887803 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py @@ -6,20 +6,33 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import cse from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars +field_type = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -def wrap_in_program(expr: itir.Expr) -> itir.Program: + +def wrap_in_program(expr: itir.Expr, *, arg_dtypes=None) -> itir.Program: + if arg_dtypes is None: + arg_dtypes = [ts.ScalarKind.FLOAT64] + arg_types = [ts.FieldType(dims=[], dtype=ts.ScalarType(kind=dtype)) for dtype in arg_dtypes] + indices = [i for i in range(1, len(arg_dtypes) + 1)] if len(arg_dtypes) > 1 else [""] return itir.Program( id="f", function_definitions=[], - params=[im.sym("d"), im.sym("inp"), im.sym("out")], + params=[ + *(im.sym(f"inp{i}", type_) for i, type_ in zip(indices, arg_types)), + im.sym("out", field_type), + ], declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("it")(expr))(im.ref("inp")), + expr=im.as_fieldop(im.lambda_(*(f"it{i}" for i in indices))(expr))( + *(im.ref(f"inp{i}") for i in indices) + ), domain=im.call("cartesian_domain")(), target=im.ref("out"), ) @@ -34,7 +47,7 @@ def unwrap_from_program(program: itir.Program) -> itir.Expr: def test_simple(): testee = im.let("var", im.lift("deref")("it"))(im.deref("var")) - expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))())(·it)" + expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))())(λ() → ·it)" actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected @@ -42,7 +55,7 @@ def test_simple(): def test_double_deref(): testee = im.let("var", im.lift("deref")("it"))(im.plus(im.deref("var"), im.deref("var"))) - expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())(·it)" + expected = "(λ(_icdlv_1) → ·(↑(λ() → 
_icdlv_1()))() + ·(↑(λ() → _icdlv_1()))())(λ() → ·it)" actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected @@ -62,3 +75,18 @@ def test_deref_at_multiple_pos(): actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert testee == actual + + +def test_bc(): + # we also check that the common subexpression is able to extract the inlined value, such + # that it is only evaluated once + testee = im.let("var", im.lift("deref")("it2"))( + im.if_(im.deref("it1"), im.literal_from_value(0), im.plus(im.deref("var"), im.deref("var"))) + ) + expected = "(λ(_icdlv_1) → if ·it1 then 0 else (λ(_cs_1) → _cs_1 + _cs_1)(·(↑(λ() → _icdlv_1()))()))(λ() → ·it2)" + + actual = InlineCenterDerefLiftVars.apply( + wrap_in_program(testee, arg_dtypes=[ts.ScalarKind.BOOL, ts.ScalarKind.FLOAT64]) + ) + simplified = unwrap_from_program(cse.CommonSubexpressionElimination.apply(actual)) + assert str(simplified) == expected