From 8d4cf628004cfd40de54c69deaa8940e31975737 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Mon, 29 Apr 2024 17:31:46 +0200 Subject: [PATCH] Refactor identity_connectivity --- src/gt4py/next/embedded/nd_array_field.py | 53 +++++++------------ .../embedded_tests/test_nd_array_field.py | 29 +++++----- 2 files changed, 33 insertions(+), 49 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e2ad6d2d3d..39f8afeaf8 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -24,7 +24,6 @@ from numpy import typing as npt from gt4py._core import definitions as core_defs -from gt4py.eve import utils from gt4py.eve.extended_typing import ( ClassVar, Iterable, @@ -262,9 +261,7 @@ def premap( conn_fields.append(connectivity) codomains_counter[connectivity.codomain] += 1 - if unknown_dims := [ - dim for dim in codomains_counter.keys() if dim not in self.domain.dims - ]: + if unknown_dims := [dim for dim in codomains_counter.keys() if dim not in self.domain.dims]: raise ValueError( f"Incompatible dimensions in the connectivity codomain(s) {unknown_dims}" f"while pre-mapping a field with domain {self.domain}." @@ -596,9 +593,7 @@ def _reshuffling_premap( # Create identity connectivities for the missing domain dimensions for dim in data.domain.dims: if dim not in conn_map: - conn_map[dim] = utils.first( - _identity_connectivities(new_domain, [dim], cls=type(connectivity)) - ) + conn_map[dim] = _identity_connectivity(new_domain, dim, cls=type(connectivity)) # Take data take_indices = tuple(conn_map[dim].ndarray for dim in data.domain.dims) @@ -654,36 +649,26 @@ def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField _ConnT = TypeVar("_ConnT", bound=common.ConnectivityField) -def _identity_connectivities( - domain: common.Domain, - codomains: Sequence[common.DimT], - *, - cls: type[_ConnT], +def _identity_connectivity( + domain: common.Domain, codomain: common.DimT, *, cls: type[_ConnT] ) -> tuple[_ConnT, ...]: + assert codomain in domain.dims xp = cls.array_ns shape = domain.shape - identities = [] - for d in codomains: - assert d in domain.dims - d_idx = domain.dim_index(d) - indices = xp.arange(domain[d_idx].unit_range.start, domain[d_idx].unit_range.stop) - identities.append( - cls.from_array( - xp.broadcast_to( - indices[ - tuple( - slice(None) if i == d_idx else None for i, dim in enumerate(domain.dims) - ) - ], - shape, - ), - codomain=d, - domain=domain, - dtype=int, - ) - ) - - return tuple(identities) + d_idx = domain.dim_index(codomain) + indices = xp.arange(domain[d_idx].unit_range.start, domain[d_idx].unit_range.stop) + + return cls.from_array( + xp.broadcast_to( + indices[ + tuple(slice(None) if i == d_idx else None for i, dim in enumerate(domain.dims)) + ], + shape, + ), + codomain=codomain, + domain=domain, + dtype=int, + ) def _hyperslice( diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index afb1a27112..e4eb89207f 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -391,7 +391,7 @@ def test_remapping_premap(): assert np.all(result.ndarray == expected.ndarray) -def test_identity_connectivities(): +def test_identity_connectivity(): D0 = Dimension("D0") D1 = Dimension("D1") D2 = Dimension("D2") @@ -402,12 +402,8 @@ def test_identity_connectivities(): ) codomains = [D0, D1, D2] - result = nd_array_field._identity_connectivities( - domain, codomains, cls=nd_array_field.NumPyArrayConnectivityField - ) - - expected = ( - nd_array_field.NumPyArrayConnectivityField.from_array( + expected = { + D0: nd_array_field.NumPyArrayConnectivityField.from_array( np.array( [ [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], @@ -419,7 +415,7 @@ def test_identity_connectivities(): codomain=D0, domain=domain, ), - nd_array_field.NumPyArrayConnectivityField.from_array( + D1: nd_array_field.NumPyArrayConnectivityField.from_array( np.array( [ [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]], @@ -431,7 +427,7 @@ def test_identity_connectivities(): codomain=D1, domain=domain, ), - nd_array_field.NumPyArrayConnectivityField.from_array( + D2: nd_array_field.NumPyArrayConnectivityField.from_array( np.array( [ [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], @@ -443,13 +439,16 @@ def test_identity_connectivities(): codomain=D2, domain=domain, ), - ) + } - for r, e in zip(result, expected): - assert r.codomain == e.codomain - assert r.domain == e.domain - assert r.dtype == e.dtype - assert np.all(r.ndarray == e.ndarray) + for codomain in codomains: + result = nd_array_field._identity_connectivity( + domain, codomain, cls=nd_array_field.NumPyArrayConnectivityField + ) + assert result.codomain == expected[codomain].codomain + assert result.domain == expected[codomain].domain + assert result.dtype == expected[codomain].dtype + assert np.all(result.ndarray == expected[codomain].ndarray) @pytest.mark.parametrize(