Skip to content

Commit

Permalink
bug[next]: Respect evaluation order in InlineCenterDerefLiftVars (#…
Browse files Browse the repository at this point in the history
…1883)

Changes the `InlineCenterDerefLiftVars` pass to respect evaluation order
by lazily evaluating the inlined values.

Consider the following case that is common in boundary conditions:
```
let(var, (↑deref)(it2))(if ·on_bc then 0 else ·var)
```
Then var should only be dereferenced in case `·on_bc` evalutes to False.
Previously we just evaluated all values unconditionally:
```
let(_icdlv_1, ·it)(if ·on_bc then 0 else _icdlv_1)
```
Now we instead create a 0-ary lambda function for `_icdlv_1` and
evaluate it when the value is needed.

```
let(_icdlv_1, λ() → ·it)(if ·on_bc then 0 else _icdlv_1())
```
Note that as a result we do evaluate the function multiple times. To
avoid redundant recompuations usage of the common subexpression
elimination is required.
  • Loading branch information
tehrengruber authored Feb 27, 2025
1 parent 847d8ab commit 587d107
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 21 deletions.
38 changes: 26 additions & 12 deletions src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,25 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator):
`let(var, (↑stencil)(it))(·var + ·var)`
Directly inlining `var` would increase the size of the tree and duplicate the calculation.
Instead, this pass computes the value at the current location once and replaces all previous
references to `var` by an applied lift which captures this value.
Instead this pass, first takes the iterator `(↑stencil)(it)` and transforms it into a
0-ary function that evaluates to the value at the current location.
`let(_icdlv_1, stencil(it))(·(↑(λ() → _icdlv_1) + ·(↑(λ() → _icdlv_1))`
`λ() → ·(↑stencil)(it)`
Then all previous occurences of `var` are replaced by this function.
`let(_icdlv_1, λ() → ·(↑stencil)(it))(·(↑(λ() → _icdlv_1()) + ·(↑(λ() → _icdlv_1()))`
The lift inliner can then later easily transform this into a nice expression:
`let(_icdlv_1, stencil(it))(_icdlv_1 + _icdlv_1)`
`let(_icdlv_1, λ() → stencil(it))(_icdlv_1() + _icdlv_1())`
Finally, recomputation is avoided by using the common subexpression elimination and lamba
inlining (can be configured opcount preserving). Both is up to the caller to do later.
`λ(_cs_1) → _cs_1 + _cs_1)(stencil(it))`
Note: This pass uses and preserves the `recorded_shifts` annex.
Note: This pass uses and preserves the `domain` and `recorded_shifts` annex.
"""

PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts")
Expand Down Expand Up @@ -78,20 +87,25 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
assert isinstance(node.fun, itir.Lambda) # to make mypy happy
eligible_params = [False] * len(node.fun.params)
new_args = []
bound_scalars: dict[str, itir.Expr] = {}
# values are 0-ary lambda functions that evaluate to the derefed argument. We don't put
# the values themselves here as they might be inside of an if to protected from an oob
# access
evaluators: dict[str, itir.Expr] = {}

for i, (param, arg) in enumerate(zip(node.fun.params, node.args)):
if cpm.is_applied_lift(arg) and is_center_derefed_only(param):
eligible_params[i] = True
bound_arg_name = self.uids.sequential_id(prefix="_icdlv")
capture_lift = im.promote_to_const_iterator(bound_arg_name)
bound_arg_evaluator = self.uids.sequential_id(prefix="_icdlv")
capture_lift = im.promote_to_const_iterator(im.call(bound_arg_evaluator)())
trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift)
new_args.append(capture_lift)
# since we deref an applied lift here we can (but don't need to) immediately
# inline
bound_scalars[bound_arg_name] = InlineLifts(
flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT
).visit(im.deref(arg), recurse=False)
evaluators[bound_arg_evaluator] = im.lambda_()(
InlineLifts(flags=InlineLifts.Flag.INLINE_DEREF_LIFT).visit(
im.deref(arg), recurse=False
)
)
else:
new_args.append(arg)

Expand All @@ -100,6 +114,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
im.call(node.fun)(*new_args), eligible_params=eligible_params
)
# TODO(tehrengruber): propagate let outwards
return im.let(*bound_scalars.items())(new_node)
return im.let(*evaluators.items())(new_node)

return node
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def test_staged_inlining():
)
expected = im.as_fieldop(
im.lambda_("a", "b")(
im.let("_icdlv_1", im.plus(im.deref("a"), im.deref("b")))(
im.plus(im.plus("_icdlv_1", 1), im.plus("_icdlv_1", 2))
im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("a"), im.deref("b"))))(
im.plus(im.plus(im.call("_icdlv_1")(), 1), im.plus(im.call("_icdlv_1")(), 2))
)
),
d,
Expand Down Expand Up @@ -328,8 +328,8 @@ def test_chained_fusion():
)
expected = im.as_fieldop(
im.lambda_("inp1", "inp2")(
im.let("_icdlv_1", im.plus(im.deref("inp1"), im.deref("inp2")))(
im.plus("_icdlv_1", "_icdlv_1")
im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("inp1"), im.deref("inp2"))))(
im.plus(im.call("_icdlv_1")(), im.call("_icdlv_1")())
)
),
d,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,33 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next.type_system import type_specifications as ts
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import cse
from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars

field_type = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))

def wrap_in_program(expr: itir.Expr) -> itir.Program:

def wrap_in_program(expr: itir.Expr, *, arg_dtypes=None) -> itir.Program:
if arg_dtypes is None:
arg_dtypes = [ts.ScalarKind.FLOAT64]
arg_types = [ts.FieldType(dims=[], dtype=ts.ScalarType(kind=dtype)) for dtype in arg_dtypes]
indices = [i for i in range(1, len(arg_dtypes) + 1)] if len(arg_dtypes) > 1 else [""]
return itir.Program(
id="f",
function_definitions=[],
params=[im.sym("d"), im.sym("inp"), im.sym("out")],
params=[
*(im.sym(f"inp{i}", type_) for i, type_ in zip(indices, arg_types)),
im.sym("out", field_type),
],
declarations=[],
body=[
itir.SetAt(
expr=im.as_fieldop(im.lambda_("it")(expr))(im.ref("inp")),
expr=im.as_fieldop(im.lambda_(*(f"it{i}" for i in indices))(expr))(
*(im.ref(f"inp{i}") for i in indices)
),
domain=im.call("cartesian_domain")(),
target=im.ref("out"),
)
Expand All @@ -34,15 +47,15 @@ def unwrap_from_program(program: itir.Program) -> itir.Expr:

def test_simple():
testee = im.let("var", im.lift("deref")("it"))(im.deref("var"))
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))())(·it)"
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))())(λ() → ·it)"

actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee)))
assert str(actual) == expected


def test_double_deref():
testee = im.let("var", im.lift("deref")("it"))(im.plus(im.deref("var"), im.deref("var")))
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())(·it)"
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))() + ·(↑(λ() → _icdlv_1()))())(λ() → ·it)"

actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee)))
assert str(actual) == expected
Expand All @@ -62,3 +75,18 @@ def test_deref_at_multiple_pos():

actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee)))
assert testee == actual


def test_bc():
# we also check that the common subexpression is able to extract the inlined value, such
# that it is only evaluated once
testee = im.let("var", im.lift("deref")("it2"))(
im.if_(im.deref("it1"), im.literal_from_value(0), im.plus(im.deref("var"), im.deref("var")))
)
expected = "(λ(_icdlv_1) → if ·it1 then 0 else (λ(_cs_1) → _cs_1 + _cs_1)(·(↑(λ() → _icdlv_1()))()))(λ() → ·it2)"

actual = InlineCenterDerefLiftVars.apply(
wrap_in_program(testee, arg_dtypes=[ts.ScalarKind.BOOL, ts.ScalarKind.FLOAT64])
)
simplified = unwrap_from_program(cse.CommonSubexpressionElimination.apply(actual))
assert str(simplified) == expected

0 comments on commit 587d107

Please sign in to comment.