diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 42b82ffdd0..bdb9bd8249 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -445,7 +445,7 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' """ - from gt4py.next.iterator.ir_utils import domain_utils + from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils result = call( call("as_fieldop")( @@ -462,7 +462,9 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Cal def _populate_domain_annex_wrapper(*args, **kwargs): node = result(*args, **kwargs) - if domain: + # note: if the domain is not a direct construction, e.g. because it is only a reference + # to a domain defined in a let, don't populate the annex + if domain and cpm.is_call_to(domain, ("cartesian_domain", "unstructured_domain")): node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) return node diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py index 03652cdf16..03a3dfb0e3 100644 --- a/src/gt4py/next/iterator/ir_utils/misc.py +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -12,7 +12,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im @dataclasses.dataclass(frozen=True) @@ -71,3 +71,79 @@ def is_equal(a: itir.Expr, b: itir.Expr): return a == b or ( CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) ) + + +def canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: + """ + Canonicalize applied `as_fieldop`s. 
+ + In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified + format to work with (e.g. each parameter has a name without the need to special case). + """ + assert cpm.is_applied_as_fieldop(expr) + + stencil = expr.fun.args[0] # type: ignore[attr-defined] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] + if cpm.is_ref_to(stencil, "deref"): + stencil = im.lambda_("arg")(im.deref("arg")) + new_expr = im.as_fieldop(stencil, domain)(*expr.args) + + return new_expr + + return expr + + +def _remove_let_alias(let_expr: itir.FunCall): + assert cpm.is_let(let_expr) + is_aliased_let = True + for param, arg in zip(let_expr.fun.params, let_expr.args, strict=True): # type: ignore[attr-defined] # ensured by cpm.is_let + is_aliased_let &= cpm.is_ref_to(arg, param.id) + if is_aliased_let: + return let_expr.fun.expr # type: ignore[attr-defined] # ensured by cpm.is_let + return let_expr + + +def unwrap_scan(stencil: itir.Lambda | itir.FunCall): + """ + If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. + + If a regular stencil is given the stencil is left as-is and the back-transformation is the + identity function. This function allows treating a scan stencil like a regular stencil during + a transformation avoiding the complexity introduced by the different IR format. + + >>> scan = im.call("scan")( + ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 + ... 
) + >>> stencil, back_trafo = unwrap_scan(scan) + >>> str(stencil) + 'λ(arg) → state + ·arg' + >>> str(back_trafo(stencil)) + 'scan(λ(state, arg) → state + ·arg, True, 0.0)' + + In case a regular stencil is given it is returned as-is: + + >>> deref_stencil = im.lambda_("it")(im.deref("it")) + >>> stencil, back_trafo = unwrap_scan(deref_stencil) + >>> assert stencil == deref_stencil + """ + if cpm.is_call_to(stencil, "scan"): + scan_pass, direction, init = stencil.args + assert isinstance(scan_pass, itir.Lambda) + # remove scan pass state to be used by caller + state_param = scan_pass.params[0] + stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) + + def restore_scan(transformed_stencil_like: itir.Lambda): + new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( + _remove_let_alias( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) + ) + ) + return im.call("scan")(new_scan_pass, direction, init) + + return stencil_like, restore_scan + + assert isinstance(stencil, itir.Lambda) + return stencil, lambda s: s diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 462f87b600..923f6d1302 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -12,44 +12,63 @@ import enum import functools import operator -from typing import Optional +import re +from typing import Literal, Optional from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, ir_makers as im, misc as ir_misc, ) -from gt4py.next.iterator.transforms import fixed_point_transformation +from gt4py.next.iterator.transforms import fixed_point_transformation, inline_lifts from 
gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda -from gt4py.next.iterator.type_system import inference as itir_type_inference +from gt4py.next.iterator.type_system import ( + inference as itir_type_inference, + type_specifications as it_ts, +) from gt4py.next.type_system import type_info, type_specifications as ts -def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): - """Given a itir.FunCall return a new call with one of its argument replaced.""" - return ir.FunCall( +def _with_altered_arg(node: itir.FunCall, arg_idx: int, new_arg: itir.Expr | str): + """Given an itir.FunCall return a new call with one of its arguments replaced.""" + return itir.FunCall( fun=node.fun, args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], ) -def _is_trivial_make_tuple_call(node: ir.Expr): +def _with_altered_iterator_element_type( + type_: it_ts.IteratorType, new_el_type: ts.DataType +) -> it_ts.IteratorType: + return it_ts.IteratorType( + position_dims=type_.position_dims, defined_dims=type_.defined_dims, element_type=new_el_type + ) + + +def _with_altered_iterator_position_dims( + type_: it_ts.IteratorType, new_position_dims: list[common.Dimension] | Literal["unknown"] +) -> it_ts.IteratorType: + return it_ts.IteratorType( + position_dims=new_position_dims, + defined_dims=type_.defined_dims, + element_type=type_.element_type, + ) + + +def _is_trivial_make_tuple_call(node: itir.Expr): """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" if not cpm.is_call_to(node, "make_tuple"): return False - if not all( - isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) - for arg in node.args - ): + if not all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args): return False return True -def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: +def _is_trivial_or_tuple_thereof_expr(node: itir.Node) -> bool: 
""" Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. @@ -65,7 +84,7 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: ... ) True """ - if isinstance(node, (ir.SymRef, ir.Literal)): + if isinstance(node, (itir.SymRef, itir.Literal)): return True if cpm.is_call_to(node, "make_tuple"): return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) @@ -82,6 +101,19 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: return False +def _flattened_as_fieldop_param_el_name(param: str, idx: int) -> str: + prefix = "__ct_flat_el_" + + # keep the original param name, but skip prefix from previous flattenings + if param.startswith(prefix): + parent_idx, suffix = re.split(r"_(?!\d)", param[len(prefix) :], maxsplit=1) + prefix = f"{prefix}{parent_idx}_" + else: + suffix = param + + return f"{prefix}{idx}_{suffix}" + + # TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, # transform each node until no transformations apply anymore, whenever a node is to be transformed # go through all available transformation and apply them. 
However the final result here still @@ -138,6 +170,10 @@ class Transformation(enum.Flag): PROPAGATE_NESTED_LET = enum.auto() #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() + #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` + FLATTEN_AS_FIELDOP_ARGS = enum.auto() + #: `let(a, b[1])(a)` -> `b[1]` + INLINE_TRIVIAL_TUPLE_LET_VAR = enum.auto() @classmethod def all(self) -> CollapseTuple.Transformation: @@ -152,7 +188,7 @@ def all(self) -> CollapseTuple.Transformation: @classmethod def apply( cls, - node: ir.Node, + node: itir.Node, *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, @@ -163,7 +199,7 @@ def apply( # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, uids: Optional[eve_utils.UIDGenerator] = None, - ) -> ir.Node: + ) -> itir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -181,7 +217,7 @@ def apply( offset_provider_type = offset_provider_type or {} uids = uids or eve_utils.UIDGenerator() - if isinstance(node, ir.Program): + if isinstance(node, itir.Program): within_stencil = False assert within_stencil in [ True, @@ -220,18 +256,18 @@ def visit(self, node, **kwargs): return super().visit(node, **kwargs) def transform_collapse_make_tuple_tuple_get( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if cpm.is_call_to(node, "make_tuple") and all( cpm.is_call_to(arg, "tuple_get") for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` - assert isinstance(node.args[0], ir.FunCall) + assert isinstance(node.args[0], itir.FunCall) first_expr = node.args[0].args[1] for i, v in enumerate(node.args): - assert isinstance(v, ir.FunCall) - assert isinstance(v.args[0], ir.Literal) + assert isinstance(v, itir.FunCall) + assert isinstance(v.args[0], itir.Literal) if not 
(int(v.args[0].value) == i and ir_misc.is_equal(v.args[1], first_expr)): # tuple argument differs, just continue with the rest of the tree return None @@ -248,11 +284,11 @@ def transform_collapse_make_tuple_tuple_get( return None def transform_collapse_tuple_get_make_tuple( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if ( cpm.is_call_to(node, "tuple_get") - and isinstance(node.args[0], ir.Literal) + and isinstance(node.args[0], itir.Literal) and cpm.is_call_to(node.args[1], "make_tuple") ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` @@ -265,8 +301,8 @@ def transform_collapse_tuple_get_make_tuple( return node.args[1].args[idx] return None - def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], ir.Literal): + def transform_propagate_tuple_get(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: + if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], itir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` @@ -289,12 +325,14 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ ) return None - def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + def transform_letify_make_tuple_elements( + self, node: itir.Node, **kwargs + ) -> Optional[itir.Node]: if cpm.is_call_to(node, "make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` - bound_vars: dict[ir.Sym, ir.Expr] = {} - new_args: list[ir.Expr] = [] + bound_vars: dict[itir.Sym, itir.Expr] = {} + new_args: list[itir.Expr] = [] for arg in node.args: if cpm.is_call_to(node, "make_tuple") and not 
_is_trivial_make_tuple_call(node): el_name = self.uids.sequential_id(prefix="__ct_el") @@ -309,7 +347,7 @@ def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optio ) return None - def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, node: itir.Node, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` @@ -318,13 +356,15 @@ def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Option return self.visit(inline_lambda(node, eligible_params=eligible_params), **kwargs) return None - def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_propagate_to_if_on_tuples( + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: if kwargs["within_stencil"]: # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation # in local-view for now. Revisit. return None - if isinstance(node, ir.FunCall) and not cpm.is_call_to(node, "if_"): + if isinstance(node, itir.FunCall) and not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -343,8 +383,8 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt return None def transform_propagate_to_if_on_tuples_cps( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: # The basic idea of this transformation is to remove tuples across if-stmts by rewriting # the expression in continuation passing style, e.g. something like a tuple reordering # ``` @@ -366,7 +406,7 @@ def transform_propagate_to_if_on_tuples_cps( # `if True then {2, 1} else {4, 3}`. 
The examples in the comments below all refer to this # tuple reordering example here. - if not isinstance(node, ir.FunCall) or cpm.is_call_to(node, "if_"): + if not isinstance(node, itir.FunCall) or cpm.is_call_to(node, "if_"): return None # The first argument that is eligible also transforms all remaining args (They will be @@ -438,7 +478,7 @@ def transform_propagate_to_if_on_tuples_cps( return None - def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_propagate_nested_let(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} @@ -464,14 +504,95 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional ) return None - def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_inline_trivial_let(self, node: itir.FunCall, **kwargs) -> Optional[itir.Node]: if cpm.is_let(node): - if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, itir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, itir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let return arg - if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): + if any( + trivial_args := [isinstance(arg, (itir.SymRef, itir.Literal)) for arg in node.args] + ): return inline_lambda(node, eligible_params=trivial_args) return None + + def transform_inline_trivial_tuple_let_var(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + if cpm.is_let(node): + if 
any(trivial_args := [_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args]): + return inline_lambda(node, eligible_params=trivial_args) + return None + + # TODO(tehrengruber): This is a transformation that should be executed before visiting the children. Then + # revisiting the body would not be needed. + def transform_flatten_as_fieldop_args( + self, node: itir.FunCall, **kwargs + ) -> Optional[itir.Node]: + # `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` + if not cpm.is_applied_as_fieldop(node): + return None + + for arg in node.args: + itir_type_inference.reinfer(arg) + + if not any(isinstance(arg.type, ts.TupleType) for arg in node.args): + return None + + node = ir_misc.canonicalize_as_fieldop(node) + stencil, restore_scan = ir_misc.unwrap_scan( + node.fun.args[0] # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + ) + + new_body = stencil.expr + domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + remapped_args: dict[ + itir.Sym, itir.Expr + ] = {} # contains the arguments that are remapped, e.g. 
`{a, b}` + new_params: list[itir.Sym] = [] + new_args: list[itir.Expr] = [] + for param, arg in zip(stencil.params, node.args, strict=True): + if isinstance(arg.type, ts.TupleType): + ref_to_remapped_arg = im.ref(f"__ct_flat_remapped_{len(remapped_args)}", arg.type) + remapped_args[im.sym(ref_to_remapped_arg.id, arg.type)] = arg + new_params_inner, lift_params = [], [] + for i, type_ in enumerate(param.type.element_type.types): + new_param = im.sym( + _flattened_as_fieldop_param_el_name(param.id, i), + _with_altered_iterator_element_type(param.type, type_), + ) + lift_params.append( + im.sym( + new_param.id, + _with_altered_iterator_position_dims(new_param.type, "unknown"), # type: ignore[arg-type] # always in IteratorType + ) + ) + new_params_inner.append(new_param) + new_args.append(im.tuple_get(i, ref_to_remapped_arg)) + + # an iterator that substitutes the original (tuple) iterator, e.g. `t`. Built + # from the new parameters which are the elements of `t`. + param_substitute = im.lift( + im.lambda_(*lift_params)( + im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in lift_params]) + ) + )(*[im.ref(p.id, p.type) for p in new_params_inner]) + + new_body = im.let(param.id, param_substitute)(new_body) + # note: the lift is trivial so inlining it is not an issue with respect to tree size + new_body = inline_lambda(new_body, force_inline_lift_args=True) + + new_params.extend(new_params_inner) + else: + new_params.append(param) + new_args.append(arg) + + # remove lifts again + new_body = inline_lifts.InlineLifts( + flags=inline_lifts.InlineLifts.Flag.INLINE_DEREF_LIFT + | inline_lifts.InlineLifts.Flag.PROPAGATE_SHIFT + ).visit(new_body) + new_body = self.visit(new_body, **kwargs) + new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) + + return im.let(*remapped_args.items())(im.as_fieldop(new_stencil, domain)(*new_args)) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 
81633dfb87..b746369152 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -21,6 +21,7 @@ common_pattern_matcher as cpm, domain_utils, ir_makers as im, + misc as ir_misc, ) from gt4py.next.iterator.transforms import ( fixed_point_transformation, @@ -46,26 +47,6 @@ def _merge_arguments( return new_args -def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: - """ - Canonicalize applied `as_fieldop`s. - - In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified - format to work with (e.g. each parameter has a name without the need to special case). - """ - assert cpm.is_applied_as_fieldop(expr) - - stencil = expr.fun.args[0] # type: ignore[attr-defined] - domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] - if cpm.is_ref_to(stencil, "deref"): - stencil = im.lambda_("arg")(im.deref("arg")) - new_expr = im.as_fieldop(stencil, domain)(*expr.args) - - return new_expr - - return expr - - def _is_tuple_expr_of_literals(expr: itir.Expr): if cpm.is_call_to(expr, "make_tuple"): return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) @@ -78,7 +59,7 @@ def _inline_as_fieldop_arg( arg: itir.Expr, *, uids: eve_utils.UIDGenerator ) -> tuple[itir.Expr, dict[str, itir.Expr]]: assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) + arg = ir_misc.canonicalize_as_fieldop(arg) stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` inner_args: list[itir.Expr] = arg.args @@ -114,50 +95,6 @@ def _inline_as_fieldop_arg( ), extracted_args -def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): - """ - If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. - - If a regular stencil is given the stencil is left as-is and the back-transformation is the - identity function. 
This function allows treating a scan stencil like a regular stencil during - a transformation avoiding the complexity introduced by the different IR format. - - >>> scan = im.call("scan")( - ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 - ... ) - >>> stencil, back_trafo = _unwrap_scan(scan) - >>> str(stencil) - 'λ(arg) → state + ·arg' - >>> str(back_trafo(stencil)) - 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' - - In case a regular stencil is given it is returned as-is: - - >>> deref_stencil = im.lambda_("it")(im.deref("it")) - >>> stencil, back_trafo = _unwrap_scan(deref_stencil) - >>> assert stencil == deref_stencil - """ - if cpm.is_call_to(stencil, "scan"): - scan_pass, direction, init = stencil.args - assert isinstance(scan_pass, itir.Lambda) - # remove scan pass state to be used by caller - state_param = scan_pass.params[0] - stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) - - def restore_scan(transformed_stencil_like: itir.Lambda): - new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( - im.call(transformed_stencil_like)( - *(param.id for param in transformed_stencil_like.params) - ) - ) - return im.call("scan")(new_scan_pass, direction, init) - - return stencil_like, restore_scan - - assert isinstance(stencil, itir.Lambda) - return stencil, lambda s: s - - def fuse_as_fieldop( expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator ) -> itir.Expr: @@ -165,7 +102,7 @@ def fuse_as_fieldop( stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop assert isinstance(expr.fun.args[0], itir.Lambda) or cpm.is_call_to(stencil, "scan") # type: ignore[attr-defined] # ensured by is_applied_as_fieldop - stencil, restore_scan = _unwrap_scan(stencil) + stencil, restore_scan = ir_misc.unwrap_scan(stencil) domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by 
is_applied_as_fieldop @@ -411,7 +348,7 @@ def transform_fuse_make_tuple(self, node: itir.Node, **kwargs): def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): if cpm.is_applied_as_fieldop(node): - node = _canonicalize_as_fieldop(node) + node = ir_misc.canonicalize_as_fieldop(node) stencil = node.fun.args[0] # type: ignore[attr-defined] # ensure cpm.is_applied_as_fieldop assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan") args: list[itir.Expr] = node.args diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index fe450625db..cc7a7123b9 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -37,6 +37,10 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: assert type_info.is_compatible_type( node.type, type_ ), "Node already has a type which differs." + if isinstance(node, itir.Lambda): + assert isinstance(type_, ts.FunctionType) + for param, param_type in zip(node.params, type_.pos_only_args): + _set_node_type(param, param_type) node.type = type_ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a39fe3c6d8..577c7bce1c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -245,6 +245,7 @@ def test_aliased_function(): assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type ) + assert result.args[0].params[0].type == int_type assert result.type == int_type diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 916ae4e578..636e66940c 100644 --- 
a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -5,11 +5,14 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +from gt4py.next import common from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.type_system import type_specifications as ts -from next_tests.unit_tests.iterator_tests.test_type_inference import int_type +from gt4py.next.iterator.type_system import type_specifications as it_ts + +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) def test_simple_make_tuple_tuple_get(): @@ -125,8 +128,11 @@ def test_propagate_tuple_get(): def test_letify_make_tuple_elements(): - # anything that is not trivial, i.e. a SymRef, works here - el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") + fun_type = ts.FunctionType( + pos_only_args=[], pos_or_kw_args={}, kw_only_args={}, returns=int_type + ) + # anything that is not trivial, works here + el1, el2 = im.call(im.ref("foo", fun_type))(), im.call(im.ref("bar", fun_type))() testee = im.make_tuple(el1, el2) expected = im.let(("__ct_el_1", el1), ("__ct_el_2", el2))( im.make_tuple("__ct_el_1", "__ct_el_2") @@ -311,3 +317,81 @@ def test_if_make_tuple_reorder_cps_external(): within_stencil=False, ) assert actual == expected + + +def test_flatten_as_fieldop_args(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType(types=[int_type, int_type]), + ) + testee = im.as_fieldop(im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))))( + im.make_tuple(1, 2) + ) + expected = im.as_fieldop( + im.lambda_("__ct_flat_el_0_it", "__ct_flat_el_1_it")(im.deref("__ct_flat_el_1_it")) + )(1, 2) + actual = 
CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_flatten_as_fieldop_args_nested(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType( + types=[ + int_type, + ts.TupleType(types=[int_type, int_type]), + ] + ), + ) + testee = im.as_fieldop( + im.lambda_(im.sym("it", it_type))(im.tuple_get(1, im.tuple_get(1, im.deref("it")))) + )(im.make_tuple(1, im.make_tuple(2, 3))) + expected = im.as_fieldop( + im.lambda_("__ct_flat_el_0_it", "__ct_flat_el_1_0_it", "__ct_flat_el_1_1_it")( + im.deref("__ct_flat_el_1_1_it") + ) + )(1, 2, 3) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_flatten_as_fieldop_args_scan(): + it_type = it_ts.IteratorType( + position_dims=[Vertex], + defined_dims=[Vertex], + element_type=ts.TupleType(types=[int_type, int_type]), + ) + testee = im.as_fieldop( + im.scan( + im.lambda_("state", im.sym("it", it_type))(im.tuple_get(1, im.deref("it"))), True, 0 + ) + )(im.make_tuple(1, 2)) + expected = im.as_fieldop( + im.scan( + im.lambda_("state", "__ct_flat_el_0_it", "__ct_flat_el_1_it")( + im.deref("__ct_flat_el_1_it") + ), + True, + 0, + ) + )(1, 2) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected