diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index e5b393f1ae..75e6cd9e9b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -69,6 +69,9 @@ def __str__(self) -> str: return self.value +dims_kind_order = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3} + + def dimension_to_implicit_offset(dim: str) -> str: """ Return name of offset implicitly defined by a dimension. @@ -1123,84 +1126,52 @@ class GridType(StrEnum): UNSTRUCTURED = "unstructured" +def check_dims(dims: list[Dimension]) -> None: + if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1: + raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") + + if dims != sorted(dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value)): + raise ValueError(f"Dimensions {dims} are not correctly ordered.") + + def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: """ - Find a unique ordering of multiple (individually ordered) lists of dimensions. - - The resulting list of dimensions contains all dimensions of the arguments - in the order they originally appear. If no unique order exists or a - contradicting order is found an exception is raised. + Find an ordering of multiple lists of dimensions. - A modified version (ensuring uniqueness of the order) of - `Kahn's algorithm `_ - is used to topologically sort the arguments. + The resulting list contains all unique dimensions from the input lists, + sorted first by dims_kind_order, i.e., `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then + lexicographically by `Dimension.value`. Examples: >>> from gt4py.next.common import Dimension - >>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"]) - >>> promote_dims([I, J], [I, J, K]) == [I, J, K] + >>> I = Dimension("I", DimensionKind.HORIZONTAL) + >>> J = Dimension("J", DimensionKind.HORIZONTAL) + >>> K = Dimension("K", DimensionKind.VERTICAL) + >>> E2V = Dimension("E2V", kind=DimensionKind.LOCAL) + >>> E2C = Dimension("E2C", kind=DimensionKind.LOCAL) + >>> promote_dims([J, K], [I, K]) == [I, J, K] True - - >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS + >>> promote_dims([K, J], [I, K]) Traceback (most recent call last): ... - ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. - - >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS + raise ValueError(f"Dimensions {dims} are not correctly ordered.") + ValueError: Dimensions [Dimension(value='K', kind=), Dimension(value='J', kind=)] are not correctly ordered. + >>> promote_dims([I, K], [J, E2V]) == [I, J, K, E2V] + True + >>> promote_dims([I, E2C], [K, E2V]) Traceback (most recent call last): ... - ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. + raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") + ValueError: There are more than one dimension with DimensionKind 'LOCAL'. """ - # build a graph with the vertices being dimensions and edges representing - # the order between two dimensions. The graph is encoded as a dictionary - # mapping dimensions to their predecessors, i.e. a dictionary containing - # adjacency lists. Since graphlib.TopologicalSorter uses predecessors - # (contrary to successors) we also use this directionality here. - graph: dict[Dimension, set[Dimension]] = {} + for dims in dims_list: - if len(dims) == 0: - continue - # create a vertex for each dimension - for dim in dims: - graph.setdefault(dim, set()) - # add edges - predecessor = dims[0] - for dim in dims[1:]: - graph[dim].add(predecessor) - predecessor = dim - - # modified version of Kahn's algorithm - topologically_sorted_list: list[Dimension] = [] - - # compute in-degree for each vertex - in_degree = {v: 0 for v in graph.keys()} - for v1 in graph: - for v2 in graph[v1]: - in_degree[v2] += 1 - - # process vertices with in-degree == 0 - # TODO(tehrengruber): avoid recomputation of zero_in_degree_vertex_list - while zero_in_degree_vertex_list := [v for v, d in in_degree.items() if d == 0]: - if len(zero_in_degree_vertex_list) != 1: - raise ValueError( - f"Dimensions can not be promoted. Could not determine " - f"order of the following dimensions: " - f"{', '.join((dim.value for dim in zero_in_degree_vertex_list))}." - ) - v = zero_in_degree_vertex_list[0] - del in_degree[v] - topologically_sorted_list.insert(0, v) - # update in-degree - for predecessor in graph[v]: - in_degree[predecessor] -= 1 - - if len(in_degree.items()) > 0: - raise ValueError( - f"Dimensions can not be promoted. The following dimensions " - f"appear in contradicting order: {', '.join((dim.value for dim in in_degree.keys()))}." - ) + check_dims(list(dims)) + unique_dims = {dim for dims in dims_list for dim in dims} - return topologically_sorted_list + promoted_dims = sorted(unique_dims, key=lambda dim: (dims_kind_order[dim.kind], dim.value)) + check_dims(promoted_dims) + return promoted_dims if unique_dims else [] class FieldBuiltinFuncRegistry: diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index fe450625db..652745db52 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -492,10 +492,7 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # probably just change the behaviour of the lowering. Until then we do this more # complicated comparison. if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): - assert ( - set(expr_type.dims).issubset(set(target_type.dims)) - and target_type.dtype == expr_type.dtype - ) + assert target_type.dtype == expr_type.dtype def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind)) diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index 7825bf1c98..0028ae71ee 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -17,7 +17,7 @@ class NamedRangeType(ts.TypeSpec): class DomainType(ts.DataType): - dims: list[common.Dimension] | Literal["unknown"] + dims: list[common.Dimension] | Literal["unknown"] # TODO: remove unknown class OffsetLiteralType(ts.TypeSpec): @@ -25,7 +25,7 @@ class OffsetLiteralType(ts.TypeSpec): class IteratorType(ts.DataType, ts.CallableType): - position_dims: list[common.Dimension] | Literal["unknown"] + position_dims: list[common.Dimension] | Literal["unknown"] # TODO: remove unknown: 80%? defined_dims: list[common.Dimension] element_type: ts.DataType diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 131b773dd2..ccd843bb16 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -14,7 +14,9 @@ from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common -from gt4py.next.iterator import builtins +from gt4py.next.ffront import fbuiltins +from gt4py.next.iterator import builtins, ir as itir +from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -302,6 +304,7 @@ def as_fieldop( # information on the ordering of dimensions. In this example # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` # it is unclear if the result has dimension I, J or J, I. + # Todo: update comment if domain is None: domain = it_ts.DomainType(dims="unknown") @@ -314,15 +317,92 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: ): return ts.DeferredType(constraint=None) + output_dims = [] + for i, field in enumerate(fields): + input_dims = common.promote_dims( + *[field.dims if isinstance(field, ts.FieldType) else []] + ) + seen = set() + if isinstance(stencil.node, itir.Expr): + shifts_results = trace_shifts.trace_stencil( + stencil.node, num_args=len(fields) + ) # TODO: access node differently? + + def resolve_shift( + input_dim: common.Dimension, shift_tuple: tuple[itir.OffsetLiteral, ...] + ) -> common.Dimension: + """ + Resolves the final dimension by applying shifts from the given shift tuple. + + Iterates through the shift tuple, updating `input_dim` based on matching offsets. + + Parameters: + - input_dim (common.Dimension): The initial dimension to resolve. + - shift_tuple (tuple[itir.OffsetLiteral, ...]): A tuple of offset literals defining the shift. + + Returns: + - common.Dimension: The resolved dimension or `input_dim` if no shift is applied. + """ + + final_target: common.Dimension = input_dim + + for off_literal in shift_tuple[::2]: + if offset_provider_type: + offset_type = offset_provider_type[off_literal.value] # type: ignore [index] # ensured by accessing only every second element + if ( + isinstance(offset_type, common.Dimension) and input_dim == offset_type + ): # no shift applied + return offset_type + if isinstance( + offset_type, (fbuiltins.FieldOffset, common.NeighborConnectivityType) + ): + off_source = ( + offset_type.source + if isinstance(offset_type, fbuiltins.FieldOffset) + else (offset_type.codomain) + ) + off_targets = ( + offset_type.target + if isinstance(offset_type, fbuiltins.FieldOffset) + else (offset_type.domain) + ) + + if input_dim == off_source: # Check if input fits to offset + for target in off_targets: + if ( + target.value != off_literal.value + ): # Exclude target matching off_literal.value + final_target = target + input_dim = target # Update input_dim for next iteration + return final_target + + for shift_tuple in shifts_results[ + i + ]: # Use shift tuple corresponding to the input field + for input_dim in input_dims: + final_dim = resolve_shift(input_dim, shift_tuple) + if final_dim not in seen: + seen.add(final_dim) + output_dims.append(final_dim) + else: + output_dims = domain.dims + stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), offset_provider_type=offset_provider_type, ) + assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( - lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type) - if domain.dims != "unknown" - else ts.DeferredType(constraint=ts.FieldType), + lambda el_type: ts.FieldType( + dims=sorted( + {dim for dim in output_dims}, + key=lambda dim: (common.dims_kind_order[dim.kind], dim.value), + ) + if output_dims != "unknown" + else [], + dtype=el_type, + ), stencil_return, ) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 27dd2cf02c..0ce07565fd 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -567,12 +567,11 @@ def promote( >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True - >>> promote( + >>> promoted: ts.FieldType = promote( ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype) - ... ) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + ... ) + >>> promoted.dims == [I, J, K] and promoted.dtype == dtype + True """ if not always_field and all(isinstance(type_, ts.ScalarType) for type_ in types): if not all(type_ == types[0] for type_ in types): @@ -642,7 +641,10 @@ def return_type_field( new_dims.append(d) else: new_dims.extend(target_dims) - return ts.FieldType(dims=new_dims, dtype=field_type.dtype) + return ts.FieldType( + dims=sorted(new_dims, key=lambda dim: (common.dims_kind_order[dim.kind], dim.value)), + dtype=field_type.dtype, + ) UNDEFINED_ARG = types.new_class("UNDEFINED_ARG") diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 2fbd039d16..5b46f9dd0d 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -105,6 +105,12 @@ def __str__(self) -> str: dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" return f"Field[{dims}, {self.dtype}]" + @eve_datamodels.validator("dims") + def _dims_validator( + self, attribute: eve_datamodels.Attribute, dims: list[common.Dimension] + ) -> None: + common.check_dims(dims) + class TupleType(DataType): # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 89ad556476..6e8ff1b3f6 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -60,6 +60,7 @@ # mypy does not accept [IDim, ...] as a type IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index ab1c625fef..d7fe252cb4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -95,16 +95,9 @@ def reduction_ek_field( return neighbor_sum(edge_f(V2E), axis=V2EDim) -@gtx.field_operator -def reduction_ke_field( - edge_f: common.Field[[KDim, Edge], np.int32], -) -> common.Field[[KDim, Vertex], np.int32]: - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - @pytest.mark.uses_unstructured_shift @pytest.mark.parametrize( - "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ + "fop", [reduction_e_field, reduction_ek_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case_3d, fop): v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index c0d762efc8..776cd4e1a9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -901,7 +901,7 @@ def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: def test_broadcast(): def foo(inp: gtx.Field[[TDim], float64]): - return broadcast(inp, (UDim, TDim)) + return broadcast(inp, (TDim, UDim)) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) @@ -912,7 +912,7 @@ def foo(inp: gtx.Field[[TDim], float64]): def test_scalar_broadcast(): def foo(): - return broadcast(1, (UDim, TDim)) + return broadcast(1, (TDim, UDim)) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 254772fd8a..373ad00aec 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -88,13 +88,11 @@ def test_binop_nonmatching_dims(): def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): return a + b - with pytest.raises( - errors.DSLError, - match=( - r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." - ), - ): - _ = FieldOperatorParser.apply_to_function(nonmatching) + parsed = FieldOperatorParser.apply_to_function(nonmatching) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[X, Y], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) def test_bitopping_float(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index a39fe3c6d8..d4918602fd 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -31,6 +31,7 @@ V2E, E2VDim, Edge, + Cell, IDim, Ioff, JDim, @@ -43,7 +44,7 @@ unstructured_case, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh - +from next_tests.integration_tests.cases import IField, JField bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) @@ -52,10 +53,16 @@ int_list_type = ts.ListType(element_type=int_type) float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) +float_j_field = ts.FieldType(dims=[JDim], dtype=float64_type) +float_k_field = ts.FieldType(dims=[KDim], dtype=float64_type) +float_ij_field = ts.FieldType(dims=[IDim, JDim], dtype=float64_type) float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) float_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=float64_type) +float_edge_field = ts.FieldType(dims=[Edge], dtype=float64_type) +float_vertex_field = ts.FieldType(dims=[Vertex], dtype=float64_type) float_vertex_v2e_field = ts.FieldType(dims=[Vertex, V2EDim], dtype=float64_type) + it_on_v_of_e_type = it_ts.IteratorType( position_dims=[Vertex, KDim], defined_dims=[Edge, KDim], element_type=int_type ) @@ -174,10 +181,11 @@ def expression_test_cases(): im.as_fieldop( im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1), + im.call("named_range")(itir.AxisLiteral(value="JDim"), 0, 1), ), - )(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)), - ts.TupleType(types=[float_i_field, float_i_field]), + )(im.ref("inp1", float_i_field), im.ref("inp2", float_j_field)), + ts.TupleType(types=[float_ij_field, float_ij_field]), ), ( im.as_fieldop(im.lambda_("x")(im.deref("x")))( @@ -185,6 +193,8 @@ def expression_test_cases(): ), ts.DeferredType(constraint=None), ), + # (im.as_fieldop(im.lambda_("x", "y")(im.plus(im.deref("x"), im.deref("y"))))( # TODO(SF-N): this needs PR 1853 + # im.ref("inp1", float_i_field), im.ref("inp2", float_j_field)), float_ij_field), # if in field-view scope ( im.if_( @@ -400,7 +410,12 @@ def test_fencil_with_nb_field_input(): ) result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - + assert result.body[0].expr.type == ts.FieldType(dims=[Vertex, V2EDim], dtype=float64_type) + assert result.body[0].expr.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims=float_vertex_k_field.dims, + defined_dims=float_vertex_field.dims, + element_type=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ) stencil = result.body[0].expr.fun.args[0] assert stencil.expr.args[0].type == float64_list_type assert stencil.type.returns == float64_type @@ -458,10 +473,7 @@ def test_program_setat_without_domain(): result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - assert ( - isinstance(result.body[0].expr.type, ts.DeferredType) - and result.body[0].expr.type.constraint == ts.FieldType - ) + assert result.body[0].expr.type, ts.FieldType(dims=[IDim], dtype=float64_type) def test_if_stmt(): @@ -488,19 +500,188 @@ def test_if_stmt(): assert result.true_branch[0].expr.type == float_i_field -def test_as_fieldop_without_domain(): +def test_as_fieldop_without_domain_I(): testee = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("IOff", 1)("it"))))( im.ref("inp", float_i_field) ) result = itir_type_inference.infer( testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) - assert result.type == ts.DeferredType(constraint=ts.FieldType) + assert result.type == ts.FieldType(dims=[IDim], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype + ) + + +def test_as_fieldop_without_domain_different_three_datatypes(): + stencil = im.lambda_("it1", "it2", "it3")( + im.plus( + im.plus( + im.deref(im.shift("C2E", 1)(im.shift("E2V", 1)("it1"))), + im.deref(im.shift("KOff", 1)("it2")), + ), + im.deref(im.shift("IOff", 1)("it3")), + ) + ) + + testee = im.as_fieldop(stencil)( + im.ref("inp1", float_vertex_field), + im.ref("inp2", float_edge_k_field), + im.ref("inp3", float_i_field), + ) + result = itir_type_inference.infer( + testee, + offset_provider_type={"C2E": C2E, "E2V": E2V, "KOff": KDim, "IOff": IDim}, + allow_undeclared_symbols=True, + ) + assert result.type == ts.FieldType(dims=[Cell, Edge, IDim, KDim], dtype=float64_type) assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_vertex_field.dims, + element_type=float_vertex_field.dtype, + ) + assert result.fun.args[0].type.pos_only_args[1] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_k_field.dims, + element_type=float_edge_k_field.dtype, + ) + assert result.fun.args[0].type.pos_only_args[2] == it_ts.IteratorType( position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype ) +def test_as_fieldop_without_domain_new(): + stencil = im.lambda_("it1", "it2")( + im.plus(im.deref("it1"), im.deref(im.shift("KOff", 1)(im.shift("V2E", 0)("it2")))) + ) + + testee = im.as_fieldop(stencil)( + im.ref("inp2", float_vertex_field), im.ref("inp1", float_edge_k_field) + ) + result = itir_type_inference.infer( + testee, offset_provider_type={"V2E": V2E, "KOff": KDim}, allow_undeclared_symbols=True + ) + assert result.type == ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_vertex_field.dims, + element_type=float_vertex_field.dtype, + ) + assert result.fun.args[0].type.pos_only_args[1] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_k_field.dims, + element_type=float_edge_k_field.dtype, + ) + + +def test_as_fieldop_without_domain_V2E(): + stencil = im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))) + + testee = im.as_fieldop(stencil)(im.ref("inp", float_edge_field)) + result = itir_type_inference.infer( + testee, offset_provider_type={"V2E": V2E}, allow_undeclared_symbols=True + ) + assert result.type == ts.FieldType(dims=[Vertex], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_field.dims, + element_type=float_edge_field.dtype, + ) + + +def test_as_fieldop_without_domain_V2E_new(): + stencil = im.lambda_("it")( + im.deref(im.shift("C2E", 0)(im.shift("E2V", 0)(im.shift("V2E", 0)("it")))) + ) + + testee = im.as_fieldop(stencil)(im.ref("inp", float_edge_field)) + result = itir_type_inference.infer( + testee, + offset_provider_type={"C2E": C2E, "E2V": E2V, "V2E": V2E}, + allow_undeclared_symbols=True, + ) + assert result.type == ts.FieldType(dims=[Cell], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_field.dims, + element_type=float_edge_field.dtype, + ) + + +def test_as_fieldop_without_domain_only_one_shift(): + stencil = im.lambda_("it1", "it2")( + im.plus(im.deref(im.shift("V2E", 1)("it1")), im.deref("it2")) + ) + + testee = im.as_fieldop(stencil)( + im.ref("inp1", float_edge_field), im.ref("inp2", float_edge_field) + ) + result = itir_type_inference.infer( + testee, offset_provider_type={"V2E": V2E, "Edge": Edge}, allow_undeclared_symbols=True + ) + assert result.type == ts.FieldType(dims=[Edge, Vertex], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_field.dims, + element_type=float_edge_field.dtype, + ) + assert result.fun.args[0].type.pos_only_args[1] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_field.dims, + element_type=float_edge_field.dtype, + ) + + +def test_as_fieldop_without_domain_new_nested_shifts(): + stencil = im.lambda_("it")(im.deref(im.shift("C2E", 0)(im.shift("E2V", 0)("it")))) + + testee = im.as_fieldop(stencil)(im.ref("inp", float_vertex_field)) + result = itir_type_inference.infer( + testee, offset_provider_type={"C2E": C2E, "E2V": E2V}, allow_undeclared_symbols=True + ) + assert result.type == ts.FieldType(dims=[Cell], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_vertex_field.dims, + element_type=float_vertex_field.dtype, + ) + + +def test_as_fieldop_without_domain_new_no_shift(): + stencil = im.lambda_("it")(im.deref("it")) + + testee = im.as_fieldop(stencil)(im.ref("inp", float_edge_field)) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert result.type == ts.FieldType(dims=[Edge], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_field.dims, + element_type=float_edge_field.dtype, + ) + + +def test_as_fieldop_without_domain_plus(): + stencil = im.lambda_("it1", "it2")( + im.plus(im.deref(im.shift("V2E", 1)("it1")), im.deref(im.shift("KOff", 1)("it2"))) + ) + + testee = im.as_fieldop(stencil)(im.ref("inp1", float_edge_field), im.ref("inp2", float_k_field)) + result = itir_type_inference.infer( + testee, offset_provider_type={"V2E": V2E, "KOff": KDim}, allow_undeclared_symbols=True + ) + assert result.type == ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", + defined_dims=float_edge_field.dims, + element_type=float_edge_field.dtype, + ) + assert result.fun.args[0].type.pos_only_args[1] == it_ts.IteratorType( + position_dims="unknown", defined_dims=float_k_field.dims, element_type=float_k_field.dtype + ) + + def test_reinference(): testee = im.make_tuple(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)) result = itir_type_inference.reinfer(copy.deepcopy(testee)) diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py index 51b6bf512b..a25732649a 100644 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py +++ b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py @@ -60,14 +60,14 @@ def function_buffer_example(): interface.Parameter( name="a_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ), interface.Parameter( name="b_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.INT64) + dims=[gtx.Dimension("bar")], dtype=ts.ScalarType(ts.ScalarKind.INT64) ), ), ], @@ -111,11 +111,11 @@ def function_tuple_example(): type_=ts.TupleType( types=[ ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ] diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8f46fc7ce1..adbe5911ab 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -10,7 +10,9 @@ from typing import Optional, Pattern import pytest +import re +from gt4py import next as gtx import gt4py.next.common as common from gt4py.next.common import ( Dimension, @@ -25,7 +27,14 @@ unit_range, ) - +I = gtx.Dimension("I") +J = gtx.Dimension("J") +K = gtx.Dimension("K", kind=DimensionKind.VERTICAL) +C2E = Dimension("C2E", kind=DimensionKind.LOCAL) +V2E = Dimension("V2E", kind=DimensionKind.LOCAL) +E2V = Dimension("E2V", kind=DimensionKind.LOCAL) +E2C = Dimension("E2C", kind=DimensionKind.LOCAL) +E2C2V = Dimension("E2C2V", kind=DimensionKind.LOCAL) ECDim = Dimension("ECDim") IDim = Dimension("IDim") JDim = Dimension("JDim") @@ -325,13 +334,11 @@ def test_domain_intersection_different_dimensions(a_domain, second_domain, expec def test_domain_intersection_reversed_dimensions(a_domain): - domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) + domain2 = Domain(dims=(IDim, JDim), ranges=(UnitRange(7, 17), UnitRange(2, 12))) - with pytest.raises( - ValueError, - match="Dimensions can not be promoted. The following dimensions appear in contradicting order: IDim, JDim.", - ): - a_domain & domain2 + assert a_domain & domain2 == Domain( + dims=(IDim, JDim, KDim), ranges=(UnitRange(7, 10), UnitRange(5, 12), UnitRange(20, 30)) + ) @pytest.mark.parametrize( @@ -571,27 +578,29 @@ def dimension_promotion_cases() -> ( ): raw_list = [ # list of list of dimensions, expected result, expected error message - ([["I", "J"], ["I"]], ["I", "J"], None), - ([["I", "J"], ["J"]], ["I", "J"], None), - ([["I", "J"], ["J", "K"]], ["I", "J", "K"], None), + ([[I, J], [I]], [I, J], None), + ([[J], [I, J]], [I, J], None), + ([[J, K], [I, J]], [I, J, K], None), ( - [["I", "J"], ["J", "I"]], + [[I, J], [J, I]], None, - r"The following dimensions appear in contradicting order: I, J.", + "Dimensions [Dimension(value='J', kind=), Dimension(value='I', kind=)] are not correctly ordered.", ), + ([[J, K], [I, K]], [I, J, K], None), ( - [["I", "K"], ["J", "K"]], + [[K, J], [I, K]], None, - r"Could not determine order of the following dimensions: I, J", + "Dimensions [Dimension(value='K', kind=), Dimension(value='J', kind=)] are not correctly ordered.", ), + ( + [[J, V2E], [I, K, E2C2V]], + None, + "There are more than one dimension with DimensionKind 'LOCAL'.", + ), + ([[J, V2E], [I, K]], [I, J, K, V2E], None), ] - # transform dimension names into Dimension objects return [ - ( - [[Dimension(el) for el in arg] for arg in args], - [Dimension(el) for el in result] if result else result, - msg, - ) + ([[el for el in arg] for arg in args], [el for el in result] if result else result, msg) for args, result, msg in raw_list ] @@ -608,7 +617,7 @@ def test_dimension_promotion( with pytest.raises(Exception) as exc_info: promote_dims(*dim_list) - assert exc_info.match(expected_error_msg) + assert exc_info.match(re.escape(expected_error_msg)) class TestCartesianConnectivity: diff --git a/tests/next_tests/unit_tests/test_type_system.py b/tests/next_tests/unit_tests/test_type_system.py index 99758d6f14..69ff54b711 100644 --- a/tests/next_tests/unit_tests/test_type_system.py +++ b/tests/next_tests/unit_tests/test_type_system.py @@ -305,10 +305,7 @@ def callable_type_info_cases(): ts.FieldType(dims=[KDim], dtype=int_type), ], {}, - [ - r"Dimensions can not be promoted. Could not determine order of the " - r"following dimensions: J, K." - ], + [], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), (