Skip to content

Commit

Permalink
Cleanup concat where:
Browse files Browse the repository at this point in the history
- Cleanup infinity literal ir node and constant folding
- Improve testing
- Fix domain complement
- Simplify & cleanup lowering, domain ops pass
  • Loading branch information
tehrengruber committed Feb 20, 2025
1 parent 157b0e2 commit 435d057
Show file tree
Hide file tree
Showing 12 changed files with 290 additions and 292 deletions.
27 changes: 2 additions & 25 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,28 +251,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
raise NotImplementedError(f"Unary operator '{node.op}' is not supported.")

def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_AND
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
return im.and_(self.visit(node.left), self.visit(node.right))
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_OR
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
return im.or_(self.visit(node.left), self.visit(node.right))
if (
node.op == dialect_ast_enums.BinaryOperator.BIT_XOR
and isinstance(node.left.type, ts.DomainType)
and isinstance(node.right.type, ts.DomainType)
):
raise NotImplementedError(
f"Binary operator '{node.op}' is not supported for '{node.right.type}' and '{node.right.type}'."
)
else:
return self._lower_and_map(node.op.value, node.left, node.right)
return self._lower_and_map(node.op.value, node.left, node.right)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall:
assert (
Expand All @@ -284,7 +263,6 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC
)

def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall:
# TODO: double-check if we need the changes in the original PR
return self._lower_and_map(node.op.value, node.left, node.right)

def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
Expand Down Expand Up @@ -506,9 +484,8 @@ def _map(
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
"""
# TODO double-check that this code is consistent with the changes in the original PR
if all(
isinstance(t, (ts.ScalarType, ts.DimensionType))
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
for arg_type in original_arg_types
for t in type_info.primitive_constituents(arg_type)
):
Expand Down
17 changes: 11 additions & 6 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import ClassVar, List, Optional, Union
from typing import ClassVar, List, Optional, Union, TYPE_CHECKING

import gt4py.eve as eve
from gt4py.eve import Coerced, SymbolName, SymbolRef
Expand Down Expand Up @@ -62,13 +63,18 @@ class Literal(Expr):
class NoneLiteral(Expr):
_none_literal: int = 0


class InfinityLiteral(Expr):
pass
if TYPE_CHECKING:
POSITIVE: ClassVar[InfinityLiteral] # TODO(tehrengruber): should be `ClassVar[InfinityLiteral]`, but self-referential not supported in eve
NEGATIVE: ClassVar[InfinityLiteral]

name: typing.Literal["POSITIVE", "NEGATIVE"]

def __str__(self):
return f"{type(self).__name__}.{self.name}"

class NegInfinityLiteral(Expr):
pass
InfinityLiteral.NEGATIVE = InfinityLiteral(name="NEGATIVE")
InfinityLiteral.POSITIVE = InfinityLiteral(name="POSITIVE")


class OffsetLiteral(Expr):
Expand Down Expand Up @@ -151,4 +157,3 @@ class Program(Node, ValidatedSymbolTableTrait):
SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign]
IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign]
InfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
NegInfinityLiteral.__hash__ = Node.__hash__ # type: ignore[method-assign]
12 changes: 7 additions & 5 deletions src/gt4py/next/iterator/ir_utils/domain_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,12 @@ def domain_complement(domain: SymbolicDomain) -> SymbolicDomain:
dims_dict = {}
for dim in domain.ranges.keys():
lb, ub = domain.ranges[dim].start, domain.ranges[dim].stop
if isinstance(lb, itir.NegInfinityLiteral):
dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral())
elif isinstance(ub, itir.InfinityLiteral):
dims_dict[dim] = SymbolicRange(start=itir.NegInfinityLiteral(), stop=lb)
# `]-inf, a[` -> `[a, inf[`
if lb == itir.InfinityLiteral.NEGATIVE:
dims_dict[dim] = SymbolicRange(start=ub, stop=itir.InfinityLiteral.POSITIVE)
# `[a, inf]` -> `]-inf, a]`
elif ub == itir.InfinityLiteral.POSITIVE:
dims_dict[dim] = SymbolicRange(start=itir.InfinityLiteral.NEGATIVE, stop=lb)
else:
raise ValueError("Invalid domain ranges")
return SymbolicDomain(domain.grid_type, dims_dict)
Expand All @@ -218,5 +220,5 @@ def promote_to_same_dimensions(
lb, ub = domain_small.ranges[dim].start, domain_small.ranges[dim].stop
dims_dict[dim] = SymbolicRange(lb, ub)
else:
dims_dict[dim] = SymbolicRange(itir.NegInfinityLiteral(), itir.InfinityLiteral())
dims_dict[dim] = SymbolicRange(itir.InfinityLiteral.NEGATIVE, itir.InfinityLiteral.POSITIVE)
return SymbolicDomain(domain_small.grid_type, dims_dict) # TODO: fix for unstructured
11 changes: 6 additions & 5 deletions src/gt4py/next/iterator/pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,12 @@ def visit_Sym(self, node: ir.Sym, *, prec: int) -> list[str]:
def visit_Literal(self, node: ir.Literal, *, prec: int) -> list[str]:
return [str(node.value)]

def visit_InfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["INF"]

def visit_NegInfinityLiteral(self, node: ir.Literal, *, prec: int) -> list[str]:
return ["-INF"]
def visit_InfinityLiteral(self, node: ir.InfinityLiteral, *, prec: int) -> list[str]:
if node == ir.InfinityLiteral.POSITIVE:
return ["∞"]
elif node == ir.InfinityLiteral.NEGATIVE:
return ["-∞"]
raise AssertionError()

def visit_OffsetLiteral(self, node: ir.OffsetLiteral, *, prec: int) -> list[str]:
return [str(node.value) + "ₒ"]
Expand Down
97 changes: 52 additions & 45 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,49 +31,58 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
): # `minimum(a, a)` -> `a`
return new_node.args[0]

if cpm.is_call_to(new_node, "plus"):
a, b = new_node.args
for arg, other_arg in ((a, b), (b, a)):
# `a + inf` -> `inf`
if arg == ir.InfinityLiteral.POSITIVE:
return ir.InfinityLiteral.POSITIVE
# `a + (-inf)` -> `-inf`
if arg == ir.InfinityLiteral.NEGATIVE:
return ir.InfinityLiteral.NEGATIVE

if cpm.is_call_to(new_node, "minimum"):
# `minimum(neg_inf, neg_inf)` -> `neg_inf`
if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance(
new_node.args[1], ir.NegInfinityLiteral
):
return ir.NegInfinityLiteral()
# `minimum(inf, a)` -> `a`
elif isinstance(new_node.args[0], ir.InfinityLiteral):
return new_node.args[1]
# `minimum(a, inf)` -> `a`
elif isinstance(new_node.args[1], ir.InfinityLiteral):
return new_node.args[0]
a, b = new_node.args
for arg, other_arg in ((a, b), (b, a)):
# `minimum(inf, a)` -> `a`
if arg == ir.InfinityLiteral.POSITIVE:
return other_arg
# `minimum(-inf, a)` -> `-inf`
if arg == ir.InfinityLiteral.NEGATIVE:
return ir.InfinityLiteral.NEGATIVE

if cpm.is_call_to(new_node, "maximum"):
# `minimum(inf, inf)` -> `inf`
if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance(
new_node.args[1], ir.InfinityLiteral
):
return ir.InfinityLiteral()
# `minimum(neg_inf, a)` -> `a`
elif isinstance(new_node.args[0], ir.NegInfinityLiteral):
return new_node.args[1]
# `minimum(a, neg_inf)` -> `a`
elif isinstance(new_node.args[1], ir.NegInfinityLiteral):
return new_node.args[0]
a, b = new_node.args
for arg, other_arg in ((a, b), (b, a)):
# `maximum(inf, a)` -> `inf`
if arg == ir.InfinityLiteral.POSITIVE:
return ir.InfinityLiteral.POSITIVE
# `maximum(-inf, a)` -> `a`
if arg == ir.InfinityLiteral.NEGATIVE:
return other_arg

if cpm.is_call_to(new_node, ("less", "less_equal")):
if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance(
new_node.args[1], ir.InfinityLiteral
):
a, b = new_node.args
# `-inf < v` -> `True`
# `v < inf` -> `True`
if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE:
return im.literal_from_value(True)
if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance(
new_node.args[1], ir.NegInfinityLiteral
):
# `inf < v` -> `False`
# `v < -inf ` -> `False`
if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE:
return im.literal_from_value(False)

if cpm.is_call_to(new_node, ("greater", "greater_equal")):
if isinstance(new_node.args[0], ir.NegInfinityLiteral) or isinstance(
new_node.args[1], ir.InfinityLiteral
):
return im.literal_from_value(False)
if isinstance(new_node.args[0], ir.InfinityLiteral) or isinstance(
new_node.args[1], ir.NegInfinityLiteral
):
a, b = new_node.args
# `inf > v` -> `True`
# `v > -inf ` -> `True`
if a == ir.InfinityLiteral.POSITIVE or b == ir.InfinityLiteral.NEGATIVE:
return im.literal_from_value(True)
# `-inf > v` -> `False`
# `v > inf` -> `False`
if a == ir.InfinityLiteral.NEGATIVE or b == ir.InfinityLiteral.POSITIVE:
return im.literal_from_value(False)

if (
isinstance(new_node.fun, ir.SymRef)
and new_node.fun.id == "if_"
Expand All @@ -90,15 +99,13 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
and len(new_node.args) > 0
and all(isinstance(arg, ir.Literal) for arg in new_node.args)
): # `1 + 1` -> `2`
try:
if new_node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(new_node.fun.id))
arg_values = [
getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition
for arg in new_node.args
]
new_node = im.literal_from_value(fun(*arg_values))
except ValueError:
pass # happens for SymRefs which are not inf or neg_inf
if new_node.fun.id in builtins.ARITHMETIC_BUILTINS:
fun = getattr(embedded, str(new_node.fun.id))
arg_values = [
getattr(embedded, str(arg.type))(arg.value)
# type: ignore[attr-defined] # arg type already established in if condition
for arg in new_node.args
]
new_node = im.literal_from_value(fun(*arg_values))

return new_node
Loading

0 comments on commit 435d057

Please sign in to comment.