Skip to content

Commit

Permalink
fix[cartesian]: Fix minimal k-range computation (#1842)
Browse files Browse the repository at this point in the history
## Description

The computation of the minimal k-ranges that happen during the
vaildate-args step is allowing for inconsistent computation to happen.
This PR stiffens the requirements on fields:

- intervals `[START+X ,END+y)` are now also considered:
- `interval(3,-1)` requires a minimal size of 5 for the interval to not
be empty
  - `interval(3,None)` now requires a minimal size of 4
- intervals `[START+X, START+Y)` and `[END+X,END+Y)` are not affected. 
- empty intervals are still allowed to have a 0-domain as to accomodate
2-dimensional fields
  - `interval(...)` still requires no k-size

## Requirements

- [x] All fixes and/or new features come with corresponding tests.

---------

Co-authored-by: Florian Deconinck <deconinck.florian@gmail.com>
  • Loading branch information
twicki and FlorianDeconinck authored Feb 27, 2025
1 parent 587d107 commit 70569bc
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 34 deletions.
54 changes: 33 additions & 21 deletions src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,21 @@ def visit_FieldAccess(
node: gtir.FieldAccess,
vloop: gtir.VerticalLoop,
field_boundaries: Dict[str, Tuple[Union[float, int], Union[float, int]]],
include_center_interval: bool,
**kwargs: Any,
**_: Any,
):
boundary = field_boundaries[node.name]
interval = vloop.interval
if not isinstance(node.offset, gtir.VariableKOffset):
if interval.start.level == LevelMarker.START and (
include_center_interval or interval.end.level == LevelMarker.START
):
boundary = (max(-interval.start.offset - node.offset.k, boundary[0]), boundary[1])
if (
include_center_interval or interval.start.level == LevelMarker.END
) and interval.end.level == LevelMarker.END:
boundary = (boundary[0], max(interval.end.offset + node.offset.k, boundary[1]))
if interval.start.level == LevelMarker.START:
boundary = (
max(-interval.start.offset - node.offset.k, boundary[0]),
boundary[1],
)
if interval.end.level == LevelMarker.END:
boundary = (
boundary[0],
max(interval.end.offset + node.offset.k, boundary[1]),
)
if node.name in [decl.name for decl in vloop.temporaries] and (
boundary[0] > 0 or boundary[1] > 0
):
Expand All @@ -63,24 +64,35 @@ def visit_FieldAccess(
field_boundaries[node.name] = boundary


def compute_k_boundary(
node: gtir.Stencil, include_center_interval=True
) -> Dict[str, Tuple[int, int]]:
def compute_k_boundary(node: gtir.Stencil) -> Dict[str, Tuple[int, int]]:
# loop from START to END is not considered as it might be empty. additional check possible in the future
return KBoundaryVisitor().visit(node, include_center_interval=include_center_interval)
return KBoundaryVisitor().visit(node)


def compute_min_k_size(node: gtir.Stencil, include_center_interval=True) -> int:
def compute_min_k_size(node: gtir.Stencil) -> int:
"""Compute the required number of k levels to run a stencil."""

min_size_start = 0
min_size_end = 0
biggest_offset = 0
for vloop in node.vertical_loops:
if vloop.interval.start.level == LevelMarker.START and (
include_center_interval or vloop.interval.end.level == LevelMarker.START
if (
vloop.interval.start.level == LevelMarker.START
and vloop.interval.end.level == LevelMarker.END
):
min_size_start = max(min_size_start, vloop.interval.end.offset)
if not (vloop.interval.start.offset == 0 and vloop.interval.end.offset == 0):
biggest_offset = max(
biggest_offset,
vloop.interval.start.offset - vloop.interval.end.offset + 1,
)
elif (
include_center_interval or vloop.interval.start.level == LevelMarker.END
) and vloop.interval.end.level == LevelMarker.END:
vloop.interval.start.level == LevelMarker.START
and vloop.interval.end.level == LevelMarker.START
):
min_size_start = max(min_size_start, vloop.interval.end.offset)
biggest_offset = max(biggest_offset, vloop.interval.end.offset)
else:
min_size_end = max(min_size_end, -vloop.interval.start.offset)
return min_size_start + min_size_end
biggest_offset = max(biggest_offset, -vloop.interval.start.offset)
minimal_size = max(min_size_start + min_size_end, biggest_offset)
return minimal_size
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def large_k_interval(in_field: Field3D, out_field: Field3D):
with computation(PARALLEL):
with interval(0, 6):
out_field = in_field
with interval(6, -10): # this stage will only run if field has more than 16 elements
# this stenicl is only legal to call with fields that have more than 16 elements
with interval(6, -10):
out_field = in_field + 1
with interval(-10, None):
out_field = in_field
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_generation(name, backend):
)
else:
args[k] = v(1.5)
# vertical domain size >= 16 required for test_large_k_interval
stencil(**args, origin=(10, 10, 5), domain=(3, 3, 16))
# vertical domain size > 16 required for test_large_k_interval
stencil(**args, origin=(10, 10, 5), domain=(3, 3, 17))


@pytest.mark.parametrize("backend", ALL_BACKENDS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from gt4py import cartesian as gt4pyc
from gt4py.cartesian import gtscript as gs
from gt4py.cartesian.backend import from_name
from gt4py.cartesian.gtc.passes.gtir_k_boundary import compute_k_boundary, compute_min_k_size
from gt4py.cartesian.gtc.passes.gtir_k_boundary import (
compute_k_boundary,
compute_min_k_size,
)
from gt4py.cartesian.gtc.passes.gtir_pipeline import prune_unused_parameters
from gt4py.cartesian.gtscript import PARALLEL, computation, interval, stencil
from gt4py.cartesian.stencil_builder import StencilBuilder
Expand Down Expand Up @@ -48,21 +51,21 @@ def stencil_no_extent_0(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(0, -2), 0), min_k_size=2)
@register_test_case(k_bounds=(0, 0), min_k_size=2)
@typing.no_type_check
def stencil_no_extent_1(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 2):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(-1, -2), 0), min_k_size=2)
@register_test_case(k_bounds=(-1, 0), min_k_size=2)
@typing.no_type_check
def stencil_no_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(1, 2):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(max(0, -2), max(-2, -2)), 0), min_k_size=3)
@register_test_case(k_bounds=(0, 0), min_k_size=4)
@typing.no_type_check
def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 2):
Expand All @@ -73,14 +76,14 @@ def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(0, max(-1, 0)), min_k_size=1)
@register_test_case(k_bounds=(0, 0), min_k_size=1)
@typing.no_type_check
def stencil_no_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(-1, None):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(max(0, -1), max(-2, 0)), min_k_size=3)
@register_test_case(k_bounds=(0, 0), min_k_size=3)
@typing.no_type_check
def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 1):
Expand All @@ -89,6 +92,13 @@ def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 0]


@register_test_case(k_bounds=(-1, -2), min_k_size=4)
@typing.no_type_check
def stencil_no_extent_6(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(1, -2):
field_a[0, 0, 0] = field_b[0, 0, 0]


# stencils with extent
@register_test_case(k_bounds=(5, -5), min_k_size=0)
@typing.no_type_check
Expand All @@ -111,7 +121,7 @@ def stencil_with_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, 5]


@register_test_case(k_bounds=(3, -3), min_k_size=3)
@register_test_case(k_bounds=(3, -3), min_k_size=4)
@typing.no_type_check
def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, 2):
Expand All @@ -122,7 +132,7 @@ def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]):
field_a = field_b[0, 0, -3]


@register_test_case(k_bounds=(-5, 5), min_k_size=1)
@register_test_case(k_bounds=(-5, 5), min_k_size=2)
@typing.no_type_check
def stencil_with_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]):
with computation(PARALLEL), interval(0, -1):
Expand Down Expand Up @@ -171,7 +181,10 @@ def test_min_k_size(definition, expected_min_k_size):

@pytest.mark.parametrize("definition,expected", test_data)
def test_k_bounds_exec(definition, expected):
expected_k_bounds, expected_min_k_size = expected["k_bounds"], expected["min_k_size"]
expected_k_bounds, expected_min_k_size = (
expected["k_bounds"],
expected["min_k_size"],
)

required_field_size = expected_min_k_size + expected_k_bounds[0] + expected_k_bounds[1]

Expand Down Expand Up @@ -234,7 +247,10 @@ def stencil_with_invalid_temporary_access_end(field_a: gs.Field[float], field_b:

@pytest.mark.parametrize(
"definition",
[stencil_with_invalid_temporary_access_start, stencil_with_invalid_temporary_access_end],
[
stencil_with_invalid_temporary_access_start,
stencil_with_invalid_temporary_access_end,
],
)
def test_invalid_temporary_access(definition):
builder = StencilBuilder(definition, backend=from_name("numpy"))
Expand Down

0 comments on commit 70569bc

Please sign in to comment.