Skip to content

Commit

Permalink
Cleanup & improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Feb 24, 2025
1 parent 24e2f57 commit d14fb21
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 43 deletions.
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

import typing
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union

import gt4py.eve as eve
Expand Down Expand Up @@ -65,10 +66,9 @@ class NoneLiteral(Expr):


class InfinityLiteral(Expr):
# TODO(tehrengruber): self referential `ClassVar` not supported in eve.
if TYPE_CHECKING:
POSITIVE: ClassVar[
InfinityLiteral
] # TODO(tehrengruber): should be `ClassVar[InfinityLiteral]`, but self-referential not supported in eve
POSITIVE: ClassVar[InfinityLiteral]
NEGATIVE: ClassVar[InfinityLiteral]

name: typing.Literal["POSITIVE", "NEGATIVE"]
Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/next/iterator/transforms/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
return new_node.args[0]

if cpm.is_call_to(new_node, "plus"):
a, b = new_node.args
for arg, other_arg in ((a, b), (b, a)):
for arg in new_node.args:
# `a + inf` -> `inf`
if arg == ir.InfinityLiteral.POSITIVE:
return ir.InfinityLiteral.POSITIVE
Expand Down
10 changes: 5 additions & 5 deletions src/gt4py/next/iterator/transforms/infer_domain_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,23 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
min_: int | ir.InfinityLiteral
max_: int | ir.InfinityLiteral

# IDim < 1
# `IDim < 1`
if cpm.is_call_to(node, "less"):
min_ = ir.InfinityLiteral.NEGATIVE
max_ = value
# IDim <= 1
# `IDim <= 1`
elif cpm.is_call_to(node, "less_equal"):
min_ = ir.InfinityLiteral.NEGATIVE
max_ = im.plus(value, 1)
# IDim > 1
# `IDim > 1`
elif cpm.is_call_to(node, "greater"):
min_ = im.plus(value, 1)
max_ = ir.InfinityLiteral.POSITIVE
# IDim >= 1
# `IDim >= 1`
elif cpm.is_call_to(node, "greater_equal"):
min_ = value
max_ = ir.InfinityLiteral.POSITIVE
# IDim == 1 # TODO: isn't this removed before and rewritten as two concat_where?
# `IDim == 1` # TODO: isn't this removed before and rewritten as two concat_where?
elif cpm.is_call_to(node, "eq"):
min_ = value
max_ = im.plus(value, 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,81 @@ def testee(ground: cases.IJKField, air: cases.IJKField) -> cases.IJKField:


@pytest.mark.uses_frontend_concat_where
def test_concat_where_non_overlapping_different_dims(cartesian_case):
def test_concat_where_single_level_broadcast(cartesian_case):
@gtx.field_operator
def testee(
ground: cases.IJField, # note: boundary field is only defined in K
air: cases.IJKField,
) -> cases.IJKField:
return concat_where(KDim == 0, ground, air)
def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField:
return concat_where(KDim == 0, a, b)

out = cases.allocate(cartesian_case, testee, cases.RETURN)()
ground = cases.allocate(cartesian_case, testee, "ground", domain=gtx.domain({KDim: (0, 1)}))()
air = cases.allocate(cartesian_case, testee, "air", domain=out.domain.slice_at[:, :, 1:])()
a = cases.allocate(
cartesian_case, testee, "a", domain=gtx.domain({KDim: out.domain.shape[2]})
)()
b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])()

ref = np.concatenate(
(
np.tile(
ground.asnumpy(), (*air.domain.shape[0:2], len(ground.domain[KDim].unit_range))
),
air.asnumpy(),
np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)),
b.asnumpy(),
),
axis=2,
)
cases.verify(cartesian_case, testee, a, b, out=out, ref=ref)

cases.verify(cartesian_case, testee, ground, air, out=out, ref=ref)

@pytest.mark.uses_frontend_concat_where
def test_concat_where_single_level_broadcast(cartesian_case):
@gtx.field_operator
def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField:
return concat_where(KDim == 0, a, b)

out = cases.allocate(cartesian_case, testee, cases.RETURN)()
# note: this field is only defined on K: 0, 1, i.e., contains only a single value
a = cases.allocate(cartesian_case, testee, "a", domain=gtx.domain({KDim: (0, 1)}))()
b = cases.allocate(cartesian_case, testee, "b", domain=out.domain.slice_at[:, :, 1:])()

ref = np.concatenate(
(
np.tile(a.asnumpy()[0], (*b.domain.shape[0:2], 1)),
b.asnumpy(),
),
axis=2,
)
cases.verify(cartesian_case, testee, a, b, out=out, ref=ref)


@pytest.mark.uses_frontend_concat_where
def test_lap_like(cartesian_case):
pytest.xfail("Requires #1847.")

@gtx.field_operator
def testee(
input: cases.IJKField, boundary: float, shape: tuple[np.int64, np.int64, np.int64]
) -> cases.IJKField:
return concat_where(
(IDim == 0)
| (JDim == 0)
| (KDim == 0)
| (IDim == shape[0] - 1)
| (JDim == shape[1] - 1)
| (KDim == shape[2] - 1),
boundary,
input,
)

out = cases.allocate(cartesian_case, testee, cases.RETURN)()
input = cases.allocate(
cartesian_case, testee, "input", domain=out.domain.slice_at[1:-1, 1:-1, 1:-1]
)()
boundary = 2.0

ref = np.full(out.domain.shape, np.nan)
ref[0, :, :] = boundary
ref[:, 0, :] = boundary
ref[:, :, 0] = boundary
ref[-1, :, :] = boundary
ref[:, -1, :] = boundary
ref[:, :, -1] = boundary
cases.verify(cartesian_case, testee, input, boundary, out.domain.shape, out=out, ref=ref)


@pytest.mark.uses_frontend_concat_where
Expand Down Expand Up @@ -155,26 +207,6 @@ def testee(interior: cases.KField, boundary: cases.KField) -> cases.KField:
cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref)


@pytest.mark.uses_frontend_concat_where
def test_boundary_horizontal_slice(cartesian_case):
@gtx.field_operator
def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
return concat_where(KDim == 0, boundary, interior)

interior = cases.allocate(cartesian_case, testee, "interior")()
boundary = cases.allocate(cartesian_case, testee, "boundary")()
out = cases.allocate(cartesian_case, testee, cases.RETURN)()

k = np.arange(0, cartesian_case.default_sizes[KDim])
ref = np.where(
k[np.newaxis, np.newaxis, :] == 0,
boundary.asnumpy()[:, :, np.newaxis],
interior.asnumpy(),
)

cases.verify(cartesian_case, testee, interior, boundary, out=out, ref=ref)


@pytest.mark.uses_frontend_concat_where
def test_boundary_single_layer(cartesian_case):
@gtx.field_operator
Expand Down

0 comments on commit d14fb21

Please sign in to comment.