diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 769be9ba5a..1cd58b78c8 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -20,7 +20,6 @@ import functools import math import numbers -from typing import overload import numpy as np import numpy.typing as npt @@ -42,6 +41,7 @@ TypeVar, Union, cast, + overload, ) diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index e406a5f097..750eeba903 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -127,11 +127,16 @@ def __dir__() -> List[str]: return self_func.__cached_dir -_T = TypeVar("_T") - # -- Common type aliases -- NoArgsCallable = Callable[[], Any] +_A = TypeVar("_A", contravariant=True) +_R = TypeVar("_R", covariant=True) + + +class ArgsOnlyCallable(Protocol[_A, _R]): + def __call__(self, *args: _A) -> _R: ... + # -- Typing annotations -- if _sys.version_info >= (3, 9): @@ -367,6 +372,9 @@ def has_type_parameters(cls: Type) -> bool: return issubclass(cls, Generic) and len(getattr(cls, "__parameters__", [])) > 0 # type: ignore[arg-type] # Generic not considered as a class +_T = TypeVar("_T") + + def get_actual_type(obj: _T) -> Type[_T]: """Return type of an object (also working for GenericAlias instances which pretend to be an actual type).""" return StdGenericAliasType if isinstance(obj, StdGenericAliasType) else type(obj) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index d1f9d0f7d5..2c2d4b6c58 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -52,6 +52,7 @@ from . import extended_typing as xtyping from .extended_typing import ( Any, + ArgsOnlyCallable, Callable, Collection, Dict, @@ -84,6 +85,15 @@ T = TypeVar("T") +def first(iterable: Iterable[T], *, default: Union[T, NothingType] = NOTHING) -> T: + try: + return next(iter(iterable)) + except StopIteration as error: + if default is not NOTHING: + return cast(T, default) + raise error + + def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], bool]: """Return a callable object that checks if operand is an instance of `type_info`. @@ -227,9 +237,31 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]: _P = ParamSpec("_P") +_S = TypeVar("_S") _T = TypeVar("_T") +@dataclasses.dataclass(frozen=True) +class IndexerCallable(Generic[_S, _T]): + """ + An indexer class applying the wrapped function to the index arguments. + + Examples: + >>> indexer = IndexerCallable(lambda x: x**2) + >>> indexer[3] + 9 + + >>> indexer = IndexerCallable(lambda a, b: a + b) + >>> indexer[3, 4] + 7 + """ + + func: ArgsOnlyCallable[_S, _T] + + def __getitem__(self, key: _S | Tuple[_S, ...]) -> _T: + return self.func(*key) if isinstance(key, tuple) else self.func(key) + + class fluid_partial(functools.partial): """Create a `functools.partial` with support for multiple applications calling `.partial()`.""" diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 7a0f0c54eb..5dadcd663b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -19,7 +19,7 @@ import dataclasses import enum import functools -import numbers +import math import types from collections.abc import Mapping, Sequence @@ -27,6 +27,7 @@ import numpy.typing as npt from gt4py._core import definitions as core_defs +from gt4py.eve import utils from gt4py.eve.extended_typing import ( TYPE_CHECKING, Any, @@ -34,6 +35,7 @@ ClassVar, Final, Generic, + Literal, NamedTuple, Never, Optional, @@ -53,11 +55,10 @@ DimT = TypeVar("DimT", bound="Dimension") # , covariant=True) -ShapeT = TypeVarTuple("ShapeT") +ShapeTs = TypeVarTuple("ShapeTs") -class Dims(Generic[Unpack[ShapeT]]): - shape: tuple[Unpack[ShapeT]] +class Dims(tuple[Unpack[ShapeTs]]): ... DimsT = TypeVar("DimsT", bound=Dims, covariant=True) @@ -375,10 +376,10 @@ def __init__( ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: - raise ValueError("Either both none of 'dims' and 'ranges' must be specified.") + raise ValueError("Either specify both 'dims' and 'ranges' or neither.") if len(args) > 0: raise ValueError( - "No extra 'args' allowed when constructing fomr 'dims' and 'ranges'." + "No extra 'args' allowed when constructing from 'dims' and 'ranges'." ) assert dims is not None and ranges is not None # for mypy @@ -409,9 +410,6 @@ def __init__( if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not '{self.dims}'.") - def __len__(self) -> int: - return len(self.ranges) - @property def ndim(self) -> int: return len(self.dims) @@ -420,13 +418,15 @@ def ndim(self) -> int: def shape(self) -> tuple[int, ...]: return tuple(len(r) for r in self.ranges) - @classmethod - def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: - # classmethod since TypeGuards requires the guarded obj as separate argument - return all(UnitRange.is_finite(rng) for rng in obj.ranges) + @property + def size(self) -> Optional[int]: + return math.prod(self.shape) if all(UnitRange.is_finite(r) for r in self.ranges) else None - def is_empty(self) -> bool: - return any(rng.is_empty() for rng in self.ranges) + def __len__(self) -> int: + return len(self.ranges) + + def __str__(self) -> str: + return f"Domain({', '.join(f'{e}' for e in self)})" @overload def __getitem__(self, index: int) -> NamedRange: ... @@ -438,37 +438,35 @@ def __getitem__(self, index: slice) -> Self: ... def __getitem__(self, index: Dimension) -> NamedRange: ... def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: + if isinstance(index, Dimension): + try: + index = self.dims.index(index) + except ValueError as ex: + raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex if isinstance(index, int): return NamedRange(dim=self.dims[index], unit_range=self.ranges[index]) - elif isinstance(index, slice): + if isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] return Domain(dims=dims_slice, ranges=ranges_slice) - elif isinstance(index, Dimension): - try: - index_pos = self.dims.index(index) - return NamedRange(dim=self.dims[index_pos], unit_range=self.ranges[index_pos]) - except ValueError as ex: - raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") from ex - else: - raise KeyError("Invalid index type, must be either int, slice, or Dimension.") + + raise KeyError("Invalid index type, must be either int, slice, or Dimension.") def __and__(self, other: Domain) -> Domain: """ Intersect `Domain`s, missing `Dimension`s are considered infinite. Examples: - --------- - >>> I = Dimension("I") - >>> J = Dimension("J") + >>> I = Dimension("I") + >>> J = Dimension("J") - >>> Domain(NamedRange(I, UnitRange(-1, 3))) & Domain(NamedRange(I, UnitRange(1, 6))) - Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) + >>> Domain(NamedRange(I, UnitRange(-1, 3))) & Domain(NamedRange(I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) - >>> Domain(NamedRange(I, UnitRange(-1, 3)), NamedRange(J, UnitRange(2, 4))) & Domain( - ... NamedRange(I, UnitRange(1, 6)) - ... ) - Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) + >>> Domain(NamedRange(I, UnitRange(-1, 3)), NamedRange(J, UnitRange(2, 4))) & Domain( + ... NamedRange(I, UnitRange(1, 6)) + ... ) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) intersected_ranges = tuple( @@ -480,11 +478,52 @@ def __and__(self, other: Domain) -> Domain: ) return Domain(dims=broadcast_dims, ranges=intersected_ranges) - def __str__(self) -> str: - return f"Domain({', '.join(f'{e}' for e in self)})" + @functools.cached_property + def slice_at(self) -> utils.IndexerCallable[slice, Domain]: + """ + Create a new domain by slicing the domain ranges at the provided relative slices. + + Examples: + >>> I, J = Dimension("I"), Dimension("J") + >>> domain = Domain(NamedRange(I, UnitRange(0, 10)), NamedRange(J, UnitRange(5, 15))) + >>> domain.slice_at[2:3, 2:5] + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 3), UnitRange(7, 10))) + """ + + def _domain_slicer(*args: slice) -> Domain: + if not all(isinstance(a, slice) for a in args): + raise TypeError(f"Indices must be 'slice's but got '{args}'") + if len(args) != len(self): + raise ValueError( + f"Number of provided slices ({len(args)}) does not match the number of dimensions ({len(self)})." + ) + return Domain(dims=self.dims, ranges=[r[s] for r, s in zip(self.ranges, args)]) - def dim_index(self, dim: Dimension) -> Optional[int]: - return self.dims.index(dim) if dim in self.dims else None + return utils.IndexerCallable(_domain_slicer) + + @classmethod + def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return all(UnitRange.is_finite(rng) for rng in obj.ranges) + + def is_empty(self) -> bool: + return any(rng.is_empty() for rng in self.ranges) + + @overload + def dim_index(self, dim: Dimension, *, allow_missing: Literal[False]) -> int: ... + + @overload + def dim_index( + self, dim: Dimension, *, allow_missing: Literal[True] = True + ) -> Optional[int]: ... + + def dim_index(self, dim: Dimension, *, allow_missing: bool = True) -> Optional[int]: + if dim in self.dims: + return self.dims.index(dim) + elif allow_missing: + return None + else: + raise ValueError(f"Dimension '{dim}' not found in Domain.") def pop(self, index: int | Dimension = -1) -> Domain: return self.replace(index) @@ -530,21 +569,20 @@ def domain(domain_like: DomainLike) -> Domain: Construct `Domain` from `DomainLike` object. Examples: - --------- - >>> I = Dimension("I") - >>> J = Dimension("J") + >>> I = Dimension("I") + >>> J = Dimension("J") - >>> domain(((I, (2, 4)), (J, (3, 5)))) - Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + >>> domain(((I, (2, 4)), (J, (3, 5)))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) - >>> domain({I: (2, 4), J: (3, 5)}) - Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + >>> domain({I: (2, 4), J: (3, 5)}) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) - >>> domain(((I, 2), (J, 4))) - Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(0, 2), UnitRange(0, 4))) + >>> domain(((I, 2), (J, 4))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(0, 2), UnitRange(0, 4))) - >>> domain({I: 2, J: 4}) - Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(0, 2), UnitRange(0, 4))) + >>> domain({I: 2, J: 4}) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(0, 2), UnitRange(0, 4))) """ if isinstance(domain_like, Domain): return domain_like @@ -630,7 +668,7 @@ def __str__(self) -> str: def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod def restrict(self, item: AnyIndexSpec) -> Self: ... @@ -716,13 +754,41 @@ def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> class ConnectivityKind(enum.Flag): - MODIFY_DIMS = enum.auto() - MODIFY_RANK = enum.auto() - MODIFY_STRUCTURE = enum.auto() + """ + Describes the kind of connectivity field. + + - `ALTER_DIMS`: change the dimensions of the data field domain. + - `ALTER_STRUCT`: transform structured of the data inside the field (non-compact transformation). + + | Dims \ Struct | No | Yes | + | ------------- | ------------------------ | ------------------------ | + | No | Translation (I -> I) | Reshuffling (I x K -> K) | + | Yes | Relocation (I -> I_half) | Remapping (V x V2E -> E) | + + """ + + ALTER_DIMS = enum.auto() + ALTER_STRUCT = enum.auto() + + @classmethod + def translation(cls) -> ConnectivityKind: + return cls(0) + + @classmethod + def relocation(cls) -> ConnectivityKind: + return cls.ALTER_DIMS + + @classmethod + def reshuffling(cls) -> ConnectivityKind: + return cls.ALTER_STRUCT + + @classmethod + def remapping(cls) -> ConnectivityKind: + return cls.ALTER_DIMS | cls.ALTER_STRUCT @runtime_checkable -# type: ignore[misc] # DimT should be covariant, but breaks in another place +# type: ignore[misc] # DimT should be covariant, but then it breaks in other places class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod @@ -730,11 +796,7 @@ def codomain(self) -> DimT: ... @property def kind(self) -> ConnectivityKind: - return ( - ConnectivityKind.MODIFY_DIMS - | ConnectivityKind.MODIFY_RANK - | ConnectivityKind.MODIFY_STRUCTURE - ) + return ConnectivityKind.remapping() @abc.abstractmethod def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ... @@ -852,11 +914,22 @@ class NeighborTable(Connectivity, Protocol): OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] +DomainDimT = TypeVar("DomainDimT", bound="Dimension") + + @dataclasses.dataclass(frozen=True, eq=False) -class CartesianConnectivity(ConnectivityField[DimsT, DimT]): - dimension: DimT +class CartesianConnectivity(ConnectivityField[Dims[DomainDimT], DimT]): + domain_dim: DomainDimT + codomain: DimT offset: int = 0 + def __init__( + self, domain_dim: DomainDimT, offset: int = 0, *, codomain: Optional[DimT] = None + ) -> None: + object.__setattr__(self, "domain_dim", domain_dim) + object.__setattr__(self, "codomain", codomain if codomain is not None else domain_dim) + object.__setattr__(self, "offset", offset) + @classmethod def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override] raise NotImplementedError() @@ -873,7 +946,7 @@ def as_scalar(self) -> Never: @functools.cached_property def domain(self) -> Domain: - return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) + return Domain(dims=(self.domain_dim,), ranges=(UnitRange.infinite(),)) @property def __gt_origin__(self) -> Never: @@ -883,9 +956,13 @@ def __gt_origin__(self) -> Never: def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: return core_defs.Int32DType() # type: ignore[return-value] - @functools.cached_property - def codomain(self) -> DimT: - return self.dimension + # This is a workaround to make this class concrete, since `codomain` is an + # abstract property of the `ConnectivityField` Protocol. + if not TYPE_CHECKING: + + @functools.cached_property + def codomain(self) -> DimT: + raise RuntimeError("This property should be always set in the constructor.") @property def skip_value(self) -> None: @@ -893,21 +970,21 @@ def skip_value(self) -> None: @functools.cached_property def kind(self) -> ConnectivityKind: - return ConnectivityKind(0) + return ( + ConnectivityKind.translation() + if self.domain_dim == self.codomain + else ConnectivityKind.relocation() + ) @classmethod - def from_offset( - cls, - definition: int, - /, - codomain: DimT, - *, - domain: Optional[DomainLike] = None, - dtype: Optional[core_defs.DTypeLike] = None, - ) -> CartesianConnectivity: - assert domain is None - assert dtype is None - return cls(codomain, definition) + def for_translation( + cls, dimension: DomainDimT, offset: int + ) -> CartesianConnectivity[DomainDimT, DomainDimT]: + return cast(CartesianConnectivity[DomainDimT, DomainDimT], cls(dimension, offset)) + + @classmethod + def for_relocation(cls, old: DimT, new: DomainDimT) -> CartesianConnectivity[DomainDimT, DimT]: + return cls(new, codomain=old) def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: if not isinstance(image_range, UnitRange): @@ -919,12 +996,12 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa image_range = image_range.unit_range assert isinstance(image_range, UnitRange) - return (named_range((self.codomain, image_range - self.offset)),) + return (named_range((self.domain_dim, image_range - self.offset)),) - def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField: + def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField: raise NotImplementedError() - __call__ = remap + __call__ = premap def restrict(self, index: AnyIndexSpec) -> Never: raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case @@ -932,9 +1009,6 @@ def restrict(self, index: AnyIndexSpec) -> Never: __getitem__ = restrict -_connectivity.register(numbers.Integral, CartesianConnectivity.from_offset) - - @enum.unique class GridType(StrEnum): CARTESIAN = "cartesian" @@ -952,18 +1026,22 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: A modified version (ensuring uniqueness of the order) of `Kahn's algorithm `_ is used to topologically sort the arguments. - >>> from gt4py.next.common import Dimension - >>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"]) - >>> promote_dims([I, J], [I, J, K]) == [I, J, K] - True - >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. - >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. + + Examples: + >>> from gt4py.next.common import Dimension + >>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"]) + >>> promote_dims([I, J], [I, J, K]) == [I, J, K] + True + + >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + + >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. """ # build a graph with the vertices being dimensions and edges representing # the order between two dimensions. The graph is encoded as a dictionary diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e290da33a2..3322d69379 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -14,17 +14,26 @@ from __future__ import annotations +import collections import dataclasses import functools from collections.abc import Callable, Sequence from types import ModuleType -from typing import ClassVar, Iterable import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Never, Optional, ParamSpec, TypeAlias, TypeVar +from gt4py.eve.extended_typing import ( + ClassVar, + Iterable, + Never, + Optional, + ParamSpec, + TypeAlias, + TypeVar, + cast, +) from gt4py.next import common from gt4py.next.embedded import ( common as embedded_common, @@ -177,52 +186,136 @@ def from_array( return cls(domain, array) - def remap( - self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset + def premap( + self: NdArrayField, + *connectivities: common.ConnectivityField | fbuiltins.FieldOffset, ) -> NdArrayField: - # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField - if not isinstance(connectivity, common.ConnectivityField): - assert isinstance(connectivity, fbuiltins.FieldOffset) - connectivity = connectivity.as_connectivity_field() - assert isinstance(connectivity, common.ConnectivityField) + """ + Rearrange the field content using the provided connectivity fields as index mappings. + + This operation is conceptually equivalent to a regular composition of mappings + `f∘c`, being `c` the `connectivity` argument and `f` the `self` data field. + Note that the connectivity field appears at the right of the composition + operator and the data field at the left. + + The composition operation is only well-defined when the codomain of `c: A → B` + matches the domain of `f: B → ℝ` and it would then result in a new mapping + `f∘c: A → ℝ` defined as `(f∘c)(x) = f(c(x))`. When remaping a field whose + domain has multiple dimensions `f: A × B → ℝ`, the domain of the connectivity + argument used in the right hand side of the operator should therefore have the + same product of dimensions `c: S × T → A × B`. Such a mapping can also be + expressed as a pair of mappings `c1: S × T → A` and `c2: S × T → B`, and this + is actually the only supported form in GT4Py because `ConnectivityField` instances + can only deal with a single dimension in its codomain. This approach makes + connectivities reusable for any combination of dimensions in a field domain + and matches the NumPy advanced indexing API, which basically is a + composition of mappings of natural numbers representing tensor indices. + + In general, the `premap()` function is able to deal with data fields with multiple + dimensions even if only one connectivity is passed. Connectivity arguments are then + expanded to fully defined connectivities for each dimension in the domain of the + field according to some rules covering the most common use cases. + + Assuming a field `f: Field[Dims[A, B], DT]` the following cases are supported: + + - If the connectivity domain only contains dimensions which are NOT part of the + field domain (new dimensions), this function will use the same rules of + advanced-indexing and replace the connectivity codomain dimension by its domain + dimensions. A way to think about this is that the data field is transformed into + a curried mapping whose domain only contains the connectivity codomain dimension, + then composed as usual with the connectivity, and finally uncurried again: + + `f: A × B → ℝ` => `f': A → (B → ℝ)` + `c: X × Y → A` + `(f'∘c): X × Y → (B → ℝ)` => `(f∘c): X × Y × B → ℝ` + + - If the connectivity domain only contains dimensions which are ALREADY part of the + data field domain, the connectivity field would be interpreted as an homomorphic + function which preserves the domain dimensions. A way to think about this is that + the connectivity defines how the current field data gets translated and rearranged + into new domain ranges, and the mappings for the missing domain dimensions + are assumed to be identities: + + `f: A × B × C → ℝ` + `c: A × B → A` => `c0: A × B × C → A`, `c1: A × B × C → B`, `c2: A × B × C → C` + `(f∘c): A × B × C → ℝ` => `(f∘(c0 × c1 × c2)): A × B × C → ℝ)` + + Note that cartesian shifts (e.g. `I → I_half`, `(I+1): I → I`) are just simpler + versions of these cases where the internal structure of the data (codomain) is + preserved and therefore the `premap` operation can be implemented as a compact + domain translation (i.e. only transform the domain without altering the data). + + A table showing the relation between the connectivity kind and the supported cases + is shown in :class:`common.ConnectivityKind`. + + Args: + *connectivities: connectivities to be used for the `premap` operation. If only one + connectivity is passed, it will be expanded to fully defined connectivities for + each dimension in the domain of the field according to the rules described above. + If more than one connectivity is passed, they all must satisfy: + - be of the same kind or encode only compact domain transformations + - the codomain of each connectivity must be different + - for reshuffling operations, all connectivities must have the same domain + (Note that remapping operations only support a single connectivity argument.) + + """ # noqa: RUF002 # TODO(egparedes): move docstring to the `premap` builtin function when it exists + + conn_fields: list[common.ConnectivityField] = [] + codomains_counter: collections.Counter[common.Dimension] = collections.Counter() + + for connectivity in connectivities: + # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField + if not isinstance(connectivity, common.ConnectivityField): + assert isinstance(connectivity, fbuiltins.FieldOffset) + connectivity = connectivity.as_connectivity_field() + assert isinstance(connectivity, common.ConnectivityField) + + # Current implementation relies on skip_value == -1: + # if we assume the indexed array has at least one element, + # we wrap around without out of bounds access + assert connectivity.skip_value is None or connectivity.skip_value == -1 + + conn_fields.append(connectivity) + codomains_counter[connectivity.codomain] += 1 + + if unknown_dims := [dim for dim in codomains_counter.keys() if dim not in self.domain.dims]: + raise ValueError( + f"Incompatible dimensions in the connectivity codomain(s) {unknown_dims}" + f"while pre-mapping a field with domain {self.domain}." + ) - # Current implementation relies on skip_value == -1: - # if we assume the indexed array has at least one element, we wrap around without out of bounds - assert connectivity.skip_value is None or connectivity.skip_value == -1 + if repeated_codomain_dims := [dim for dim, count in codomains_counter.items() if count > 1]: + raise ValueError( + "All connectivities must have different codomains but some are repeated:" + f" {repeated_codomain_dims}." + ) - # Compute the new domain - dim = connectivity.codomain - dim_idx = self.domain.dim_index(dim) - if dim_idx is None: - raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") + if any(c.kind & common.ConnectivityKind.ALTER_STRUCT for c in conn_fields) and any( + (~c.kind & common.ConnectivityKind.ALTER_STRUCT) for c in conn_fields + ): + raise ValueError( + "Mixing connectivities that change the data structure with connectivities that do not is not allowed." + ) - current_range: common.UnitRange = self.domain[dim_idx].unit_range - new_ranges = connectivity.inverse_image(current_range) - new_domain = self.domain.replace(dim_idx, *new_ranges) + # Select actual implementation of the transformation + if not (conn_fields[0].kind & common.ConnectivityKind.ALTER_STRUCT): + return _domain_premap(self, *conn_fields) - # perform contramap - if not (connectivity.kind & common.ConnectivityKind.MODIFY_STRUCTURE): - # shortcut for compact remap: don't change the array, only the domain - new_buffer = self._ndarray - else: - # general case: first restrict the connectivity to the new domain - restricted_connectivity_domain = common.Domain(*new_ranges) - restricted_connectivity = ( - connectivity.restrict(restricted_connectivity_domain) - if restricted_connectivity_domain != connectivity.domain - else connectivity + if any(c.kind & common.ConnectivityKind.ALTER_DIMS for c in conn_fields) and any( + (~c.kind & common.ConnectivityKind.ALTER_DIMS) for c in conn_fields + ): + raise ValueError( + "Mixing connectivities that change the dimensions in the domain with connectivities that do not is not allowed." ) - assert isinstance(restricted_connectivity, common.ConnectivityField) - # then compute the index array - xp = self.array_ns - new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start - # finally, take the new array - new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx) + if not (conn_fields[0].kind & common.ConnectivityKind.ALTER_DIMS): + assert all(isinstance(c, NdArrayConnectivityField) for c in conn_fields) + return _reshuffling_premap(self, *cast(list[NdArrayConnectivityField], conn_fields)) - return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype) + assert len(conn_fields) == 1 + return _remapping_premap(self, conn_fields[0]) - __call__ = remap # type: ignore[assignment] + __call__ = premap # type: ignore[assignment] def restrict(self, index: common.AnyIndexSpec) -> NdArrayField: new_domain, buffer_slice = self._slice(index) @@ -334,6 +427,12 @@ class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ ): _codomain: common.DimT _skip_value: Optional[core_defs.IntegralScalar] + _kind: Optional[common.ConnectivityKind] = None + + def __post_init__(self) -> None: + assert self._kind is None or bool(self._kind & common.ConnectivityKind.ALTER_DIMS) == ( + self.domain.dim_index(self.codomain) is not None + ) @functools.cached_property def _cache(self) -> dict: @@ -352,16 +451,22 @@ def codomain(self) -> common.DimT: def skip_value(self) -> Optional[core_defs.IntegralScalar]: return self._skip_value - @functools.cached_property + @property def kind(self) -> common.ConnectivityKind: - kind = common.ConnectivityKind.MODIFY_STRUCTURE - if self.domain.ndim > 1: - kind |= common.ConnectivityKind.MODIFY_RANK - kind |= common.ConnectivityKind.MODIFY_DIMS - if self.domain.dim_index(self.codomain) is None: - kind |= common.ConnectivityKind.MODIFY_DIMS + if self._kind is None: + object.__setattr__( + self, + "_kind", + common.ConnectivityKind.ALTER_STRUCT + | ( + common.ConnectivityKind.ALTER_DIMS + if self.domain.dim_index(self.codomain) is None + else common.ConnectivityKind(0) + ), + ) + assert self._kind is not None - return kind + return self._kind @classmethod def from_array( # type: ignore[override] @@ -393,14 +498,10 @@ def from_array( # type: ignore[override] return cls(domain, array, codomain, _skip_value=skip_value) - def inverse_image( - self, image_range: common.UnitRange | common.NamedRange - ) -> Sequence[common.NamedRange]: + def inverse_image(self, image_range: common.UnitRange | common.NamedRange) -> common.Domain: cache_key = hash((id(self.ndarray), self.domain, image_range)) - if (new_dims := self._cache.get(cache_key, None)) is None: - xp = self.array_ns - + if (new_domain := self._cache.get(cache_key, None)) is None: if not isinstance( image_range, common.UnitRange ): # TODO(havogt): cleanup duplication with CartesianConnectivity @@ -412,19 +513,17 @@ def inverse_image( image_range = image_range.unit_range assert isinstance(image_range, common.UnitRange) - assert common.UnitRange.is_finite(image_range) - relative_ranges = _hypercube(self._ndarray, image_range, xp, self.skip_value) - - if relative_ranges is None: + xp = self.array_ns + slices = _hyperslice(self._ndarray, image_range, xp, self.skip_value) + if slices is None: raise ValueError("Restriction generates non-contiguous dimensions.") - new_dims = _relative_ranges_to_domain(relative_ranges, self.domain) - - self._cache[cache_key] = new_dims + new_domain = self.domain.slice_at[slices] + self._cache[cache_key] = new_domain - return new_dims + return new_domain def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: cache_key = (id(self.ndarray), self.domain, index) @@ -442,31 +541,180 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: __getitem__ = restrict -def _relative_ranges_to_domain( - relative_ranges: Sequence[common.UnitRange], domain: common.Domain -) -> common.Domain: - return common.Domain( - dims=domain.dims, ranges=[rr + ar.start for ar, rr in zip(domain.ranges, relative_ranges)] +def _domain_premap(data: NdArrayField, *connectivities: common.ConnectivityField) -> NdArrayField: + """`premap` implementation transforming only the field domain not the data (i.e. translation and relocation).""" + new_domain = data.domain + for connectivity in connectivities: + dim = connectivity.codomain + dim_idx = data.domain.dim_index(dim) + if dim_idx is None: + raise ValueError( + f"Incompatible index field expects a data field with dimension '{dim}'" + f"but got '{data.domain}'." + ) + + current_range: common.UnitRange = data.domain[dim_idx].unit_range + new_ranges = connectivity.inverse_image(current_range) + new_domain = new_domain.replace(dim_idx, *new_ranges) + + return data.__class__.from_array(data._ndarray, domain=new_domain, dtype=data.dtype) + + +def _reshuffling_premap( + data: NdArrayField, *connectivities: NdArrayConnectivityField +) -> NdArrayField: + # Check that all connectivities have the same domain + assert len(connectivities) == 1 or all( + c.domain == connectivities[0].domain for c in connectivities[1:] + ) + + connectivity = connectivities[0] + xp = data.array_ns + + # Reorder and complete connectivity dimensions to match the field domain + # It should be enough to check this only the first connectivity + # since all connectivities must have the same domain + transposed_axes = [] + expanded_axes: list[int] = [] + transpose_needed = False + for new_dim_idx, dim in enumerate(data.domain.dims): + if (dim_idx := connectivity.domain.dim_index(dim)) is None: + expanded_axes.append(connectivity.domain.ndim + len(expanded_axes)) + dim_idx = expanded_axes[-1] + transposed_axes.append(dim_idx) + transpose_needed = transpose_needed | (dim_idx != new_dim_idx) + + # Broadcast connectivity arrays to match the full domain + conn_map = {} + new_ranges = data.domain.ranges + for conn in connectivities: + conn_ndarray = conn.ndarray + if expanded_axes: + conn_ndarray = xp.expand_dims(conn_ndarray, axis=expanded_axes) + if transpose_needed: + conn_ndarray = xp.transpose(conn_ndarray, transposed_axes) + if conn_ndarray.shape != data.domain.shape: + conn_ndarray = xp.broadcast_to(conn_ndarray, data.domain.shape) + if conn_ndarray is not conn.ndarray: + conn = conn.__class__.from_array( + conn_ndarray, domain=data.domain, codomain=conn.codomain + ) + conn_map[conn.codomain] = conn + dim_idx = data.domain.dim_index(conn.codomain, allow_missing=False) + current_range: common.UnitRange = data.domain.ranges[dim_idx] + new_conn_ranges = connectivity.inverse_image(current_range).ranges + new_ranges = tuple(r & s for r, s in zip(new_ranges, new_conn_ranges)) + + conns_dims = [c.domain.dims for c in conn_map.values()] + for i in range(len(conns_dims) - 1): + if conns_dims[i] != conns_dims[i + 1]: + raise ValueError( + f"All premapping connectivities must have the same dimensions, got: {conns_dims}." + ) + + new_domain = common.Domain(dims=data.domain.dims, ranges=new_ranges) + + # Create identity connectivities for the missing domain dimensions + for dim in data.domain.dims: + if dim not in conn_map: + conn_map[dim] = _identity_connectivity(new_domain, dim, cls=type(connectivity)) + + # Take data + take_indices = tuple(conn_map[dim].ndarray for dim in data.domain.dims) + new_buffer = data._ndarray.__getitem__(take_indices) + + return data.__class__.from_array( + new_buffer, + domain=new_domain, + dtype=data.dtype, + ) + + +def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField) -> NdArrayField: + new_dims = {*connectivity.domain.dims} - {connectivity.codomain} + if repeated_dims := (new_dims & {*data.domain.dims}): + raise ValueError(f"Remapped field will contain repeated dimensions '{repeated_dims}'.") + + # Compute the new domain + dim = connectivity.codomain + dim_idx = data.domain.dim_index(dim) + if dim_idx is None: + raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") + + current_range: common.UnitRange = data.domain[dim_idx][1] + new_ranges = connectivity.inverse_image(current_range) + new_domain = data.domain.replace(dim_idx, *new_ranges) + + # Perform premap: + xp = data.array_ns + + # 1- first restrict the connectivity to the new domain + restricted_connectivity_domain = common.Domain(*new_ranges) + restricted_connectivity = ( + connectivity.restrict(restricted_connectivity_domain) + if restricted_connectivity_domain != connectivity.domain + else connectivity + ) + assert isinstance(restricted_connectivity, common.ConnectivityField) + + # 2- then compute the index array + new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start + + # 3- finally, take the new array + new_buffer = xp.take(data._ndarray, new_idx_array, axis=dim_idx) + + return data.__class__.from_array( + new_buffer, + domain=new_domain, + dtype=data.dtype, + ) + + +_NdConnT = TypeVar("_NdConnT", bound=NdArrayConnectivityField) + + +def _identity_connectivity( + domain: common.Domain, codomain: common.DimT, *, cls: type[_NdConnT] +) -> _NdConnT: + assert codomain in domain.dims + xp = cls.array_ns + shape = domain.shape + d_idx = domain.dim_index(codomain, allow_missing=False) + indices = xp.arange(domain[d_idx].unit_range.start, domain[d_idx].unit_range.stop) + result = cls.from_array( + xp.broadcast_to( + indices[ + tuple(slice(None) if i == d_idx else None for i, dim in enumerate(domain.dims)) + ], + shape, + ), + codomain=codomain, + domain=domain, + dtype=int, ) + return cast(_NdConnT, result) + -def _hypercube( +def _hyperslice( index_array: core_defs.NDArrayObject, image_range: common.UnitRange, xp: ModuleType, skip_value: Optional[core_defs.IntegralScalar] = None, -) -> Optional[list[common.UnitRange]]: +) -> Optional[tuple[slice, ...]]: """ - Return the hypercube that contains all indices in `index_array` that are within `image_range`, or `None` if no such hypercube exists. + Return the hypercube slice that contains all indices in `index_array` that are within `image_range`, or `None` if no such hypercube exists. If `skip_value` is given, the selected values are ignored. It returns the smallest hypercube. A bigger hypercube could be constructed by adding lines that contain only `skip_value`s. + Example: - index_array = 0 1 -1 - 3 4 -1 - -1 -1 -1 - skip_value = -1 - would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. + index_array = 0 1 -1 + 3 4 -1 + -1 -1 -1 + skip_value = -1 + + would currently select the 2x2 range [0,2], [0,2], but could also select the 3x3 range [0,3], [0,3]. """ select_mask = (index_array >= image_range.start) & (index_array < image_range.stop) @@ -482,7 +730,7 @@ def _hypercube( if not xp.all(hcube): return None - return [common.UnitRange(s.start, s.stop) for s in slices] + return slices # -- Specialized implementations for builtin operations on array fields -- @@ -521,21 +769,22 @@ def _hypercube( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _compute_mask_ranges(mask: core_defs.NDArrayObject) -> list[tuple[bool, common.UnitRange]]: - """Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges.""" +def _compute_mask_slices( + mask: core_defs.NDArrayObject, +) -> list[tuple[bool, slice]]: + """Take a 1-dimensional mask and return a sequence of mappings from boolean values to slices.""" # TODO: does it make sense to upgrade this naive algorithm to numpy? assert mask.ndim == 1 cur = bool(mask[0].item()) ind = 0 res = [] for i in range(1, mask.shape[0]): - if ( - mask_i := bool(mask[i].item()) - ) != cur: # `.item()` to extract the scalar from a 0-d array in case of e.g. cupy - res.append((cur, common.UnitRange(ind, i))) + # Use `.item()` to extract the scalar from a 0-d array in case of e.g. cupy + if (mask_i := bool(mask[i].item())) != cur: + res.append((cur, slice(ind, i))) cur = mask_i ind = i - res.append((cur, common.UnitRange(ind, mask.shape[0]))) + res.append((cur, slice(ind, mask.shape[0]))) return res @@ -620,7 +869,7 @@ def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: return nd_array_class.from_array( nd_array_class.array_ns.concatenate( [nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], - axis=new_domain.dim_index(dim), + axis=new_domain.dim_index(dim, allow_missing=False), ), domain=new_domain, ) @@ -642,12 +891,12 @@ def _concat_where( # TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils # compute the consecutive ranges (first relative, then domain) of true and false values - mask_values_to_relative_range_mapping: Iterable[tuple[bool, common.UnitRange]] = ( - _compute_mask_ranges(mask_field.ndarray) + mask_values_to_slices_mapping: Iterable[tuple[bool, slice]] = _compute_mask_slices( + mask_field.ndarray ) mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( - (mask, _relative_ranges_to_domain((relative_range,), mask_field.domain)) - for mask, relative_range in mask_values_to_relative_range_mapping + (mask, mask_field.domain.slice_at[domain_slice]) + for mask, domain_slice in mask_values_to_slices_mapping ) # mask domains intersected with the respective fields mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index f5d4c6e53b..47bcd938e0 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1071,7 +1071,7 @@ def as_scalar(self) -> core_defs.IntegralScalar: assert self._cur_index is not None return self._cur_index - def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: + def premap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1085,7 +1085,7 @@ def restrict(self, item: common.AnyIndexSpec) -> Self: # TODO set a domain... raise NotImplementedError() - __call__ = remap + __call__ = premap __getitem__ = restrict def __abs__(self) -> common.Field: @@ -1191,7 +1191,7 @@ def ndarray(self) -> core_defs.NDArrayObject: def asnumpy(self) -> np.ndarray: raise NotImplementedError() - def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: + def premap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1203,7 +1203,7 @@ def as_scalar(self) -> core_defs.ScalarT: assert self.domain.ndim == 0 return self._value - __call__ = remap + __call__ = premap __getitem__ = restrict def __abs__(self) -> common.Field: diff --git a/tests/eve_tests/unit_tests/test_utils.py b/tests/eve_tests/unit_tests/test_utils.py index e64de42db8..7777f4bb75 100644 --- a/tests/eve_tests/unit_tests/test_utils.py +++ b/tests/eve_tests/unit_tests/test_utils.py @@ -24,6 +24,31 @@ from gt4py.eve.utils import XIterable +def test_first(): + from gt4py.eve.utils import first + + # Test case 1: Non-empty iterable + iterable = [1, 2, 3, 4, 5] + result = first(iterable) + assert result == 1 + + # Test case 2: Empty iterable with default value + iterable = [] + default = "default" + result = first(iterable, default=default) + assert result == default + + # Test case 3: Empty iterable without default value + iterable = [] + with pytest.raises(StopIteration): + first(iterable) + + # Test case 4: Iterable with single element + iterable = [42] + result = first(iterable) + assert result == 42 + + def test_getitem_(): from gt4py.eve.utils import getitem_ diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 06226548ed..a02f2d210a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -510,7 +510,7 @@ def not_int(a: Field[[TDim], int64]): @pytest.fixture -def remap_setup(): +def premap_setup(): X = Dimension("X") Y = Dimension("Y") Y2XDim = Dimension("Y2X", kind=DimensionKind.LOCAL) @@ -518,52 +518,52 @@ def remap_setup(): return X, Y, Y2XDim, Y2X -def test_remap(remap_setup): - X, Y, Y2XDim, Y2X = remap_setup +def test_premap(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup - def remap_fo(bar: Field[[X], int64]) -> Field[[Y], int64]: + def premap_fo(bar: Field[[X], int64]) -> Field[[Y], int64]: return bar(Y2X[0]) - parsed = FieldOperatorParser.apply_to_function(remap_fo) + parsed = FieldOperatorParser.apply_to_function(premap_fo) assert parsed.body.stmts[0].value.type == ts.FieldType( dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) ) -def test_remap_nbfield(remap_setup): - X, Y, Y2XDim, Y2X = remap_setup +def test_premap_nbfield(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup - def remap_fo(bar: Field[[X], int64]) -> Field[[Y, Y2XDim], int64]: + def premap_fo(bar: Field[[X], int64]) -> Field[[Y, Y2XDim], int64]: return bar(Y2X) - parsed = FieldOperatorParser.apply_to_function(remap_fo) + parsed = FieldOperatorParser.apply_to_function(premap_fo) assert parsed.body.stmts[0].value.type == ts.FieldType( dims=[Y, Y2XDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) ) -def test_remap_reduce(remap_setup): - X, Y, Y2XDim, Y2X = remap_setup +def test_premap_reduce(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup - def remap_fo(bar: Field[[X], int32]) -> Field[[Y], int32]: + def premap_fo(bar: Field[[X], int32]) -> Field[[Y], int32]: return 2 * neighbor_sum(bar(Y2X), axis=Y2XDim) - parsed = FieldOperatorParser.apply_to_function(remap_fo) + parsed = FieldOperatorParser.apply_to_function(premap_fo) assert parsed.body.stmts[0].value.type == ts.FieldType( dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) ) -def test_remap_reduce_sparse(remap_setup): - X, Y, Y2XDim, Y2X = remap_setup +def test_premap_reduce_sparse(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup - def remap_fo(bar: Field[[Y, Y2XDim], int32]) -> Field[[Y], int32]: + def premap_fo(bar: Field[[Y, Y2XDim], int32]) -> Field[[Y], int32]: return 5 * neighbor_sum(bar, axis=Y2XDim) - parsed = FieldOperatorParser.apply_to_function(remap_fo) + parsed = FieldOperatorParser.apply_to_function(premap_fo) assert parsed.body.stmts[0].value.type == ts.FieldType( dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1c818c8e23..67c19b6c44 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import itertools import math import operator from typing import Callable, Iterable, Optional @@ -320,7 +319,81 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: assert np.allclose(result.ndarray, expected) -def test_remap_implementation(): +def test_domain_premap(): + # Translation case + I = Dimension("I") + J = Dimension("J") + + N = 10 + data_field = common._field( + 0.1 * np.arange(N * N).reshape((N, N)), + domain=common.Domain( + common.NamedRange(I, common.unit_range(N)), common.NamedRange(J, common.unit_range(N)) + ), + ) + conn = common.CartesianConnectivity.for_translation(J, +1) + + result = data_field.premap(conn) + expected = common._field( + data_field.ndarray, + domain=common.Domain( + common.NamedRange(I, common.unit_range(N)), + common.NamedRange(J, common.unit_range((-1, N - 1))), + ), + ) + + assert result.domain == expected.domain + assert np.all(result.ndarray == expected.ndarray) + + # Relocation case + I_half = Dimension("I_half") + + conn = common.CartesianConnectivity.for_relocation(I, I_half) + + result = data_field.premap(conn) + expected = common._field( + data_field.ndarray, + domain=common.Domain( + dims=( + I_half, + J, + ), + ranges=(data_field.domain[I].unit_range, data_field.domain[J].unit_range), + ), + ) + + assert result.domain == expected.domain + assert np.all(result.ndarray == expected.ndarray) + + +def test_reshuffling_premap(): + I = Dimension("I") + J = Dimension("J") + + ij_field = common._field( + np.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]), + domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))), + ) + max_ij_conn = common._connectivity( + np.fromfunction(lambda i, j: np.maximum(i, j), (3, 3), dtype=int), + domain=common.Domain( + dims=ij_field.domain.dims, + ranges=ij_field.domain.ranges, + ), + codomain=I, + ) + + result = ij_field.premap(max_ij_conn) + expected = common._field( + np.asarray([[0.0, 4.0, 8.0], [3.0, 4.0, 8.0], [6.0, 7.0, 8.0]]), + domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))), + ) + + assert result.domain == expected.domain + assert np.all(result.ndarray == expected.ndarray) + + +def test_remapping_premap(): V = Dimension("V") E = Dimension("E") @@ -336,7 +409,7 @@ def test_remap_implementation(): codomain=V, ) - result = v_field.remap(e2v_conn) + result = v_field.premap(e2v_conn) expected = common._field( -0.1 * np.arange(V_START, V_STOP), domain=common.Domain(dims=(E,), ranges=(UnitRange(V_START, V_STOP),)), @@ -346,26 +419,64 @@ def test_remap_implementation(): assert np.all(result.ndarray == expected.ndarray) -def test_cartesian_remap_implementation(): - V = Dimension("V") - E = Dimension("E") - - V_START, V_STOP = 2, 7 - OFFSET = 2 - v_field = common._field( - -0.1 * np.arange(V_START, V_STOP), - domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START, V_STOP),)), - ) - v2_conn = common._connectivity(OFFSET, V) +def test_identity_connectivity(): + D0 = Dimension("D0") + D1 = Dimension("D1") + D2 = Dimension("D2") - result = v_field.remap(v2_conn) - expected = common._field( - v_field.ndarray, - domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START - OFFSET, V_STOP - OFFSET),)), + domain = common.Domain( + dims=(D0, D1, D2), + ranges=(common.UnitRange(0, 3), common.UnitRange(0, 4), common.UnitRange(0, 5)), ) + codomains = [D0, D1, D2] + + expected = { + D0: nd_array_field.NumPyArrayConnectivityField.from_array( + np.array( + [ + [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], + [[2, 2, 2, 2, 2], [2, 2, 2, 2, 2], [2, 2, 2, 2, 2], [2, 2, 2, 2, 2]], + ], + dtype=int, + ), + codomain=D0, + domain=domain, + ), + D1: nd_array_field.NumPyArrayConnectivityField.from_array( + np.array( + [ + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]], + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]], + [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]], + ], + dtype=int, + ), + codomain=D1, + domain=domain, + ), + D2: nd_array_field.NumPyArrayConnectivityField.from_array( + np.array( + [ + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], + ], + dtype=int, + ), + codomain=D2, + domain=domain, + ), + } - assert result.domain == expected.domain - assert np.all(result.ndarray == expected.ndarray) + for codomain in codomains: + result = nd_array_field._identity_connectivity( + domain, codomain, cls=nd_array_field.NumPyArrayConnectivityField + ) + assert result.codomain == expected[codomain].codomain + assert result.domain == expected[codomain].domain + assert result.dtype == expected[codomain].dtype + assert np.all(result.ndarray == expected[codomain].ndarray) @pytest.mark.parametrize( @@ -867,14 +978,14 @@ def test_connectivity_field_inverse_image_2d_domain_skip_values(): ([[1, 0, -1], [1, 0, 0]], [(0, 2), (1, 3)]), ], ) -def test_hypercube(index_array, expected): +def test_hyperslice(index_array, expected): index_array = np.asarray(index_array) image_range = common.UnitRange(0, 1) skip_value = -1 - expected = [common.unit_range(e) for e in expected] if expected is not None else None + expected = tuple(slice(*e) for e in expected) if expected is not None else None - result = nd_array_field._hypercube(index_array, image_range, np, skip_value) + result = nd_array_field._hyperslice(index_array, image_range, np, skip_value) assert result == expected diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 44150f344e..b84873bbdf 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -16,6 +16,7 @@ import pytest +import gt4py.next.common as common from gt4py.next.common import ( Dimension, DimensionKind, @@ -30,8 +31,8 @@ ) -IDim = Dimension("IDim") ECDim = Dimension("ECDim") +IDim = Dimension("IDim") JDim = Dimension("JDim") KDim = Dimension("KDim", kind=DimensionKind.VERTICAL) @@ -404,6 +405,41 @@ def test_domain_dims_ranges_length_mismatch(): Domain(dims=dims, ranges=ranges) +def test_domain_slice_at(): + # Create a sample domain + domain = Domain( + NamedRange(IDim, UnitRange(0, 10)), + NamedRange(JDim, UnitRange(5, 15)), + NamedRange(KDim, UnitRange(20, 30)), + ) + + # Test indexing with slices + result = domain.slice_at[slice(2, 5), slice(5, 7), slice(7, 10)] + expected_result = Domain( + NamedRange(IDim, UnitRange(2, 5)), + NamedRange(JDim, UnitRange(10, 12)), + NamedRange(KDim, UnitRange(27, 30)), + ) + assert result == expected_result + + # Test indexing with out-of-range slices + result = domain.slice_at[slice(2, 15), slice(5, 7), slice(7, 10)] + expected_result = Domain( + NamedRange(IDim, UnitRange(2, 10)), + NamedRange(JDim, UnitRange(10, 12)), + NamedRange(KDim, UnitRange(27, 30)), + ) + assert result == expected_result + + # Test indexing with incorrect types + with pytest.raises(TypeError): + domain.slice_at["a", 7, 25] + + # Test indexing with incorrect number of indices + with pytest.raises(ValueError, match="not match the number of dimensions"): + domain.slice_at[slice(2, 5), slice(7, 10)] + + def test_domain_dim_index(): dims = [Dimension("X"), Dimension("Y"), Dimension("Z")] ranges = [UnitRange(0, 1), UnitRange(0, 1), UnitRange(0, 1)] @@ -578,3 +614,25 @@ def test_dimension_promotion( promote_dims(*dim_list) assert exc_info.match(expected_error_msg) + + +class TestCartesianConnectivity: + def test_for_translation(self): + offset = 5 + I = common.Dimension("I") + + result = common.CartesianConnectivity.for_translation(I, offset) + assert isinstance(result, common.CartesianConnectivity) + assert result.domain_dim == I + assert result.codomain == I + assert result.offset == offset + + def test_for_relocation(self): + I = common.Dimension("I") + I_half = common.Dimension("I_half") + + result = common.CartesianConnectivity.for_relocation(I, I_half) + assert isinstance(result, common.CartesianConnectivity) + assert result.domain_dim == I_half + assert result.codomain == I + assert result.offset == 0