diff --git a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py index 96cec5b6d4..40c31dca53 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py @@ -41,20 +41,21 @@ def visit_FieldAccess( node: gtir.FieldAccess, vloop: gtir.VerticalLoop, field_boundaries: Dict[str, Tuple[Union[float, int], Union[float, int]]], - include_center_interval: bool, - **kwargs: Any, + **_: Any, ): boundary = field_boundaries[node.name] interval = vloop.interval if not isinstance(node.offset, gtir.VariableKOffset): - if interval.start.level == LevelMarker.START and ( - include_center_interval or interval.end.level == LevelMarker.START - ): - boundary = (max(-interval.start.offset - node.offset.k, boundary[0]), boundary[1]) - if ( - include_center_interval or interval.start.level == LevelMarker.END - ) and interval.end.level == LevelMarker.END: - boundary = (boundary[0], max(interval.end.offset + node.offset.k, boundary[1])) + if interval.start.level == LevelMarker.START: + boundary = ( + max(-interval.start.offset - node.offset.k, boundary[0]), + boundary[1], + ) + if interval.end.level == LevelMarker.END: + boundary = ( + boundary[0], + max(interval.end.offset + node.offset.k, boundary[1]), + ) if node.name in [decl.name for decl in vloop.temporaries] and ( boundary[0] > 0 or boundary[1] > 0 ): @@ -63,24 +64,35 @@ def visit_FieldAccess( field_boundaries[node.name] = boundary -def compute_k_boundary( - node: gtir.Stencil, include_center_interval=True -) -> Dict[str, Tuple[int, int]]: +def compute_k_boundary(node: gtir.Stencil) -> Dict[str, Tuple[int, int]]: # loop from START to END is not considered as it might be empty. additional check possible in the future - return KBoundaryVisitor().visit(node, include_center_interval=include_center_interval) + return KBoundaryVisitor().visit(node) -def compute_min_k_size(node: gtir.Stencil, include_center_interval=True) -> int: +def compute_min_k_size(node: gtir.Stencil) -> int: """Compute the required number of k levels to run a stencil.""" + min_size_start = 0 min_size_end = 0 + biggest_offset = 0 for vloop in node.vertical_loops: - if vloop.interval.start.level == LevelMarker.START and ( - include_center_interval or vloop.interval.end.level == LevelMarker.START + if ( + vloop.interval.start.level == LevelMarker.START + and vloop.interval.end.level == LevelMarker.END ): - min_size_start = max(min_size_start, vloop.interval.end.offset) + if not (vloop.interval.start.offset == 0 and vloop.interval.end.offset == 0): + biggest_offset = max( + biggest_offset, + vloop.interval.start.offset - vloop.interval.end.offset + 1, + ) elif ( - include_center_interval or vloop.interval.start.level == LevelMarker.END - ) and vloop.interval.end.level == LevelMarker.END: + vloop.interval.start.level == LevelMarker.START + and vloop.interval.end.level == LevelMarker.START + ): + min_size_start = max(min_size_start, vloop.interval.end.offset) + biggest_offset = max(biggest_offset, vloop.interval.end.offset) + else: min_size_end = max(min_size_end, -vloop.interval.start.offset) - return min_size_start + min_size_end + biggest_offset = max(biggest_offset, -vloop.interval.start.offset) + minimal_size = max(min_size_start + min_size_end, biggest_offset) + return minimal_size diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 8112866092..217c0ee488 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -295,7 +295,8 @@ def large_k_interval(in_field: Field3D, out_field: Field3D): with computation(PARALLEL): with interval(0, 6): out_field = in_field - with interval(6, -10): # this stage will only run if field has more than 16 elements + # this stenicl is only legal to call with fields that have more than 16 elements + with interval(6, -10): out_field = in_field + 1 with interval(-10, None): out_field = in_field diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 66d45abe21..4e0fa8903c 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -56,8 +56,8 @@ def test_generation(name, backend): ) else: args[k] = v(1.5) - # vertical domain size >= 16 required for test_large_k_interval - stencil(**args, origin=(10, 10, 5), domain=(3, 3, 16)) + # vertical domain size > 16 required for test_large_k_interval + stencil(**args, origin=(10, 10, 5), domain=(3, 3, 17)) @pytest.mark.parametrize("backend", ALL_BACKENDS) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py index 078adcc8da..6bb4ec63f6 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py @@ -16,7 +16,10 @@ from gt4py import cartesian as gt4pyc from gt4py.cartesian import gtscript as gs from gt4py.cartesian.backend import from_name -from gt4py.cartesian.gtc.passes.gtir_k_boundary import compute_k_boundary, compute_min_k_size +from gt4py.cartesian.gtc.passes.gtir_k_boundary import ( + compute_k_boundary, + compute_min_k_size, +) from gt4py.cartesian.gtc.passes.gtir_pipeline import prune_unused_parameters from gt4py.cartesian.gtscript import PARALLEL, computation, interval, stencil from gt4py.cartesian.stencil_builder import StencilBuilder @@ -48,21 +51,21 @@ def stencil_no_extent_0(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(0, -2), 0), min_k_size=2) +@register_test_case(k_bounds=(0, 0), min_k_size=2) @typing.no_type_check def stencil_no_extent_1(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(-1, -2), 0), min_k_size=2) +@register_test_case(k_bounds=(-1, 0), min_k_size=2) @typing.no_type_check def stencil_no_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(1, 2): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(max(0, -2), max(-2, -2)), 0), min_k_size=3) +@register_test_case(k_bounds=(0, 0), min_k_size=4) @typing.no_type_check def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): @@ -73,14 +76,14 @@ def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(0, max(-1, 0)), min_k_size=1) +@register_test_case(k_bounds=(0, 0), min_k_size=1) @typing.no_type_check def stencil_no_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(-1, None): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(0, -1), max(-2, 0)), min_k_size=3) +@register_test_case(k_bounds=(0, 0), min_k_size=3) @typing.no_type_check def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 1): @@ -89,6 +92,13 @@ def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] +@register_test_case(k_bounds=(-1, -2), min_k_size=4) +@typing.no_type_check +def stencil_no_extent_6(field_a: gs.Field[float], field_b: gs.Field[float]): + with computation(PARALLEL), interval(1, -2): + field_a[0, 0, 0] = field_b[0, 0, 0] + + # stencils with extent @register_test_case(k_bounds=(5, -5), min_k_size=0) @typing.no_type_check @@ -111,7 +121,7 @@ def stencil_with_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 5] -@register_test_case(k_bounds=(3, -3), min_k_size=3) +@register_test_case(k_bounds=(3, -3), min_k_size=4) @typing.no_type_check def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): @@ -122,7 +132,7 @@ def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, -3] -@register_test_case(k_bounds=(-5, 5), min_k_size=1) +@register_test_case(k_bounds=(-5, 5), min_k_size=2) @typing.no_type_check def stencil_with_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, -1): @@ -171,7 +181,10 @@ def test_min_k_size(definition, expected_min_k_size): @pytest.mark.parametrize("definition,expected", test_data) def test_k_bounds_exec(definition, expected): - expected_k_bounds, expected_min_k_size = expected["k_bounds"], expected["min_k_size"] + expected_k_bounds, expected_min_k_size = ( + expected["k_bounds"], + expected["min_k_size"], + ) required_field_size = expected_min_k_size + expected_k_bounds[0] + expected_k_bounds[1] @@ -234,7 +247,10 @@ def stencil_with_invalid_temporary_access_end(field_a: gs.Field[float], field_b: @pytest.mark.parametrize( "definition", - [stencil_with_invalid_temporary_access_start, stencil_with_invalid_temporary_access_end], + [ + stencil_with_invalid_temporary_access_start, + stencil_with_invalid_temporary_access_end, + ], ) def test_invalid_temporary_access(definition): builder = StencilBuilder(definition, backend=from_name("numpy"))