From d16bbd5d69aea6972d4e8d2e62828fe06265602d Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Sat, 15 Feb 2025 14:25:19 +0100 Subject: [PATCH 1/9] ITIR type inference: store param type in Lambda --- src/gt4py/next/iterator/type_system/inference.py | 4 ++++ .../unit_tests/iterator_tests/test_type_inference.py | 1 + 2 files changed, 5 insertions(+) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index fe450625db..cc7a7123b9 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -37,6 +37,10 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: assert type_info.is_compatible_type( node.type, type_ ), "Node already has a type which differs." + if isinstance(node, itir.Lambda): + assert isinstance(type_, ts.FunctionType) + for param, param_type in zip(node.params, type_.pos_only_args): + _set_node_type(param, param_type) node.type = type_ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a39fe3c6d8..577c7bce1c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -245,6 +245,7 @@ def test_aliased_function(): assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type ) + assert result.args[0].params[0].type == int_type assert result.type == int_type From 813f3285bbf9b3d2876544e082baaf3394569bcf Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 18 Feb 2025 15:32:25 +0100 Subject: [PATCH 2/9] Flatten as_fieldop tuple arguments --- src/gt4py/next/iterator/ir_utils/misc.py | 22 ++- .../iterator/transforms/collapse_tuple.py | 151 +++++++++++++----- .../iterator/transforms/fuse_as_fieldop.py | 25 +-- .../transforms_tests/test_collapse_tuple.py | 29 +++- 4 
files changed, 165 insertions(+), 62 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 03652cdf16..e04ccd7dd3 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -12,7 +12,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im @dataclasses.dataclass(frozen=True) @@ -71,3 +71,23 @@ def is_equal(a: itir.Expr, b: itir.Expr): return a == b or ( CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) ) + + +def canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: + """ + Canonicalize applied `as_fieldop`s. + + In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified + format to work with (e.g. each parameter has a name without the need to special case). 
+ """ + assert cpm.is_applied_as_fieldop(expr) + + stencil = expr.fun.args[0] # type: ignore[attr-defined] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] + if cpm.is_ref_to(stencil, "deref"): + stencil = im.lambda_("arg")(im.deref("arg")) + new_expr = im.as_fieldop(stencil, domain)(*expr.args) + + return new_expr + + return expr diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 462f87b600..d22a92faf8 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -17,39 +17,48 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, ir_makers as im, misc as ir_misc, ) -from gt4py.next.iterator.transforms import fixed_point_transformation +from gt4py.next.iterator.transforms import fixed_point_transformation, inline_lifts from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda -from gt4py.next.iterator.type_system import inference as itir_type_inference +from gt4py.next.iterator.type_system import ( + inference as itir_type_inference, + type_specifications as it_ts, +) from gt4py.next.type_system import type_info, type_specifications as ts -def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): - """Given a itir.FunCall return a new call with one of its argument replaced.""" - return ir.FunCall( +def _with_altered_arg(node: itir.FunCall, arg_idx: int, new_arg: itir.Expr | str): + """Given a itir.FunCall return a new call with one of its argument replaced.""" + return itir.FunCall( fun=node.fun, args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], ) -def _is_trivial_make_tuple_call(node: ir.Expr): 
+def _with_altered_iterator_element_type(type_: it_ts.IteratorType, new_el_type: ts.DataType): + return it_ts.IteratorType( + position_dims=type_.position_dims, defined_dims=type_.defined_dims, element_type=new_el_type + ) + + +def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): return False if not all( - isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) + isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_make_tuple_call(arg) for arg in node.args ): return False return True -def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: +def _is_trivial_or_tuple_thereof_expr(node: itir.Node) -> bool: """ Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. @@ -65,7 +74,7 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: ... ) True """ - if isinstance(node, (ir.SymRef, ir.Literal)): + if isinstance(node, (itir.SymRef, itir.Literal)): return True if cpm.is_call_to(node, "make_tuple"): return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) @@ -138,6 +147,8 @@ class Transformation(enum.Flag): PROPAGATE_NESTED_LET = enum.auto() #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() + #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b) + FLATTEN_AS_FIELDOP_ARGS = enum.auto() @classmethod def all(self) -> CollapseTuple.Transformation: @@ -152,7 +163,7 @@ def all(self) -> CollapseTuple.Transformation: @classmethod def apply( cls, - node: ir.Node, + node: itir.Node, *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, @@ -163,7 +174,7 @@ def apply( # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, uids: Optional[eve_utils.UIDGenerator] = None, - ) -> ir.Node: + ) -> 
itir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -181,7 +192,7 @@ def apply( offset_provider_type = offset_provider_type or {} uids = uids or eve_utils.UIDGenerator() - if isinstance(node, ir.Program): + if isinstance(node, itir.Program): within_stencil = False assert within_stencil in [ True, @@ -220,18 +231,18 @@ def visit(self, node, **kwargs): return super().visit(node, **kwargs) def transform_collapse_make_tuple_tuple_get( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if cpm.is_call_to(node, "make_tuple") and all( cpm.is_call_to(arg, "tuple_get") for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` - assert isinstance(node.args[0], ir.FunCall) + assert isinstance(node.args[0], itir.FunCall) first_expr = node.args[0].args[1] for i, v in enumerate(node.args): - assert isinstance(v, ir.FunCall) - assert isinstance(v.args[0], ir.Literal) + assert isinstance(v, itir.FunCall) + assert isinstance(v.args[0], itir.Literal) if not (int(v.args[0].value) == i and ir_misc.is_equal(v.args[1], first_expr)): # tuple argument differs, just continue with the rest of the tree return None @@ -248,11 +259,11 @@ def transform_collapse_make_tuple_tuple_get( return None def transform_collapse_tuple_get_make_tuple( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if ( cpm.is_call_to(node, "tuple_get") - and isinstance(node.args[0], ir.Literal) + and isinstance(node.args[0], itir.Literal) and cpm.is_call_to(node.args[1], "make_tuple") ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` @@ -265,8 +276,8 @@ def transform_collapse_tuple_get_make_tuple( return node.args[1].args[idx] return None - def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], 
ir.Literal): + def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: + if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], itir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` @@ -289,12 +300,14 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ ) return None - def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + def transform_letify_make_tuple_elements( + self, node: itir.Node, **kwargs + ) -> Optional[itir.Node]: if cpm.is_call_to(node, "make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` - bound_vars: dict[ir.Sym, ir.Expr] = {} - new_args: list[ir.Expr] = [] + bound_vars: dict[itir.Sym, itir.Expr] = {} + new_args: list[itir.Expr] = [] for arg in node.args: if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): el_name = self.uids.sequential_id(prefix="__ct_el") @@ -309,7 +322,7 @@ def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optio ) return None - def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, node: itir.Node, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` @@ -318,13 +331,15 @@ def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Option return self.visit(inline_lambda(node, eligible_params=eligible_params), **kwargs) return None - def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_propagate_to_if_on_tuples( + self, node: itir.FunCall, **kwargs + ) -> 
Optional[itir.Node]: if kwargs["within_stencil"]: # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation # in local-view for now. Revisit. return None - if isinstance(node, ir.FunCall) and not cpm.is_call_to(node, "if_"): + if isinstance(node, itir.FunCall) and not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -343,8 +358,8 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt return None def transform_propagate_to_if_on_tuples_cps( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: # The basic idea of this transformation is to remove tuples across if-stmts by rewriting # the expression in continuation passing style, e.g. something like a tuple reordering # ``` @@ -366,7 +381,7 @@ def transform_propagate_to_if_on_tuples_cps( # `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this # tuple reordering example here. 
- if not isinstance(node, ir.FunCall) or cpm.is_call_to(node, "if_"): + if not isinstance(node, itir.FunCall) or cpm.is_call_to(node, "if_"): return None # The first argument that is eligible also transforms all remaining args (They will be @@ -438,7 +453,7 @@ def transform_propagate_to_if_on_tuples_cps( return None - def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_propagate_nested_let(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} @@ -464,14 +479,76 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional ) return None - def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_inline_trivial_let(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): - if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, itir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, itir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let return arg - if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): + if any( + trivial_args := [isinstance(arg, (itir.SymRef, itir.Literal)) for arg in node.args] + ): return inline_lambda(node, eligible_params=trivial_args) return None + + # TODO(tehrengruber): This is a transformation that should be executed before visiting the children. Then + # revisiting the body would not be needed. 
+ def transform_flatten_as_fieldop_args( + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: + if not cpm.is_applied_as_fieldop(node): + return None + + for arg in node.args: + itir_type_inference.reinfer(arg) + + if not any(isinstance(arg.type, ts.TupleType) for arg in node.args): + return None + + node = ir_misc.canonicalize_as_fieldop(node) + stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + new_body = stencil.expr + domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + orig_args_map: dict[itir.Sym, itir.Expr] = {} + new_params: list[itir.Sym] = [] + new_args: list[itir.Expr] = [] + for param, arg in zip(stencil.params, node.args, strict=True): + if isinstance(arg.type, ts.TupleType): + ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type) + orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg + new_params_inner, new_args_inner = [], [] + for i, type_ in enumerate(param.type.element_type.types): + new_params_inner.append( + im.sym( + f"__ct_flat_el{i}_{param.id}", + _with_altered_iterator_element_type(param.type, type_), + ) + ) + new_args_inner.append(im.tuple_get(i, ref_to_orig_arg)) + + param_substitute = im.lift( + im.lambda_(*new_params_inner)( + im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in new_params_inner]) + ) + )(*[im.ref(p.id, p.type) for p in new_params_inner]) + + new_body = im.let(param.id, param_substitute)(new_body) + # note: the lift is trivial so inlining it is not an issue with respect to tree size + new_body = inline_lambda(new_body, force_inline_lift_args=True) + new_params.extend(new_params_inner) + new_args.extend(new_args_inner) + else: + new_params.append(param) + new_args.append(arg) + + # remove lifts again + new_body = inline_lifts.InlineLifts( + flags=inline_lifts.InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + ).visit(new_body) + new_body = self.visit(new_body, 
**kwargs) + + return im.let(*orig_args_map.items())( + im.as_fieldop(im.lambda_(*new_params)(new_body), domain)(*new_args) + ) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 81633dfb87..ff1a1b36e8 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -21,6 +21,7 @@ common_pattern_matcher as cpm, domain_utils, ir_makers as im, + misc as ir_misc, ) from gt4py.next.iterator.transforms import ( fixed_point_transformation, @@ -46,26 +47,6 @@ def _merge_arguments( return new_args -def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: - """ - Canonicalize applied `as_fieldop`s. - - In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified - format to work with (e.g. each parameter has a name without the need to special case). - """ - assert cpm.is_applied_as_fieldop(expr) - - stencil = expr.fun.args[0] # type: ignore[attr-defined] - domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] - if cpm.is_ref_to(stencil, "deref"): - stencil = im.lambda_("arg")(im.deref("arg")) - new_expr = im.as_fieldop(stencil, domain)(*expr.args) - - return new_expr - - return expr - - def _is_tuple_expr_of_literals(expr: itir.Expr): if cpm.is_call_to(expr, "make_tuple"): return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) @@ -78,7 +59,7 @@ def _inline_as_fieldop_arg( arg: itir.Expr, *, uids: eve_utils.UIDGenerator ) -> tuple[itir.Expr, dict[str, itir.Expr]]: assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) + arg = ir_misc.canonicalize_as_fieldop(arg) stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` inner_args: list[itir.Expr] = arg.args @@ -411,7 +392,7 @@ def transform_fuse_make_tuple(self, node: itir.Node, **kwargs): def transform_fuse_as_fieldop(self, node: itir.Node, 
**kwargs): if cpm.is_applied_as_fieldop(node): - node = _canonicalize_as_fieldop(node) + node = ir_misc.canonicalize_as_fieldop(node) stencil = node.fun.args[0] # type: ignore[attr-defined] # ensure cpm.is_applied_as_fieldop assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan") args: list[itir.Expr] = node.args diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 916ae4e578..7813224f4d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -5,11 +5,15 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +from gt4py.next import common from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts -from next_tests.unit_tests.iterator_tests.test_type_inference import int_type +from gt4py.next.iterator.type_system import type_specifications as it_ts + +bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) def test_simple_make_tuple_tuple_get(): @@ -311,3 +315,24 @@ def test_if_make_tuple_reorder_cps_external(): within_stencil=False, ) assert actual == expected + + +def test_flatten_as_fieldop_args(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType(types=[bool_type, int_type]), + ) + testee = im.as_fieldop(im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))))( + im.make_tuple(1, 2) + ) + expected = im.as_fieldop( + im.lambda_("__ct_flat_el0_it", "__ct_flat_el1_it")(im.deref("__ct_flat_el1_it")) + )(1, 2) + 
actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected From 374546133e96dd18ee01d4ee3e3ba54cdc841493 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 02:10:34 +0100 Subject: [PATCH 3/9] Add support for scan and nested tuples --- src/gt4py/next/iterator/ir_utils/misc.py | 44 +++++++++++++ .../iterator/transforms/collapse_tuple.py | 26 ++++++-- .../iterator/transforms/fuse_as_fieldop.py | 46 +------------- .../transforms_tests/test_collapse_tuple.py | 62 ++++++++++++++++++- 4 files changed, 125 insertions(+), 53 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index e04ccd7dd3..bcffd5fe51 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -91,3 +91,47 @@ def canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: return new_expr return expr + + +def unwrap_scan(stencil: itir.Lambda | itir.FunCall): + """ + If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. + + If a regular stencil is given the stencil is left as-is and the back-transformation is the + identity function. This function allows treating a scan stencil like a regular stencil during + a transformation avoiding the complexity introduced by the different IR format. + + >>> scan = im.call("scan")( + ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 + ... 
) + >>> stencil, back_trafo = _unwrap_scan(scan) + >>> str(stencil) + 'λ(arg) → state + ·arg' + >>> str(back_trafo(stencil)) + 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' + + In case a regular stencil is given it is returned as-is: + + >>> deref_stencil = im.lambda_("it")(im.deref("it")) + >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> assert stencil == deref_stencil + """ + if cpm.is_call_to(stencil, "scan"): + scan_pass, direction, init = stencil.args + assert isinstance(scan_pass, itir.Lambda) + # remove scan pass state to be used by caller + state_param = scan_pass.params[0] + stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) + + def restore_scan(transformed_stencil_like: itir.Lambda): + new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) + ) + return im.call("scan")(new_scan_pass, direction, init) + + return stencil_like, restore_scan + + assert isinstance(stencil, itir.Lambda) + return stencil, lambda s: s diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index d22a92faf8..f0afd5bafc 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -12,6 +12,7 @@ import enum import functools import operator +import re from typing import Optional from gt4py import eve @@ -91,6 +92,19 @@ def _is_trivial_or_tuple_thereof_expr(node: itir.Node) -> bool: return False +def _flattened_as_fieldop_param_el_name(param: str, idx: int) -> str: + prefix = "__ct_flat_el_" + + # keep the original param name, but skip prefix from previous flattenings + if param.startswith(prefix): + parent_idx, suffix = re.split(r"_(?!\d)", param[len(prefix) :], maxsplit=1) + prefix = f"{prefix}{parent_idx}_" + else: + suffix = param + + return f"{prefix}{idx}_{suffix}" + + # 
TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, # transform each node until no transformations apply anymore, whenever a node is to be transformed # go through all available transformation and apply them. However the final result here still @@ -508,7 +522,10 @@ def transform_flatten_as_fieldop_args( return None node = ir_misc.canonicalize_as_fieldop(node) - stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + stencil, restore_scan = ir_misc.unwrap_scan( + node.fun.args[0] # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + ) + new_body = stencil.expr domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop orig_args_map: dict[itir.Sym, itir.Expr] = {} @@ -522,7 +539,7 @@ def transform_flatten_as_fieldop_args( for i, type_ in enumerate(param.type.element_type.types): new_params_inner.append( im.sym( - f"__ct_flat_el{i}_{param.id}", + _flattened_as_fieldop_param_el_name(param.id, i), _with_altered_iterator_element_type(param.type, type_), ) ) @@ -548,7 +565,6 @@ def transform_flatten_as_fieldop_args( flags=inline_lifts.InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT ).visit(new_body) new_body = self.visit(new_body, **kwargs) + new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) - return im.let(*orig_args_map.items())( - im.as_fieldop(im.lambda_(*new_params)(new_body), domain)(*new_args) - ) + return im.let(*orig_args_map.items())(im.as_fieldop(new_stencil, domain)(*new_args)) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index ff1a1b36e8..b746369152 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -95,50 +95,6 @@ def _inline_as_fieldop_arg( ), extracted_args -def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): - """ - If given a 
scan, extract stencil part of its scan pass and a back-transformation into a scan. - - If a regular stencil is given the stencil is left as-is and the back-transformation is the - identity function. This function allows treating a scan stencil like a regular stencil during - a transformation avoiding the complexity introduced by the different IR format. - - >>> scan = im.call("scan")( - ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 - ... ) - >>> stencil, back_trafo = _unwrap_scan(scan) - >>> str(stencil) - 'λ(arg) → state + ·arg' - >>> str(back_trafo(stencil)) - 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' - - In case a regular stencil is given it is returned as-is: - - >>> deref_stencil = im.lambda_("it")(im.deref("it")) - >>> stencil, back_trafo = _unwrap_scan(deref_stencil) - >>> assert stencil == deref_stencil - """ - if cpm.is_call_to(stencil, "scan"): - scan_pass, direction, init = stencil.args - assert isinstance(scan_pass, itir.Lambda) - # remove scan pass state to be used by caller - state_param = scan_pass.params[0] - stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) - - def restore_scan(transformed_stencil_like: itir.Lambda): - new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( - im.call(transformed_stencil_like)( - *(param.id for param in transformed_stencil_like.params) - ) - ) - return im.call("scan")(new_scan_pass, direction, init) - - return stencil_like, restore_scan - - assert isinstance(stencil, itir.Lambda) - return stencil, lambda s: s - - def fuse_as_fieldop( expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator ) -> itir.Expr: @@ -146,7 +102,7 @@ def fuse_as_fieldop( stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop assert isinstance(expr.fun.args[0], itir.Lambda) or cpm.is_call_to(stencil, "scan") # type: ignore[attr-defined] # ensured by is_applied_as_fieldop - stencil, 
restore_scan = _unwrap_scan(stencil) + stencil, restore_scan = ir_misc.unwrap_scan(stencil) domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 7813224f4d..08f1926277 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -11,7 +11,6 @@ from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.type_system import type_specifications as it_ts -bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) @@ -321,13 +320,70 @@ def test_flatten_as_fieldop_args(): it_type = it_ts.IteratorType( position_dims=[Vertex], defined_dims=[Vertex], - element_type=ts.TupleType(types=[bool_type, int_type]), + element_type=ts.TupleType(types=[int_type, int_type]), ) testee = im.as_fieldop(im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))))( im.make_tuple(1, 2) ) expected = im.as_fieldop( - im.lambda_("__ct_flat_el0_it", "__ct_flat_el1_it")(im.deref("__ct_flat_el1_it")) + im.lambda_("__ct_flat_el_0_it", "__ct_flat_el_1_it")(im.deref("__ct_flat_el_1_it")) + )(1, 2) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_flatten_as_fieldop_args_nested(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType( + types=[ + int_type, + ts.TupleType(types=[int_type, int_type]), + ] + ), + ) + testee = im.as_fieldop( + 
im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.tuple_get(1, im.deref("it")))) + )(im.make_tuple(1, im.make_tuple(2, 3))) + expected = im.as_fieldop( + im.lambda_("__ct_flat_el_0_it", "__ct_flat_el_1_0_it", "__ct_flat_el_1_1_it")( + im.deref("__ct_flat_el_1_1_it") + ) + )(1, 2, 3) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_flatten_as_fieldop_args_scan(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType(types=[int_type, int_type]), + ) + testee = im.as_fieldop( + im.scan( + im.lambda_("state", im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))), True, 0 + ) + )(im.make_tuple(1, 2)) + expected = im.as_fieldop( + im.scan( + im.lambda_("state", "__ct_flat_el_0_it", "__ct_flat_el_1_it")( + im.deref("__ct_flat_el_1_it") + ), + True, + 0, + ) )(1, 2) actual = CollapseTuple.apply( testee, From fc20d7c5e2ff4d71abfd9c70f07434b8ab9a6b0f Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 12:11:33 +0100 Subject: [PATCH 4/9] Fix doctest --- src/gt4py/next/iterator/ir_utils/misc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index bcffd5fe51..00a1ab5609 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -104,7 +104,7 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): >>> scan = im.call("scan")( ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 ... 
) - >>> stencil, back_trafo = _unwrap_scan(scan) + >>> stencil, back_trafo = unwrap_scan(scan) >>> str(stencil) 'λ(arg) → state + ·arg' >>> str(back_trafo(stencil)) @@ -113,7 +113,7 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): In case a regular stencil is given it is returned as-is: >>> deref_stencil = im.lambda_("it")(im.deref("it")) - >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> stencil, back_trafo = unwrap_scan(deref_stencil) >>> assert stencil == deref_stencil """ if cpm.is_call_to(stencil, "scan"): From 335e932086b19481ef904a65417ac1bd611fcc59 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 13:13:15 +0100 Subject: [PATCH 5/9] Fix broken scan (e.g. test_tuple_scalar_scan) --- src/gt4py/next/iterator/ir_utils/misc.py | 18 +++++++++++++++--- .../next/iterator/transforms/collapse_tuple.py | 15 ++++++++++----- .../transforms_tests/test_collapse_tuple.py | 9 ++++++--- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 00a1ab5609..03a3dfb0e3 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -93,6 +93,16 @@ def canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: return expr +def _remove_let_alias(let_expr: itir.FunCall): + assert cpm.is_let(let_expr) + is_aliased_let = True + for param, arg in zip(let_expr.fun.params, let_expr.args, strict=True): # type: ignore[attr-defined] # ensured by cpm.is_let + is_aliased_let &= cpm.is_ref_to(arg, param.id) + if is_aliased_let: + return let_expr.fun.expr # type: ignore[attr-defined] # ensured by cpm.is_let + return let_expr + + def unwrap_scan(stencil: itir.Lambda | itir.FunCall): """ If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. 
@@ -108,7 +118,7 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): >>> str(stencil) 'λ(arg) → state + ·arg' >>> str(back_trafo(stencil)) - 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' + 'scan(λ(state, arg) → state + ·arg, True, 0.0)' In case a regular stencil is given it is returned as-is: @@ -125,8 +135,10 @@ def unwrap_scan(stencil: itir.Lambda | itir.FunCall): def restore_scan(transformed_stencil_like: itir.Lambda): new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( - im.call(transformed_stencil_like)( - *(param.id for param in transformed_stencil_like.params) + _remove_let_alias( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) ) ) return im.call("scan")(new_scan_pass, direction, init) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f0afd5bafc..03364451b4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -18,7 +18,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import ir, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, ir_makers as im, @@ -51,10 +51,7 @@ def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): return False - if not all( - isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_make_tuple_call(arg) - for arg in node.args - ): + if not all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args): return False return True @@ -163,6 +160,8 @@ class Transformation(enum.Flag): INLINE_TRIVIAL_LET = enum.auto() #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b) 
FLATTEN_AS_FIELDOP_ARGS = enum.auto() + #: `let(a, b[1])(a)` -> `b[1]` + INLINE_TRIVIAL_TUPLE_LET_VAR = enum.auto() @classmethod def all(self) -> CollapseTuple.Transformation: @@ -507,6 +506,12 @@ def transform_inline_trivial_let(self, node: itir.FunCall, **kwargs) -> Optional return None + def transform_inline_trivial_tuple_let_var(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + if cpm.is_let(node): + if any(trivial_args := [_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args]): + return inline_lambda(node, eligible_params=trivial_args) + return None + # TODO(tehrengruber): This is a transformation that should be executed before visiting the children. Then # revisiting the body would not be needed. def transform_flatten_as_fieldop_args( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 08f1926277..e8d04096b4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -128,8 +128,11 @@ def test_propagate_tuple_get(): def test_letify_make_tuple_elements(): - # anything that is not trivial, i.e. 
a SymRef, works here - el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") + fun_type = ts.FunctionType( + pos_only_args=[], pos_or_kw_args={}, kw_only_args={}, returns=int_type + ) + # anything that is not trivial, works here + el1, el2 = im.call(im.ref("foo", fun_type))(), im.call(im.ref("bar", fun_type))() testee = im.make_tuple(el1, el2) expected = im.let(("__ct_el_1", el1), ("__ct_el_2", el2))( im.make_tuple("__ct_el_1", "__ct_el_2") @@ -391,4 +394,4 @@ def test_flatten_as_fieldop_args_scan(): allow_undeclared_symbols=True, within_stencil=False, ) - assert actual == expected + assert actual == expected \ No newline at end of file From d399c65848f2916c0c760ed217ec80d62d5a887e Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 13:34:51 +0100 Subject: [PATCH 6/9] Fix format --- .../iterator_tests/transforms_tests/test_collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index e8d04096b4..636e66940c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -394,4 +394,4 @@ def test_flatten_as_fieldop_args_scan(): allow_undeclared_symbols=True, within_stencil=False, ) - assert actual == expected \ No newline at end of file + assert actual == expected From 5ad77013267c520d44ae59dcaeefba7dadb3ee20 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 15:29:50 +0100 Subject: [PATCH 7/9] Fix failing tests --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 ++- .../iterator/transforms/collapse_tuple.py | 39 +++++++++++++------ 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 
42b82ffdd0..bdb9bd8249 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -445,7 +445,7 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' """ - from gt4py.next.iterator.ir_utils import domain_utils + from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils result = call( call("as_fieldop")( @@ -462,7 +462,9 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal def _populate_domain_annex_wrapper(*args, **kwargs): node = result(*args, **kwargs) - if domain: + # note: if the domain is not a direct construction, e.g. because it is only a reference + # to a domain defined in a let, don't populate the annex + if domain and cpm.is_call_to(domain, ("cartesian_domain", "unstructured_domain")): node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 03364451b4..6b04790644 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -13,7 +13,7 @@ import functools import operator import re -from typing import Optional +from typing import Literal, Optional from gt4py import eve from gt4py.eve import utils as eve_utils @@ -41,12 +41,24 @@ def _with_altered_arg(node: itir.FunCall, arg_idx: int, new_arg: itir.Expr | str ) -def _with_altered_iterator_element_type(type_: it_ts.IteratorType, new_el_type: ts.DataType): +def _with_altered_iterator_element_type( + type_: it_ts.IteratorType, new_el_type: ts.DataType +) -> it_ts.IteratorType: return it_ts.IteratorType( position_dims=type_.position_dims, defined_dims=type_.defined_dims, element_type=new_el_type ) +def 
_with_altered_iterator_position_dims( + type_: it_ts.IteratorType, new_position_dims: list[common.Dimension] | Literal["unknown"] +) -> it_ts.IteratorType: + return it_ts.IteratorType( + position_dims=new_position_dims, + defined_dims=type_.defined_dims, + element_type=type_.element_type, + ) + + def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): @@ -540,19 +552,24 @@ def transform_flatten_as_fieldop_args( if isinstance(arg.type, ts.TupleType): ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type) orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg - new_params_inner, new_args_inner = [], [] + new_params_inner, lift_params = [], [] for i, type_ in enumerate(param.type.element_type.types): - new_params_inner.append( + new_param = im.sym( + _flattened_as_fieldop_param_el_name(param.id, i), + _with_altered_iterator_element_type(param.type, type_), + ) + lift_params.append( im.sym( - _flattened_as_fieldop_param_el_name(param.id, i), - _with_altered_iterator_element_type(param.type, type_), + new_param.id, + _with_altered_iterator_position_dims(new_param.type, "unknown"), ) ) - new_args_inner.append(im.tuple_get(i, ref_to_orig_arg)) + new_params_inner.append(new_param) + new_args.append(im.tuple_get(i, ref_to_orig_arg)) param_substitute = im.lift( - im.lambda_(*new_params_inner)( - im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in new_params_inner]) + im.lambda_(*lift_params)( + im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in lift_params]) ) )(*[im.ref(p.id, p.type) for p in new_params_inner]) @@ -560,14 +577,14 @@ def transform_flatten_as_fieldop_args( # note: the lift is trivial so inlining it is not an issue with respect to tree size new_body = inline_lambda(new_body, force_inline_lift_args=True) new_params.extend(new_params_inner) - new_args.extend(new_args_inner) else: 
new_params.append(param) new_args.append(arg) # remove lifts again new_body = inline_lifts.InlineLifts( - flags=inline_lifts.InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT + flags=inline_lifts.InlineLifts.Flag.INLINE_DEREF_LIFT + | inline_lifts.InlineLifts.Flag.PROPAGATE_SHIFT ).visit(new_body) new_body = self.visit(new_body, **kwargs) new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) From d3957bd805f96c1d16e8891b9976a08ecd872366 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 15:30:53 +0100 Subject: [PATCH 8/9] Fix format --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 6b04790644..97c8f1ca02 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -561,7 +561,7 @@ def transform_flatten_as_fieldop_args( lift_params.append( im.sym( new_param.id, - _with_altered_iterator_position_dims(new_param.type, "unknown"), + _with_altered_iterator_position_dims(new_param.type, "unknown"), # type: ignore[arg-type] # always in IteratorType ) ) new_params_inner.append(new_param) From b52a07c4288fe9b1dd6ecca23b050e8f7eeceec1 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 16:14:01 +0100 Subject: [PATCH 9/9] Cleanup --- .../next/iterator/transforms/collapse_tuple.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 97c8f1ca02..923f6d1302 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -170,7 +170,7 @@ class Transformation(enum.Flag): PROPAGATE_NESTED_LET = enum.auto() #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() - 
#: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b) + #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` FLATTEN_AS_FIELDOP_ARGS = enum.auto() #: `let(a, b[1])(a)` -> `b[1]` INLINE_TRIVIAL_TUPLE_LET_VAR = enum.auto() @@ -529,6 +529,7 @@ def transform_inline_trivial_tuple_let_var(self, node: ir.Node, **kwargs) -> Opt def transform_flatten_as_fieldop_args( self, node: itir.FunCall, **kwargs ) -> Optional[itir.Node]: + # `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` if not cpm.is_applied_as_fieldop(node): return None @@ -545,13 +546,15 @@ def transform_flatten_as_fieldop_args( new_body = stencil.expr domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop - orig_args_map: dict[itir.Sym, itir.Expr] = {} + remapped_args: dict[ + itir.Sym, itir.Expr + ] = {} # contains the arguments that are remapped, e.g. `{a, b}` new_params: list[itir.Sym] = [] new_args: list[itir.Expr] = [] for param, arg in zip(stencil.params, node.args, strict=True): if isinstance(arg.type, ts.TupleType): - ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type) - orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg + ref_to_remapped_arg = im.ref(f"__ct_flat_remapped_{len(remapped_args)}", arg.type) + remapped_args[im.sym(ref_to_remapped_arg.id, arg.type)] = arg new_params_inner, lift_params = [], [] for i, type_ in enumerate(param.type.element_type.types): new_param = im.sym( @@ -565,8 +568,10 @@ def transform_flatten_as_fieldop_args( ) ) new_params_inner.append(new_param) - new_args.append(im.tuple_get(i, ref_to_orig_arg)) + new_args.append(im.tuple_get(i, ref_to_remapped_arg)) + # an iterator that substitutes the original (tuple) iterator, e.g. `t`. Built + # from the new parameters which are the elements of `t`. 
param_substitute = im.lift( im.lambda_(*lift_params)( im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in lift_params]) @@ -576,6 +581,7 @@ def transform_flatten_as_fieldop_args( new_body = im.let(param.id, param_substitute)(new_body) # note: the lift is trivial so inlining it is not an issue with respect to tree size new_body = inline_lambda(new_body, force_inline_lift_args=True) + new_params.extend(new_params_inner) else: new_params.append(param) @@ -589,4 +595,4 @@ def transform_flatten_as_fieldop_args( new_body = self.visit(new_body, **kwargs) new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) - return im.let(*orig_args_map.items())(im.as_fieldop(new_stencil, domain)(*new_args)) + return im.let(*remapped_args.items())(im.as_fieldop(new_stencil, domain)(*new_args))