Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: Respect evaluation order in InlineCenterDerefLiftVars #1883

Merged
merged 6 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,25 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator):
`let(var, (↑stencil)(it))(·var + ·var)`

Directly inlining `var` would increase the size of the tree and duplicate the calculation.
Instead, this pass computes the value at the current location once and replaces all previous
references to `var` by an applied lift which captures this value.
Instead, this pass first takes the iterator `(↑stencil)(it)` and transforms it into a
0-ary function that evaluates to the value at the current location.

`let(_icdlv_1, stencil(it))(·(↑(λ() → _icdlv_1) + ·(↑(λ() → _icdlv_1))`
`λ() → ·(↑stencil)(it)`

Then all previous occurrences of `var` are replaced by this function.

`let(_icdlv_1, λ() → ·(↑stencil)(it))(·(↑(λ() → _icdlv_1()))() + ·(↑(λ() → _icdlv_1()))())`

The lift inliner can then later easily transform this into a nice expression:

`let(_icdlv_1, stencil(it))(_icdlv_1 + _icdlv_1)`
`let(_icdlv_1, λ() → stencil(it))(_icdlv_1() + _icdlv_1())`

Finally, recomputation is avoided by using common subexpression elimination and lambda
inlining (which can be configured to be opcount preserving). Both are left to the caller to run later.

`(λ(_cs_1) → _cs_1 + _cs_1)(stencil(it))`

Note: This pass uses and preserves the `recorded_shifts` annex.
Note: This pass uses and preserves the `domain` and `recorded_shifts` annex.
"""

PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts")
Expand Down Expand Up @@ -78,20 +87,25 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
assert isinstance(node.fun, itir.Lambda) # to make mypy happy
eligible_params = [False] * len(node.fun.params)
new_args = []
bound_scalars: dict[str, itir.Expr] = {}
# The values are 0-ary lambda functions that evaluate to the dereferenced argument. We
# don't store the values themselves here, as the deref might sit inside of an `if` that
# protects it from an out-of-bounds access.
evaluators: dict[str, itir.Expr] = {}

for i, (param, arg) in enumerate(zip(node.fun.params, node.args)):
if cpm.is_applied_lift(arg) and is_center_derefed_only(param):
eligible_params[i] = True
bound_arg_name = self.uids.sequential_id(prefix="_icdlv")
capture_lift = im.promote_to_const_iterator(bound_arg_name)
bound_arg_evaluator = self.uids.sequential_id(prefix="_icdlv")
capture_lift = im.promote_to_const_iterator(im.call(bound_arg_evaluator)())
trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift)
new_args.append(capture_lift)
# since we deref an applied lift here we can (but don't need to) immediately
# inline
bound_scalars[bound_arg_name] = InlineLifts(
flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT
).visit(im.deref(arg), recurse=False)
evaluators[bound_arg_evaluator] = im.lambda_()(
InlineLifts(flags=InlineLifts.Flag.INLINE_DEREF_LIFT).visit(
im.deref(arg), recurse=False
)
)
else:
new_args.append(arg)

Expand All @@ -100,6 +114,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs):
im.call(node.fun)(*new_args), eligible_params=eligible_params
)
# TODO(tehrengruber): propagate let outwards
return im.let(*bound_scalars.items())(new_node)
return im.let(*evaluators.items())(new_node)

return node
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def test_staged_inlining():
)
expected = im.as_fieldop(
im.lambda_("a", "b")(
im.let("_icdlv_1", im.plus(im.deref("a"), im.deref("b")))(
im.plus(im.plus("_icdlv_1", 1), im.plus("_icdlv_1", 2))
im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("a"), im.deref("b"))))(
im.plus(im.plus(im.call("_icdlv_1")(), 1), im.plus(im.call("_icdlv_1")(), 2))
)
),
d,
Expand Down Expand Up @@ -328,8 +328,8 @@ def test_chained_fusion():
)
expected = im.as_fieldop(
im.lambda_("inp1", "inp2")(
im.let("_icdlv_1", im.plus(im.deref("inp1"), im.deref("inp2")))(
im.plus("_icdlv_1", "_icdlv_1")
im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("inp1"), im.deref("inp2"))))(
im.plus(im.call("_icdlv_1")(), im.call("_icdlv_1")())
)
),
d,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,33 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next.type_system import type_specifications as ts
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import cse
from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars

field_type = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))

def wrap_in_program(expr: itir.Expr) -> itir.Program:

def wrap_in_program(expr: itir.Expr, *, arg_dtypes=None) -> itir.Program:
if arg_dtypes is None:
arg_dtypes = [ts.ScalarKind.FLOAT64]
arg_types = [ts.FieldType(dims=[], dtype=ts.ScalarType(kind=dtype)) for dtype in arg_dtypes]
indices = [i for i in range(1, len(arg_dtypes) + 1)] if len(arg_dtypes) > 1 else [""]
return itir.Program(
id="f",
function_definitions=[],
params=[im.sym("d"), im.sym("inp"), im.sym("out")],
params=[
*(im.sym(f"inp{i}", type_) for i, type_ in zip(indices, arg_types)),
im.sym("out", field_type),
],
declarations=[],
body=[
itir.SetAt(
expr=im.as_fieldop(im.lambda_("it")(expr))(im.ref("inp")),
expr=im.as_fieldop(im.lambda_(*(f"it{i}" for i in indices))(expr))(
*(im.ref(f"inp{i}") for i in indices)
),
domain=im.call("cartesian_domain")(),
target=im.ref("out"),
)
Expand All @@ -34,15 +47,15 @@ def unwrap_from_program(program: itir.Program) -> itir.Expr:

def test_simple():
testee = im.let("var", im.lift("deref")("it"))(im.deref("var"))
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))())(·it)"
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))())(λ() → ·it)"

actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee)))
assert str(actual) == expected


def test_double_deref():
testee = im.let("var", im.lift("deref")("it"))(im.plus(im.deref("var"), im.deref("var")))
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())(·it)"
expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))() + ·(↑(λ() → _icdlv_1()))())(λ() → ·it)"

actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee)))
assert str(actual) == expected
Expand All @@ -62,3 +75,18 @@ def test_deref_at_multiple_pos():

actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee)))
assert testee == actual


def test_bc():
    """Check the pass on a lifted variable that is only dereferenced inside one `if_` branch.

    The let-bound applied lift `var` is dereferenced twice, but only in the else-branch of
    a conditional on `·it1`. After `InlineCenterDerefLiftVars`, common subexpression
    elimination is run as well to verify that it is able to extract the inlined value, such
    that it is only evaluated once (and only inside that branch).
    """
    guarded_sum = im.if_(
        im.deref("it1"),
        im.literal_from_value(0),
        im.plus(im.deref("var"), im.deref("var")),
    )
    testee = im.let("var", im.lift("deref")("it2"))(guarded_sum)
    program = wrap_in_program(testee, arg_dtypes=[ts.ScalarKind.BOOL, ts.ScalarKind.FLOAT64])

    inlined = InlineCenterDerefLiftVars.apply(program)
    # CSE collapses the two identical evaluator calls in the else-branch into a single `_cs_1`.
    simplified = unwrap_from_program(cse.CommonSubexpressionElimination.apply(inlined))

    expected = "(λ(_icdlv_1) → if ·it1 then 0 else (λ(_cs_1) → _cs_1 + _cs_1)(·(↑(λ() → _icdlv_1()))()))(λ() → ·it2)"
    assert str(simplified) == expected