From 4456f433019a7db8aa185398ca5a90181a50f1a8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 20 Feb 2025 01:35:18 +0100 Subject: [PATCH 1/7] fix[next]: Git ignore gt4py cache (#1875) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index b1c8ed26e9..ebbbfaebeb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ _local /src/__init__.py /tests/__init__.py .gt_cache/ +.gt4py_cache/ .gt_cache_pytest*/ # DaCe From 1a46fb0f7bce91ddd793c771d2f5d31be26528af Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 20 Feb 2025 13:13:19 +0100 Subject: [PATCH 2/7] fix[next][dace]: remove temporary arrays with runtime shape on the output of a mapped nested SDFG (#1877) This PR provides a better fix than the one delivered earlier in https://github.com/GridTools/gt4py/pull/1828. It adds a check to detect whether the temporary output data has compile-time or runtime size. In case of runtime size, the transient array on the output connector of a mapped nested SDFG is removed. This is needed in order to avoid dynamic memory allocation inside the cuda kernel that represents a parallel map scope. --- .../runners/dace/gtir_dataflow.py | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py index e6f33208e3..43e7c6354d 100644 --- a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -232,15 +232,36 @@ def connect( dest: dace.nodes.AccessNode, subset: dace_subsets.Range, ) -> None: - # retrieve the node which writes the result - last_node = self.state.in_edges(self.result.dc_node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): - # the last transient node can be deleted - # Note that it could also be applied when `last_node` is a NestedSDFG, - # but an exception would be when the inner write to global data is a - # WCR memlet, because that prevents fusion of the outer map. This case - # happens for the reduce with skip values, which uses a map with WCR. - last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn + write_edge = self.state.in_edges(self.result.dc_node)[0] + write_size = write_edge.data.dst_subset.num_elements() + # check the kind of node which writes the result + if isinstance(write_edge.src, dace.nodes.Tasklet): + # The temporary data written by a tasklet can be safely deleted + assert write_size.is_constant() + remove_last_node = True + elif isinstance(write_edge.src, dace.nodes.NestedSDFG): + if write_size.is_constant(): + # Temporary data with compile-time size is allocated on the stack + # and therefore is safe to keep. We decide to keep it as a workaround + # for a dace issue with memlet propagation in combination with + # nested SDFGs containing conditional blocks. The output memlet + # of such blocks will be marked as dynamic because dace is not able + # to detect the exact size of a conditional branch dataflow, even + # in case of if-else expressions with exact same output data. + remove_last_node = False + else: + # In case the output data has runtime size it is necessary to remove + # it in order to avoid dynamic memory allocation inside a parallel + # map scope. Otherwise, the memory allocation will for sure lead + # to performance degradation, and eventually illegal memory issues + # when the gpu runs out of local memory. + remove_last_node = True + else: + remove_last_node = False + + if remove_last_node: + last_node = write_edge.src + last_node_connector = write_edge.src_conn self.state.remove_node(self.result.dc_node) else: last_node = self.result.dc_node From 1176b2d4bf733f1f43b36c12d837904a9e2b52ad Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 20 Feb 2025 19:40:00 +0100 Subject: [PATCH 3/7] fix[cartesian]: DataType.isinteger() for 16-bit integers (#1878) ## Description 16-bit integers were missing in the set of data types that return true for `DataType.isinteger()`. Added missing test coverage. ## Requirements - [x] All fixes and/or new features come with corresponding tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- src/gt4py/cartesian/gtc/common.py | 2 +- .../unit_tests/test_gtc/test_common.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index ef38a9a658..60236a3e97 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -118,7 +118,7 @@ def isbool(self): return self == self.BOOL def isinteger(self): - return self in (self.INT8, self.INT32, self.INT64) + return self in (self.INT8, self.INT16, self.INT32, self.INT64) def isfloat(self): return self in (self.FLOAT32, self.FLOAT64) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index 68006c113b..4e799d2090 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -41,6 +41,24 @@ # - For testing non-leave nodes, introduce builders with defaults (for leave nodes as well) +def test_data_type_methods(): + for type in DataType: + if type == DataType.BOOL: + assert type.isbool() + else: + assert not type.isbool() + + if type in (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64): + assert type.isinteger() + else: + assert not type.isinteger() + + if type in (DataType.FLOAT32, DataType.FLOAT64): + assert type.isfloat() + else: + assert not type.isfloat() + + class DummyExpr(Expr): """Fake expression for cases where a concrete expression is not needed.""" From 198469177aa7f7cd493589fa155ec65bd74fd5dc Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 21 Feb 2025 15:02:20 +0100 Subject: [PATCH 4/7] fix[cartesian, dace]: warn about missing support for casting in variable k offsets (#1882) ## Description We figured that DaCe backends are currently missing support for casting in variable k offsets. This PR - adds a codegen test with a cast in a variable k offset - adds a node validator for the DaCe backends complaining about missing for support. - adds an `xfail` test for the node validator This should be fixed down the road. Here's the issue https://github.com/GridTools/gt4py/issues/1881 to keep track. The PR also has two smaller and unrelated commits - 741c448f5258fccbca942a6cc9548c7554e454c9 increases test coverage with another codgen test that has a couple of read after write access patterns which were breaking the "new bridge" (see https://github.com/GEOS-ESM/NDSL/issues/53). - e98ddc54f8571d8d24d2169a421955c4b4e795e1 just forwards all keyword arguments when visiting offsets. I don't think this was a problem until now, but it's best practice to forward everything. ## Requirements - [x] All fixes and/or new features come with corresponding tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. N/A --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Co-authored-by: Florian Deconinck --- src/gt4py/cartesian/gtc/dace/daceir.py | 8 +- .../gtc/dace/expansion/tasklet_codegen.py | 4 +- tests/cartesian_tests/definitions.py | 6 ++ .../test_code_generation.py | 94 ++++++++++++++++++- 4 files changed, 109 insertions(+), 3 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 492a9598c5..43a33fdd6d 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -734,7 +734,13 @@ class ScalarAccess(common.ScalarAccess, Expr): class VariableKOffset(common.VariableKOffset[Expr]): - pass + @datamodels.validator("k") + def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Expr) -> None: + for part in expression.walk_values(): + if isinstance(part, Cast): + raise ValueError( + "DaCe backends are currently missing support for casts in variable k offsets. See issue https://github.com/GridTools/gt4py/issues/1881." + ) class IndexAccess(common.FieldAccess, Expr): diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 8033c64710..2948b9d76d 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -44,7 +44,9 @@ def _visit_offset( else: int_sizes.append(None) sym_offsets = [ - dace.symbolic.pystr_to_symbolic(self.visit(off, **kwargs)) + dace.symbolic.pystr_to_symbolic( + self.visit(off, access_info=access_info, decl=decl, **kwargs) + ) for off in (node.to_dict()["i"], node.to_dict()["j"], node.k) ] for axis in access_info.variable_offset_axes: diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 7499ad4a95..4d52b9b773 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -51,6 +51,12 @@ def _get_backends_with_storage_info(storage_info_kind: str): _PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")] PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES] +DACE_BACKENDS = [ + _backend_name_as_param(name) + for name in filter(lambda name: name.startswith("dace:"), _ALL_BACKEND_NAMES) +] +NON_DACE_BACKENDS = [backend for backend in ALL_BACKENDS if backend not in DACE_BACKENDS] + @pytest.fixture() def id_version(): diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 8ace0de740..8e5f3466d0 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -27,7 +27,13 @@ ) from gt4py.storage.cartesian import utils as storage_utils -from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library +from cartesian_tests.definitions import ( + ALL_BACKENDS, + CPU_BACKENDS, + DACE_BACKENDS, + NON_DACE_BACKENDS, + get_array_library, +) from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import ( EXTERNALS_REGISTRY as externals_registry, REGISTRY as stencil_definitions, @@ -762,3 +768,89 @@ def test( out_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64) test(in_arr, out_arr) assert (out_arr[:, :, :] == 388.0).all() + + +@pytest.mark.parametrize("backend", NON_DACE_BACKENDS) +def test_cast_in_index(backend): + @gtscript.stencil(backend) + def cast_in_index( + in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64] + ): + """Simple copy stencil with forced cast in index calculation.""" + with computation(PARALLEL), interval(...): + out_field = in_field[0, 0, i32 - i64] + + +@pytest.mark.parametrize("backend", DACE_BACKENDS) +@pytest.mark.xfail(raises=ValueError) +def test_dace_no_cast_in_index(backend): + @gtscript.stencil(backend) + def cast_in_index( + in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64] + ): + """Simple copy stencil with forced cast in index calculation.""" + with computation(PARALLEL), interval(...): + out_field = in_field[0, 0, i32 - i64] + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_read_after_write_stencil(backend): + """Stencil with multiple read after write access patterns.""" + + @gtscript.stencil(backend=backend) + def lagrangian_contributions( + q: Field[np.float64], + pe1: Field[np.float64], + pe2: Field[np.float64], + q4_1: Field[np.float64], + q4_2: Field[np.float64], + q4_3: Field[np.float64], + q4_4: Field[np.float64], + dp1: Field[np.float64], + lev: gtscript.Field[gtscript.IJ, np.int64], + ): + """ + Args: + q (out): + pe1 (in): + pe2 (in): + q4_1 (in): + q4_2 (in): + q4_3 (in): + q4_4 (in): + dp1 (in): + lev (inout): + """ + with computation(FORWARD), interval(...): + pl = (pe2 - pe1[0, 0, lev]) / dp1[0, 0, lev] + if pe2[0, 0, 1] <= pe1[0, 0, lev + 1]: + pr = (pe2[0, 0, 1] - pe1[0, 0, lev]) / dp1[0, 0, lev] + q = ( + q4_2[0, 0, lev] + + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (pr + pl) + - q4_4[0, 0, lev] * 1.0 / 3.0 * (pr * (pr + pl) + pl * pl) + ) + else: + qsum = (pe1[0, 0, lev + 1] - pe2) * ( + q4_2[0, 0, lev] + + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (1.0 + pl) + - q4_4[0, 0, lev] * 1.0 / 3.0 * (1.0 + pl * (1.0 + pl)) + ) + lev = lev + 1 + while pe1[0, 0, lev + 1] < pe2[0, 0, 1]: + qsum += dp1[0, 0, lev] * q4_1[0, 0, lev] + lev = lev + 1 + dp = pe2[0, 0, 1] - pe1[0, 0, lev] + esl = dp / dp1[0, 0, lev] + qsum += dp * ( + q4_2[0, 0, lev] + + 0.5 + * esl + * ( + q4_3[0, 0, lev] + - q4_2[0, 0, lev] + + q4_4[0, 0, lev] * (1.0 - (2.0 / 3.0) * esl) + ) + ) + q = qsum / (pe2[0, 0, 1] - pe2) + lev = lev - 1 From c6a841ad045b2b08c26642bb5fc642d88611d4e7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 25 Feb 2025 09:58:15 +0100 Subject: [PATCH 5/7] Address review comments --- src/gt4py/next/common.py | 37 ++++++++++--- .../iterator/type_system/type_synthesizer.py | 12 ++--- .../next/type_system/type_specifications.py | 6 +++ .../iterator_tests/test_type_inference.py | 3 +- tests/next_tests/unit_tests/test_common.py | 53 +++++++++---------- 5 files changed, 65 insertions(+), 46 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index b37c292bc1..edf98d8e4a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -69,6 +69,9 @@ def __str__(self) -> str: return self.value +dims_kind_order = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3} + + def dimension_to_implicit_offset(dim: str) -> str: """ Return name of offset implicitly defined by a dimension. @@ -1123,12 +1126,20 @@ class GridType(StrEnum): UNSTRUCTURED = "unstructured" +def check_dims(dims: list[Dimension]) -> None: + if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1: + raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") + + if dims != sorted(dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value)): + raise ValueError("Dimensions are not correctly ordered.") + + def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: """ - Find a sorted ordering of multiple lists of dimensions. + Find an ordering of multiple lists of dimensions. The resulting list contains all unique dimensions from the input lists, - sorted first by `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then + sorted first by dims_kind_order, i.e., `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then lexicographically by `Dimension.value`. Examples: @@ -1138,17 +1149,29 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: >>> K = Dimension("K", DimensionKind.VERTICAL) >>> E2V = Dimension("E2V", kind=DimensionKind.LOCAL) >>> E2C = Dimension("E2C", kind=DimensionKind.LOCAL) - >>> promote_dims([K, J], [I, K]) == [I, J, K] + >>> promote_dims([J, K], [I, K]) == [I, J, K] True - >>> promote_dims([K, I], [E2C, E2V]) == [I, K, E2C, E2V] + >>> promote_dims([K, J], [I, K]) + Traceback (most recent call last): + ... + raise ValueError("Dimensions are not correctly ordered.") + ValueError: Dimensions are not correctly ordered. + >>> promote_dims([I, K], [J, E2V]) == [I, J, K, E2V] True + >>> promote_dims([I, E2C], [K, E2V]) + Traceback (most recent call last): + ... + raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") + ValueError: There are more than one dimension with DimensionKind 'LOCAL'. """ + for dims in dims_list: + check_dims(list(dims)) unique_dims = {dim for dims in dims_list for dim in dims} - kind_order = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3} - - return sorted(unique_dims, key=lambda dim: (kind_order[dim.kind], dim.value)) + promoted_dims = sorted(unique_dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value)) + check_dims(promoted_dims) + return promoted_dims class FieldBuiltinFuncRegistry: diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 2327a47df4..131b773dd2 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -93,19 +93,13 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: @_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_NUMBER_BUILTINS) -def _( - lhs: ts.ScalarType | ts.FieldType, rhs: ts.ScalarType | ts.FieldType -) -> ts.ScalarType | ts.FieldType: +def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: if isinstance(lhs, ts.DeferredType): return rhs if isinstance(rhs, ts.DeferredType): return lhs - if lhs == rhs: - return lhs - else: - assert isinstance(lhs, ts.FieldType) and isinstance(rhs, ts.FieldType) - assert lhs.dtype == rhs.dtype - return ts.FieldType(dims=common.promote_dims(*[lhs.dims, rhs.dims]), dtype=lhs.dtype) + assert lhs == rhs + return lhs @_register_builtin_type_synthesizer( diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 2fbd039d16..5b46f9dd0d 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -105,6 +105,12 @@ def __str__(self) -> str: dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" return f"Field[{dims}, {self.dtype}]" + @eve_datamodels.validator("dims") + def _dims_validator( + self, attribute: eve_datamodels.Attribute, dims: list[common.Dimension] + ) -> None: + common.check_dims(dims) + class TupleType(DataType): # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 59d8dbdb87..391cdc9157 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -78,7 +78,6 @@ def expression_test_cases(): (im.call("abs")(1), int_type), (im.call("power")(2.0, 2), float64_type), (im.plus(1, 2), int_type), - (im.plus(im.ref("inp1", float_i_field), im.ref("inp2", float_j_field)), float_ij_field), (im.eq(1, 2), bool_type), (im.deref(im.ref("it", it_on_e_of_e_type)), it_on_e_of_e_type.element_type), (im.can_deref(im.ref("it", it_on_e_of_e_type)), bool_type), @@ -189,7 +188,7 @@ def expression_test_cases(): ), ts.DeferredType(constraint=None), ), - # (im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref("x"), im.deref("y"))))( + # (im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref("x"), im.deref("y"))))( # TODO(SF-N): this needs PR 1853 # im.ref("inp1", float_i_field), im.ref("inp2", float_j_field)), float_ij_field), # if in field-view scope ( diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index c9126be127..fcf2caa03f 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -333,7 +333,7 @@ def test_domain_intersection_different_dimensions(a_domain, second_domain, expec def test_domain_intersection_reversed_dimensions(a_domain): - domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) + domain2 = Domain(dims=(IDim, JDim), ranges=(UnitRange(7, 17), UnitRange(2, 12))) assert a_domain & domain2 == Domain( dims=(IDim, JDim, KDim), ranges=(UnitRange(7, 10), UnitRange(5, 12), UnitRange(20, 30)) @@ -573,45 +573,42 @@ def test_domain_replace(index, named_ranges, domain, expected): def dimension_promotion_cases() -> ( - list[tuple[list[list[Dimension]], list[Dimension]]] -): # TODO: rename and remove promote_dims + list[tuple[list[list[Dimension]], list[Dimension] | None, None | Pattern]] +): raw_list = [ - # list of list of dimensions, expected result - ([[I, J], [I]], [I, J]), - ([[J], [I, J]], [I, J]), - ([[J, K], [I, J]], [I, J, K]), - ( - [[I, J], [J, I]], - [I, J], - ), - ( - [[K, J], [I, K]], - [I, J, K], - ), + # list of list of dimensions, expected result, expected error message + ([[I, J], [I]], [I, J], None), + ([[J], [I, J]], [I, J], None), + ([[J, K], [I, J]], [I, J, K], None), + ([[I, J], [J, I]], None, "Dimensions are not correctly ordered."), + ([[J, K], [I, K]], [I, J, K], None), + ([[K, J], [I, K]], None, "Dimensions are not correctly ordered."), ( - [[K, J], [I, K]], - [I, J, K], - ), - ( - [[V2E, C2E, J], [K, I, E2C2V], [E2C, E2V]], - [I, J, K, C2E, E2C, E2C2V, E2V, V2E], + [[J, V2E], [I, K, E2C2V]], + None, + "There are more than one dimension with DimensionKind 'LOCAL'.", ), + ([[J, V2E], [I, K]], [I, J, K, V2E], None), ] return [ - ( - [[el for el in arg] for arg in args], - [el for el in result] if result else result, - ) - for args, result in raw_list + ([[el for el in arg] for arg in args], [el for el in result] if result else result, msg) + for args, result, msg in raw_list ] -@pytest.mark.parametrize("dim_list,expected_result", dimension_promotion_cases()) +@pytest.mark.parametrize("dim_list,expected_result,expected_error_msg", dimension_promotion_cases()) def test_dimension_promotion( dim_list: list[list[Dimension]], expected_result: Optional[list[Dimension]], + expected_error_msg: Optional[str], ): - assert promote_dims(*dim_list) == expected_result + if expected_result: + assert promote_dims(*dim_list) == expected_result + else: + with pytest.raises(Exception) as exc_info: + promote_dims(*dim_list) + + assert exc_info.match(expected_error_msg) class TestCartesianConnectivity: From bd66acb9fbae649a179e79f936cbda711c044916 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 25 Feb 2025 10:50:03 +0100 Subject: [PATCH 6/7] Fix tests --- src/gt4py/next/common.py | 6 +++--- src/gt4py/next/type_system/type_info.py | 5 ++++- .../ffront_tests/test_gt4py_builtins.py | 9 +-------- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 4 ++-- .../otf_tests/binding_tests/test_cpp_interface.py | 8 ++++---- tests/next_tests/unit_tests/test_common.py | 15 ++++++++++++--- 6 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index edf98d8e4a..0b2e033520 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -1131,7 +1131,7 @@ def check_dims(dims: list[Dimension]) -> None: raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") if dims != sorted(dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value)): - raise ValueError("Dimensions are not correctly ordered.") + raise ValueError(f"Dimensions {dims} are not correctly ordered.") def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: @@ -1154,8 +1154,8 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: >>> promote_dims([K, J], [I, K]) Traceback (most recent call last): ... - raise ValueError("Dimensions are not correctly ordered.") - ValueError: Dimensions are not correctly ordered. + raise ValueError(f"Dimensions {dims} are not correctly ordered.") + ValueError: Dimensions [Dimension(value='K', kind=), Dimension(value='J', kind=)] are not correctly ordered. >>> promote_dims([I, K], [J, E2V]) == [I, J, K, E2V] True >>> promote_dims([I, E2C], [K, E2V]) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index b45bd4101f..e8c629753c 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -641,7 +641,10 @@ def return_type_field( new_dims.append(d) else: new_dims.extend(target_dims) - return ts.FieldType(dims=new_dims, dtype=field_type.dtype) + return ts.FieldType( + dims=sorted(new_dims, key=lambda dim: (common.dims_kind_order[dim.kind], dim.value)), + dtype=field_type.dtype, + ) UNDEFINED_ARG = types.new_class("UNDEFINED_ARG") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index ab1c625fef..d7fe252cb4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -95,16 +95,9 @@ def reduction_ek_field( return neighbor_sum(edge_f(V2E), axis=V2EDim) -@gtx.field_operator -def reduction_ke_field( - edge_f: common.Field[[KDim, Edge], np.int32], -) -> common.Field[[KDim, Vertex], np.int32]: - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - @pytest.mark.uses_unstructured_shift @pytest.mark.parametrize( - "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ + "fop", [reduction_e_field, reduction_ek_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case_3d, fop): v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index c0d762efc8..776cd4e1a9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -901,7 +901,7 @@ def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: def test_broadcast(): def foo(inp: gtx.Field[[TDim], float64]): - return broadcast(inp, (UDim, TDim)) + return broadcast(inp, (TDim, UDim)) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) @@ -912,7 +912,7 @@ def foo(inp: gtx.Field[[TDim], float64]): def test_scalar_broadcast(): def foo(): - return broadcast(1, (UDim, TDim)) + return broadcast(1, (TDim, UDim)) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py index 51b6bf512b..a25732649a 100644 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py +++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py @@ -60,14 +60,14 @@ def function_buffer_example(): interface.Parameter( name="a_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ), interface.Parameter( name="b_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.INT64) + dims=[gtx.Dimension("bar")], dtype=ts.ScalarType(ts.ScalarKind.INT64) ), ), ], @@ -111,11 +111,11 @@ def function_tuple_example(): type_=ts.TupleType( types=[ ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ] diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index fcf2caa03f..adbe5911ab 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -10,6 +10,7 @@ from typing import Optional, Pattern import pytest +import re from gt4py import next as gtx import gt4py.next.common as common @@ -580,9 +581,17 @@ def dimension_promotion_cases() -> ( ([[I, J], [I]], [I, J], None), ([[J], [I, J]], [I, J], None), ([[J, K], [I, J]], [I, J, K], None), - ([[I, J], [J, I]], None, "Dimensions are not correctly ordered."), + ( + [[I, J], [J, I]], + None, + "Dimensions [Dimension(value='J', kind=), Dimension(value='I', kind=)] are not correctly ordered.", + ), ([[J, K], [I, K]], [I, J, K], None), - ([[K, J], [I, K]], None, "Dimensions are not correctly ordered."), + ( + [[K, J], [I, K]], + None, + "Dimensions [Dimension(value='K', kind=), Dimension(value='J', kind=)] are not correctly ordered.", + ), ( [[J, V2E], [I, K, E2C2V]], None, @@ -608,7 +617,7 @@ def test_dimension_promotion( with pytest.raises(Exception) as exc_info: promote_dims(*dim_list) - assert exc_info.match(expected_error_msg) + assert exc_info.match(re.escape(expected_error_msg)) class TestCartesianConnectivity: From e5ed26260eba01740107a4988760f0b6a4390b19 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 25 Feb 2025 10:59:09 +0100 Subject: [PATCH 7/7] Fix test --- src/gt4py/next/type_system/type_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index e8c629753c..0ce07565fd 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -568,7 +568,7 @@ def promote( True >>> promoted: ts.FieldType = promote( - ... ts.FieldType(dims=[J, I], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype) + ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype) ... ) >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True