From e4f9be12c3664cabf4d491ee47426e515095a172 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 15 Jan 2025 20:53:49 +0100 Subject: [PATCH 01/39] constant fold min/max for domain expressions --- .../iterator/transforms/constant_folding.py | 25 +++++- .../transforms_tests/test_constant_folding.py | 86 +++++++++++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..fdee6b41ac 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -26,7 +26,30 @@ def visit_FunCall(self, node: ir.FunCall): and new_node.args[0] == new_node.args[1] ): # `minimum(a, a)` -> `a` return new_node.args[0] - + if isinstance(new_node.fun, ir.SymRef) and new_node.fun.id in ["minimum", "maximum"]: + if new_node.args[0] == new_node.args[1]: + return new_node.args[0] + if isinstance(new_node.args[0], ir.FunCall) and isinstance(new_node.args[1], ir.SymRef): + fun_call, sym_ref = new_node.args + elif isinstance(new_node.args[0], ir.SymRef) and isinstance( + new_node.args[1], ir.FunCall + ): + sym_ref, fun_call = new_node.args + else: + return new_node + if fun_call.fun.id in ["plus", "minus"]: + if fun_call.args[0] == sym_ref: + if new_node.fun.id == "minimum": + if fun_call.fun.id == "plus": + return sym_ref if fun_call.args[1].value >= "0" else fun_call + elif fun_call.fun.id == "minus": + return fun_call if fun_call.args[1].value > "0" else sym_ref + elif new_node.fun.id == "maximum": + if fun_call.fun.id == "plus": + return fun_call if fun_call.args[1].value > "0" else sym_ref + elif fun_call.fun.id == "minus": + return sym_ref if fun_call.args[1].value >= "0" else fun_call + return new_node.args[0] if ( isinstance(new_node.fun, ir.SymRef) and new_node.fun.id == "if_" diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0bf8dcb65d..cc80606eb1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -48,6 +48,92 @@ def test_constant_folding_minimum(): assert actual == expected +def test_constant_folding_maximum_literal_plus(): + testee = im.call("maximum")( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") + ) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")( + im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + ) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")( + im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + ) + expected = im.ref("__out_size_1") + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")( + im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") + ) + expected = im.ref("__out_size_1") + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")( + im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + ) + expected = im.ref("__out_size_1") + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("maximum")( + im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) + ) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") + ) + expected = im.ref("__out_size_1") + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")( + im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + ) + expected = im.ref("__out_size_1") + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")( + im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + ) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")( + im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") + ) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")( + im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + ) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.call("minimum")( + im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) + ) + expected = im.ref("__out_size_1") + actual = ConstantFolding.apply(testee) + assert actual == expected + + def test_constant_folding_literal(): testee = im.plus(im.literal_from_value(1), im.literal_from_value(2)) expected = im.literal_from_value(3) From 21de7d787e88719aee0dbca883d8a17244f56de1 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 21 Jan 2025 12:11:22 +0100 Subject: [PATCH 02/39] Extend and refactor constant folding --- src/gt4py/next/iterator/ir.py | 1 + .../iterator/transforms/constant_folding.py | 405 +++++++++++++++--- .../transforms_tests/test_constant_folding.py | 226 +++++++++- 3 files changed, 573 insertions(+), 59 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e875709631..b680c8d679 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -116,6 +116,7 @@ class FunctionDefinition(Node, SymbolTableTrait): "floor", "ceil", "trunc", + "neg", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} BINARY_MATH_NUMBER_BUILTINS = { diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index fdee6b41ac..7c64f8a08c 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -8,73 +8,370 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im +import functools +import operator +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.type_system import type_specifications as ts, type_translation +import dataclasses +import enum +from typing import Optional + + +class ConvertMinusToUnary(PreserveLocationVisitor, NodeTranslator): + def visit(self, node: ir.Node): + node = self.generic_visit(node) + + # im.minus(1, im.ref("a")) -> im.plus(im.call("neg)(im.ref("a")), 1) + if cpm.is_call_to(node, "minus"): + if isinstance(node.args[1], (ir.SymRef, ir.FunCall)): + node = im.plus(im.call("neg")(node.args[1]), node.args[0]) + return node + +class ConvertUnaryToMinus(PreserveLocationVisitor, NodeTranslator): + + def visit(self, node: ir.Node): + + if isinstance(node, ir.FunCall) and cpm.is_call_to(node.args[0], "neg"): + manipulated_first_arg = False + if node.args[0].args[0].type: + zero = im.literal(str(0), node.args[0].args[0].type) + else: + zero = im.literal_from_value(0.0) # TODO: fix datatype + if cpm.is_call_to(node, "minus"): + node = im.plus(node.args[1], node.args[0].args[0]) + manipulated_first_arg = True + elif cpm.is_call_to(node, "plus"): + node = im.minus(node.args[1], node.args[0].args[0]) + manipulated_first_arg = True + elif cpm.is_call_to(node, "multiplies"): + node = im.multiplies_(im.minus(zero,node.args[0].args[0]), node.args[1]) + manipulated_first_arg = True + elif cpm.is_call_to(node, "divides"): + node = im.divides_(im.minus(zero,node.args[0].args[0]), node.args[1]) + manipulated_first_arg = True + elif cpm.is_call_to(node, "minimum"): + node = im.call("minimum")(im.minus(zero,node.args[0].args[0]), node.args[1]) + manipulated_first_arg = True + elif cpm.is_call_to(node, "maximum"): + node = im.call("maximum")(im.minus(zero,node.args[0].args[0]), node.args[1]) + manipulated_first_arg = True + if manipulated_first_arg: + node = self.visit(node) + elif isinstance(node, ir.FunCall) and len(node.args) > 2 and cpm.is_call_to(node.args[1], "neg"): + if node.args[0].args[0].type: + zero = im.literal(str(0), node.args[0].args[0].type) + else: + zero = im.literal_from_value(0.0) # TODO: fix datatype + if cpm.is_call_to(node, "minus"): + node = im.plus(node.args[0], node.args[1].args[0]) + elif cpm.is_call_to(node, "plus"): + node = im.minus(node.args[0], node.args[1].args[0]) + elif cpm.is_call_to(node, "multiplies"): + node = im.multiplies_(node.args[0], im.minus(zero,node.args[1].args[0])) + elif cpm.is_call_to(node, "divides"): + node = im.divides_(node.args[0], im.minus(zero, node.args[1].args[0])) + elif cpm.is_call_to(node, "minimum"): + node = im.call("minimum")(im.node.args[1], im.minus(zero,node.args[1].args[0])) + elif cpm.is_call_to(node, "maximum"): + node = im.call("maximum")(im.node.args[1], im.minus(zero,node.args[1].args[0])) + + return self.generic_visit(node) + + + +@dataclasses.dataclass(frozen=True) class ConstantFolding(PreserveLocationVisitor, NodeTranslator): + class Flag(enum.Flag): + # literal + symref -> symref + literal + CANONICALIZE_SYMREF_LITERAL = enum.auto() + + # literal + funcall -> funcall + literal + CANONICALIZE_FUNCALL_LITERAL = enum.auto() + + # `__out_size_1 + 1 + 1` -> `__out_size_1 + 2` + FOLD_FUNCALL_LITERAL = enum.auto() + + # `maximum(1, __out_size_1)` -> `maximum(__out_size_1, 1)` and `maximum(__out_size_1, maximum(__out_size_1, 1))` -> `maximum(maximum(__out_size_1, 1), __out_size_1)` + CANONICALIZE_MIN_MAX_FUNCALL_SYMREF_LITERAL = enum.auto() + + # `maximum(maximum(__out_size_1, 1), __out_size_1)` -> `maximum(__out_size_1, 1)` + FOLD_MIN_MAX_FUNCALL_SYMREF_LITERAL = enum.auto() + + # `minus(__out_size_1, literal) -> plus(__out_size_1,-literal)` + CANONICALIZE_MINUS_SYMREF_LITERAL = enum.auto() + + # `maximum(plus(__out_size_1, 1), __out_size_1)` -> `plus(__out_size_1, 1)` + # and `maximum(plus(__out_size_1, 1), plus(__out_size_1, -1))` -> `plus(__out_size_1, 1)` + FOLD_MIN_MAX_PLUS_MINUS = enum.auto() + + # `__out_size_1 + 0` -> `__out_size_1` + FOLD_SYMREF_PLUS_MINUS_ZERO = enum.auto() + + # `sym + 1 + (sym + 2)` -> `sym + sym + 2 + 1` + CANONICALIZE_PLUS_SYMREF_LITERAL = enum.auto() + + # `1 + 1` -> `2` + FOLD_ARITHMETIC_BUILTINS = enum.auto() + + # `neg(1)` -> `-1` + CANONICALIZE_NEG_LITERAL = enum.auto() + + # `minimum(a, a)` -> `a` + FOLD_MIN_MAX_LITERALS = enum.auto() + + # `if_(True, true_branch, false_branch)` -> `true_branch` + FOLD_IF = enum.auto() + + @classmethod + def all(self): # TODO: -> ConstantFolding.Flag + return functools.reduce(operator.or_, self.__members__.values()) + + flags: Flag = Flag.all() + + @classmethod - def apply(cls, node: ir.Node) -> ir.Node: - return cls().visit(node) + def apply(cls, node: ir.Node, flags: Optional[Flag] = None) -> ir.Node: + flags = flags or cls.flags + + node = ConvertMinusToUnary().visit(node) + node = cls().visit(node, flags=flags) #TODO: remove flags? + node = ConvertUnaryToMinus().visit(node) + return node + - def visit_FunCall(self, node: ir.FunCall): + def visit_FunCall(self, node: ir.FunCall, **kwargs): # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded - new_node = self.generic_visit(node) + node = self.generic_visit(node, **kwargs) + return self.fp_transform(node, **kwargs) - if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id in ["minimum", "maximum"] - and new_node.args[0] == new_node.args[1] - ): # `minimum(a, a)` -> `a` - return new_node.args[0] - if isinstance(new_node.fun, ir.SymRef) and new_node.fun.id in ["minimum", "maximum"]: - if new_node.args[0] == new_node.args[1]: - return new_node.args[0] - if isinstance(new_node.args[0], ir.FunCall) and isinstance(new_node.args[1], ir.SymRef): - fun_call, sym_ref = new_node.args - elif isinstance(new_node.args[0], ir.SymRef) and isinstance( - new_node.args[1], ir.FunCall + def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: + while True: + new_node = self.transform(node, **kwargs) + if new_node is None: + break + assert new_node != node + node = new_node + return node + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + if not isinstance(node, ir.FunCall): + return None + + for transformation in self.Flag: + if self.flags & transformation: + assert isinstance(transformation.name, str) + method = getattr(self, f"transform_{transformation.name.lower()}") + result = method(node) + if result is not None: + assert ( + result is not node + ) # transformation should have returned None, since nothing changed + return result + return None + + def transform_canonicalize_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # literal + symref -> symref + literal + if cpm.is_call_to(node, ("plus", "multiplies")): + if cpm.is_call_to(node, ("plus", "times")): + if isinstance(node.args[1], ir.SymRef) and isinstance(node.args[0], ir.Literal): + return im.call(node.fun.id)(node.args[1], node.args[0]) + return None + + def transform_canonicalize_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # literal + funcall -> funcall + literal + if cpm.is_call_to(node, ("plus", "multiplies")): + if cpm.is_call_to(node, ("plus", "multiplies")): + if isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.Literal): + return im.call(node.fun.id)(node.args[1], node.args[0]) + return None + + def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `__out_size_1 + 1 + 1` -> `__out_size_1 + 2` + if cpm.is_call_to(node, ("plus", "minus")): + if isinstance(node.args[0], ir.FunCall) and isinstance( + node.args[1], ir.Literal ): - sym_ref, fun_call = new_node.args - else: - return new_node - if fun_call.fun.id in ["plus", "minus"]: - if fun_call.args[0] == sym_ref: - if new_node.fun.id == "minimum": - if fun_call.fun.id == "plus": - return sym_ref if fun_call.args[1].value >= "0" else fun_call - elif fun_call.fun.id == "minus": - return fun_call if fun_call.args[1].value > "0" else sym_ref - elif new_node.fun.id == "maximum": - if fun_call.fun.id == "plus": - return fun_call if fun_call.args[1].value > "0" else sym_ref - elif fun_call.fun.id == "minus": - return sym_ref if fun_call.args[1].value >= "0" else fun_call - return new_node.args[0] - if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id == "if_" - and isinstance(new_node.args[0], ir.Literal) - ): # `if_(True, true_branch, false_branch)` -> `true_branch` - if new_node.args[0].value == "True": - new_node = new_node.args[1] - else: - new_node = new_node.args[2] + fun_call, literal = node.args + if cpm.is_call_to(fun_call, ("plus", "minus")): + if isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) and isinstance( + fun_call.args[1], ir.Literal + ): + return self.visit(im.plus( + fun_call.args[0], + self.visit(im.call(node.fun.id)(fun_call.args[1], literal)))) + return None + def transform_canonicalize_min_max_funcall_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(1, __out_size_1)` -> `maximum(__out_size_1, 1)` and `maximum(__out_size_1, maximum(__out_size_1, 1))` -> `maximum(maximum(__out_size_1, 1), __out_size_1)` + if cpm.is_call_to(node, ("minimum", "maximum")): + if ((isinstance(node.args[0], ir.Literal) and isinstance(node.args[1], (ir.SymRef, ir.FunCall))) or + (isinstance(node.args[0], ir.SymRef) and isinstance(node.args[1], ir.FunCall))): + return im.call(node.fun.id)(node.args[1], node.args[0]) + return None + + + def transform_fold_min_max_funcall_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(maximum(__out_size_1, 1), __out_size_1)` -> `maximum(__out_size_1, 1)` + if cpm.is_call_to(node, ("minimum", "maximum")): + if isinstance(node.args[0], ir.FunCall): + fun_call, arg1, = node.args + if cpm.is_call_to(fun_call, ("maximum", "minimum")): + if arg1 == fun_call.args[0]: + return self.visit(im.call(fun_call.fun.id)(fun_call.args[1], arg1)) + if arg1 == fun_call.args[1]: + return self.visit(im.call(fun_call.fun.id)(fun_call.args[0], arg1)) + return None + + def transform_canonicalize_minus_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `minus(__out_size_1, literal) -> plus(__out_size_1,-literal)` + if cpm.is_call_to(node, "minus") and isinstance(node.args[0], (ir.SymRef, ir.FunCall)) and isinstance(node.args[1], ir.Literal): + return self.visit(im.plus(node.args[0], im.minus(0, node.args[1]))) + return None + + def transform_fold_min_max_plus_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + if cpm.is_call_to(node, ("minimum", "maximum")): + arg0, arg1 = node.args + # `maximum(plus(__out_size_1, 1), __out_size_1)` -> `plus(__out_size_1, 1)` + if cpm.is_call_to(arg0, ("plus", "minus")) and isinstance(arg1, (ir.SymRef, ir.FunCall)): + if arg0.args[0] == arg1: + return self.visit(im.call(arg0.fun.id)(arg0.args[0], im.call(node.fun.id)(0, arg0.args[1]))) + # `maximum(plus(__out_size_1, 1), plus(__out_size_1, -1))` -> `plus(__out_size_1, 1)` + if cpm.is_call_to(arg0, ("plus", "minus")) and cpm.is_call_to(arg1, ("plus", "minus")): + if arg0.args[0] == arg1.args[0]: + return self.visit(im.call(arg0.fun.id)(arg0.args[0], im.call(node.fun.id)(arg0.args[1], arg1.args[1]))) + return None + + def transform_fold_symref_plus_minus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `__out_size_1 + 0` -> `__out_size_1` + if cpm.is_call_to(node, ("plus", "minus")) and isinstance(node.args[0], (ir.SymRef, ir.FunCall)) and isinstance(node.args[1], ir.Literal) and node.args[1].value.isdigit() and int(node.args[1].value) == 0: + return node.args[0] + return None + + def transform_canonicalize_plus_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `sym1 + 1 + (sym2 + 2)` -> `sym1 + sym2 + 2 + 1` + if cpm.is_call_to(node, "plus"): + if cpm.is_call_to(node.args[0], "plus") and cpm.is_call_to(node.args[1], "plus") and isinstance(node.args[0].args[1], ir.Literal) and isinstance(node.args[1].args[1], ir.Literal): + return self.visit(im.plus(im.plus(node.args[0].args[0], node.args[1].args[0]),im.plus(node.args[0].args[1], node.args[1].args[1]))) + return None + + def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `1 + 1` -> `2` if ( - isinstance(new_node, ir.FunCall) - and isinstance(new_node.fun, ir.SymRef) - and len(new_node.args) > 0 - and all(isinstance(arg, ir.Literal) for arg in new_node.args) - ): # `1 + 1` -> `2` + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and len(node.args) > 0 + and all(isinstance(arg, ir.Literal) for arg in node.args) + ): try: - if new_node.fun.id in ir.ARITHMETIC_BUILTINS: - fun = getattr(embedded, str(new_node.fun.id)) + if node.fun.id in ir.ARITHMETIC_BUILTINS and not cpm.is_call_to(node, "neg"): + fun = getattr(embedded, str(node.fun.id)) arg_values = [ - getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition - for arg in new_node.args + getattr(embedded, str(arg.type))(arg.value) + # type: ignore[attr-defined] # arg type already established in if condition + for arg in node.args ] - new_node = im.literal_from_value(fun(*arg_values)) + return im.literal_from_value(fun(*arg_values)) except ValueError: pass # happens for inf and neginf + return None + + def transform_canonicalize_neg_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `neg(1)` -> `-1` + if cpm.is_call_to(node, ("neg")): + if isinstance(node.args[0], ir.Literal): + return self.visit(im.minus(0, int(node.args[0].value))) + return None + + def transform_fold_min_max_literals(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `minimum(a, a)` -> `a` + if cpm.is_call_to(node, ("minimum", "maximum")): + if node.args[0] == node.args[1]: + return node.args[0] + return None + + def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `if_(True, true_branch, false_branch)` -> `true_branch` + if cpm.is_call_to(node, "if_") and isinstance(node.args[0], ir.Literal): + if node.args[0].value == "True": + return node.args[1] + else: + return node.args[2] + return None + + # # `maximum(maximum(__out_size_1, 1), maximum(1, __out_size_1))` -> `maximum(__out_size_1, 1)` + # if cpm.is_call_to(new_node, ("minimum", "maximum")): + # if all(cpm.is_call_to(arg, "maximum") for arg in new_node.args) or all( + # cpm.is_call_to(arg, "minimum") for arg in new_node.args + # ): + # if ( + # new_node.args[0].args[0] == new_node.args[1].args[1] + # and new_node.args[0].args[1] == new_node.args[1].args[0] + # ): + # new_node = new_node.args[0] + # # `maximum(maximum(__out_size_1, 1), __out_size_1)` -> `maximum(__out_size_1, 1)` + # if cpm.is_call_to(new_node, ("minimum", "maximum")): + # match = False + # if isinstance(new_node.args[0], ir.FunCall) and isinstance( + # new_node.args[1], (ir.Literal, ir.SymRef) + # ): + # fun_call, sym_lit = new_node.args + # match = True + # elif isinstance(new_node.args[0], (ir.Literal, ir.SymRef)) and isinstance( + # new_node.args[1], ir.FunCall + # ): + # match = True + # sym_lit, fun_call = new_node.args + # if match and cpm.is_call_to(fun_call, ("maximum", "minimum")): + # if isinstance(fun_call.args[0], ir.SymRef) and isinstance( + # fun_call.args[1], ir.Literal + # ): + # if sym_lit == fun_call.args[0]: + # new_node = im.call(fun_call.fun.id)(sym_lit, fun_call.args[1]) + # if sym_lit == fun_call.args[1]: + # new_node = im.call(fun_call.fun.id)(fun_call.args[0], sym_lit) + # if isinstance(fun_call.args[0], ir.Literal) and isinstance( + # fun_call.args[1], ir.SymRef + # ): + # if sym_lit == fun_call.args[0]: + # new_node = im.call(fun_call.fun.id)(fun_call.args[1], sym_lit) + # if sym_lit == fun_call.args[1]: + # new_node = im.call(fun_call.fun.id)(sym_lit, fun_call.args[0]) + # # `maximum(plus(__out_size_1, 1), minus(__out_size_1, 1))` -> `plus(__out_size_1, 1)` + # if cpm.is_call_to(new_node, ("minimum", "maximum")): + # if all(cpm.is_call_to(arg, ("plus", "minus")) for arg in new_node.args): + # if new_node.args[0].args[0] == new_node.args[1].args[0]: + # new_node = im.plus( + # new_node.args[0].args[0], + # self.visit( + # im.call(new_node.fun.id)( + # im.call(new_node.args[0].fun.id)(0, new_node.args[0].args[1]), + # im.call(new_node.args[1].fun.id)(0, new_node.args[1].args[1]), + # ) + # ), + # ) + # # `maximum(plus(__out_size_1, 1), __out_size_1)` -> `plus(__out_size_1, 1)` + # match = False + # if isinstance(new_node.args[0], ir.FunCall) and isinstance(new_node.args[1], ir.SymRef): + # fun_call, sym_ref = new_node.args + # match = True + # elif isinstance(new_node.args[0], ir.SymRef) and isinstance( + # new_node.args[1], ir.FunCall + # ): + # match = True + # sym_ref, fun_call = new_node.args + # if match and fun_call.fun.id in ["plus", "minus"]: + # if fun_call.args[0] == sym_ref: + # if new_node.fun.id == "minimum": + # if fun_call.fun.id == "plus": + # new_node = sym_ref if int(fun_call.args[1].value) >= 0 else fun_call + # elif fun_call.fun.id == "minus": + # new_node = fun_call if int(fun_call.args[1].value) > 0 else sym_ref + # elif new_node.fun.id == "maximum": + # if fun_call.fun.id == "plus": + # new_node = fun_call if int(fun_call.args[1].value) > 0 else sym_ref + # elif fun_call.fun.id == "minus": + # new_node = sym_ref if int(fun_call.args[1].value) >= 0 else fun_call - return new_node + # return new_node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index cc80606eb1..3d82851f19 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -10,6 +10,14 @@ from gt4py.next.iterator.transforms.constant_folding import ConstantFolding + +def test_constant_folding_plus(): + expected = im.literal_from_value(2) + testee = im.plus( + im.literal_from_value(1), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + def test_constant_folding_boolean(): testee = im.not_(im.literal_from_value(True)) expected = im.literal_from_value(False) @@ -31,7 +39,7 @@ def test_constant_folding_math_op(): def test_constant_folding_if(): - expected = im.call("plus")("a", 2) + expected = im.plus("a", 2) testee = im.if_( im.literal_from_value(True), im.plus(im.ref("a"), im.literal_from_value(2)), @@ -48,7 +56,132 @@ def test_constant_folding_minimum(): assert actual == expected -def test_constant_folding_maximum_literal_plus(): +def test_constant_folding_literal_plus0(): + testee = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.plus(im.literal_from_value(1), im.ref("__out_size_1")) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_literal_minus0(): + testee = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) + expected = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_funcall_literal(): + testee = im.plus(im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.literal_from_value(1)) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(2)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.plus(im.literal_from_value(1), im.plus(im.ref("__out_size_1"), im.literal_from_value(1))) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(2)) + actual = ConstantFolding.apply(testee) + assert actual == expected +def test_constant_folding_maximum_literal_minus(): + testee = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + testee = im.plus(im.literal_from_value(1), im.ref("__out_size_1")) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus1(): + testee = im.call("maximum")( + im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)), + im.literal_from_value(1), + ) + expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus2(): + testee = im.call("maximum")( + im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), + im.literal_from_value(1), + ) + expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus3(): + testee = im.call("maximum")( + im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), + im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), + ) + expected = im.call("maximum")(im.ref("__out_size_1"),im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus4(): + testee = im.call("maximum")( + im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), + im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)), + ) + expected = im.call("maximum")(im.ref("__out_size_1"),im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus5(): + testee = im.call("maximum")( + im.ref("__out_size_1"), im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")) + ) + expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus6(): + testee = im.call("maximum")( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), + im.plus(im.ref("__out_size_1"), im.literal_from_value(0)), + ) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus7(): + testee = im.minus( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), + im.plus(im.literal_from_value(1), im.literal_from_value(1)), + ) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus8(): + testee = im.plus( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), + im.plus(im.literal_from_value(1), im.literal_from_value(1)), + ) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(3)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus9(): + testee = im.call("maximum")( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), + im.plus( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.literal_from_value(0) + ), + ) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus10(): testee = im.call("maximum")( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") ) @@ -56,6 +189,14 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus10a(): + testee = im.plus(im.ref("__out_size_1"), im.call("maximum")(im.literal_from_value(0), im.literal_from_value(-1))) + + expected = im.ref("__out_size_1") + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_maximum_literal_plus11(): testee = im.call("maximum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) ) @@ -63,6 +204,7 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus12(): testee = im.call("maximum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) ) @@ -70,6 +212,7 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus13(): testee = im.call("maximum")( im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") ) @@ -77,6 +220,7 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus14(): testee = im.call("maximum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) ) @@ -84,13 +228,15 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus15(): testee = im.call("maximum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) ) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus16(): testee = im.call("minimum")( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") ) @@ -98,6 +244,7 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus17(): testee = im.call("minimum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) ) @@ -105,6 +252,7 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus18(): testee = im.call("minimum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) ) @@ -112,20 +260,23 @@ def test_constant_folding_maximum_literal_plus(): actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus19(): testee = im.call("minimum")( im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") ) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus20(): testee = im.call("minimum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) ) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) actual = ConstantFolding.apply(testee) assert actual == expected +def test_constant_folding_maximum_literal_plus21(): testee = im.call("minimum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) ) @@ -146,3 +297,68 @@ def test_constant_folding_literal_maximum(): expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected + +def test_constant_folding_complex(): + # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) + testee = im.minus(im.literal_from_value(1), im.call("maximum")(im.call("maximum")(im.literal_from_value(1), im.call("maximum")(im.literal_from_value(1), im.ref("sym")), im.call("minimum")(im.literal_from_value(1), im.ref("sym")), im.ref("sym")), im.plus(im.literal_from_value(1), im.plus(im.call("minimum")(im.literal_from_value(-1), 2 ), im.call("maximum")(im.literal_from_value(-1), im.minus (im.literal_from_value(1), im.ref("sym"))))))) + # 1 - max(max(sym, 1), max(1 - sym, -1)) + expected = im.minus(im.literal_from_value(1),im.call("maximum")(im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), im.call("maximum")(im.minus(im.literal_from_value(1), im.ref("sym")),im.literal_from_value(-1)))) + actual = ConstantFolding.apply(testee) + assert actual == expected + + + +#( (min(1 - sym, 1 + sym) + (max(max(1 - sym, 1 + sym),1 - sym) + max(1 - sym, 1 - sym)))))) - 2 + #max(sym, 1 + sym) + (max(1, max(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 + +def test_constant_folding_complex_1(): + sym = im.ref("sym") + # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 + testee = im.minus(im.plus(im.call("maximum")(sym, im.plus(im.literal_from_value(1), sym)), im.plus(im.call("maximum")(im.literal_from_value(1), im.call("maximum")(im.literal_from_value(1),sym)),im.plus(im.minus(sym,im.literal_from_value(1)), im.plus(im.plus(im.literal_from_value(1),im.plus(sym,im.literal_from_value(1))),im.literal_from_value(1)) ))) , im.literal_from_value(2)) + # sym + 1 + (maximum(sym, 1) + (sym + sym + 2)) + -2 + expected = im.plus(im.plus(im.plus(sym,1),im.plus( im.call("maximum")(sym,im.literal_from_value(1)),im.plus(im.plus(sym, sym),im.literal_from_value(2)))), im.literal_from_value(-2)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_complex_3(): + sym = im.ref("sym") + # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) + testee = im.plus(im.call("minimum")(im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym)), im.plus(im.call("maximum")(im.call("maximum")(im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym)), im.minus(im.literal_from_value(1), sym)),im.call("maximum")(im.minus(im.literal_from_value(1), sym), im.minus(im.literal_from_value(1), sym)))) + # minimum(1 - sym, sym + 1) + (maximum(sym + 1, 1 - sym) + (1 - sym)) + expected = im.plus(im.call("minimum")(im.minus(im.literal_from_value(1), sym), im.plus(sym, im.literal_from_value(1))), im.plus(im.call("maximum")(im.plus(sym, im.literal_from_value(1)), im.minus(im.literal_from_value(1), sym)),im.minus(im.literal_from_value(1), sym))) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_complex_3a(): + sym = im.ref("sym") + # maximum(maximum(1 + sym, 1), 1 + sym) + testee = im.call("maximum")(im.call("maximum")( im.plus(im.literal_from_value(1), sym), 1), im.plus(im.literal_from_value(1), sym)) + # maximum(1 + sym, 1) + expected =im.call("maximum")( im.plus(sym, im.literal_from_value(1)), 1) + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_constant_folding_complex_2(): + sym = im.ref("sym") + testee = im.plus(im.plus(sym, im.literal_from_value(-1)),im.plus(sym, im.literal_from_value(3))) + expected = im.plus(im.plus(sym, sym), im.literal_from_value(2)) + actual = ConstantFolding.apply(testee) + assert actual == expected + +#sym + 1 + (maximum(sym, 1) + (sym + -1 + (sym + 3))) + -2 +#maximum(1, sym) + (3 × sym + 1) + + + + +def test_constant_folding_complex_4(): + sym = im.ref("sym", "float32") + testee = im.divides_(im.minus(im.literal_from_value(1), sym), im.minus( im.literal_from_value(2), sym)) + expected = im.divides_(im.minus(im.literal_from_value(1), sym), im.minus( im.literal_from_value(2), sym)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + + + From c630b2d7cca67f98b94dd2202cf24a715cf1c8a1 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 21 Jan 2025 18:34:04 +0100 Subject: [PATCH 03/39] Fix ConvertMinusToUnary --- .../iterator/transforms/constant_folding.py | 38 ++++++++++++------- .../transforms_tests/test_constant_folding.py | 12 +++--- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 7c64f8a08c..325f5efa6f 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -23,22 +23,31 @@ class ConvertMinusToUnary(PreserveLocationVisitor, NodeTranslator): def visit(self, node: ir.Node): node = self.generic_visit(node) - # im.minus(1, im.ref("a")) -> im.plus(im.call("neg)(im.ref("a")), 1) - if cpm.is_call_to(node, "minus"): - if isinstance(node.args[1], (ir.SymRef, ir.FunCall)): - node = im.plus(im.call("neg")(node.args[1]), node.args[0]) + # im.call("...")(im.minus(1, im.ref("a")), 1) -> im.call("...")(im.plus(im.call("neg)(im.ref("a")), 1), 1) + if isinstance(node, ir.FunCall) and len(node.args) > 0 and cpm.is_call_to(node.args[0], "minus"): + if cpm.is_call_to(node, ("minus", "plus", "multiplies", "divides")): + if isinstance(node.args[0].args[1], (ir.SymRef, ir.FunCall)): + node = im.call(node.fun.id)(im.plus(im.call("neg")(node.args[0].args[1]), node.args[0].args[0]), node.args[1]) + node = self.visit(node) + # im.call("...")(1, im.minus(1, im.ref("a"))) -> im.call("...")(1, im.plus(im.call("neg)(im.ref("a")), 1)) + elif isinstance(node, ir.FunCall) and len(node.args) > 1 and cpm.is_call_to(node.args[1], "minus"): + if cpm.is_call_to(node, ("minus", "plus", "multiplies", "divides")): + if isinstance(node.args[1].args[1], (ir.SymRef, ir.FunCall)): + node = im.call(node.fun.id)(node.args[0], im.plus(im.call("neg")(node.args[1].args[1]), node.args[1].args[0])) + node = self.visit(node) return node class ConvertUnaryToMinus(PreserveLocationVisitor, NodeTranslator): def visit(self, node: ir.Node): - if isinstance(node, ir.FunCall) and cpm.is_call_to(node.args[0], "neg"): + if isinstance(node, ir.FunCall) and len(node.args) > 0 and cpm.is_call_to(node.args[0], "neg"): manipulated_first_arg = False - if node.args[0].args[0].type: - zero = im.literal(str(0), node.args[0].args[0].type) - else: - zero = im.literal_from_value(0.0) # TODO: fix datatype + if cpm.is_call_to(node, ("multiplies", "divides", "maximum", "minimum")): + if node.args[0].args[0].type: + zero = im.literal(str(0), node.args[0].args[0].type) + else: + zero = im.literal_from_value(0.0) # TODO: fix datatype if cpm.is_call_to(node, "minus"): node = im.plus(node.args[1], node.args[0].args[0]) manipulated_first_arg = True @@ -59,11 +68,12 @@ def visit(self, node: ir.Node): manipulated_first_arg = True if manipulated_first_arg: node = self.visit(node) - elif isinstance(node, ir.FunCall) and len(node.args) > 2 and cpm.is_call_to(node.args[1], "neg"): - if node.args[0].args[0].type: - zero = im.literal(str(0), node.args[0].args[0].type) - else: - zero = im.literal_from_value(0.0) # TODO: fix datatype + elif isinstance(node, ir.FunCall) and len(node.args) > 1 and cpm.is_call_to(node.args[1], "neg"): + if cpm.is_call_to(node, ("multiplies", "divides", "maximum", "minimum")): + if node.args[0].args[0].type: + zero = im.literal(str(0), node.args[0].args[0].type) + else: + zero = im.literal_from_value(0.0) # TODO: fix datatype if cpm.is_call_to(node, "minus"): node = im.plus(node.args[0], node.args[1].args[0]) elif cpm.is_call_to(node, "plus"): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 3d82851f19..3371224ad7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -346,11 +346,6 @@ def test_constant_folding_complex_2(): actual = ConstantFolding.apply(testee) assert actual == expected -#sym + 1 + (maximum(sym, 1) + (sym + -1 + (sym + 3))) + -2 -#maximum(1, sym) + (3 × sym + 1) - - - def test_constant_folding_complex_4(): sym = im.ref("sym", "float32") @@ -360,5 +355,8 @@ def test_constant_folding_complex_4(): assert actual == expected - - +def test_constant_folding_max(): + testee = im.call("maximum")(0,0) + expected = im.literal_from_value(0) + actual = ConstantFolding.apply(testee) + assert actual == expected \ No newline at end of file From fddbc81908c022e5efa79d4b85669c53ec5e3f00 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 21 Jan 2025 19:34:18 +0100 Subject: [PATCH 04/39] Cleanup ConstantFolding --- .../iterator/transforms/constant_folding.py | 136 ++++++++---------- .../transforms_tests/test_constant_folding.py | 7 + 2 files changed, 65 insertions(+), 78 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 325f5efa6f..5eef13c5b6 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -94,30 +94,26 @@ def visit(self, node: ir.Node): @dataclasses.dataclass(frozen=True) class ConstantFolding(PreserveLocationVisitor, NodeTranslator): class Flag(enum.Flag): - # literal + symref -> symref + literal - CANONICALIZE_SYMREF_LITERAL = enum.auto() + # e.g. `literal + symref` -> `symref + literal` and + # `literal + funcall` -> `funcall + literal` and + # `symref + funcall` -> `funcall + symref` + CANONICALIZE_FUNCALL_SYMREF_LITERAL = enum.auto() - # literal + funcall -> funcall + literal - CANONICALIZE_FUNCALL_LITERAL = enum.auto() + # `minus(symref, literal) -> plus(symref,-literal)` + CANONICALIZE_MINUS_SYMREF_LITERAL = enum.auto() - # `__out_size_1 + 1 + 1` -> `__out_size_1 + 2` + # `sym + 1 + 1` -> `sym + 2` FOLD_FUNCALL_LITERAL = enum.auto() - # `maximum(1, __out_size_1)` -> `maximum(__out_size_1, 1)` and `maximum(__out_size_1, maximum(__out_size_1, 1))` -> `maximum(maximum(__out_size_1, 1), __out_size_1)` - CANONICALIZE_MIN_MAX_FUNCALL_SYMREF_LITERAL = enum.auto() - - # `maximum(maximum(__out_size_1, 1), __out_size_1)` -> `maximum(__out_size_1, 1)` + # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` FOLD_MIN_MAX_FUNCALL_SYMREF_LITERAL = enum.auto() - # `minus(__out_size_1, literal) -> plus(__out_size_1,-literal)` - CANONICALIZE_MINUS_SYMREF_LITERAL = enum.auto() + # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` and + # `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)` + FOLD_MIN_MAX_PLUS= enum.auto() - # `maximum(plus(__out_size_1, 1), __out_size_1)` -> `plus(__out_size_1, 1)` - # and `maximum(plus(__out_size_1, 1), plus(__out_size_1, -1))` -> `plus(__out_size_1, 1)` - FOLD_MIN_MAX_PLUS_MINUS = enum.auto() - - # `__out_size_1 + 0` -> `__out_size_1` - FOLD_SYMREF_PLUS_MINUS_ZERO = enum.auto() + # `sym + 0` -> `sym` + FOLD_SYMREF_PLUS_ZERO = enum.auto() # `sym + 1 + (sym + 2)` -> `sym + sym + 2 + 1` CANONICALIZE_PLUS_SYMREF_LITERAL = enum.auto() @@ -181,89 +177,73 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: return result return None - def transform_canonicalize_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # literal + symref -> symref + literal - if cpm.is_call_to(node, ("plus", "multiplies")): - if cpm.is_call_to(node, ("plus", "times")): - if isinstance(node.args[1], ir.SymRef) and isinstance(node.args[0], ir.Literal): - return im.call(node.fun.id)(node.args[1], node.args[0]) + def transform_canonicalize_funcall_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # e.g. `literal + symref` -> `symref + literal` and + # `literal + funcall` -> `funcall + literal` and + # `symref + funcall` -> `funcall + symref` + if cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")): + if (isinstance(node.args[1], (ir.SymRef, ir.FunCall)) and isinstance(node.args[0], ir.Literal) + or isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.SymRef)): + return im.call(node.fun.id)(node.args[1], node.args[0]) return None - def transform_canonicalize_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # literal + funcall -> funcall + literal - if cpm.is_call_to(node, ("plus", "multiplies")): - if cpm.is_call_to(node, ("plus", "multiplies")): - if isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.Literal): - return im.call(node.fun.id)(node.args[1], node.args[0]) + + def transform_canonicalize_minus_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `minus(symref, literal) -> plus(symref,-literal)` + if (cpm.is_call_to(node, "minus") and + isinstance(node.args[0], (ir.SymRef, ir.FunCall)) and + isinstance(node.args[1], ir.Literal)): + return self.visit(im.plus(node.args[0], im.minus(0, node.args[1]))) return None + def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `__out_size_1 + 1 + 1` -> `__out_size_1 + 2` - if cpm.is_call_to(node, ("plus", "minus")): - if isinstance(node.args[0], ir.FunCall) and isinstance( - node.args[1], ir.Literal - ): + # `sym + 1 + 1` -> `sym + 2` + if cpm.is_call_to(node, "plus"): + if cpm.is_call_to(node.args[0], "plus") and isinstance( node.args[1], ir.Literal): fun_call, literal = node.args - if cpm.is_call_to(fun_call, ("plus", "minus")): - if isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) and isinstance( - fun_call.args[1], ir.Literal - ): - return self.visit(im.plus( - fun_call.args[0], - self.visit(im.call(node.fun.id)(fun_call.args[1], literal)))) - return None - - def transform_canonicalize_min_max_funcall_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `maximum(1, __out_size_1)` -> `maximum(__out_size_1, 1)` and `maximum(__out_size_1, maximum(__out_size_1, 1))` -> `maximum(maximum(__out_size_1, 1), __out_size_1)` - if cpm.is_call_to(node, ("minimum", "maximum")): - if ((isinstance(node.args[0], ir.Literal) and isinstance(node.args[1], (ir.SymRef, ir.FunCall))) or - (isinstance(node.args[0], ir.SymRef) and isinstance(node.args[1], ir.FunCall))): - return im.call(node.fun.id)(node.args[1], node.args[0]) + if isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) and isinstance( fun_call.args[1], ir.Literal): + return self.visit(im.plus(fun_call.args[0], im.plus(fun_call.args[1], literal))) return None - def transform_fold_min_max_funcall_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `maximum(maximum(__out_size_1, 1), __out_size_1)` -> `maximum(__out_size_1, 1)` + # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` if cpm.is_call_to(node, ("minimum", "maximum")): - if isinstance(node.args[0], ir.FunCall): + if cpm.is_call_to(node.args[0], ("maximum", "minimum")): fun_call, arg1, = node.args - if cpm.is_call_to(fun_call, ("maximum", "minimum")): - if arg1 == fun_call.args[0]: - return self.visit(im.call(fun_call.fun.id)(fun_call.args[1], arg1)) - if arg1 == fun_call.args[1]: - return self.visit(im.call(fun_call.fun.id)(fun_call.args[0], arg1)) + if arg1 == fun_call.args[0]: + return self.visit(im.call(fun_call.fun.id)(fun_call.args[1], arg1)) + if arg1 == fun_call.args[1]: + return self.visit(im.call(fun_call.fun.id)(fun_call.args[0], arg1)) return None - def transform_canonicalize_minus_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `minus(__out_size_1, literal) -> plus(__out_size_1,-literal)` - if cpm.is_call_to(node, "minus") and isinstance(node.args[0], (ir.SymRef, ir.FunCall)) and isinstance(node.args[1], ir.Literal): - return self.visit(im.plus(node.args[0], im.minus(0, node.args[1]))) - return None - def transform_fold_min_max_plus_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_call_to(node, ("minimum", "maximum")): arg0, arg1 = node.args - # `maximum(plus(__out_size_1, 1), __out_size_1)` -> `plus(__out_size_1, 1)` - if cpm.is_call_to(arg0, ("plus", "minus")) and isinstance(arg1, (ir.SymRef, ir.FunCall)): + # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` + if cpm.is_call_to(arg0, "plus"): if arg0.args[0] == arg1: - return self.visit(im.call(arg0.fun.id)(arg0.args[0], im.call(node.fun.id)(0, arg0.args[1]))) - # `maximum(plus(__out_size_1, 1), plus(__out_size_1, -1))` -> `plus(__out_size_1, 1)` - if cpm.is_call_to(arg0, ("plus", "minus")) and cpm.is_call_to(arg1, ("plus", "minus")): + return self.visit(im.plus(arg0.args[0], im.call(node.fun.id)(arg0.args[1], 0))) + # `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)` + if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"): if arg0.args[0] == arg1.args[0]: - return self.visit(im.call(arg0.fun.id)(arg0.args[0], im.call(node.fun.id)(arg0.args[1], arg1.args[1]))) + return self.visit(im.plus(arg0.args[0], im.call(node.fun.id)(arg0.args[1], arg1.args[1]))) return None - def transform_fold_symref_plus_minus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `__out_size_1 + 0` -> `__out_size_1` - if cpm.is_call_to(node, ("plus", "minus")) and isinstance(node.args[0], (ir.SymRef, ir.FunCall)) and isinstance(node.args[1], ir.Literal) and node.args[1].value.isdigit() and int(node.args[1].value) == 0: + def transform_fold_symref_plus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `sym + 0` -> `sym` + if (cpm.is_call_to(node, "plus") and isinstance(node.args[1], ir.Literal) and + node.args[1].value.isdigit() and int(node.args[1].value) == 0): return node.args[0] return None def transform_canonicalize_plus_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `sym1 + 1 + (sym2 + 2)` -> `sym1 + sym2 + 2 + 1` if cpm.is_call_to(node, "plus"): - if cpm.is_call_to(node.args[0], "plus") and cpm.is_call_to(node.args[1], "plus") and isinstance(node.args[0].args[1], ir.Literal) and isinstance(node.args[1].args[1], ir.Literal): - return self.visit(im.plus(im.plus(node.args[0].args[0], node.args[1].args[0]),im.plus(node.args[0].args[1], node.args[1].args[1]))) + if (cpm.is_call_to(node.args[0], "plus") and cpm.is_call_to(node.args[1], "plus") and + isinstance(node.args[0].args[1], ir.Literal) and isinstance(node.args[1].args[1], ir.Literal)): + return self.visit(im.plus(im.plus(node.args[0].args[0], node.args[1].args[0]), im.plus(node.args[0].args[1], node.args[1].args[1]))) return None def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: @@ -310,7 +290,7 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: return node.args[2] return None - # # `maximum(maximum(__out_size_1, 1), maximum(1, __out_size_1))` -> `maximum(__out_size_1, 1)` + # # `maximum(maximum(sym, 1), maximum(1, sym))` -> `maximum(sym, 1)` # if cpm.is_call_to(new_node, ("minimum", "maximum")): # if all(cpm.is_call_to(arg, "maximum") for arg in new_node.args) or all( # cpm.is_call_to(arg, "minimum") for arg in new_node.args @@ -320,7 +300,7 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # and new_node.args[0].args[1] == new_node.args[1].args[0] # ): # new_node = new_node.args[0] - # # `maximum(maximum(__out_size_1, 1), __out_size_1)` -> `maximum(__out_size_1, 1)` + # # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` # if cpm.is_call_to(new_node, ("minimum", "maximum")): # match = False # if isinstance(new_node.args[0], ir.FunCall) and isinstance( @@ -348,7 +328,7 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # new_node = im.call(fun_call.fun.id)(fun_call.args[1], sym_lit) # if sym_lit == fun_call.args[1]: # new_node = im.call(fun_call.fun.id)(sym_lit, fun_call.args[0]) - # # `maximum(plus(__out_size_1, 1), minus(__out_size_1, 1))` -> `plus(__out_size_1, 1)` + # # `maximum(plus(sym, 1), minus(sym, 1))` -> `plus(sym, 1)` # if cpm.is_call_to(new_node, ("minimum", "maximum")): # if all(cpm.is_call_to(arg, ("plus", "minus")) for arg in new_node.args): # if new_node.args[0].args[0] == new_node.args[1].args[0]: @@ -361,7 +341,7 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # ) # ), # ) - # # `maximum(plus(__out_size_1, 1), __out_size_1)` -> `plus(__out_size_1, 1)` + # # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` # match = False # if isinstance(new_node.args[0], ir.FunCall) and isinstance(new_node.args[1], ir.SymRef): # fun_call, sym_ref = new_node.args diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 3371224ad7..46dc2c5bcb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -359,4 +359,11 @@ def test_constant_folding_max(): testee = im.call("maximum")(0,0) expected = im.literal_from_value(0) actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_constant_folding_plus_new(): + sym = im.ref("sym") + testee = im.plus(im.minus(sym, im.literal_from_value(1)), im.literal_from_value(2)) + expected = im.plus(sym, im.literal_from_value(1)) + actual = ConstantFolding.apply(testee) assert actual == expected \ No newline at end of file From 16cec90e3052f431ac874ce4c94062f2abb50c4e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 21 Jan 2025 19:35:35 +0100 Subject: [PATCH 05/39] Cleanup ConstantFolding --- .../iterator/transforms/constant_folding.py | 75 ------------------- 1 file changed, 75 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 5eef13c5b6..db40320221 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -290,78 +290,3 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: return node.args[2] return None - # # `maximum(maximum(sym, 1), maximum(1, sym))` -> `maximum(sym, 1)` - # if cpm.is_call_to(new_node, ("minimum", "maximum")): - # if all(cpm.is_call_to(arg, "maximum") for arg in new_node.args) or all( - # cpm.is_call_to(arg, "minimum") for arg in new_node.args - # ): - # if ( - # new_node.args[0].args[0] == new_node.args[1].args[1] - # and new_node.args[0].args[1] == new_node.args[1].args[0] - # ): - # new_node = new_node.args[0] - # # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` - # if cpm.is_call_to(new_node, ("minimum", "maximum")): - # match = False - # if isinstance(new_node.args[0], ir.FunCall) and isinstance( - # new_node.args[1], (ir.Literal, ir.SymRef) - # ): - # fun_call, sym_lit = new_node.args - # match = True - # elif isinstance(new_node.args[0], (ir.Literal, ir.SymRef)) and isinstance( - # new_node.args[1], ir.FunCall - # ): - # match = True - # sym_lit, fun_call = new_node.args - # if match and cpm.is_call_to(fun_call, ("maximum", "minimum")): - # if isinstance(fun_call.args[0], ir.SymRef) and isinstance( - # fun_call.args[1], ir.Literal - # ): - # if sym_lit == fun_call.args[0]: - # new_node = im.call(fun_call.fun.id)(sym_lit, fun_call.args[1]) - # if sym_lit == fun_call.args[1]: - # new_node = im.call(fun_call.fun.id)(fun_call.args[0], sym_lit) - # if isinstance(fun_call.args[0], ir.Literal) and isinstance( - # fun_call.args[1], ir.SymRef - # ): - # if sym_lit == fun_call.args[0]: - # new_node = im.call(fun_call.fun.id)(fun_call.args[1], sym_lit) - # if sym_lit == fun_call.args[1]: - # new_node = im.call(fun_call.fun.id)(sym_lit, fun_call.args[0]) - # # `maximum(plus(sym, 1), minus(sym, 1))` -> `plus(sym, 1)` - # if cpm.is_call_to(new_node, ("minimum", "maximum")): - # if all(cpm.is_call_to(arg, ("plus", "minus")) for arg in new_node.args): - # if new_node.args[0].args[0] == new_node.args[1].args[0]: - # new_node = im.plus( - # new_node.args[0].args[0], - # self.visit( - # im.call(new_node.fun.id)( - # im.call(new_node.args[0].fun.id)(0, new_node.args[0].args[1]), - # im.call(new_node.args[1].fun.id)(0, new_node.args[1].args[1]), - # ) - # ), - # ) - # # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` - # match = False - # if isinstance(new_node.args[0], ir.FunCall) and isinstance(new_node.args[1], ir.SymRef): - # fun_call, sym_ref = new_node.args - # match = True - # elif isinstance(new_node.args[0], ir.SymRef) and isinstance( - # new_node.args[1], ir.FunCall - # ): - # match = True - # sym_ref, fun_call = new_node.args - # if match and fun_call.fun.id in ["plus", "minus"]: - # if fun_call.args[0] == sym_ref: - # if new_node.fun.id == "minimum": - # if fun_call.fun.id == "plus": - # new_node = sym_ref if int(fun_call.args[1].value) >= 0 else fun_call - # elif fun_call.fun.id == "minus": - # new_node = fun_call if int(fun_call.args[1].value) > 0 else sym_ref - # elif new_node.fun.id == "maximum": - # if fun_call.fun.id == "plus": - # new_node = fun_call if int(fun_call.args[1].value) > 0 else sym_ref - # elif fun_call.fun.id == "minus": - # new_node = sym_ref if int(fun_call.args[1].value) >= 0 else fun_call - - # return new_node From 4a52c380595f0fff23c7cae155e81735aba3a5f9 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 22 Jan 2025 11:10:23 +0100 Subject: [PATCH 06/39] Fix imports --- src/gt4py/next/iterator/transforms/constant_folding.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 76afd41ab2..10c8904b1e 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -7,8 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.eve import NodeTranslator, PreserveLocationVisitor -<<<<<<< HEAD -from gt4py.next.iterator import embedded, ir +from gt4py.next.iterator import builtins, embedded, ir import functools import operator from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im @@ -17,10 +16,6 @@ import dataclasses import enum from typing import Optional -======= -from gt4py.next.iterator import builtins, embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im ->>>>>>> origin-main From 8dd80a283529333cd280836d85739b99636940ee Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 22 Jan 2025 11:17:06 +0100 Subject: [PATCH 07/39] Some fixes --- src/gt4py/next/iterator/builtins.py | 4 ++++ src/gt4py/next/iterator/embedded.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 920bc6a69f..95c1058d2e 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -291,6 +291,10 @@ def ceil(*args): def trunc(*args): raise BackendNotSelectedError() +@builtin_dispatch +def neg(*args): + raise BackendNotSelectedError() + @builtin_dispatch def isfinite(*args): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 970e88e8c5..4888a28817 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -541,7 +541,7 @@ def promote_scalars(val: CompositeOfScalarOrField): } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable - if math_builtin_name in ["gamma", "not_"]: + if math_builtin_name in ["gamma", "not_", "neg"]: continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent From ce8adeca0ba2150b76a38890da5a62a959566414 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 23 Jan 2025 10:47:11 +0100 Subject: [PATCH 08/39] Add builtin for unary minus --- src/gt4py/next/ffront/fbuiltins.py | 1 + src/gt4py/next/iterator/builtins.py | 6 ++++++ src/gt4py/next/iterator/embedded.py | 9 ++++++++- .../next/program_processors/codegens/gtfn/codegen.py | 1 + .../runners/dace/gtir_python_codegen.py | 1 + .../feature_tests/iterator_tests/test_builtins.py | 2 ++ 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index cef7fc101f..49e798826f 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -227,6 +227,7 @@ def astype( "floor": math.floor, "ceil": math.ceil, "trunc": math.trunc, + "neg": np.negative, } UNARY_MATH_FP_BUILTIN_NAMES: Final = [*_UNARY_MATH_FP_BUILTIN_IMPL.keys()] diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 959f451e01..7174107f0a 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -292,6 +292,11 @@ def trunc(*args): raise BackendNotSelectedError() +@builtin_dispatch +def neg(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def isfinite(*args): raise BackendNotSelectedError() @@ -420,6 +425,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "floor", "ceil", "trunc", + "neg", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} BINARY_MATH_NUMBER_BUILTINS = { diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 970e88e8c5..8030478cb6 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -392,6 +392,13 @@ def not_(a): return not a +@builtins.neg.register(EMBEDDED) +def neg(a): + if isinstance(a, Column): + return np.negative(a) + return np.negative(a) + + @builtins.gamma.register(EMBEDDED) def gamma(a): gamma_ = np.vectorize(math.gamma) @@ -541,7 +548,7 @@ def promote_scalars(val: CompositeOfScalarOrField): } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable - if math_builtin_name in ["gamma", "not_"]: + if math_builtin_name in ["gamma", "not_", "neg"]: continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index c6bf28d8e0..640b9a9940 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -76,6 +76,7 @@ class GTFNCodegen(codegen.TemplatedGenerator): "xor_": "std::bit_xor{}", "mod": "std::modulus{}", "not_": "std::logical_not{}", + "neg": "std::negate{}", } Sym = as_fmt("{id}") diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index dfbba9c88b..65d27e3af7 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -70,6 +70,7 @@ "xor_": "({} ^ {})", "mod": "({} % {})", "not_": "(not {})", # ~ is not bitwise in numpy + "neg": "(- {})", } diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 885a272bfe..0e7ca9f769 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -46,6 +46,7 @@ plus, shift, xor_, + neg, ) from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at from gt4py.next.program_processors.runners.gtfn import run_gtfn @@ -135,6 +136,7 @@ def fenimpl(size, arg0, arg1, arg2, out): def arithmetic_and_logical_test_data(): return [ # (builtin, inputs, expected) + (neg, [[-1.0, 1.0]], [1.0, -1.0]), (plus, [2.0, 3.0], 5.0), (minus, [2.0, 3.0], -1.0), (multiplies, [2.0, 3.0], 6.0), From d0b99edbbf411830b52adf3e7eb0ac3d722f78af Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 23 Jan 2025 11:23:32 +0100 Subject: [PATCH 09/39] Address review comments --- src/gt4py/next/ffront/fbuiltins.py | 4 ++-- src/gt4py/next/iterator/builtins.py | 3 +-- src/gt4py/next/iterator/embedded.py | 3 ++- src/gt4py/next/program_processors/codegens/gtfn/codegen.py | 2 +- .../program_processors/runners/dace/gtir_python_codegen.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 49e798826f..ee14006b22 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -10,6 +10,7 @@ import functools import inspect import math +import operator from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast @@ -203,7 +204,7 @@ def astype( return core_defs.dtype(type_).scalar_type(value) -_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs} +_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs, "neg": operator.neg} UNARY_MATH_NUMBER_BUILTIN_NAMES: Final = [*_UNARY_MATH_NUMBER_BUILTIN_IMPL.keys()] _UNARY_MATH_FP_BUILTIN_IMPL: Final = { @@ -227,7 +228,6 @@ def astype( "floor": math.floor, "ceil": math.ceil, "trunc": math.trunc, - "neg": np.negative, } UNARY_MATH_FP_BUILTIN_NAMES: Final = [*_UNARY_MATH_FP_BUILTIN_IMPL.keys()] diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 7174107f0a..8e5f7addca 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -402,7 +402,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() -UNARY_MATH_NUMBER_BUILTINS = {"abs"} +UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { "sin", @@ -425,7 +425,6 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "floor", "ceil", "trunc", - "neg", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} BINARY_MATH_NUMBER_BUILTINS = { diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 8030478cb6..16b1fa9d03 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -545,10 +545,11 @@ def promote_scalars(val: CompositeOfScalarOrField): "and_": operator.and_, "or_": operator.or_, "xor_": operator.xor, + "neg": operator.neg, } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable - if math_builtin_name in ["gamma", "not_", "neg"]: + if math_builtin_name in ["gamma", "not_"]: continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 640b9a9940..50181941e6 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -23,6 +23,7 @@ class GTFNCodegen(codegen.TemplatedGenerator): _builtins_mapping: Final = { "abs": "std::abs", + "neg": "std::negate{}", "sin": "std::sin", "cos": "std::cos", "tan": "std::tan", @@ -76,7 +77,6 @@ class GTFNCodegen(codegen.TemplatedGenerator): "xor_": "std::bit_xor{}", "mod": "std::modulus{}", "not_": "std::logical_not{}", - "neg": "std::negate{}", } Sym = as_fmt("{id}") diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py index 65d27e3af7..56a67510e7 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_python_codegen.py @@ -20,6 +20,7 @@ MATH_BUILTINS_MAPPING = { "abs": "abs({})", + "neg": "(- {})", "sin": "math.sin({})", "cos": "math.cos({})", "tan": "math.tan({})", @@ -70,7 +71,6 @@ "xor_": "({} ^ {})", "mod": "({} % {})", "not_": "(not {})", # ~ is not bitwise in numpy - "neg": "(- {})", } From 8f8a8fb03b620487699a9346af5915b478a188c2 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 23 Jan 2025 13:53:37 +0100 Subject: [PATCH 10/39] Add abs to test --- .../feature_tests/iterator_tests/test_builtins.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 0e7ca9f769..73ec64a7fb 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -47,6 +47,7 @@ shift, xor_, neg, + abs, ) from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at from gt4py.next.program_processors.runners.gtfn import run_gtfn @@ -136,6 +137,7 @@ def fenimpl(size, arg0, arg1, arg2, out): def arithmetic_and_logical_test_data(): return [ # (builtin, inputs, expected) + (abs, [[-1.0, 1.0]], [1.0, 1.0]), (neg, [[-1.0, 1.0]], [1.0, -1.0]), (plus, [2.0, 3.0], 5.0), (minus, [2.0, 3.0], -1.0), @@ -182,8 +184,8 @@ def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, exp @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): - if builtin == if_: - pytest.skip("If cannot be used unapplied") + if builtin == if_ or builtin == abs: + pytest.skip("If and abs cannot be used unapplied") inps = field_maker(*array_maker(*inputs)) out = field_maker((np.zeros_like(*array_maker(expected))))[0] From 61ec67b16679f03c77a50922126cb5ff19807647 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 23 Jan 2025 16:18:10 +0100 Subject: [PATCH 11/39] Extend visit_UnaryOp in foast_to_gtir --- src/gt4py/next/ffront/foast_to_gtir.py | 12 ++++++------ .../ffront_tests/test_math_unary_builtins.py | 8 ++++++++ .../feature_tests/iterator_tests/test_builtins.py | 2 +- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4519b4e571..007e195f3e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -241,12 +241,12 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") return self._lower_and_map("not_", node.operand) - - return self._lower_and_map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) + if node.op in [dialect_ast_enums.UnaryOperator.USUB]: + return self._lower_and_map("neg", node.operand) + if node.op in [dialect_ast_enums.UnaryOperator.UADD]: + return self.visit(node.operand) + else: + raise NotImplementedError(f"Unary operator '{node.op}' is not supported.") def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: return self._lower_and_map(node.op.value, node.left, node.right) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 89c341e9a6..1707adada8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -128,6 +128,14 @@ def uneg(inp: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, uneg, ref=lambda inp1: -inp1) +def test_unary_pos(cartesian_case): + @gtx.field_operator + def upos(inp: cases.IField) -> cases.IField: + return +inp + + cases.verify_with_default_data(cartesian_case, upos, ref=lambda inp1: inp1) + + def test_unary_neg_float_conversion(cartesian_case): @gtx.field_operator def uneg_float() -> cases.IFloatField: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 73ec64a7fb..191f0e3c5e 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -185,7 +185,7 @@ def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, exp @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): if builtin == if_ or builtin == abs: - pytest.skip("If and abs cannot be used unapplied") + pytest.skip("If and abs cannot be used unapplied.") inps = field_maker(*array_maker(*inputs)) out = field_maker((np.zeros_like(*array_maker(expected))))[0] From 596216441f55784dbd73f1154455799c23292021 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 23 Jan 2025 17:09:24 +0100 Subject: [PATCH 12/39] Use neg in ConstantFolding --- .../iterator/transforms/constant_folding.py | 103 ++---------------- .../transforms_tests/test_constant_folding.py | 22 ++-- 2 files changed, 23 insertions(+), 102 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 10c8904b1e..5cb6ac5163 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -11,86 +11,12 @@ import functools import operator from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.type_system import type_specifications as ts, type_translation import dataclasses import enum from typing import Optional - -class ConvertMinusToUnary(PreserveLocationVisitor, NodeTranslator): - def visit(self, node: ir.Node): - node = self.generic_visit(node) - - # im.call("...")(im.minus(1, im.ref("a")), 1) -> im.call("...")(im.plus(im.call("neg)(im.ref("a")), 1), 1) - if isinstance(node, ir.FunCall) and len(node.args) > 0 and cpm.is_call_to(node.args[0], "minus"): - if cpm.is_call_to(node, ("minus", "plus", "multiplies", "divides")): - if isinstance(node.args[0].args[1], (ir.SymRef, ir.FunCall)): - node = im.call(node.fun.id)(im.plus(im.call("neg")(node.args[0].args[1]), node.args[0].args[0]), node.args[1]) - node = self.visit(node) - # im.call("...")(1, im.minus(1, im.ref("a"))) -> im.call("...")(1, im.plus(im.call("neg)(im.ref("a")), 1)) - elif isinstance(node, ir.FunCall) and len(node.args) > 1 and cpm.is_call_to(node.args[1], "minus"): - if cpm.is_call_to(node, ("minus", "plus", "multiplies", "divides")): - if isinstance(node.args[1].args[1], (ir.SymRef, ir.FunCall)): - node = im.call(node.fun.id)(node.args[0], im.plus(im.call("neg")(node.args[1].args[1]), node.args[1].args[0])) - node = self.visit(node) - return node - -class ConvertUnaryToMinus(PreserveLocationVisitor, NodeTranslator): - - def visit(self, node: ir.Node): - - if isinstance(node, ir.FunCall) and len(node.args) > 0 and cpm.is_call_to(node.args[0], "neg"): - manipulated_first_arg = False - if cpm.is_call_to(node, ("multiplies", "divides", "maximum", "minimum")): - if node.args[0].args[0].type: - zero = im.literal(str(0), node.args[0].args[0].type) - else: - zero = im.literal_from_value(0.0) # TODO: fix datatype - if cpm.is_call_to(node, "minus"): - node = im.plus(node.args[1], node.args[0].args[0]) - manipulated_first_arg = True - elif cpm.is_call_to(node, "plus"): - node = im.minus(node.args[1], node.args[0].args[0]) - manipulated_first_arg = True - elif cpm.is_call_to(node, "multiplies"): - node = im.multiplies_(im.minus(zero,node.args[0].args[0]), node.args[1]) - manipulated_first_arg = True - elif cpm.is_call_to(node, "divides"): - node = im.divides_(im.minus(zero,node.args[0].args[0]), node.args[1]) - manipulated_first_arg = True - elif cpm.is_call_to(node, "minimum"): - node = im.call("minimum")(im.minus(zero,node.args[0].args[0]), node.args[1]) - manipulated_first_arg = True - elif cpm.is_call_to(node, "maximum"): - node = im.call("maximum")(im.minus(zero,node.args[0].args[0]), node.args[1]) - manipulated_first_arg = True - if manipulated_first_arg: - node = self.visit(node) - elif isinstance(node, ir.FunCall) and len(node.args) > 1 and cpm.is_call_to(node.args[1], "neg"): - if cpm.is_call_to(node, ("multiplies", "divides", "maximum", "minimum")): - if node.args[0].args[0].type: - zero = im.literal(str(0), node.args[0].args[0].type) - else: - zero = im.literal_from_value(0.0) # TODO: fix datatype - if cpm.is_call_to(node, "minus"): - node = im.plus(node.args[0], node.args[1].args[0]) - elif cpm.is_call_to(node, "plus"): - node = im.minus(node.args[0], node.args[1].args[0]) - elif cpm.is_call_to(node, "multiplies"): - node = im.multiplies_(node.args[0], im.minus(zero,node.args[1].args[0])) - elif cpm.is_call_to(node, "divides"): - node = im.divides_(node.args[0], im.minus(zero, node.args[1].args[0])) - elif cpm.is_call_to(node, "minimum"): - node = im.call("minimum")(im.node.args[1], im.minus(zero,node.args[1].args[0])) - elif cpm.is_call_to(node, "maximum"): - node = im.call("maximum")(im.node.args[1], im.minus(zero,node.args[1].args[0])) - - return self.generic_visit(node) - - - @dataclasses.dataclass(frozen=True) class ConstantFolding(PreserveLocationVisitor, NodeTranslator): class Flag(enum.Flag): @@ -99,8 +25,8 @@ class Flag(enum.Flag): # `symref + funcall` -> `funcall + symref` CANONICALIZE_FUNCALL_SYMREF_LITERAL = enum.auto() - # `minus(symref, literal) -> plus(symref,-literal)` - CANONICALIZE_MINUS_SYMREF_LITERAL = enum.auto() + # `minus(arg0, arg1) -> plus(arg1, im.call("neg")(arg0))` + CANONICALIZE_MINUS = enum.auto() # `sym + 1 + 1` -> `sym + 2` FOLD_FUNCALL_LITERAL = enum.auto() @@ -110,7 +36,7 @@ class Flag(enum.Flag): # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` and # `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)` - FOLD_MIN_MAX_PLUS= enum.auto() + FOLD_MIN_MAX_PLUS = enum.auto() # `sym + 0` -> `sym` FOLD_SYMREF_PLUS_ZERO = enum.auto() @@ -121,8 +47,6 @@ class Flag(enum.Flag): # `1 + 1` -> `2` FOLD_ARITHMETIC_BUILTINS = enum.auto() - # `neg(1)` -> `-1` - CANONICALIZE_NEG_LITERAL = enum.auto() # `minimum(a, a)` -> `a` FOLD_MIN_MAX_LITERALS = enum.auto() @@ -141,9 +65,7 @@ def all(self): # TODO: -> ConstantFolding.Flag def apply(cls, node: ir.Node, flags: Optional[Flag] = None) -> ir.Node: flags = flags or cls.flags - node = ConvertMinusToUnary().visit(node) node = cls().visit(node, flags=flags) #TODO: remove flags? - node = ConvertUnaryToMinus().visit(node) return node @@ -188,12 +110,10 @@ def transform_canonicalize_funcall_symref_literal(self, node: ir.FunCall, **kwar return None - def transform_canonicalize_minus_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `minus(symref, literal) -> plus(symref,-literal)` - if (cpm.is_call_to(node, "minus") and - isinstance(node.args[0], (ir.SymRef, ir.FunCall)) and - isinstance(node.args[1], ir.Literal)): - return self.visit(im.plus(node.args[0], im.minus(0, node.args[1]))) + def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `minus(arg0, arg1) -> plus(arg1, im.call("neg")(arg0))` + if cpm.is_call_to(node, "minus"): + return self.visit(im.plus(im.call("neg")(node.args[1]), node.args[0])) return None @@ -255,7 +175,7 @@ def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Opti and all(isinstance(arg, ir.Literal) for arg in node.args) ): try: - if node.fun.id in builtins.ARITHMETIC_BUILTINS and not cpm.is_call_to(node, "neg"): + if node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(node.fun.id)) arg_values = [ getattr(embedded, str(arg.type))(arg.value) @@ -267,13 +187,6 @@ def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Opti pass # happens for inf and neginf return None - def transform_canonicalize_neg_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `neg(1)` -> `-1` - if cpm.is_call_to(node, ("neg")): - if isinstance(node.args[0], ir.Literal): - return self.visit(im.minus(0, int(node.args[0].value))) - return None - def transform_fold_min_max_literals(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `minimum(a, a)` -> `a` if cpm.is_call_to(node, ("minimum", "maximum")): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 46dc2c5bcb..be8fe3847d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -74,7 +74,7 @@ def test_constant_folding_literal_minus0(): assert actual == expected testee = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) - expected = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) + expected = im.plus(im.call("neg")(im.ref("__out_size_1")), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -299,10 +299,11 @@ def test_constant_folding_literal_maximum(): assert actual == expected def test_constant_folding_complex(): - # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) + sym = im.ref("sym") + # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) testee = im.minus(im.literal_from_value(1), im.call("maximum")(im.call("maximum")(im.literal_from_value(1), im.call("maximum")(im.literal_from_value(1), im.ref("sym")), im.call("minimum")(im.literal_from_value(1), im.ref("sym")), im.ref("sym")), im.plus(im.literal_from_value(1), im.plus(im.call("minimum")(im.literal_from_value(-1), 2 ), im.call("maximum")(im.literal_from_value(-1), im.minus (im.literal_from_value(1), im.ref("sym"))))))) - # 1 - max(max(sym, 1), max(1 - sym, -1)) - expected = im.minus(im.literal_from_value(1),im.call("maximum")(im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), im.call("maximum")(im.minus(im.literal_from_value(1), im.ref("sym")),im.literal_from_value(-1)))) + # neg(maximum(maximum(sym, 1), maximum(neg(sym) + 1, -1))) + 1 + expected = im.plus(im.call("neg")(im.call("maximum")(im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), im.call("maximum")(im.plus(im.call("neg")(sym),im.literal_from_value(1)),im.literal_from_value(-1)))),im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -324,8 +325,8 @@ def test_constant_folding_complex_3(): sym = im.ref("sym") # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) testee = im.plus(im.call("minimum")(im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym)), im.plus(im.call("maximum")(im.call("maximum")(im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym)), im.minus(im.literal_from_value(1), sym)),im.call("maximum")(im.minus(im.literal_from_value(1), sym), im.minus(im.literal_from_value(1), sym)))) - # minimum(1 - sym, sym + 1) + (maximum(sym + 1, 1 - sym) + (1 - sym)) - expected = im.plus(im.call("minimum")(im.minus(im.literal_from_value(1), sym), im.plus(sym, im.literal_from_value(1))), im.plus(im.call("maximum")(im.plus(sym, im.literal_from_value(1)), im.minus(im.literal_from_value(1), sym)),im.minus(im.literal_from_value(1), sym))) + # minimum(neg(sym) + 1, sym + 1) + (maximum(sym + 1, neg(sym) + 1) + (neg(sym) + 1)) + expected = im.plus(im.call("minimum")(im.plus(im.call("neg")(sym),im.literal_from_value(1)), im.plus(sym, im.literal_from_value(1))), im.plus(im.call("maximum")(im.plus(sym, im.literal_from_value(1)), im.plus(im.call("neg")(sym),im.literal_from_value(1))),im.plus(im.call("neg")(sym),im.literal_from_value(1)))) actual = ConstantFolding.apply(testee) assert actual == expected @@ -350,7 +351,7 @@ def test_constant_folding_complex_2(): def test_constant_folding_complex_4(): sym = im.ref("sym", "float32") testee = im.divides_(im.minus(im.literal_from_value(1), sym), im.minus( im.literal_from_value(2), sym)) - expected = im.divides_(im.minus(im.literal_from_value(1), sym), im.minus( im.literal_from_value(2), sym)) + expected = im.divides_(im.plus(im.call("neg")(sym), im.literal_from_value(1)), im.plus(im.call("neg")(sym), im.literal_from_value(2))) actual = ConstantFolding.apply(testee) assert actual == expected @@ -366,4 +367,11 @@ def test_constant_folding_plus_new(): testee = im.plus(im.minus(sym, im.literal_from_value(1)), im.literal_from_value(2)) expected = im.plus(sym, im.literal_from_value(1)) actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_minus(): + sym = im.ref("sym") + testee = im.plus(im.minus(im.literal_from_value(1), sym), im.literal_from_value(1)) + expected = im.plus(im.call("neg")(sym), im.literal_from_value(2)) + actual = ConstantFolding.apply(testee) assert actual == expected \ No newline at end of file From a9f279171d55e2c009ec453655dd251cdeb28fb7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 23 Jan 2025 17:18:21 +0100 Subject: [PATCH 13/39] Fix foast_to_gtir test references --- .../next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 59a8dc961b..d2d5404cb5 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -378,7 +378,7 @@ def foo(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop("minus")(im.literal("0", "float64"), "inp") + reference = im.op_as_fieldop("neg")("inp") assert lowered.expr == reference @@ -390,7 +390,7 @@ def foo(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop("plus")(im.literal("0", "float64"), "inp") + reference = im.ref("inp") assert lowered.expr == reference From 586147c49ec9012c7dd32d30fe23847fbc420700 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 23 Jan 2025 17:57:13 +0100 Subject: [PATCH 14/39] Cleanup and use neg builtin --- src/gt4py/next/iterator/builtins.py | 4 - .../iterator/transforms/constant_folding.py | 120 +++++++----- .../transforms_tests/test_constant_folding.py | 176 +++++++++++++++--- 3 files changed, 222 insertions(+), 78 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 3faac068ea..daa0fe5df1 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -291,10 +291,6 @@ def ceil(*args): def trunc(*args): raise BackendNotSelectedError() -@builtin_dispatch -def neg(*args): - raise BackendNotSelectedError() - @builtin_dispatch def neg(*args): diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 5cb6ac5163..78d742a368 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -6,29 +6,29 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import builtins, embedded, ir -import functools -import operator -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im - import dataclasses import enum +import functools +import operator from typing import Optional +from gt4py import eve +from gt4py.next.iterator import builtins, embedded, ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im + @dataclasses.dataclass(frozen=True) -class ConstantFolding(PreserveLocationVisitor, NodeTranslator): +class ConstantFolding(eve.PreserveLocationVisitor, eve.NodeTranslator): class Flag(enum.Flag): # e.g. `literal + symref` -> `symref + literal` and # `literal + funcall` -> `funcall + literal` and # `symref + funcall` -> `funcall + symref` CANONICALIZE_FUNCALL_SYMREF_LITERAL = enum.auto() - # `minus(arg0, arg1) -> plus(arg1, im.call("neg")(arg0))` + # `minus(arg0, arg1) -> plus(im.call("neg")(arg0), arg1)` CANONICALIZE_MINUS = enum.auto() - # `sym + 1 + 1` -> `sym + 2` + # `(sym + 1) + 1` -> `sym + 2` FOLD_FUNCALL_LITERAL = enum.auto() # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` @@ -47,7 +47,6 @@ class Flag(enum.Flag): # `1 + 1` -> `2` FOLD_ARITHMETIC_BUILTINS = enum.auto() - # `minimum(a, a)` -> `a` FOLD_MIN_MAX_LITERALS = enum.auto() @@ -55,20 +54,16 @@ class Flag(enum.Flag): FOLD_IF = enum.auto() @classmethod - def all(self): # TODO: -> ConstantFolding.Flag + def all(self): # TODO -> ConstantFolding.Flag: return functools.reduce(operator.or_, self.__members__.values()) - flags: Flag = Flag.all() - + flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] @classmethod - def apply(cls, node: ir.Node, flags: Optional[Flag] = None) -> ir.Node: - flags = flags or cls.flags - - node = cls().visit(node, flags=flags) #TODO: remove flags? + def apply(cls, node: ir.Node) -> ir.Node: + node = cls().visit(node) return node - def visit_FunCall(self, node: ir.FunCall, **kwargs): # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded node = self.generic_visit(node, **kwargs) @@ -88,58 +83,66 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: return None for transformation in self.Flag: - if self.flags & transformation: + if transformation: assert isinstance(transformation.name, str) method = getattr(self, f"transform_{transformation.name.lower()}") result = method(node) if result is not None: assert ( - result is not node + result is not node ) # transformation should have returned None, since nothing changed return result return None - def transform_canonicalize_funcall_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_canonicalize_funcall_symref_literal( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: # e.g. `literal + symref` -> `symref + literal` and # `literal + funcall` -> `funcall + literal` and # `symref + funcall` -> `funcall + symref` - if cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")): - if (isinstance(node.args[1], (ir.SymRef, ir.FunCall)) and isinstance(node.args[0], ir.Literal) - or isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.SymRef)): + if isinstance(node.fun, ir.SymRef) and cpm.is_call_to( + node, ("plus", "multiplies", "minimum", "maximum") + ): + if ( + isinstance(node.args[1], (ir.SymRef, ir.FunCall)) + and isinstance(node.args[0], ir.Literal) + ) or (isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.SymRef)): return im.call(node.fun.id)(node.args[1], node.args[0]) return None - def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `minus(arg0, arg1) -> plus(arg1, im.call("neg")(arg0))` + # `minus(arg0, arg1) -> plus(im.call("neg")(arg0), arg1)` if cpm.is_call_to(node, "minus"): return self.visit(im.plus(im.call("neg")(node.args[1]), node.args[0])) return None - def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `sym + 1 + 1` -> `sym + 2` + # `(sym + 1) + 1` -> `sym + 2` if cpm.is_call_to(node, "plus"): - if cpm.is_call_to(node.args[0], "plus") and isinstance( node.args[1], ir.Literal): + if cpm.is_call_to(node.args[0], "plus") and isinstance(node.args[1], ir.Literal): fun_call, literal = node.args - if isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) and isinstance( fun_call.args[1], ir.Literal): - return self.visit(im.plus(fun_call.args[0], im.plus(fun_call.args[1], literal))) + if ( + isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) # type: ignore[attr-defined] # assured by if above + and isinstance(fun_call.args[1], ir.Literal) # type: ignore[attr-defined] # assured by if above + ): + return self.visit(im.plus(fun_call.args[0], im.plus(fun_call.args[1], literal))) # type: ignore[attr-defined] # assured by if above return None - def transform_fold_min_max_funcall_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_fold_min_max_funcall_symref_literal( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` - if cpm.is_call_to(node, ("minimum", "maximum")): - if cpm.is_call_to(node.args[0], ("maximum", "minimum")): - fun_call, arg1, = node.args - if arg1 == fun_call.args[0]: - return self.visit(im.call(fun_call.fun.id)(fun_call.args[1], arg1)) - if arg1 == fun_call.args[1]: - return self.visit(im.call(fun_call.fun.id)(fun_call.args[0], arg1)) + if isinstance(node.fun, ir.SymRef) and cpm.is_call_to(node, ("minimum", "maximum")): + if cpm.is_call_to(node.args[0], ("maximum", "minimum")): + fun_call, arg1 = node.args + if arg1 == fun_call.args[0]: # type: ignore[attr-defined] # assured by if above + return self.visit(im.call(fun_call.fun.id)(fun_call.args[1], arg1)) # type: ignore[attr-defined] # assured by if above + if arg1 == fun_call.args[1]: # type: ignore[attr-defined] # assured by if above + return self.visit(im.call(fun_call.fun.id)(fun_call.args[0], arg1)) # type: ignore[attr-defined] # assured by if above return None - def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_call_to(node, ("minimum", "maximum")): + if isinstance(node.fun, ir.SymRef) and cpm.is_call_to(node, ("minimum", "maximum")): arg0, arg1 = node.args # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` if cpm.is_call_to(arg0, "plus"): @@ -148,22 +151,39 @@ def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir # `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)` if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"): if arg0.args[0] == arg1.args[0]: - return self.visit(im.plus(arg0.args[0], im.call(node.fun.id)(arg0.args[1], arg1.args[1]))) + return self.visit( + im.plus(arg0.args[0], im.call(node.fun.id)(arg0.args[1], arg1.args[1])) + ) return None def transform_fold_symref_plus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `sym + 0` -> `sym` - if (cpm.is_call_to(node, "plus") and isinstance(node.args[1], ir.Literal) and - node.args[1].value.isdigit() and int(node.args[1].value) == 0): + if ( + cpm.is_call_to(node, "plus") + and isinstance(node.args[1], ir.Literal) + and node.args[1].value.isdigit() + and int(node.args[1].value) == 0 + ): return node.args[0] return None - def transform_canonicalize_plus_symref_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_canonicalize_plus_symref_literal( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: # `sym1 + 1 + (sym2 + 2)` -> `sym1 + sym2 + 2 + 1` if cpm.is_call_to(node, "plus"): - if (cpm.is_call_to(node.args[0], "plus") and cpm.is_call_to(node.args[1], "plus") and - isinstance(node.args[0].args[1], ir.Literal) and isinstance(node.args[1].args[1], ir.Literal)): - return self.visit(im.plus(im.plus(node.args[0].args[0], node.args[1].args[0]), im.plus(node.args[0].args[1], node.args[1].args[1]))) + if ( + cpm.is_call_to(node.args[0], "plus") + and cpm.is_call_to(node.args[1], "plus") + and isinstance(node.args[0].args[1], ir.Literal) + and isinstance(node.args[1].args[1], ir.Literal) + ): + return self.visit( + im.plus( + im.plus(node.args[0].args[0], node.args[1].args[0]), + im.plus(node.args[0].args[1], node.args[1].args[1]), + ) + ) return None def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: @@ -178,8 +198,7 @@ def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Opti if node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(node.fun.id)) arg_values = [ - getattr(embedded, str(arg.type))(arg.value) - # type: ignore[attr-defined] # arg type already established in if condition + getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition for arg in node.args ] return im.literal_from_value(fun(*arg_values)) @@ -202,4 +221,3 @@ def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: else: return node.args[2] return None - diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index be8fe3847d..bde6160a76 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -10,14 +10,13 @@ from gt4py.next.iterator.transforms.constant_folding import ConstantFolding - def test_constant_folding_plus(): expected = im.literal_from_value(2) - testee = im.plus( - im.literal_from_value(1), im.literal_from_value(1)) + testee = im.plus(im.literal_from_value(1), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_boolean(): testee = im.not_(im.literal_from_value(True)) expected = im.literal_from_value(False) @@ -67,6 +66,7 @@ def test_constant_folding_literal_plus0(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_literal_minus0(): testee = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) @@ -78,16 +78,23 @@ def test_constant_folding_literal_minus0(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_funcall_literal(): - testee = im.plus(im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.literal_from_value(1)) + testee = im.plus( + im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.literal_from_value(1) + ) expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(2)) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.plus(im.literal_from_value(1), im.plus(im.ref("__out_size_1"), im.literal_from_value(1))) + testee = im.plus( + im.literal_from_value(1), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + ) expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(2)) actual = ConstantFolding.apply(testee) assert actual == expected + + def test_constant_folding_maximum_literal_minus(): testee = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) @@ -99,6 +106,7 @@ def test_constant_folding_maximum_literal_minus(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus1(): testee = im.call("maximum")( im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)), @@ -108,6 +116,7 @@ def test_constant_folding_maximum_literal_plus1(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus2(): testee = im.call("maximum")( im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), @@ -117,24 +126,27 @@ def test_constant_folding_maximum_literal_plus2(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus3(): testee = im.call("maximum")( im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), ) - expected = im.call("maximum")(im.ref("__out_size_1"),im.literal_from_value(1)) + expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus4(): testee = im.call("maximum")( im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)), ) - expected = im.call("maximum")(im.ref("__out_size_1"),im.literal_from_value(1)) + expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus5(): testee = im.call("maximum")( im.ref("__out_size_1"), im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")) @@ -143,6 +155,7 @@ def test_constant_folding_maximum_literal_plus5(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus6(): testee = im.call("maximum")( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), @@ -152,6 +165,7 @@ def test_constant_folding_maximum_literal_plus6(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus7(): testee = im.minus( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), @@ -161,6 +175,7 @@ def test_constant_folding_maximum_literal_plus7(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus8(): testee = im.plus( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), @@ -170,6 +185,7 @@ def test_constant_folding_maximum_literal_plus8(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus9(): testee = im.call("maximum")( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), @@ -181,6 +197,7 @@ def test_constant_folding_maximum_literal_plus9(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus10(): testee = im.call("maximum")( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") @@ -189,13 +206,18 @@ def test_constant_folding_maximum_literal_plus10(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus10a(): - testee = im.plus(im.ref("__out_size_1"), im.call("maximum")(im.literal_from_value(0), im.literal_from_value(-1))) + testee = im.plus( + im.ref("__out_size_1"), + im.call("maximum")(im.literal_from_value(0), im.literal_from_value(-1)), + ) expected = im.ref("__out_size_1") actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus11(): testee = im.call("maximum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) @@ -204,6 +226,7 @@ def test_constant_folding_maximum_literal_plus11(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus12(): testee = im.call("maximum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) @@ -212,6 +235,7 @@ def test_constant_folding_maximum_literal_plus12(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus13(): testee = im.call("maximum")( im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") @@ -220,6 +244,7 @@ def test_constant_folding_maximum_literal_plus13(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus14(): testee = im.call("maximum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) @@ -228,6 +253,7 @@ def test_constant_folding_maximum_literal_plus14(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus15(): testee = im.call("maximum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) @@ -236,6 +262,7 @@ def test_constant_folding_maximum_literal_plus15(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus16(): testee = im.call("minimum")( im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") @@ -244,6 +271,7 @@ def test_constant_folding_maximum_literal_plus16(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus17(): testee = im.call("minimum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) @@ -252,6 +280,7 @@ def test_constant_folding_maximum_literal_plus17(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus18(): testee = im.call("minimum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) @@ -260,6 +289,7 @@ def test_constant_folding_maximum_literal_plus18(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus19(): testee = im.call("minimum")( im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") @@ -268,6 +298,7 @@ def test_constant_folding_maximum_literal_plus19(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus20(): testee = im.call("minimum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) @@ -276,6 +307,7 @@ def test_constant_folding_maximum_literal_plus20(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_maximum_literal_plus21(): testee = im.call("minimum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) @@ -298,51 +330,142 @@ def test_constant_folding_literal_maximum(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_complex(): sym = im.ref("sym") # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) - testee = im.minus(im.literal_from_value(1), im.call("maximum")(im.call("maximum")(im.literal_from_value(1), im.call("maximum")(im.literal_from_value(1), im.ref("sym")), im.call("minimum")(im.literal_from_value(1), im.ref("sym")), im.ref("sym")), im.plus(im.literal_from_value(1), im.plus(im.call("minimum")(im.literal_from_value(-1), 2 ), im.call("maximum")(im.literal_from_value(-1), im.minus (im.literal_from_value(1), im.ref("sym"))))))) + testee = im.minus( + im.literal_from_value(1), + im.call("maximum")( + im.call("maximum")( + im.literal_from_value(1), + im.call("maximum")(im.literal_from_value(1), im.ref("sym")), + im.call("minimum")(im.literal_from_value(1), im.ref("sym")), + im.ref("sym"), + ), + im.plus( + im.literal_from_value(1), + im.plus( + im.call("minimum")(im.literal_from_value(-1), 2), + im.call("maximum")( + im.literal_from_value(-1), im.minus(im.literal_from_value(1), im.ref("sym")) + ), + ), + ), + ), + ) # neg(maximum(maximum(sym, 1), maximum(neg(sym) + 1, -1))) + 1 - expected = im.plus(im.call("neg")(im.call("maximum")(im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), im.call("maximum")(im.plus(im.call("neg")(sym),im.literal_from_value(1)),im.literal_from_value(-1)))),im.literal_from_value(1)) + expected = im.plus( + im.call("neg")( + im.call("maximum")( + im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), + im.call("maximum")( + im.plus(im.call("neg")(sym), im.literal_from_value(1)), + im.literal_from_value(-1), + ), + ) + ), + im.literal_from_value(1), + ) actual = ConstantFolding.apply(testee) assert actual == expected +# ( (min(1 - sym, 1 + sym) + (max(max(1 - sym, 1 + sym),1 - sym) + max(1 - sym, 1 - sym)))))) - 2 +# max(sym, 1 + sym) + (max(1, max(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 -#( (min(1 - sym, 1 + sym) + (max(max(1 - sym, 1 + sym),1 - sym) + max(1 - sym, 1 - sym)))))) - 2 - #max(sym, 1 + sym) + (max(1, max(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 def test_constant_folding_complex_1(): sym = im.ref("sym") # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 - testee = im.minus(im.plus(im.call("maximum")(sym, im.plus(im.literal_from_value(1), sym)), im.plus(im.call("maximum")(im.literal_from_value(1), im.call("maximum")(im.literal_from_value(1),sym)),im.plus(im.minus(sym,im.literal_from_value(1)), im.plus(im.plus(im.literal_from_value(1),im.plus(sym,im.literal_from_value(1))),im.literal_from_value(1)) ))) , im.literal_from_value(2)) + testee = im.minus( + im.plus( + im.call("maximum")(sym, im.plus(im.literal_from_value(1), sym)), + im.plus( + im.call("maximum")( + im.literal_from_value(1), im.call("maximum")(im.literal_from_value(1), sym) + ), + im.plus( + im.minus(sym, im.literal_from_value(1)), + im.plus( + im.plus(im.literal_from_value(1), im.plus(sym, im.literal_from_value(1))), + im.literal_from_value(1), + ), + ), + ), + ), + im.literal_from_value(2), + ) # sym + 1 + (maximum(sym, 1) + (sym + sym + 2)) + -2 - expected = im.plus(im.plus(im.plus(sym,1),im.plus( im.call("maximum")(sym,im.literal_from_value(1)),im.plus(im.plus(sym, sym),im.literal_from_value(2)))), im.literal_from_value(-2)) + expected = im.plus( + im.plus( + im.plus(sym, 1), + im.plus( + im.call("maximum")(sym, im.literal_from_value(1)), + im.plus(im.plus(sym, sym), im.literal_from_value(2)), + ), + ), + im.literal_from_value(-2), + ) actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_complex_3(): sym = im.ref("sym") # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) - testee = im.plus(im.call("minimum")(im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym)), im.plus(im.call("maximum")(im.call("maximum")(im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym)), im.minus(im.literal_from_value(1), sym)),im.call("maximum")(im.minus(im.literal_from_value(1), sym), im.minus(im.literal_from_value(1), sym)))) + testee = im.plus( + im.call("minimum")( + im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym) + ), + im.plus( + im.call("maximum")( + im.call("maximum")( + im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym) + ), + im.minus(im.literal_from_value(1), sym), + ), + im.call("maximum")( + im.minus(im.literal_from_value(1), sym), im.minus(im.literal_from_value(1), sym) + ), + ), + ) # minimum(neg(sym) + 1, sym + 1) + (maximum(sym + 1, neg(sym) + 1) + (neg(sym) + 1)) - expected = im.plus(im.call("minimum")(im.plus(im.call("neg")(sym),im.literal_from_value(1)), im.plus(sym, im.literal_from_value(1))), im.plus(im.call("maximum")(im.plus(sym, im.literal_from_value(1)), im.plus(im.call("neg")(sym),im.literal_from_value(1))),im.plus(im.call("neg")(sym),im.literal_from_value(1)))) + expected = im.plus( + im.call("minimum")( + im.plus(im.call("neg")(sym), im.literal_from_value(1)), + im.plus(sym, im.literal_from_value(1)), + ), + im.plus( + im.call("maximum")( + im.plus(sym, im.literal_from_value(1)), + im.plus(im.call("neg")(sym), im.literal_from_value(1)), + ), + im.plus(im.call("neg")(sym), im.literal_from_value(1)), + ), + ) actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_complex_3a(): sym = im.ref("sym") # maximum(maximum(1 + sym, 1), 1 + sym) - testee = im.call("maximum")(im.call("maximum")( im.plus(im.literal_from_value(1), sym), 1), im.plus(im.literal_from_value(1), sym)) + testee = im.call("maximum")( + im.call("maximum")(im.plus(im.literal_from_value(1), sym), 1), + im.plus(im.literal_from_value(1), sym), + ) # maximum(1 + sym, 1) - expected =im.call("maximum")( im.plus(sym, im.literal_from_value(1)), 1) + expected = im.call("maximum")(im.plus(sym, im.literal_from_value(1)), 1) actual = ConstantFolding.apply(testee) assert actual == expected def test_constant_folding_complex_2(): sym = im.ref("sym") - testee = im.plus(im.plus(sym, im.literal_from_value(-1)),im.plus(sym, im.literal_from_value(3))) + testee = im.plus( + im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) + ) expected = im.plus(im.plus(sym, sym), im.literal_from_value(2)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -350,18 +473,24 @@ def test_constant_folding_complex_2(): def test_constant_folding_complex_4(): sym = im.ref("sym", "float32") - testee = im.divides_(im.minus(im.literal_from_value(1), sym), im.minus( im.literal_from_value(2), sym)) - expected = im.divides_(im.plus(im.call("neg")(sym), im.literal_from_value(1)), im.plus(im.call("neg")(sym), im.literal_from_value(2))) + testee = im.divides_( + im.minus(im.literal_from_value(1), sym), im.minus(im.literal_from_value(2), sym) + ) + expected = im.divides_( + im.plus(im.call("neg")(sym), im.literal_from_value(1)), + im.plus(im.call("neg")(sym), im.literal_from_value(2)), + ) actual = ConstantFolding.apply(testee) assert actual == expected def test_constant_folding_max(): - testee = im.call("maximum")(0,0) + testee = im.call("maximum")(0, 0) expected = im.literal_from_value(0) actual = ConstantFolding.apply(testee) assert actual == expected + def test_constant_folding_plus_new(): sym = im.ref("sym") testee = im.plus(im.minus(sym, im.literal_from_value(1)), im.literal_from_value(2)) @@ -369,9 +498,10 @@ def test_constant_folding_plus_new(): actual = ConstantFolding.apply(testee) assert actual == expected + def test_minus(): sym = im.ref("sym") testee = im.plus(im.minus(im.literal_from_value(1), sym), im.literal_from_value(1)) expected = im.plus(im.call("neg")(sym), im.literal_from_value(2)) actual = ConstantFolding.apply(testee) - assert actual == expected \ No newline at end of file + assert actual == expected From ed0f06295a745e05e64d876552cae510adb3175b Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 24 Jan 2025 13:13:51 +0100 Subject: [PATCH 15/39] Take care of tuple_get --- .../iterator/transforms/constant_folding.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 78d742a368..17e4db3fd8 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -28,6 +28,12 @@ class Flag(enum.Flag): # `minus(arg0, arg1) -> plus(im.call("neg")(arg0), arg1)` CANONICALIZE_MINUS = enum.auto() + # `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())` + CANONICALIZE_MIN_MAX = enum.auto() + + # `im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))` + CANONICALIZE_TUPLE_GET_PLUS = enum.auto() + # `(sym + 1) + 1` -> `sym + 2` FOLD_FUNCALL_LITERAL = enum.auto() @@ -116,6 +122,20 @@ def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[i return self.visit(im.plus(im.call("neg")(node.args[1]), node.args[0])) return None + def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())` + if cpm.is_call_to(node, ("maximum", "minimum")): + if not cpm.is_call_to(node.args[0], ("maximum", "minimum")) and cpm.is_call_to(node.args[1], ("maximum", "minimum")): + return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) + return None + + def transform_canonicalize_tuple_get_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))` + if isinstance(node, ir.FunCall): + if cpm.is_call_to(node.args[0], "tuple_get") and cpm.is_call_to(node.args[1], "plus"): + return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) + return None + def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `(sym + 1) + 1` -> `sym + 2` if cpm.is_call_to(node, "plus"): From ed56f02da738cec27cb090f98390eb4e6e298dd7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 24 Jan 2025 14:44:24 +0100 Subject: [PATCH 16/39] Minor improvements --- .../iterator/transforms/constant_folding.py | 9 ++++--- .../transforms_tests/test_constant_folding.py | 25 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 17e4db3fd8..969a5b73d6 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -25,7 +25,7 @@ class Flag(enum.Flag): # `symref + funcall` -> `funcall + symref` CANONICALIZE_FUNCALL_SYMREF_LITERAL = enum.auto() - # `minus(arg0, arg1) -> plus(im.call("neg")(arg0), arg1)` + # `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)` CANONICALIZE_MINUS = enum.auto() # `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())` @@ -117,7 +117,7 @@ def transform_canonicalize_funcall_symref_literal( return None def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `minus(arg0, arg1) -> plus(im.call("neg")(arg0), arg1)` + # `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)` if cpm.is_call_to(node, "minus"): return self.visit(im.plus(im.call("neg")(node.args[1]), node.args[0])) return None @@ -125,13 +125,14 @@ def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[i def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())` if cpm.is_call_to(node, ("maximum", "minimum")): - if not cpm.is_call_to(node.args[0], ("maximum", "minimum")) and cpm.is_call_to(node.args[1], ("maximum", "minimum")): + if (isinstance(node.args[0], ir.FunCall) and not cpm.is_call_to(node.args[0], ("maximum", "minimum")) + and cpm.is_call_to(node.args[1], ("maximum", "minimum"))): return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) return None def transform_canonicalize_tuple_get_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))` - if isinstance(node, ir.FunCall): + if isinstance(node, ir.FunCall) and len(node.args) > 1: if cpm.is_call_to(node.args[0], "tuple_get") and cpm.is_call_to(node.args[1], "plus"): return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) return None diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index bde6160a76..0f89e7fa0d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -505,3 +505,28 @@ def test_minus(): expected = im.plus(im.call("neg")(sym), im.literal_from_value(2)) actual = ConstantFolding.apply(testee) assert actual == expected + + +def test_fold_min_max_plus(): + sym = im.ref("sym") + testee = im.call("minimum")(im.plus(sym, im.literal_from_value(-1)), sym) + expected = im.plus(sym, im.literal_from_value(-1)) + actual = ConstantFolding.apply(testee) + assert actual == expected + + +def test_max_tuple_get(): + sym = im.ref("sym") + testee = im.call("maximum")(im.plus(im.tuple_get(1, sym), 1), im.call("maximum")( im.tuple_get(1, sym),im.plus(im.tuple_get(1, sym), 1))) + expected = im.plus(im.tuple_get(1, sym), 1) + actual = ConstantFolding.apply(testee) + assert actual == expected + +def test_max_syms(): + sym1 = im.ref("sym1") + sym2 = im.ref("sym2") + testee = im.call("maximum")( sym1, im.call("maximum")(sym2, sym1)) + expected = im.call("maximum")(sym2, sym1) + actual = ConstantFolding.apply(testee) + assert actual == expected +>>>>>>> 6b7f2888 (Minor improvements) From b1e7197fe7c1056736671dc9687cc3354f79b3b7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Jan 2025 13:08:30 +0100 Subject: [PATCH 17/39] Move fixed point transformations to a new class --- .../iterator/transforms/collapse_tuple.py | 30 +--------- .../iterator/transforms/constant_folding.py | 48 ++++++--------- .../transforms/fixed_point_transform.py | 58 +++++++++++++++++++ .../transforms_tests/test_constant_folding.py | 15 +++-- 4 files changed, 86 insertions(+), 65 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/fixed_point_transform.py diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 0a0cf6d37e..7fda0851aa 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -14,7 +14,6 @@ import operator from typing import Optional -from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir @@ -23,6 +22,7 @@ ir_makers as im, misc as ir_misc, ) +from gt4py.next.iterator.transforms.fixed_point_transform import FixedPointTransform from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -87,7 +87,7 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: # reads a little convoluted and is also different to how we write other transformations. We # should revisit the pattern here and try to find a more general mechanism. @dataclasses.dataclass(frozen=True) -class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): +class CollapseTuple(FixedPointTransform): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -217,32 +217,6 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: node = self.generic_visit(node, **kwargs) return self.fp_transform(node, **kwargs) - def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: - while True: - new_node = self.transform(node, **kwargs) - if new_node is None: - break - assert new_node != node - node = new_node - return node - - def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: - if not isinstance(node, ir.FunCall): - return None - - for transformation in self.Flag: - if self.flags & transformation: - assert isinstance(transformation.name, str) - method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node, **kwargs) - if result is not None: - assert ( - result is not node - ) # transformation should have returned None, since nothing changed - itir_type_inference.reinfer(result) - return result - return None - def transform_collapse_make_tuple_tuple_get( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 969a5b73d6..376167de7c 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,13 +12,18 @@ import operator from typing import Optional -from gt4py import eve from gt4py.next.iterator import builtins, embedded, ir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms.fixed_point_transform import FixedPointTransform @dataclasses.dataclass(frozen=True) -class ConstantFolding(eve.PreserveLocationVisitor, eve.NodeTranslator): +class ConstantFolding(FixedPointTransform): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + class Flag(enum.Flag): # e.g. `literal + symref` -> `symref + literal` and # `literal + funcall` -> `funcall + literal` and @@ -75,31 +80,6 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): node = self.generic_visit(node, **kwargs) return self.fp_transform(node, **kwargs) - def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: - while True: - new_node = self.transform(node, **kwargs) - if new_node is None: - break - assert new_node != node - node = new_node - return node - - def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: - if not isinstance(node, ir.FunCall): - return None - - for transformation in self.Flag: - if transformation: - assert isinstance(transformation.name, str) - method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node) - if result is not None: - assert ( - result is not node - ) # transformation should have returned None, since nothing changed - return result - return None - def transform_canonicalize_funcall_symref_literal( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: @@ -125,14 +105,20 @@ def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[i def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())` if cpm.is_call_to(node, ("maximum", "minimum")): - if (isinstance(node.args[0], ir.FunCall) and not cpm.is_call_to(node.args[0], ("maximum", "minimum")) - and cpm.is_call_to(node.args[1], ("maximum", "minimum"))): + if ( + isinstance(node.args[0], ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and not cpm.is_call_to(node.args[0], ("maximum", "minimum")) + and cpm.is_call_to(node.args[1], ("maximum", "minimum")) + ): return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) return None - def transform_canonicalize_tuple_get_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + def transform_canonicalize_tuple_get_plus( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: # im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))` - if isinstance(node, ir.FunCall) and len(node.args) > 1: + if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.SymRef) and len(node.args) > 1: if cpm.is_call_to(node.args[0], "tuple_get") and cpm.is_call_to(node.args[1], "plus"): return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) return None diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transform.py b/src/gt4py/next/iterator/transforms/fixed_point_transform.py new file mode 100644 index 0000000000..57b34c7207 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/fixed_point_transform.py @@ -0,0 +1,58 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +import enum +import functools +import operator +from typing import Optional, Type + +from abc import abstractmethod + +from gt4py import eve +from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as itir_type_inference + + +@dataclasses.dataclass(frozen=True) +class FixedPointTransform(eve.PreserveLocationVisitor, eve.NodeTranslator): + @property + @abstractmethod + def Flag(self) -> Type[enum.Flag]: + pass + + @property + @abstractmethod + def flags(self) -> enum.Flag: + pass + + def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: + while True: + new_node = self.transform(node, **kwargs) + if new_node is None: + break + assert new_node != node + node = new_node + return node + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + if not isinstance(node, ir.FunCall): + return None + + for transformation in self.Flag: + if self.flags & transformation: + assert isinstance(transformation.name, str) + method = getattr(self, f"transform_{transformation.name.lower()}") + result = method(node, **kwargs) + if result is not None: + assert ( + result is not node + ) # transformation should have returned None, since nothing changed + itir_type_inference.reinfer(result) + return result + return None diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0f89e7fa0d..7f5c4c3f6a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -474,11 +474,11 @@ def test_constant_folding_complex_2(): def test_constant_folding_complex_4(): sym = im.ref("sym", "float32") testee = im.divides_( - im.minus(im.literal_from_value(1), sym), im.minus(im.literal_from_value(2), sym) + im.minus(im.literal("1", "float32"), sym), im.minus(im.literal("2", "float32"), sym) ) expected = im.divides_( - im.plus(im.call("neg")(sym), im.literal_from_value(1)), - im.plus(im.call("neg")(sym), im.literal_from_value(2)), + im.plus(im.call("neg")(sym), im.literal("1", "float32")), + im.plus(im.call("neg")(sym), im.literal("2", "float32")), ) actual = ConstantFolding.apply(testee) assert actual == expected @@ -517,16 +517,19 @@ def test_fold_min_max_plus(): def test_max_tuple_get(): sym = im.ref("sym") - testee = im.call("maximum")(im.plus(im.tuple_get(1, sym), 1), im.call("maximum")( im.tuple_get(1, sym),im.plus(im.tuple_get(1, sym), 1))) + testee = im.call("maximum")( + im.plus(im.tuple_get(1, sym), 1), + im.call("maximum")(im.tuple_get(1, sym), im.plus(im.tuple_get(1, sym), 1)), + ) expected = im.plus(im.tuple_get(1, sym), 1) actual = ConstantFolding.apply(testee) assert actual == expected + def test_max_syms(): sym1 = im.ref("sym1") sym2 = im.ref("sym2") - testee = im.call("maximum")( sym1, im.call("maximum")(sym2, sym1)) + testee = im.call("maximum")(sym1, im.call("maximum")(sym2, sym1)) expected = im.call("maximum")(sym2, sym1) actual = ConstantFolding.apply(testee) assert actual == expected ->>>>>>> 6b7f2888 (Minor improvements) From f250699b708a3304a151cc4afbd02c3035419c27 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Jan 2025 13:18:58 +0100 Subject: [PATCH 18/39] Run pre-commit --- .../next/iterator/transforms/fixed_point_transform.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transform.py b/src/gt4py/next/iterator/transforms/fixed_point_transform.py index 57b34c7207..22a75cc37d 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transform.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transform.py @@ -8,11 +8,8 @@ import dataclasses import enum -import functools -import operator -from typing import Optional, Type - from abc import abstractmethod +from typing import Optional, Type from gt4py import eve from gt4py.next.iterator import ir @@ -51,7 +48,7 @@ def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: result = method(node, **kwargs) if result is not None: assert ( - result is not node + result is not node ) # transformation should have returned None, since nothing changed itir_type_inference.reinfer(result) return result From c24ea806ac6bb366f76564b62e560d8acc0c0676 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Jan 2025 13:58:12 +0100 Subject: [PATCH 19/39] Address review comments --- .../iterator/transforms/collapse_tuple.py | 35 ++++++++++++------- .../iterator/transforms/constant_folding.py | 20 ++++++++--- .../transforms/fixed_point_transform.py | 19 +++------- 3 files changed, 41 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7fda0851aa..e807652212 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -86,7 +86,7 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: # go through all available transformation and apply them. However the final result here still # reads a little convoluted and is also different to how we write other transformations. We # should revisit the pattern here and try to find a more general mechanism. -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class CollapseTuple(FixedPointTransform): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -220,9 +220,13 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: def transform_collapse_make_tuple_tuple_get( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="make_tuple") and all( - isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") - for arg in node.args + if ( + isinstance(node, ir.FunCall) + and node.fun == ir.SymRef(id="make_tuple") + and all( + isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") + for arg in node.args + ) ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` assert isinstance(node.args[0], ir.FunCall) @@ -249,7 +253,8 @@ def transform_collapse_tuple_get_make_tuple( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: if ( - node.fun == ir.SymRef(id="tuple_get") + isinstance(node, ir.FunCall) + and node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) and node.args[1].fun == ir.SymRef(id="make_tuple") and isinstance(node.args[0], ir.Literal) @@ -265,11 +270,15 @@ def transform_collapse_tuple_get_make_tuple( return None def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): + if ( + isinstance(node, ir.FunCall) + and node.fun == ir.SymRef(id="tuple_get") + and isinstance(node.args[0], ir.Literal) + ): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - if cpm.is_let(node.args[1]): + if isinstance(node, ir.FunCall) and cpm.is_let(node.args[1]): idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let @@ -289,7 +298,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ return None def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="make_tuple"): + if isinstance(node, ir.FunCall) and node.fun == ir.SymRef(id="make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` bound_vars: dict[ir.Sym, ir.Expr] = {} @@ -309,7 +318,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op return None def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_let(node): + if isinstance(node, ir.FunCall) and cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] @@ -323,7 +332,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt # in local-view for now. Revisit. return None - if not cpm.is_call_to(node, "if_"): + if isinstance(node, ir.FunCall) and not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): Only inline if type of branch value is a tuple. # Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -365,7 +374,7 @@ def transform_propagate_to_if_on_tuples_cps( # `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this # tuple reordering example here. - if cpm.is_call_to(node, "if_"): + if not isinstance(node, ir.FunCall) or cpm.is_call_to(node, "if_"): return None # The first argument that is eligible also transforms all remaining args (They will be @@ -438,7 +447,7 @@ def transform_propagate_to_if_on_tuples_cps( return None def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_let(node): + if isinstance(node, ir.FunCall) and cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} @@ -464,7 +473,7 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional return None def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if cpm.is_let(node): + if isinstance(node, ir.FunCall) and cpm.is_let(node): if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 376167de7c..c613d98825 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -17,7 +17,7 @@ from gt4py.next.iterator.transforms.fixed_point_transform import FixedPointTransform -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class ConstantFolding(FixedPointTransform): PRESERVED_ANNEX_ATTRS = ( "type", @@ -86,8 +86,10 @@ def transform_canonicalize_funcall_symref_literal( # e.g. `literal + symref` -> `symref + literal` and # `literal + funcall` -> `funcall + literal` and # `symref + funcall` -> `funcall + symref` - if isinstance(node.fun, ir.SymRef) and cpm.is_call_to( - node, ("plus", "multiplies", "minimum", "maximum") + if ( + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")) ): if ( isinstance(node.args[1], (ir.SymRef, ir.FunCall)) @@ -139,7 +141,11 @@ def transform_fold_min_max_funcall_symref_literal( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` - if isinstance(node.fun, ir.SymRef) and cpm.is_call_to(node, ("minimum", "maximum")): + if ( + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and cpm.is_call_to(node, ("minimum", "maximum")) + ): if cpm.is_call_to(node.args[0], ("maximum", "minimum")): fun_call, arg1 = node.args if arg1 == fun_call.args[0]: # type: ignore[attr-defined] # assured by if above @@ -149,7 +155,11 @@ def transform_fold_min_max_funcall_symref_literal( return None def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if isinstance(node.fun, ir.SymRef) and cpm.is_call_to(node, ("minimum", "maximum")): + if ( + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and cpm.is_call_to(node, ("minimum", "maximum")) + ): arg0, arg1 = node.args # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` if cpm.is_call_to(arg0, "plus"): diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transform.py b/src/gt4py/next/iterator/transforms/fixed_point_transform.py index 22a75cc37d..8b6902385f 100644 --- a/src/gt4py/next/iterator/transforms/fixed_point_transform.py +++ b/src/gt4py/next/iterator/transforms/fixed_point_transform.py @@ -8,25 +8,17 @@ import dataclasses import enum -from abc import abstractmethod -from typing import Optional, Type +from typing import ClassVar, Optional, Type from gt4py import eve from gt4py.next.iterator import ir from gt4py.next.iterator.type_system import inference as itir_type_inference -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, kw_only=True) class FixedPointTransform(eve.PreserveLocationVisitor, eve.NodeTranslator): - @property - @abstractmethod - def Flag(self) -> Type[enum.Flag]: - pass - - @property - @abstractmethod - def flags(self) -> enum.Flag: - pass + Flag: ClassVar[Type[enum.Flag]] + flags: enum.Flag def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: while True: @@ -38,9 +30,6 @@ def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: return node def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: - if not isinstance(node, ir.FunCall): - return None - for transformation in self.Flag: if self.flags & transformation: assert isinstance(transformation.name, str) From 4e4869a450716482bf3792339955b70f8737e700 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Jan 2025 14:41:07 +0100 Subject: [PATCH 20/39] Remove self.visit --- .../iterator/transforms/constant_folding.py | 46 +++++++++++-------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index c613d98825..1453e0587c 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -100,8 +100,8 @@ def transform_canonicalize_funcall_symref_literal( def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)` - if cpm.is_call_to(node, "minus"): - return self.visit(im.plus(im.call("neg")(node.args[1]), node.args[0])) + if isinstance(node, ir.FunCall) and cpm.is_call_to(node, "minus"): + return im.plus(self.fp_transform(im.call("neg")(node.args[1])), node.args[0]) return None def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: @@ -113,7 +113,7 @@ def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional and not cpm.is_call_to(node.args[0], ("maximum", "minimum")) and cpm.is_call_to(node.args[1], ("maximum", "minimum")) ): - return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) + return im.call(node.fun.id)(node.args[1], node.args[0]) return None def transform_canonicalize_tuple_get_plus( @@ -122,19 +122,27 @@ def transform_canonicalize_tuple_get_plus( # im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))` if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.SymRef) and len(node.args) > 1: if cpm.is_call_to(node.args[0], "tuple_get") and cpm.is_call_to(node.args[1], "plus"): - return self.visit(im.call(node.fun.id)(node.args[1], node.args[0])) + return im.call(node.fun.id)(node.args[1], node.args[0]) return None def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `(sym + 1) + 1` -> `sym + 2` if cpm.is_call_to(node, "plus"): - if cpm.is_call_to(node.args[0], "plus") and isinstance(node.args[1], ir.Literal): + if ( + isinstance(node.args[0], ir.FunCall) + and cpm.is_call_to(node.args[0], "plus") + and isinstance(node.args[1], ir.Literal) + ): fun_call, literal = node.args if ( - isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) # type: ignore[attr-defined] # assured by if above - and isinstance(fun_call.args[1], ir.Literal) # type: ignore[attr-defined] # assured by if above + isinstance(fun_call, ir.FunCall) + and isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) + and isinstance(fun_call.args[1], ir.Literal) ): - return self.visit(im.plus(fun_call.args[0], im.plus(fun_call.args[1], literal))) # type: ignore[attr-defined] # assured by if above + return im.plus( + fun_call.args[0], + self.fp_transform(im.plus(fun_call.args[1], literal)), + ) return None def transform_fold_min_max_funcall_symref_literal( @@ -149,9 +157,9 @@ def transform_fold_min_max_funcall_symref_literal( if cpm.is_call_to(node.args[0], ("maximum", "minimum")): fun_call, arg1 = node.args if arg1 == fun_call.args[0]: # type: ignore[attr-defined] # assured by if above - return self.visit(im.call(fun_call.fun.id)(fun_call.args[1], arg1)) # type: ignore[attr-defined] # assured by if above + return im.call(fun_call.fun.id)(fun_call.args[1], arg1) # type: ignore[attr-defined] # assured by if above if arg1 == fun_call.args[1]: # type: ignore[attr-defined] # assured by if above - return self.visit(im.call(fun_call.fun.id)(fun_call.args[0], arg1)) # type: ignore[attr-defined] # assured by if above + return im.call(fun_call.fun.id)(fun_call.args[0], arg1) # type: ignore[attr-defined] # assured by if above return None def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: @@ -164,13 +172,17 @@ def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` if cpm.is_call_to(arg0, "plus"): if arg0.args[0] == arg1: - return self.visit(im.plus(arg0.args[0], im.call(node.fun.id)(arg0.args[1], 0))) + return im.plus( + arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0)) + ) # `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)` if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"): if arg0.args[0] == arg1.args[0]: - return self.visit( - im.plus(arg0.args[0], im.call(node.fun.id)(arg0.args[1], arg1.args[1])) + return im.plus( + arg0.args[0], + self.fp_transform(im.call(node.fun.id)(arg0.args[1], arg1.args[1])), ) + return None def transform_fold_symref_plus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: @@ -195,11 +207,9 @@ def transform_canonicalize_plus_symref_literal( and isinstance(node.args[0].args[1], ir.Literal) and isinstance(node.args[1].args[1], ir.Literal) ): - return self.visit( - im.plus( - im.plus(node.args[0].args[0], node.args[1].args[0]), - im.plus(node.args[0].args[1], node.args[1].args[1]), - ) + return im.plus( + self.fp_transform(im.plus(node.args[0].args[0], node.args[1].args[0])), + self.fp_transform(im.plus(node.args[0].args[1], node.args[1].args[1])), ) return None From 39eda41fc2a354761f35c160afe9cb332bc41e3f Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Jan 2025 11:49:36 +0100 Subject: [PATCH 21/39] Minor --- src/gt4py/next/iterator/transforms/constant_folding.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index a67650d681..caae85fd78 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import dataclasses import enum import functools @@ -68,7 +70,7 @@ class Transformation(enum.Flag): FOLD_IF = enum.auto() @classmethod - def all(self): # TODO -> ConstantFolding.Flag: + def all(self) -> ConstantFolding.Transformation: return functools.reduce(operator.or_, self.__members__.values()) enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] @@ -78,11 +80,6 @@ def apply(cls, node: ir.Node) -> ir.Node: node = cls().visit(node) return node - def visit_FunCall(self, node: ir.FunCall, **kwargs): - # visit depth-first such that nested constant expressions (e.g. `(1+2)+3`) are properly folded - node = self.generic_visit(node, **kwargs) - return self.fp_transform(node, **kwargs) # TODO: is that as intended? - def transform_canonicalize_funcall_symref_literal( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: From b5f32cf4bd521ab46f00c0b8b63ca66005c16b83 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Jan 2025 11:58:10 +0100 Subject: [PATCH 22/39] Update src/gt4py/next/iterator/builtins.py --- src/gt4py/next/iterator/builtins.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index daa0fe5df1..8e5f7addca 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -425,7 +425,6 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "floor", "ceil", "trunc", - "neg", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} BINARY_MATH_NUMBER_BUILTINS = { From 054bc418ddd894db58e9c513d62cb13693a3129e Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Jan 2025 11:58:16 +0100 Subject: [PATCH 23/39] Update src/gt4py/next/iterator/embedded.py --- src/gt4py/next/iterator/embedded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 6a44b880e3..16b1fa9d03 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -549,7 +549,7 @@ def promote_scalars(val: CompositeOfScalarOrField): } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable - if math_builtin_name in ["gamma", "not_", "neg"]: + if math_builtin_name in ["gamma", "not_"]: continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent From 93e0ab5848ba1239b6010ab626a3478d02927884 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Jan 2025 11:58:22 +0100 Subject: [PATCH 24/39] Update src/gt4py/next/iterator/transforms/collapse_tuple.py --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index e9b14f95c7..1ce24fafe4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -463,7 +463,7 @@ def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional return None def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if isinstance(node, ir.FunCall) and cpm.is_let(node): + if cpm.is_let(node): if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let # `let(a, 1)(a)` -> `1` for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let From 3ea7217ea3dee9351043b3dbca8bd4d84b995e72 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Jan 2025 11:58:29 +0100 Subject: [PATCH 25/39] Update src/gt4py/next/iterator/transforms/collapse_tuple.py --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 1ce24fafe4..699279d2f7 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -437,7 +437,7 @@ def transform_propagate_to_if_on_tuples_cps( return None def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - if isinstance(node, ir.FunCall) and cpm.is_let(node): + if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} inner_vars = {} From 00332ccbace76def3543765af5b51651479dc49d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Jan 2025 11:58:34 +0100 Subject: [PATCH 26/39] Update src/gt4py/next/iterator/transforms/collapse_tuple.py --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 699279d2f7..1f9ac00b77 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -268,7 +268,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` - if isinstance(node, ir.FunCall) and cpm.is_let(node.args[1]): + if cpm.is_let(node.args[1]): idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let From c65dbfb88373cc0ea97413c0da34268576232e86 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Jan 2025 11:59:00 +0100 Subject: [PATCH 27/39] Remove file --- .../transforms/fixed_point_transform.py | 44 ------------------- 1 file changed, 44 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/fixed_point_transform.py diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transform.py b/src/gt4py/next/iterator/transforms/fixed_point_transform.py deleted file mode 100644 index 8b6902385f..0000000000 --- a/src/gt4py/next/iterator/transforms/fixed_point_transform.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import dataclasses -import enum -from typing import ClassVar, Optional, Type - -from gt4py import eve -from gt4py.next.iterator import ir -from gt4py.next.iterator.type_system import inference as itir_type_inference - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class FixedPointTransform(eve.PreserveLocationVisitor, eve.NodeTranslator): - Flag: ClassVar[Type[enum.Flag]] - flags: enum.Flag - - def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: - while True: - new_node = self.transform(node, **kwargs) - if new_node is None: - break - assert new_node != node - node = new_node - return node - - def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: - for transformation in self.Flag: - if self.flags & transformation: - assert isinstance(transformation.name, str) - method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node, **kwargs) - if result is not None: - assert ( - result is not node - ) # transformation should have returned None, since nothing changed - itir_type_inference.reinfer(result) - return result - return None From 96763340f21964d525bb0c30ef00c97bf7cc1638 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Jan 2025 14:56:53 +0100 Subject: [PATCH 28/39] Add new line --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 1f9ac00b77..6db58f3765 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -216,6 +216,7 @@ def apply( def visit(self, node, **kwargs): if cpm.is_call_to(node, "as_fieldop"): kwargs = {**kwargs, "within_stencil": True} + return super().visit(node, **kwargs) def transform_collapse_make_tuple_tuple_get( From 231c0e52095161b2817d04cf1f8003777af940ae Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Jan 2025 18:29:10 +0100 Subject: [PATCH 29/39] Address review comments --- .../iterator/transforms/constant_folding.py | 155 +++++++----------- .../transforms_tests/test_constant_folding.py | 26 ++- 2 files changed, 73 insertions(+), 108 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index caae85fd78..0856a5785b 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -30,35 +30,30 @@ class ConstantFolding( ) class Transformation(enum.Flag): - # e.g. `literal + symref` -> `symref + literal` and - # `literal + funcall` -> `funcall + literal` and - # `symref + funcall` -> `funcall + symref` - CANONICALIZE_FUNCALL_SYMREF_LITERAL = enum.auto() + # `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `funcall, op` -> `op, funcall` for s[0] + (s[0] + 1), prerequisite for FOLD_MIN_MAX_PLUS + CANONICALIZE_OP_FUNCALL_SYMREF_LITERAL = enum.auto() - # `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)` + # `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS CANONICALIZE_MINUS = enum.auto() - # `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())` + # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for FOLD_MIN_MAX CANONICALIZE_MIN_MAX = enum.auto() - # `im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))` - CANONICALIZE_TUPLE_GET_PLUS = enum.auto() - - # `(sym + 1) + 1` -> `sym + 2` + # `(a + 1) + 1` -> `a + (1 + 1)` FOLD_FUNCALL_LITERAL = enum.auto() - # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` - FOLD_MIN_MAX_FUNCALL_SYMREF_LITERAL = enum.auto() + # `maximum(maximum(a, 1), a)` -> `maximum(a, 1)` + # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` + FOLD_MIN_MAX = enum.auto() - # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` and - # `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)` + # `maximum(plus(a, 1), a)` -> `plus(a, 1)` + # `maximum(plus(a, 1), plus(a, -1))` -> `plus(a, maximum(1, -1))` FOLD_MIN_MAX_PLUS = enum.auto() - # `sym + 0` -> `sym` - FOLD_SYMREF_PLUS_ZERO = enum.auto() - - # `sym + 1 + (sym + 2)` -> `sym + sym + 2 + 1` - CANONICALIZE_PLUS_SYMREF_LITERAL = enum.auto() + # `a + 0` -> `a`, `a * 1` -> `a` + FOLD_NEUTRAL_OP = enum.auto() # `1 + 1` -> `2` FOLD_ARITHMETIC_BUILTINS = enum.auto() @@ -80,86 +75,60 @@ def apply(cls, node: ir.Node) -> ir.Node: node = cls().visit(node) return node - def transform_canonicalize_funcall_symref_literal( + def transform_canonicalize_op_funcall_symref_literal( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: - # e.g. `literal + symref` -> `symref + literal` and - # `literal + funcall` -> `funcall + literal` and - # `symref + funcall` -> `funcall + symref` - if ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.SymRef) - and cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")) - ): + # `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `funcall, op` -> `op, funcall` for s[0] + (s[0] + 1), prerequisite for FOLD_MIN_MAX_PLUS + if cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")): if ( - isinstance(node.args[1], (ir.SymRef, ir.FunCall)) - and isinstance(node.args[0], ir.Literal) - ) or (isinstance(node.args[1], ir.FunCall) and isinstance(node.args[0], ir.SymRef)): - return im.call(node.fun.id)(node.args[1], node.args[0]) + isinstance(node.args[0], ir.Literal) and not isinstance(node.args[1], ir.Literal) + ) or ( + (not cpm.is_call_to(node.args[0], ("plus", "multiplies", "minimum", "maximum"))) + and cpm.is_call_to(node.args[1], ("plus", "multiplies", "minimum", "maximum")) + ): + return im.call(node.fun)(node.args[1], node.args[0]) return None def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `minus(arg0, arg1) -> plus(im.call("neg")(arg1), arg0)` - if isinstance(node, ir.FunCall) and cpm.is_call_to(node, "minus"): - return im.plus(self.fp_transform(im.call("neg")(node.args[1])), node.args[0]) + # `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS + if cpm.is_call_to(node, "minus"): + return im.plus(node.args[0], self.fp_transform(im.call("neg")(node.args[1]))) return None def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `maximum(im.call(...)(), maximum(...))` -> `maximum(maximum(...), im.call(...)())` + # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for FOLD_MIN_MAX if cpm.is_call_to(node, ("maximum", "minimum")): - if ( - isinstance(node.args[0], ir.FunCall) - and isinstance(node.fun, ir.SymRef) - and not cpm.is_call_to(node.args[0], ("maximum", "minimum")) - and cpm.is_call_to(node.args[1], ("maximum", "minimum")) - ): - return im.call(node.fun.id)(node.args[1], node.args[0]) - return None - - def transform_canonicalize_tuple_get_plus( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: - # im.call(...)(im.tuple_get(...), im.plus(...)))` -> `im.call(...)( im.plus(...)), im.tuple_get(...))` - if isinstance(node, ir.FunCall) and isinstance(node.fun, ir.SymRef) and len(node.args) > 1: - if cpm.is_call_to(node.args[0], "tuple_get") and cpm.is_call_to(node.args[1], "plus"): - return im.call(node.fun.id)(node.args[1], node.args[0]) + op = node.fun.id # type: ignore[attr-defined] # assured by if above + if cpm.is_call_to(node.args[1], op) and not cpm.is_call_to(node.args[0], op): + return im.call(op)(node.args[1], node.args[0]) return None def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `(sym + 1) + 1` -> `sym + 2` + # `(a + 1) + 1` -> `a + (1 + 1)` if cpm.is_call_to(node, "plus"): - if ( - isinstance(node.args[0], ir.FunCall) - and cpm.is_call_to(node.args[0], "plus") - and isinstance(node.args[1], ir.Literal) - ): + if cpm.is_call_to(node.args[0], "plus") and isinstance(node.args[1], ir.Literal): fun_call, literal = node.args - if ( - isinstance(fun_call, ir.FunCall) - and isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) - and isinstance(fun_call.args[1], ir.Literal) + if isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) and isinstance( # type: ignore[attr-defined] # assured by if above + fun_call.args[1], # type: ignore[attr-defined] # assured by if above + ir.Literal, ): return im.plus( - fun_call.args[0], - self.fp_transform(im.plus(fun_call.args[1], literal)), + fun_call.args[0], # type: ignore[attr-defined] # assured by if above + self.fp_transform(im.plus(fun_call.args[1], literal)), # type: ignore[attr-defined] # assured by if above ) return None - def transform_fold_min_max_funcall_symref_literal( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: - # `maximum(maximum(sym, 1), sym)` -> `maximum(sym, 1)` - if ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.SymRef) - and cpm.is_call_to(node, ("minimum", "maximum")) - ): - if cpm.is_call_to(node.args[0], ("maximum", "minimum")): + def transform_fold_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(maximum(a, 1), a)` -> `maximum(a, 1)` + # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` + if cpm.is_call_to(node, ("minimum", "maximum")): + op = node.fun.id # type: ignore[attr-defined] # assured by if above + if cpm.is_call_to(node.args[0], op): fun_call, arg1 = node.args - if arg1 == fun_call.args[0]: # type: ignore[attr-defined] # assured by if above - return im.call(fun_call.fun.id)(fun_call.args[1], arg1) # type: ignore[attr-defined] # assured by if above - if arg1 == fun_call.args[1]: # type: ignore[attr-defined] # assured by if above - return im.call(fun_call.fun.id)(fun_call.args[0], arg1) # type: ignore[attr-defined] # assured by if above + if arg1 in fun_call.args: # type: ignore[attr-defined] # assured by if above + return fun_call return None def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: @@ -169,13 +138,13 @@ def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir and cpm.is_call_to(node, ("minimum", "maximum")) ): arg0, arg1 = node.args - # `maximum(plus(sym, 1), sym)` -> `plus(sym, 1)` + # `maximum(plus(a, 1), a)` -> `plus(a, 1)` if cpm.is_call_to(arg0, "plus"): if arg0.args[0] == arg1: return im.plus( arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0)) ) - # `maximum(plus(sym, 1), plus(sym, -1))` -> `plus(sym, 1)` + # `maximum(plus(a, 1), plus(a, -1))` -> `plus(a, maximum(1, -1))` if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"): if arg0.args[0] == arg1.args[0]: return im.plus( @@ -185,34 +154,22 @@ def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir return None - def transform_fold_symref_plus_zero(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `sym + 0` -> `sym` + def transform_fold_neutral_op(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `a + 0` -> `a`, `a * 1` -> `a` if ( cpm.is_call_to(node, "plus") and isinstance(node.args[1], ir.Literal) and node.args[1].value.isdigit() and int(node.args[1].value) == 0 + ) or ( + cpm.is_call_to(node, "multiplies") + and isinstance(node.args[1], ir.Literal) + and node.args[1].value.isdigit() + and int(node.args[1].value) == 1 ): return node.args[0] return None - def transform_canonicalize_plus_symref_literal( - self, node: ir.FunCall, **kwargs - ) -> Optional[ir.Node]: - # `sym1 + 1 + (sym2 + 2)` -> `sym1 + sym2 + 2 + 1` - if cpm.is_call_to(node, "plus"): - if ( - cpm.is_call_to(node.args[0], "plus") - and cpm.is_call_to(node.args[1], "plus") - and isinstance(node.args[0].args[1], ir.Literal) - and isinstance(node.args[1].args[1], ir.Literal) - ): - return im.plus( - self.fp_transform(im.plus(node.args[0].args[0], node.args[1].args[0])), - self.fp_transform(im.plus(node.args[0].args[1], node.args[1].args[1])), - ) - return None - def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `1 + 1` -> `2` if ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0cb77d4496..94663b123c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -371,10 +371,6 @@ def test_constant_folding_complex(): assert actual == expected -# ( (min(1 - sym, 1 + sym) + (max(max(1 - sym, 1 + sym),1 - sym) + max(1 - sym, 1 - sym)))))) - 2 -# max(sym, 1 + sym) + (max(1, max(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 - - def test_constant_folding_complex_1(): sym = im.ref("sym") # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 @@ -396,13 +392,15 @@ def test_constant_folding_complex_1(): ), im.literal_from_value(2), ) - # sym + 1 + (maximum(sym, 1) + (sym + sym + 2)) + -2 + # sym + 1 + (maximum(sym, 1) + (sym + -1 + (sym + 3))) + -2 expected = im.plus( im.plus( im.plus(sym, 1), im.plus( im.call("maximum")(sym, im.literal_from_value(1)), - im.plus(im.plus(sym, sym), im.literal_from_value(2)), + im.plus( + im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) + ), ), ), im.literal_from_value(-2), @@ -430,7 +428,7 @@ def test_constant_folding_complex_3(): ), ), ) - # minimum(neg(sym) + 1, sym + 1) + (maximum(sym + 1, neg(sym) + 1) + (neg(sym) + 1)) + # minimum(neg(sym) + 1, sym + 1) + (maximum(neg(sym) + 1, sym + 1) + (neg(sym) + 1)) expected = im.plus( im.call("minimum")( im.plus(im.call("neg")(sym), im.literal_from_value(1)), @@ -438,8 +436,8 @@ def test_constant_folding_complex_3(): ), im.plus( im.call("maximum")( - im.plus(sym, im.literal_from_value(1)), im.plus(im.call("neg")(sym), im.literal_from_value(1)), + im.plus(sym, im.literal_from_value(1)), ), im.plus(im.call("neg")(sym), im.literal_from_value(1)), ), @@ -466,7 +464,9 @@ def test_constant_folding_complex_2(): testee = im.plus( im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) ) - expected = im.plus(im.plus(sym, sym), im.literal_from_value(2)) + expected = im.plus( + im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) + ) actual = ConstantFolding.apply(testee) assert actual == expected @@ -533,3 +533,11 @@ def test_max_syms(): expected = im.call("maximum")(sym2, sym1) actual = ConstantFolding.apply(testee) assert actual == expected + + +def test_max_min(): + sym = im.ref("sym1") + testee = im.call("maximum")(im.call("minimum")(sym, 1), sym) + expected = im.call("maximum")(im.call("minimum")(sym, 1), sym) + actual = ConstantFolding.apply(testee) + assert actual == expected From 5df9b8b8fbf66b38dcc89ba44741402329bcf09d Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Jan 2025 23:55:17 +0100 Subject: [PATCH 30/39] Remove unary minus in the end --- .../iterator/transforms/constant_folding.py | 40 ++++++++++++- .../transforms_tests/test_constant_folding.py | 59 +++++++++++-------- 2 files changed, 72 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 0856a5785b..1d2d15cdda 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -20,6 +20,43 @@ from gt4py.next.iterator.transforms import fixed_point_transformation +class UndoCanonicalizeMinus(eve.NodeTranslator): + def generic_visit(self, node, **kwargs) -> ir.Node: + node = super().generic_visit(node, **kwargs) + return self.undo_canonicalize_minus(node, **kwargs) + + def undo_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> ir.Node: + # `a + (-b)` -> `a - b` , `-a + b` -> `b - a`, `-a + (-b)` -> `-a - b` + if cpm.is_call_to(node, "plus"): + if cpm.is_call_to(node.args[1], "neg"): + return im.minus(node.args[0], node.args[1].args[0]) + if ( + isinstance(node.args[1], ir.Literal) + and node.args[1].value + < im.literal("0", typename=node.args[1].type.kind.name.lower()).value + ): + return im.minus( + node.args[0], + ConstantFolding.transform_fold_arithmetic_builtins( + im.multiplies_(node.args[1], -1) + ), + ) + if cpm.is_call_to(node.args[0], "neg"): + return im.minus(node.args[1], node.args[0].args[0]) + if ( + isinstance(node.args[0], ir.Literal) + and node.args[0].value + < im.literal("0", typename=node.args[0].type.kind.name.lower()).value + ): + return im.minus( + node.args[1], + ConstantFolding.transform_fold_arithmetic_builtins( + im.multiplies_(node.args[0], -1) + ), + ) + return node + + @dataclasses.dataclass(frozen=True, kw_only=True) class ConstantFolding( fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor @@ -73,7 +110,7 @@ def all(self) -> ConstantFolding.Transformation: @classmethod def apply(cls, node: ir.Node) -> ir.Node: node = cls().visit(node) - return node + return UndoCanonicalizeMinus().generic_visit(node) def transform_canonicalize_op_funcall_symref_literal( self, node: ir.FunCall, **kwargs @@ -170,6 +207,7 @@ def transform_fold_neutral_op(self, node: ir.FunCall, **kwargs) -> Optional[ir.N return node.args[0] return None + @classmethod def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: # `1 + 1` -> `2` if ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 94663b123c..e5c85922e2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -69,12 +69,12 @@ def test_constant_folding_literal_plus0(): def test_constant_folding_literal_minus0(): testee = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected testee = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) - expected = im.plus(im.call("neg")(im.ref("__out_size_1")), im.literal_from_value(1)) + expected = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) actual = ConstantFolding.apply(testee) assert actual == expected @@ -171,7 +171,7 @@ def test_constant_folding_maximum_literal_plus7(): im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.plus(im.literal_from_value(1), im.literal_from_value(1)), ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -285,7 +285,7 @@ def test_constant_folding_maximum_literal_plus18(): testee = im.call("minimum")( im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -294,7 +294,7 @@ def test_constant_folding_maximum_literal_plus19(): testee = im.call("minimum")( im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -303,7 +303,7 @@ def test_constant_folding_maximum_literal_plus20(): testee = im.call("minimum")( im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) + expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -355,17 +355,15 @@ def test_constant_folding_complex(): ), ) # neg(maximum(maximum(sym, 1), maximum(neg(sym) + 1, -1))) + 1 - expected = im.plus( - im.call("neg")( + expected = im.minus( + im.literal_from_value(1), + im.call("maximum")( + im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), im.call("maximum")( - im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), - im.call("maximum")( - im.plus(im.call("neg")(sym), im.literal_from_value(1)), - im.literal_from_value(-1), - ), - ) + im.minus(im.literal_from_value(1), sym), + im.literal_from_value(-1), + ), ), - im.literal_from_value(1), ) actual = ConstantFolding.apply(testee) assert actual == expected @@ -393,17 +391,17 @@ def test_constant_folding_complex_1(): im.literal_from_value(2), ) # sym + 1 + (maximum(sym, 1) + (sym + -1 + (sym + 3))) + -2 - expected = im.plus( + expected = im.minus( im.plus( im.plus(sym, 1), im.plus( im.call("maximum")(sym, im.literal_from_value(1)), im.plus( - im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) + im.minus(sym, im.literal_from_value(1)), im.plus(sym, im.literal_from_value(3)) ), ), ), - im.literal_from_value(-2), + im.literal_from_value(2), ) actual = ConstantFolding.apply(testee) assert actual == expected @@ -431,15 +429,15 @@ def test_constant_folding_complex_3(): # minimum(neg(sym) + 1, sym + 1) + (maximum(neg(sym) + 1, sym + 1) + (neg(sym) + 1)) expected = im.plus( im.call("minimum")( - im.plus(im.call("neg")(sym), im.literal_from_value(1)), + im.minus(im.literal_from_value(1), sym), im.plus(sym, im.literal_from_value(1)), ), im.plus( im.call("maximum")( - im.plus(im.call("neg")(sym), im.literal_from_value(1)), + im.minus(im.literal_from_value(1), sym), im.plus(sym, im.literal_from_value(1)), ), - im.plus(im.call("neg")(sym), im.literal_from_value(1)), + im.minus(im.literal_from_value(1), sym), ), ) actual = ConstantFolding.apply(testee) @@ -465,7 +463,7 @@ def test_constant_folding_complex_2(): im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) ) expected = im.plus( - im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) + im.minus(sym, im.literal_from_value(1)), im.plus(sym, im.literal_from_value(3)) ) actual = ConstantFolding.apply(testee) assert actual == expected @@ -477,8 +475,8 @@ def test_constant_folding_complex_4(): im.minus(im.literal("1", "float32"), sym), im.minus(im.literal("2", "float32"), sym) ) expected = im.divides_( - im.plus(im.call("neg")(sym), im.literal("1", "float32")), - im.plus(im.call("neg")(sym), im.literal("2", "float32")), + im.minus(im.literal("1", "float32"), sym), + im.minus(im.literal("2", "float32"), sym), ) actual = ConstantFolding.apply(testee) assert actual == expected @@ -502,7 +500,7 @@ def test_constant_folding_plus_new(): def test_minus(): sym = im.ref("sym") testee = im.plus(im.minus(im.literal_from_value(1), sym), im.literal_from_value(1)) - expected = im.plus(im.call("neg")(sym), im.literal_from_value(2)) + expected = im.minus(im.literal_from_value(2), sym) actual = ConstantFolding.apply(testee) assert actual == expected @@ -510,7 +508,7 @@ def test_minus(): def test_fold_min_max_plus(): sym = im.ref("sym") testee = im.call("minimum")(im.plus(sym, im.literal_from_value(-1)), sym) - expected = im.plus(sym, im.literal_from_value(-1)) + expected = im.minus(sym, im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -541,3 +539,12 @@ def test_max_min(): expected = im.call("maximum")(im.call("minimum")(sym, 1), sym) actual = ConstantFolding.apply(testee) assert actual == expected + + +def test_minus_1(): + sym = im.ref("sym") + testee = im.maximum(im.plus(sym, im.literal_from_value(-1)), im.literal_from_value(1)) + expected = im.maximum(im.minus(sym, im.literal_from_value(1)), im.literal_from_value(1)) + + actual = ConstantFolding.apply(testee) + assert actual == expected From 91f2a664b008fa8cba911733cd2d68699e34fe39 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 29 Jan 2025 11:05:20 +0100 Subject: [PATCH 31/39] Cleanup UndoCanonicalizeMinus --- .../iterator/transforms/constant_folding.py | 32 ++++++------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 1d2d15cdda..abdaa81f81 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -20,6 +20,10 @@ from gt4py.next.iterator.transforms import fixed_point_transformation +def get_value_from_literal(literal: ir.Literal): + return getattr(embedded, str(literal.type))(literal.value) + + class UndoCanonicalizeMinus(eve.NodeTranslator): def generic_visit(self, node, **kwargs) -> ir.Node: node = super().generic_visit(node, **kwargs) @@ -30,30 +34,12 @@ def undo_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> ir.Node: if cpm.is_call_to(node, "plus"): if cpm.is_call_to(node.args[1], "neg"): return im.minus(node.args[0], node.args[1].args[0]) - if ( - isinstance(node.args[1], ir.Literal) - and node.args[1].value - < im.literal("0", typename=node.args[1].type.kind.name.lower()).value - ): - return im.minus( - node.args[0], - ConstantFolding.transform_fold_arithmetic_builtins( - im.multiplies_(node.args[1], -1) - ), - ) + if isinstance(node.args[1], ir.Literal) and get_value_from_literal(node.args[1]) < 0: + return im.minus(node.args[0], -get_value_from_literal(node.args[1])) if cpm.is_call_to(node.args[0], "neg"): return im.minus(node.args[1], node.args[0].args[0]) - if ( - isinstance(node.args[0], ir.Literal) - and node.args[0].value - < im.literal("0", typename=node.args[0].type.kind.name.lower()).value - ): - return im.minus( - node.args[1], - ConstantFolding.transform_fold_arithmetic_builtins( - im.multiplies_(node.args[0], -1) - ), - ) + if isinstance(node.args[0], ir.Literal) and get_value_from_literal(node.args[0]) < 0: + return im.minus(node.args[1], -get_value_from_literal(node.args[0])) return node @@ -220,7 +206,7 @@ def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Opti if node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(node.fun.id)) arg_values = [ - getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition + get_value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition for arg in node.args ] return im.literal_from_value(fun(*arg_values)) From 84d813de364a5fc1a8c588f17c608e136bc46e63 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 29 Jan 2025 11:25:51 +0100 Subject: [PATCH 32/39] Cleanup UndoCanonicalizeMinus --- .../iterator/transforms/constant_folding.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index abdaa81f81..2a2ad3f278 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -20,26 +20,24 @@ from gt4py.next.iterator.transforms import fixed_point_transformation -def get_value_from_literal(literal: ir.Literal): +def value_from_literal(literal: ir.Literal): return getattr(embedded, str(literal.type))(literal.value) class UndoCanonicalizeMinus(eve.NodeTranslator): - def generic_visit(self, node, **kwargs) -> ir.Node: + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: node = super().generic_visit(node, **kwargs) - return self.undo_canonicalize_minus(node, **kwargs) - - def undo_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> ir.Node: # `a + (-b)` -> `a - b` , `-a + b` -> `b - a`, `-a + (-b)` -> `-a - b` if cpm.is_call_to(node, "plus"): - if cpm.is_call_to(node.args[1], "neg"): - return im.minus(node.args[0], node.args[1].args[0]) - if isinstance(node.args[1], ir.Literal) and get_value_from_literal(node.args[1]) < 0: - return im.minus(node.args[0], -get_value_from_literal(node.args[1])) - if cpm.is_call_to(node.args[0], "neg"): - return im.minus(node.args[1], node.args[0].args[0]) - if isinstance(node.args[0], ir.Literal) and get_value_from_literal(node.args[0]) < 0: - return im.minus(node.args[1], -get_value_from_literal(node.args[0])) + a, b = node.args + if cpm.is_call_to(b, "neg"): + return im.minus(a, b.args[0]) + if isinstance(b, ir.Literal) and value_from_literal(b) < 0: + return im.minus(a, -value_from_literal(b)) + if cpm.is_call_to(a, "neg"): + return im.minus(b, a.args[0]) + if isinstance(a, ir.Literal) and value_from_literal(a) < 0: + return im.minus(b, -value_from_literal(a)) return node @@ -96,7 +94,7 @@ def all(self) -> ConstantFolding.Transformation: @classmethod def apply(cls, node: ir.Node) -> ir.Node: node = cls().visit(node) - return UndoCanonicalizeMinus().generic_visit(node) + return UndoCanonicalizeMinus().visit(node) def transform_canonicalize_op_funcall_symref_literal( self, node: ir.FunCall, **kwargs @@ -206,7 +204,7 @@ def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Opti if node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(node.fun.id)) arg_values = [ - get_value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition + value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition for arg in node.args ] return im.literal_from_value(fun(*arg_values)) From ed0248e624ec87d2d2b8532d6a237abb7633f3ae Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 11 Feb 2025 13:30:24 +0100 Subject: [PATCH 33/39] Fix some test failures --- src/gt4py/next/iterator/transforms/constant_folding.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2a2ad3f278..ebb0a77fe5 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -25,6 +25,11 @@ def value_from_literal(literal: ir.Literal): class UndoCanonicalizeMinus(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: node = super().generic_visit(node, **kwargs) # `a + (-b)` -> `a - b` , `-a + b` -> `b - a`, `-a + (-b)` -> `-a - b` From d452af96d500c9b709c7199f7d7d6d9fdf29d1c3 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 11 Feb 2025 14:56:04 +0100 Subject: [PATCH 34/39] Clean up constant_folding tests --- .../transforms_tests/test_constant_folding.py | 472 +++++------------- 1 file changed, 135 insertions(+), 337 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index e5c85922e2..dbd574bf2a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -9,10 +9,12 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding +one = im.literal_from_value(1) + def test_constant_folding_plus(): + testee = im.plus(one, one) expected = im.literal_from_value(2) - testee = im.plus(im.literal_from_value(1), im.literal_from_value(1)) actual = ConstantFolding.apply(testee) assert actual == expected @@ -38,331 +40,263 @@ def test_constant_folding_math_op(): def test_constant_folding_if(): - expected = im.plus("a", 2) + expected = im.plus("sym", 2) testee = im.if_( im.literal_from_value(True), - im.plus(im.ref("a"), im.literal_from_value(2)), + im.plus("sym", im.literal_from_value(2)), im.minus(im.literal_from_value(9), im.literal_from_value(5)), ) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_minimum(): - testee = im.minimum("a", "a") - expected = im.ref("a") +def test_constant_folding_maximum_literal(): + testee = im.maximum(one, im.literal_from_value(2)) + expected = im.literal_from_value(2) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_literal_plus0(): - testee = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_constant_folding_minimum(): + testee = im.minimum("sym", "sym") + expected = im.ref("sym") actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.plus(im.literal_from_value(1), im.ref("__out_size_1")) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum(0, 0) + expected = im.literal_from_value(0) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_literal_minus0(): - testee = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_cannonicalize_plus_funcall_symref_literal(): + testee = im.plus("sym", one) + expected = im.plus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) - expected = im.minus(im.literal_from_value(1), im.ref("__out_size_1")) + testee = im.plus(one, "sym") + expected = im.plus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_funcall_literal(): - testee = im.plus( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.literal_from_value(1) - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(2)) +def test_canonicalize_minus(): + testee = im.minus("sym", one) + expected = im.minus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.plus( - im.literal_from_value(1), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(2)) + testee = im.minus(one, "sym") + expected = im.minus(one, "sym") actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_minus(): - testee = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_canonicalize_fold_op_funcall_symref_literal(): + testee = im.plus(im.plus("sym", one), one) + expected = im.plus("sym", im.literal_from_value(2)) actual = ConstantFolding.apply(testee) assert actual == expected - testee = im.plus(im.literal_from_value(1), im.ref("__out_size_1")) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.plus(one, im.plus("sym", one)) + expected = im.plus("sym", im.literal_from_value(2)) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus1(): - testee = im.call("maximum")( - im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)), - im.literal_from_value(1), - ) - expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_constant_folding_nested_maximum(): + testee = im.maximum(im.maximum("sym", one), one) + expected = im.maximum("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus2(): - testee = im.call("maximum")( - im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), - im.literal_from_value(1), - ) - expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum(im.maximum(one, "sym"), one) + expected = im.maximum("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus3(): - testee = im.call("maximum")( - im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), - im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), - ) - expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum("sym", im.maximum(one, "sym")) + expected = im.maximum("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus4(): - testee = im.call("maximum")( - im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")), - im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)), - ) - expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum(im.maximum(one, "sym"), im.maximum(one, "sym")) + expected = im.maximum("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus5(): - testee = im.call("maximum")( - im.ref("__out_size_1"), im.call("maximum")(im.literal_from_value(1), im.ref("__out_size_1")) - ) - expected = im.call("maximum")(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum(im.maximum(one, "sym"), im.maximum("sym", one)) + expected = im.maximum("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus6(): - testee = im.call("maximum")( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), - im.plus(im.ref("__out_size_1"), im.literal_from_value(0)), - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum(im.minimum("sym", 1), "sym") + expected = im.maximum(im.minimum("sym", 1), "sym") actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus7(): - testee = im.minus( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), - im.plus(im.literal_from_value(1), im.literal_from_value(1)), - ) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_constant_folding_maximum_plus(): + testee = im.maximum(im.plus("sym", one), im.plus("sym", im.literal_from_value(0))) + expected = im.plus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus8(): - testee = im.plus( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), - im.plus(im.literal_from_value(1), im.literal_from_value(1)), - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(3)) + testee = im.maximum(im.plus("sym", one), im.plus(im.plus("sym", one), im.literal_from_value(0))) + expected = im.plus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus9(): - testee = im.call("maximum")( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), - im.plus( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.literal_from_value(0) - ), - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum("sym", im.plus("sym", one)) + expected = im.plus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus10(): - testee = im.call("maximum")( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.maximum("sym", im.plus("sym", im.literal_from_value(-1))) + expected = im.ref("sym") actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus10a(): - testee = im.plus( - im.ref("__out_size_1"), - im.call("maximum")(im.literal_from_value(0), im.literal_from_value(-1)), - ) - - expected = im.ref("__out_size_1") + testee = im.plus("sym", im.maximum(im.literal_from_value(0), im.literal_from_value(-1))) + expected = im.ref("sym") actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus11(): - testee = im.call("maximum")( - im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_constant_folding_plus_minus(): + testee = im.minus(im.plus("sym", one), im.plus(one, one)) + expected = im.minus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected + testee = im.plus(im.minus("sym", one), im.literal_from_value(2)) + expected = im.plus("sym", one) + actual = ConstantFolding.apply(testee) + assert actual == expected -def test_constant_folding_maximum_literal_plus12(): - testee = im.call("maximum")( - im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) - ) - expected = im.ref("__out_size_1") + testee = im.plus(im.minus(one, "sym"), one) + expected = im.minus(im.literal_from_value(2), "sym") actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus13(): - testee = im.call("maximum")( - im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") - ) - expected = im.ref("__out_size_1") +def test_constant_folding_nested_plus(): + testee = im.plus(im.plus("sym", one), im.plus(one, one)) + expected = im.plus("sym", im.literal_from_value(3)) actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus14(): - testee = im.call("maximum")( - im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.plus( + im.plus("sym", im.literal_from_value(-1)), im.plus("sym", im.literal_from_value(3)) ) - expected = im.ref("__out_size_1") + expected = im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus15(): - testee = im.call("maximum")( - im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) - ) - expected = im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_constant_folding_maximum_minus(): + testee = im.maximum(im.minus("sym", one), "sym") + expected = im.ref("sym") actual = ConstantFolding.apply(testee) assert actual == expected + testee = im.maximum("sym", im.minus("sym", im.literal_from_value(-1))) + expected = im.plus("sym", one) + actual = ConstantFolding.apply(testee) + assert actual == expected -def test_constant_folding_maximum_literal_plus16(): - testee = im.call("minimum")( - im.plus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") - ) - expected = im.ref("__out_size_1") + testee = im.maximum(im.plus("sym", im.literal_from_value(-1)), one) + expected = im.maximum(im.minus("sym", one), one) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus17(): - testee = im.call("minimum")( - im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(1)) - ) - expected = im.ref("__out_size_1") +def test_constant_folding_minimum_plus_minus(): + testee = im.minimum(im.plus("sym", one), "sym") + expected = im.ref("sym") actual = ConstantFolding.apply(testee) assert actual == expected - -def test_constant_folding_maximum_literal_plus18(): - testee = im.call("minimum")( - im.ref("__out_size_1"), im.plus(im.ref("__out_size_1"), im.literal_from_value(-1)) - ) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.minimum("sym", im.plus("sym", im.literal_from_value(-1))) + expected = im.minus("sym", one) actual = ConstantFolding.apply(testee) assert actual == expected + testee = im.minimum(im.minus("sym", one), "sym") + expected = im.minus("sym", one) + actual = ConstantFolding.apply(testee) + assert actual == expected -def test_constant_folding_maximum_literal_plus19(): - testee = im.call("minimum")( - im.minus(im.ref("__out_size_1"), im.literal_from_value(1)), im.ref("__out_size_1") - ) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) + testee = im.minimum("sym", im.minus("sym", im.literal_from_value(-1))) + expected = im.ref("sym") actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus20(): - testee = im.call("minimum")( - im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) - ) - expected = im.minus(im.ref("__out_size_1"), im.literal_from_value(1)) +def test_constant_folding_max_syms(): + testee = im.maximum("sym1", im.maximum("sym2", "sym1")) + expected = im.maximum("sym2", "sym1") actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_maximum_literal_plus21(): - testee = im.call("minimum")( - im.ref("__out_size_1"), im.minus(im.ref("__out_size_1"), im.literal_from_value(-1)) +def test_constant_folding_max_tuple_get(): + testee = im.maximum( + im.plus(im.tuple_get(1, "sym"), 1), + im.maximum(im.tuple_get(1, "sym"), im.plus(im.tuple_get(1, "sym"), 1)), ) - expected = im.ref("__out_size_1") + expected = im.plus(im.tuple_get(1, "sym"), 1) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_literal(): - testee = im.plus(im.literal_from_value(1), im.literal_from_value(2)) - expected = im.literal_from_value(3) +def test_constant_folding_nested_max_plus(): + # maximum(maximum(1 + sym, 1), 1 + sym) + testee = im.maximum(im.maximum(im.plus(one, "sym"), 1), im.plus(one, "sym")) + + # maximum(sym + 1, 1) + expected = im.maximum(im.plus("sym", one), 1) actual = ConstantFolding.apply(testee) assert actual == expected -def test_constant_folding_literal_maximum(): - testee = im.maximum(im.literal_from_value(1), im.literal_from_value(2)) - expected = im.literal_from_value(2) +def test_constant_folding_divides_float32(): + sym = im.ref("sym", "float32") + testee = im.divides_( + im.minus(im.literal("1", "float32"), sym), im.plus(im.literal("2", "float32"), sym) + ) + expected = im.divides_( + im.minus(im.literal("1", "float32"), sym), im.plus(sym, im.literal("2", "float32")) + ) actual = ConstantFolding.apply(testee) assert actual == expected def test_constant_folding_complex(): - sym = im.ref("sym") # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) testee = im.minus( - im.literal_from_value(1), - im.call("maximum")( - im.call("maximum")( - im.literal_from_value(1), - im.call("maximum")(im.literal_from_value(1), im.ref("sym")), - im.call("minimum")(im.literal_from_value(1), im.ref("sym")), - im.ref("sym"), + one, + im.maximum( + im.maximum( + im.maximum(one, im.maximum(one, "sym")), + im.maximum(im.maximum(one, "sym"), "sym"), ), im.plus( - im.literal_from_value(1), + one, im.plus( - im.call("minimum")(im.literal_from_value(-1), 2), - im.call("maximum")( - im.literal_from_value(-1), im.minus(im.literal_from_value(1), im.ref("sym")) - ), + im.minimum(im.literal_from_value(-1), 2), + im.maximum(im.literal_from_value(-1), im.minus(one, "sym")), ), ), ), ) - # neg(maximum(maximum(sym, 1), maximum(neg(sym) + 1, -1))) + 1 + # 1 - maximum(maximum(sym, 1), maximum(1 - sym, -1)) expected = im.minus( - im.literal_from_value(1), - im.call("maximum")( - im.call("maximum")(im.ref("sym"), im.literal_from_value(1)), - im.call("maximum")( - im.minus(im.literal_from_value(1), sym), - im.literal_from_value(-1), - ), + one, + im.maximum( + im.maximum("sym", one), + im.maximum(im.minus(one, "sym"), im.literal_from_value(-1)), ), ) actual = ConstantFolding.apply(testee) @@ -370,35 +304,24 @@ def test_constant_folding_complex(): def test_constant_folding_complex_1(): - sym = im.ref("sym") # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 testee = im.minus( im.plus( - im.call("maximum")(sym, im.plus(im.literal_from_value(1), sym)), + im.maximum("sym", im.plus(one, "sym")), im.plus( - im.call("maximum")( - im.literal_from_value(1), im.call("maximum")(im.literal_from_value(1), sym) - ), - im.plus( - im.minus(sym, im.literal_from_value(1)), - im.plus( - im.plus(im.literal_from_value(1), im.plus(sym, im.literal_from_value(1))), - im.literal_from_value(1), - ), - ), + im.maximum(one, im.maximum(one, "sym")), + im.plus(im.minus("sym", one), im.plus(im.plus(one, im.plus("sym", one)), one)), ), ), im.literal_from_value(2), ) - # sym + 1 + (maximum(sym, 1) + (sym + -1 + (sym + 3))) + -2 + # sym + 1 + (maximum(sym, 1) + (sym - 1 + (sym + 3))) - 2 expected = im.minus( im.plus( - im.plus(sym, 1), + im.plus("sym", 1), im.plus( - im.call("maximum")(sym, im.literal_from_value(1)), - im.plus( - im.minus(sym, im.literal_from_value(1)), im.plus(sym, im.literal_from_value(3)) - ), + im.maximum("sym", one), + im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))), ), ), im.literal_from_value(2), @@ -408,143 +331,18 @@ def test_constant_folding_complex_1(): def test_constant_folding_complex_3(): - sym = im.ref("sym") # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) testee = im.plus( - im.call("minimum")( - im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym) - ), + im.minimum(im.minus(one, "sym"), im.plus(one, "sym")), im.plus( - im.call("maximum")( - im.call("maximum")( - im.minus(im.literal_from_value(1), sym), im.plus(im.literal_from_value(1), sym) - ), - im.minus(im.literal_from_value(1), sym), - ), - im.call("maximum")( - im.minus(im.literal_from_value(1), sym), im.minus(im.literal_from_value(1), sym) - ), + im.maximum(im.maximum(im.minus(one, "sym"), im.plus(one, "sym")), im.minus(one, "sym")), + im.maximum(im.minus(one, "sym"), im.minus(one, "sym")), ), ) - # minimum(neg(sym) + 1, sym + 1) + (maximum(neg(sym) + 1, sym + 1) + (neg(sym) + 1)) + # minimum(1 - sym, sym + 1) + (maximum(1 - sym, sym + 1) + (1 - sym)) expected = im.plus( - im.call("minimum")( - im.minus(im.literal_from_value(1), sym), - im.plus(sym, im.literal_from_value(1)), - ), - im.plus( - im.call("maximum")( - im.minus(im.literal_from_value(1), sym), - im.plus(sym, im.literal_from_value(1)), - ), - im.minus(im.literal_from_value(1), sym), - ), - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_complex_3a(): - sym = im.ref("sym") - # maximum(maximum(1 + sym, 1), 1 + sym) - testee = im.call("maximum")( - im.call("maximum")(im.plus(im.literal_from_value(1), sym), 1), - im.plus(im.literal_from_value(1), sym), - ) - # maximum(1 + sym, 1) - expected = im.call("maximum")(im.plus(sym, im.literal_from_value(1)), 1) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_complex_2(): - sym = im.ref("sym") - testee = im.plus( - im.plus(sym, im.literal_from_value(-1)), im.plus(sym, im.literal_from_value(3)) - ) - expected = im.plus( - im.minus(sym, im.literal_from_value(1)), im.plus(sym, im.literal_from_value(3)) - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_complex_4(): - sym = im.ref("sym", "float32") - testee = im.divides_( - im.minus(im.literal("1", "float32"), sym), im.minus(im.literal("2", "float32"), sym) - ) - expected = im.divides_( - im.minus(im.literal("1", "float32"), sym), - im.minus(im.literal("2", "float32"), sym), - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_max(): - testee = im.call("maximum")(0, 0) - expected = im.literal_from_value(0) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_plus_new(): - sym = im.ref("sym") - testee = im.plus(im.minus(sym, im.literal_from_value(1)), im.literal_from_value(2)) - expected = im.plus(sym, im.literal_from_value(1)) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_minus(): - sym = im.ref("sym") - testee = im.plus(im.minus(im.literal_from_value(1), sym), im.literal_from_value(1)) - expected = im.minus(im.literal_from_value(2), sym) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_fold_min_max_plus(): - sym = im.ref("sym") - testee = im.call("minimum")(im.plus(sym, im.literal_from_value(-1)), sym) - expected = im.minus(sym, im.literal_from_value(1)) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_max_tuple_get(): - sym = im.ref("sym") - testee = im.call("maximum")( - im.plus(im.tuple_get(1, sym), 1), - im.call("maximum")(im.tuple_get(1, sym), im.plus(im.tuple_get(1, sym), 1)), + im.minimum(im.minus(one, "sym"), im.plus("sym", one)), + im.plus(im.maximum(im.minus(one, "sym"), im.plus("sym", one)), im.minus(one, "sym")), ) - expected = im.plus(im.tuple_get(1, sym), 1) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_max_syms(): - sym1 = im.ref("sym1") - sym2 = im.ref("sym2") - testee = im.call("maximum")(sym1, im.call("maximum")(sym2, sym1)) - expected = im.call("maximum")(sym2, sym1) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_max_min(): - sym = im.ref("sym1") - testee = im.call("maximum")(im.call("minimum")(sym, 1), sym) - expected = im.call("maximum")(im.call("minimum")(sym, 1), sym) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_minus_1(): - sym = im.ref("sym") - testee = im.maximum(im.plus(sym, im.literal_from_value(-1)), im.literal_from_value(1)) - expected = im.maximum(im.minus(sym, im.literal_from_value(1)), im.literal_from_value(1)) - actual = ConstantFolding.apply(testee) assert actual == expected From 225acf72bede03b0faca293ab29ab3c639268d18 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 12 Feb 2025 11:59:53 +0100 Subject: [PATCH 35/39] Reformat --- src/gt4py/next/iterator/transforms/constant_folding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index ebb0a77fe5..5019cb0523 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -118,13 +118,13 @@ def transform_canonicalize_op_funcall_symref_literal( return None def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS + # `a - b` -> `a + (-b)` if cpm.is_call_to(node, "minus"): return im.plus(node.args[0], self.fp_transform(im.call("neg")(node.args[1]))) return None def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: - # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for FOLD_MIN_MAX + # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)` if cpm.is_call_to(node, ("maximum", "minimum")): op = node.fun.id # type: ignore[attr-defined] # assured by if above if cpm.is_call_to(node.args[1], op) and not cpm.is_call_to(node.args[0], op): From c91b0114da0f2eb97675017719e93a26d2d38996 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 03:51:20 +0100 Subject: [PATCH 36/39] Cleanup tests --- .../transforms_tests/test_constant_folding.py | 457 ++++++------------ 1 file changed, 141 insertions(+), 316 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index dbd574bf2a..f8ec301017 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -11,338 +11,163 @@ one = im.literal_from_value(1) - -def test_constant_folding_plus(): - testee = im.plus(one, one) - expected = im.literal_from_value(2) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_boolean(): - testee = im.not_(im.literal_from_value(True)) - expected = im.literal_from_value(False) - - actual = ConstantFolding.apply(testee) - assert actual == expected +import pytest +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -def test_constant_folding_math_op(): - expected = im.literal_from_value(13) - testee = im.plus( - im.literal_from_value(4), - im.plus( - im.literal_from_value(7), im.minus(im.literal_from_value(7), im.literal_from_value(5)) +def test_cases(): + return ( + # expr, simplified expr + (im.plus(1, 1), 2), + (im.not_(True), False), + (im.plus(4, im.plus(7, im.minus(7, 5))), 13), + (im.if_(True, im.plus(im.ref("a"), 2), im.minus(9, 5)), im.plus("a", 2)), + (im.minimum("a", "a"), "a"), + (im.maximum(1, 2), 2), + # canonicalization + (im.plus("a", 1), im.plus("a", 1)), + (im.plus(1, "a"), im.plus("a", 1)), + # nested plus + (im.plus(im.plus("a", 1), 1), im.plus("a", 2)), + (im.plus(1, im.plus("a", 1)), im.plus("a", 2)), + # nested maximum + (im.maximum(im.maximum("a", 1), 1), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), 1), im.maximum("a", 1)), + (im.maximum("a", im.maximum(1, "a")), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), im.maximum(1, "a")), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), im.maximum("a", 1)), im.maximum("a", 1)), + (im.maximum(im.minimum("a", 1), "a"), im.maximum(im.minimum("a", 1), "a")), + # maximum & plus + (im.maximum(im.plus("a", one), im.plus("a", im.literal_from_value(0))), im.plus("a", one)), + ( + im.maximum(im.plus("a", one), im.plus(im.plus("a", one), im.literal_from_value(0))), + im.plus("a", one), ), - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_if(): - expected = im.plus("sym", 2) - testee = im.if_( - im.literal_from_value(True), - im.plus("sym", im.literal_from_value(2)), - im.minus(im.literal_from_value(9), im.literal_from_value(5)), - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_maximum_literal(): - testee = im.maximum(one, im.literal_from_value(2)) - expected = im.literal_from_value(2) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_minimum(): - testee = im.minimum("sym", "sym") - expected = im.ref("sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum(0, 0) - expected = im.literal_from_value(0) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_cannonicalize_plus_funcall_symref_literal(): - testee = im.plus("sym", one) - expected = im.plus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.plus(one, "sym") - expected = im.plus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_canonicalize_minus(): - testee = im.minus("sym", one) - expected = im.minus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.minus(one, "sym") - expected = im.minus(one, "sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_canonicalize_fold_op_funcall_symref_literal(): - testee = im.plus(im.plus("sym", one), one) - expected = im.plus("sym", im.literal_from_value(2)) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.plus(one, im.plus("sym", one)) - expected = im.plus("sym", im.literal_from_value(2)) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_nested_maximum(): - testee = im.maximum(im.maximum("sym", one), one) - expected = im.maximum("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum(im.maximum(one, "sym"), one) - expected = im.maximum("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum("sym", im.maximum(one, "sym")) - expected = im.maximum("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum(im.maximum(one, "sym"), im.maximum(one, "sym")) - expected = im.maximum("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum(im.maximum(one, "sym"), im.maximum("sym", one)) - expected = im.maximum("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum(im.minimum("sym", 1), "sym") - expected = im.maximum(im.minimum("sym", 1), "sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_maximum_plus(): - testee = im.maximum(im.plus("sym", one), im.plus("sym", im.literal_from_value(0))) - expected = im.plus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum(im.plus("sym", one), im.plus(im.plus("sym", one), im.literal_from_value(0))) - expected = im.plus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum("sym", im.plus("sym", one)) - expected = im.plus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum("sym", im.plus("sym", im.literal_from_value(-1))) - expected = im.ref("sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.plus("sym", im.maximum(im.literal_from_value(0), im.literal_from_value(-1))) - expected = im.ref("sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_plus_minus(): - testee = im.minus(im.plus("sym", one), im.plus(one, one)) - expected = im.minus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.plus(im.minus("sym", one), im.literal_from_value(2)) - expected = im.plus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.plus(im.minus(one, "sym"), one) - expected = im.minus(im.literal_from_value(2), "sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_nested_plus(): - testee = im.plus(im.plus("sym", one), im.plus(one, one)) - expected = im.plus("sym", im.literal_from_value(3)) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.plus( - im.plus("sym", im.literal_from_value(-1)), im.plus("sym", im.literal_from_value(3)) - ) - expected = im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_maximum_minus(): - testee = im.maximum(im.minus("sym", one), "sym") - expected = im.ref("sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum("sym", im.minus("sym", im.literal_from_value(-1))) - expected = im.plus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.maximum(im.plus("sym", im.literal_from_value(-1)), one) - expected = im.maximum(im.minus("sym", one), one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_minimum_plus_minus(): - testee = im.minimum(im.plus("sym", one), "sym") - expected = im.ref("sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.minimum("sym", im.plus("sym", im.literal_from_value(-1))) - expected = im.minus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.minimum(im.minus("sym", one), "sym") - expected = im.minus("sym", one) - actual = ConstantFolding.apply(testee) - assert actual == expected - - testee = im.minimum("sym", im.minus("sym", im.literal_from_value(-1))) - expected = im.ref("sym") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_max_syms(): - testee = im.maximum("sym1", im.maximum("sym2", "sym1")) - expected = im.maximum("sym2", "sym1") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_max_tuple_get(): - testee = im.maximum( - im.plus(im.tuple_get(1, "sym"), 1), - im.maximum(im.tuple_get(1, "sym"), im.plus(im.tuple_get(1, "sym"), 1)), - ) - expected = im.plus(im.tuple_get(1, "sym"), 1) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_nested_max_plus(): - # maximum(maximum(1 + sym, 1), 1 + sym) - testee = im.maximum(im.maximum(im.plus(one, "sym"), 1), im.plus(one, "sym")) - - # maximum(sym + 1, 1) - expected = im.maximum(im.plus("sym", one), 1) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_divides_float32(): - sym = im.ref("sym", "float32") - testee = im.divides_( - im.minus(im.literal("1", "float32"), sym), im.plus(im.literal("2", "float32"), sym) - ) - expected = im.divides_( - im.minus(im.literal("1", "float32"), sym), im.plus(sym, im.literal("2", "float32")) - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_complex(): - # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) - testee = im.minus( - one, - im.maximum( + (im.maximum("a", im.plus("a", one)), im.plus("a", one)), + (im.maximum("a", im.plus("a", im.literal_from_value(-1))), im.ref("a")), + ( + im.plus("a", im.maximum(im.literal_from_value(0), im.literal_from_value(-1))), + im.ref("a"), + ), + # plus & minus + (im.minus(im.plus("sym", one), im.plus(one, one)), im.minus("sym", one)), + (im.plus(im.minus("sym", one), im.literal_from_value(2)), im.plus("sym", one)), + (im.plus(im.minus(one, "sym"), one), im.minus(im.literal_from_value(2), "sym")), + # nested plus + (im.plus(im.plus("sym", one), im.plus(one, one)), im.plus("sym", im.literal_from_value(3))), + ( + im.plus( + im.plus("sym", im.literal_from_value(-1)), im.plus("sym", im.literal_from_value(3)) + ), + im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))), + ), + # maximum & minus + (im.maximum(im.minus("sym", one), "sym"), im.ref("sym")), + (im.maximum("sym", im.minus("sym", im.literal_from_value(-1))), im.plus("sym", one)), + ( + im.maximum(im.plus("sym", im.literal_from_value(-1)), one), + im.maximum(im.minus("sym", one), one), + ), + # minimum & plus & minus + (im.minimum(im.plus("sym", one), "sym"), im.ref("sym")), + (im.minimum("sym", im.plus("sym", im.literal_from_value(-1))), im.minus("sym", one)), + (im.minimum(im.minus("sym", one), "sym"), im.minus("sym", one)), + (im.minimum("sym", im.minus("sym", im.literal_from_value(-1))), im.ref("sym")), + # nested maximum + (im.maximum("sym1", im.maximum("sym2", "sym1")), im.maximum("sym2", "sym1")), + # maximum & plus on complicated expr (tuple_get) + ( im.maximum( - im.maximum(one, im.maximum(one, "sym")), - im.maximum(im.maximum(one, "sym"), "sym"), + im.plus(im.tuple_get(1, "a"), 1), + im.maximum(im.tuple_get(1, "a"), im.plus(im.tuple_get(1, "a"), 1)), ), - im.plus( + im.plus(im.tuple_get(1, "a"), 1), + ), + # nested maximum & plus + ( + im.maximum(im.maximum(im.plus(1, "sym"), 1), im.plus(1, "sym")), + im.maximum(im.plus("sym", 1), 1), + ), + # sanity check that no strange things happen + # complex tests + ( + # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) + im.minus( one, - im.plus( - im.minimum(im.literal_from_value(-1), 2), - im.maximum(im.literal_from_value(-1), im.minus(one, "sym")), + im.maximum( + im.maximum( + im.maximum(one, im.maximum(one, "sym")), + im.maximum(im.maximum(one, "sym"), "sym"), + ), + im.plus( + one, + im.plus( + im.minimum(im.literal_from_value(-1), 2), + im.maximum(im.literal_from_value(-1), im.minus(one, "sym")), + ), + ), + ), + ), + # 1 - maximum(maximum(sym, 1), maximum(1 - sym, -1)) + im.minus( + one, + im.maximum( + im.maximum("sym", one), + im.maximum(im.minus(one, "sym"), im.literal_from_value(-1)), ), ), ), - ) - # 1 - maximum(maximum(sym, 1), maximum(1 - sym, -1)) - expected = im.minus( - one, - im.maximum( - im.maximum("sym", one), - im.maximum(im.minus(one, "sym"), im.literal_from_value(-1)), + ( + # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 + im.minus( + im.plus( + im.maximum("sym", im.plus(one, "sym")), + im.plus( + im.maximum(one, im.maximum(one, "sym")), + im.plus( + im.minus("sym", one), im.plus(im.plus(one, im.plus("sym", one)), one) + ), + ), + ), + im.literal_from_value(2), + ), + # sym + 1 + (maximum(sym, 1) + (sym - 1 + (sym + 3))) - 2 + im.minus( + im.plus( + im.plus("sym", 1), + im.plus( + im.maximum("sym", one), + im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))), + ), + ), + im.literal_from_value(2), + ), ), - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_complex_1(): - # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 - testee = im.minus( - im.plus( - im.maximum("sym", im.plus(one, "sym")), + ( + # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) im.plus( - im.maximum(one, im.maximum(one, "sym")), - im.plus(im.minus("sym", one), im.plus(im.plus(one, im.plus("sym", one)), one)), + im.minimum(im.minus(one, "sym"), im.plus(one, "sym")), + im.plus( + im.maximum( + im.maximum(im.minus(one, "sym"), im.plus(one, "sym")), im.minus(one, "sym") + ), + im.maximum(im.minus(one, "sym"), im.minus(one, "sym")), + ), ), - ), - im.literal_from_value(2), - ) - # sym + 1 + (maximum(sym, 1) + (sym - 1 + (sym + 3))) - 2 - expected = im.minus( - im.plus( - im.plus("sym", 1), + # minimum(1 - sym, sym + 1) + (maximum(1 - sym, sym + 1) + (1 - sym)) im.plus( - im.maximum("sym", one), - im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))), + im.minimum(im.minus(one, "sym"), im.plus("sym", one)), + im.plus( + im.maximum(im.minus(one, "sym"), im.plus("sym", one)), im.minus(one, "sym") + ), ), ), - im.literal_from_value(2), ) - actual = ConstantFolding.apply(testee) - assert actual == expected -def test_constant_folding_complex_3(): - # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) - testee = im.plus( - im.minimum(im.minus(one, "sym"), im.plus(one, "sym")), - im.plus( - im.maximum(im.maximum(im.minus(one, "sym"), im.plus(one, "sym")), im.minus(one, "sym")), - im.maximum(im.minus(one, "sym"), im.minus(one, "sym")), - ), - ) - # minimum(1 - sym, sym + 1) + (maximum(1 - sym, sym + 1) + (1 - sym)) - expected = im.plus( - im.minimum(im.minus(one, "sym"), im.plus("sym", one)), - im.plus(im.maximum(im.minus(one, "sym"), im.plus("sym", one)), im.minus(one, "sym")), - ) +@pytest.mark.parametrize("test_case", test_cases()) +def test_constant_folding(test_case): + testee, expected = test_case actual = ConstantFolding.apply(testee) - assert actual == expected + assert actual == im.ensure_expr(expected) From 39cadaaaa858fd53329be0fbfbd27d75ef4245d1 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 03:53:32 +0100 Subject: [PATCH 37/39] Cleanup tests --- .../transforms_tests/test_constant_folding.py | 88 +++++++++---------- 1 file changed, 43 insertions(+), 45 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index f8ec301017..16c7edf4e3 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -9,8 +9,6 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -one = im.literal_from_value(1) - import pytest from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -39,43 +37,43 @@ def test_cases(): (im.maximum(im.maximum(1, "a"), im.maximum("a", 1)), im.maximum("a", 1)), (im.maximum(im.minimum("a", 1), "a"), im.maximum(im.minimum("a", 1), "a")), # maximum & plus - (im.maximum(im.plus("a", one), im.plus("a", im.literal_from_value(0))), im.plus("a", one)), + (im.maximum(im.plus("a", 1), im.plus("a", 0)), im.plus("a", 1)), ( - im.maximum(im.plus("a", one), im.plus(im.plus("a", one), im.literal_from_value(0))), - im.plus("a", one), + im.maximum(im.plus("a", 1), im.plus(im.plus("a", 1), 0)), + im.plus("a", 1), ), - (im.maximum("a", im.plus("a", one)), im.plus("a", one)), + (im.maximum("a", im.plus("a", 1)), im.plus("a", 1)), (im.maximum("a", im.plus("a", im.literal_from_value(-1))), im.ref("a")), ( - im.plus("a", im.maximum(im.literal_from_value(0), im.literal_from_value(-1))), + im.plus("a", im.maximum(0, im.literal_from_value(-1))), im.ref("a"), ), # plus & minus - (im.minus(im.plus("sym", one), im.plus(one, one)), im.minus("sym", one)), - (im.plus(im.minus("sym", one), im.literal_from_value(2)), im.plus("sym", one)), - (im.plus(im.minus(one, "sym"), one), im.minus(im.literal_from_value(2), "sym")), + (im.minus(im.plus("a", 1), im.plus(1, 1)), im.minus("a", 1)), + (im.plus(im.minus("a", 1), 2), im.plus("a", 1)), + (im.plus(im.minus(1, "a"), 1), im.minus(2, "a")), # nested plus - (im.plus(im.plus("sym", one), im.plus(one, one)), im.plus("sym", im.literal_from_value(3))), + (im.plus(im.plus("a", 1), im.plus(1, 1)), im.plus("a", 3)), ( im.plus( - im.plus("sym", im.literal_from_value(-1)), im.plus("sym", im.literal_from_value(3)) + im.plus("a", im.literal_from_value(-1)), im.plus("a", 3) ), - im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))), + im.plus(im.minus("a", 1), im.plus("a", 3)), ), # maximum & minus - (im.maximum(im.minus("sym", one), "sym"), im.ref("sym")), - (im.maximum("sym", im.minus("sym", im.literal_from_value(-1))), im.plus("sym", one)), + (im.maximum(im.minus("a", 1), "a"), im.ref("a")), + (im.maximum("a", im.minus("a", im.literal_from_value(-1))), im.plus("a", 1)), ( - im.maximum(im.plus("sym", im.literal_from_value(-1)), one), - im.maximum(im.minus("sym", one), one), + im.maximum(im.plus("a", im.literal_from_value(-1)), 1), + im.maximum(im.minus("a", 1), 1), ), # minimum & plus & minus - (im.minimum(im.plus("sym", one), "sym"), im.ref("sym")), - (im.minimum("sym", im.plus("sym", im.literal_from_value(-1))), im.minus("sym", one)), - (im.minimum(im.minus("sym", one), "sym"), im.minus("sym", one)), - (im.minimum("sym", im.minus("sym", im.literal_from_value(-1))), im.ref("sym")), + (im.minimum(im.plus("a", 1), "a"), im.ref("a")), + (im.minimum("a", im.plus("a", im.literal_from_value(-1))), im.minus("a", 1)), + (im.minimum(im.minus("a", 1), "a"), im.minus("a", 1)), + (im.minimum("a", im.minus("a", im.literal_from_value(-1))), im.ref("a")), # nested maximum - (im.maximum("sym1", im.maximum("sym2", "sym1")), im.maximum("sym2", "sym1")), + (im.maximum("a", im.maximum("b", "a")), im.maximum("b", "a")), # maximum & plus on complicated expr (tuple_get) ( im.maximum( @@ -86,35 +84,35 @@ def test_cases(): ), # nested maximum & plus ( - im.maximum(im.maximum(im.plus(1, "sym"), 1), im.plus(1, "sym")), - im.maximum(im.plus("sym", 1), 1), + im.maximum(im.maximum(im.plus(1, "a"), 1), im.plus(1, "a")), + im.maximum(im.plus("a", 1), 1), ), # sanity check that no strange things happen # complex tests ( # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) im.minus( - one, + 1, im.maximum( im.maximum( - im.maximum(one, im.maximum(one, "sym")), - im.maximum(im.maximum(one, "sym"), "sym"), + im.maximum(1, im.maximum(1, "a")), + im.maximum(im.maximum(1, "a"), "a"), ), im.plus( - one, + 1, im.plus( im.minimum(im.literal_from_value(-1), 2), - im.maximum(im.literal_from_value(-1), im.minus(one, "sym")), + im.maximum(im.literal_from_value(-1), im.minus(1, "a")), ), ), ), ), # 1 - maximum(maximum(sym, 1), maximum(1 - sym, -1)) im.minus( - one, + 1, im.maximum( - im.maximum("sym", one), - im.maximum(im.minus(one, "sym"), im.literal_from_value(-1)), + im.maximum("a", 1), + im.maximum(im.minus(1, "a"), im.literal_from_value(-1)), ), ), ), @@ -122,44 +120,44 @@ def test_cases(): # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 im.minus( im.plus( - im.maximum("sym", im.plus(one, "sym")), + im.maximum("a", im.plus(1, "a")), im.plus( - im.maximum(one, im.maximum(one, "sym")), + im.maximum(1, im.maximum(1, "a")), im.plus( - im.minus("sym", one), im.plus(im.plus(one, im.plus("sym", one)), one) + im.minus("a", 1), im.plus(im.plus(1, im.plus("a", 1)), 1) ), ), ), - im.literal_from_value(2), + 2, ), # sym + 1 + (maximum(sym, 1) + (sym - 1 + (sym + 3))) - 2 im.minus( im.plus( - im.plus("sym", 1), + im.plus("a", 1), im.plus( - im.maximum("sym", one), - im.plus(im.minus("sym", one), im.plus("sym", im.literal_from_value(3))), + im.maximum("a", 1), + im.plus(im.minus("a", 1), im.plus("a", 3)), ), ), - im.literal_from_value(2), + 2, ), ), ( # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) im.plus( - im.minimum(im.minus(one, "sym"), im.plus(one, "sym")), + im.minimum(im.minus(1, "a"), im.plus(1, "a")), im.plus( im.maximum( - im.maximum(im.minus(one, "sym"), im.plus(one, "sym")), im.minus(one, "sym") + im.maximum(im.minus(1, "a"), im.plus(1, "a")), im.minus(1, "a") ), - im.maximum(im.minus(one, "sym"), im.minus(one, "sym")), + im.maximum(im.minus(1, "a"), im.minus(1, "a")), ), ), # minimum(1 - sym, sym + 1) + (maximum(1 - sym, sym + 1) + (1 - sym)) im.plus( - im.minimum(im.minus(one, "sym"), im.plus("sym", one)), + im.minimum(im.minus(1, "a"), im.plus("a", 1)), im.plus( - im.maximum(im.minus(one, "sym"), im.plus("sym", one)), im.minus(one, "sym") + im.maximum(im.minus(1, "a"), im.plus("a", 1)), im.minus(1, "a") ), ), ), From 3ca3630e76b5d044e13b0b1523b7122ffec96a85 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 14 Feb 2025 04:11:53 +0100 Subject: [PATCH 38/39] Cleanup --- .../iterator/transforms/constant_folding.py | 24 ++++++++----------- .../transforms_tests/test_constant_folding.py | 16 ++++--------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 5019cb0523..23910670ac 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -106,15 +106,14 @@ def transform_canonicalize_op_funcall_symref_literal( ) -> Optional[ir.Node]: # `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP # `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP - # `funcall, op` -> `op, funcall` for s[0] + (s[0] + 1), prerequisite for FOLD_MIN_MAX_PLUS + # `funcall, op` -> `op, funcall` for `s[0] + (s[0] + 1)`, prerequisite for FOLD_MIN_MAX_PLUS if cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")): - if ( - isinstance(node.args[0], ir.Literal) and not isinstance(node.args[1], ir.Literal) - ) or ( - (not cpm.is_call_to(node.args[0], ("plus", "multiplies", "minimum", "maximum"))) - and cpm.is_call_to(node.args[1], ("plus", "multiplies", "minimum", "maximum")) + a, b = node.args + if (isinstance(a, ir.Literal) and not isinstance(b, ir.Literal)) or ( + (not cpm.is_call_to(a, ("plus", "multiplies", "minimum", "maximum"))) + and cpm.is_call_to(b, ("plus", "multiplies", "minimum", "maximum")) ): - return im.call(node.fun)(node.args[1], node.args[0]) + return im.call(node.fun)(b, a) return None def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: @@ -135,14 +134,11 @@ def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional # `(a + 1) + 1` -> `a + (1 + 1)` if cpm.is_call_to(node, "plus"): if cpm.is_call_to(node.args[0], "plus") and isinstance(node.args[1], ir.Literal): - fun_call, literal = node.args - if isinstance(fun_call.args[0], (ir.SymRef, ir.FunCall)) and isinstance( # type: ignore[attr-defined] # assured by if above - fun_call.args[1], # type: ignore[attr-defined] # assured by if above - ir.Literal, - ): + (expr, lit1), lit2 = node.args[0].args, node.args[1] + if isinstance(expr, (ir.SymRef, ir.FunCall)) and isinstance(lit1, ir.Literal): return im.plus( - fun_call.args[0], # type: ignore[attr-defined] # assured by if above - self.fp_transform(im.plus(fun_call.args[1], literal)), # type: ignore[attr-defined] # assured by if above + expr, + self.fp_transform(im.plus(lit1, lit2)), ) return None diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 16c7edf4e3..1da2b8cec5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -55,9 +55,7 @@ def test_cases(): # nested plus (im.plus(im.plus("a", 1), im.plus(1, 1)), im.plus("a", 3)), ( - im.plus( - im.plus("a", im.literal_from_value(-1)), im.plus("a", 3) - ), + im.plus(im.plus("a", im.literal_from_value(-1)), im.plus("a", 3)), im.plus(im.minus("a", 1), im.plus("a", 3)), ), # maximum & minus @@ -123,9 +121,7 @@ def test_cases(): im.maximum("a", im.plus(1, "a")), im.plus( im.maximum(1, im.maximum(1, "a")), - im.plus( - im.minus("a", 1), im.plus(im.plus(1, im.plus("a", 1)), 1) - ), + im.plus(im.minus("a", 1), im.plus(im.plus(1, im.plus("a", 1)), 1)), ), ), 2, @@ -147,18 +143,14 @@ def test_cases(): im.plus( im.minimum(im.minus(1, "a"), im.plus(1, "a")), im.plus( - im.maximum( - im.maximum(im.minus(1, "a"), im.plus(1, "a")), im.minus(1, "a") - ), + im.maximum(im.maximum(im.minus(1, "a"), im.plus(1, "a")), im.minus(1, "a")), im.maximum(im.minus(1, "a"), im.minus(1, "a")), ), ), # minimum(1 - sym, sym + 1) + (maximum(1 - sym, sym + 1) + (1 - sym)) im.plus( im.minimum(im.minus(1, "a"), im.plus("a", 1)), - im.plus( - im.maximum(im.minus(1, "a"), im.plus("a", 1)), im.minus(1, "a") - ), + im.plus(im.maximum(im.minus(1, "a"), im.plus("a", 1)), im.minus(1, "a")), ), ), ) From 74a401616bfc2c7f2886279f93ad19dad236496d Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Wed, 19 Feb 2025 03:17:18 +0100 Subject: [PATCH 39/39] Cleanup --- .../iterator/transforms/constant_folding.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 23910670ac..fdbfec99ca 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -20,7 +20,7 @@ from gt4py.next.iterator.transforms import fixed_point_transformation -def value_from_literal(literal: ir.Literal): +def _value_from_literal(literal: ir.Literal): return getattr(embedded, str(literal.type))(literal.value) @@ -37,15 +37,18 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: a, b = node.args if cpm.is_call_to(b, "neg"): return im.minus(a, b.args[0]) - if isinstance(b, ir.Literal) and value_from_literal(b) < 0: - return im.minus(a, -value_from_literal(b)) + if isinstance(b, ir.Literal) and _value_from_literal(b) < 0: + return im.minus(a, -_value_from_literal(b)) if cpm.is_call_to(a, "neg"): return im.minus(b, a.args[0]) - if isinstance(a, ir.Literal) and value_from_literal(a) < 0: - return im.minus(b, -value_from_literal(a)) + if isinstance(a, ir.Literal) and _value_from_literal(a) < 0: + return im.minus(b, -_value_from_literal(a)) return node +_COMMUTATIVE_OPS = ("plus", "multiplies", "minimum", "maximum") + + @dataclasses.dataclass(frozen=True, kw_only=True) class ConstantFolding( fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor @@ -56,9 +59,9 @@ class ConstantFolding( ) class Transformation(enum.Flag): - # `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP - # `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP - # `funcall, op` -> `op, funcall` for s[0] + (s[0] + 1), prerequisite for FOLD_MIN_MAX_PLUS + # `1 + a` -> `a + 1`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `1 + f(...)` -> `f(...) + 1`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `f(...) + (expr1 + expr2)` -> `(expr1 + expr2) + f(...)`, for `s[0] + (s[0] + 1)`, prerequisite for FOLD_MIN_MAX_PLUS CANONICALIZE_OP_FUNCALL_SYMREF_LITERAL = enum.auto() # `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS @@ -74,8 +77,8 @@ class Transformation(enum.Flag): # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` FOLD_MIN_MAX = enum.auto() - # `maximum(plus(a, 1), a)` -> `plus(a, 1)` - # `maximum(plus(a, 1), plus(a, -1))` -> `plus(a, maximum(1, -1))` + # `maximum(a + 1), a)` -> `a + 1` + # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` FOLD_MIN_MAX_PLUS = enum.auto() # `a + 0` -> `a`, `a * 1` -> `a` @@ -104,14 +107,12 @@ def apply(cls, node: ir.Node) -> ir.Node: def transform_canonicalize_op_funcall_symref_literal( self, node: ir.FunCall, **kwargs ) -> Optional[ir.Node]: - # `literal, symref` -> `symref, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP - # `literal, funcall` -> `funcall, literal`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP - # `funcall, op` -> `op, funcall` for `s[0] + (s[0] + 1)`, prerequisite for FOLD_MIN_MAX_PLUS - if cpm.is_call_to(node, ("plus", "multiplies", "minimum", "maximum")): + # `op(literal, symref|funcall)` -> `op(symref|funcall, literal)` + # `op1(funcall, op2(...))` -> `op1(op2(...), funcall)` for `s[0] + (s[0] + 1)` + if cpm.is_call_to(node, _COMMUTATIVE_OPS): a, b = node.args if (isinstance(a, ir.Literal) and not isinstance(b, ir.Literal)) or ( - (not cpm.is_call_to(a, ("plus", "multiplies", "minimum", "maximum"))) - and cpm.is_call_to(b, ("plus", "multiplies", "minimum", "maximum")) + not cpm.is_call_to(a, _COMMUTATIVE_OPS) and cpm.is_call_to(b, _COMMUTATIVE_OPS) ): return im.call(node.fun)(b, a) return None @@ -160,13 +161,13 @@ def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir and cpm.is_call_to(node, ("minimum", "maximum")) ): arg0, arg1 = node.args - # `maximum(plus(a, 1), a)` -> `plus(a, 1)` + # `maximum(a + 1, a)` -> `a + 1` if cpm.is_call_to(arg0, "plus"): if arg0.args[0] == arg1: return im.plus( arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0)) ) - # `maximum(plus(a, 1), plus(a, -1))` -> `plus(a, maximum(1, -1))` + # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"): if arg0.args[0] == arg1.args[0]: return im.plus( @@ -205,7 +206,7 @@ def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Opti if node.fun.id in builtins.ARITHMETIC_BUILTINS: fun = getattr(embedded, str(node.fun.id)) arg_values = [ - value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition + _value_from_literal(arg) # type: ignore[arg-type] # arg type already established in if condition for arg in node.args ] return im.literal_from_value(fun(*arg_values))