Skip to content

Commit

Permalink
Refactor identity_connectivity
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Apr 29, 2024
1 parent e8a5c42 commit 8d4cf62
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 49 deletions.
53 changes: 19 additions & 34 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from numpy import typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve import utils
from gt4py.eve.extended_typing import (
ClassVar,
Iterable,
Expand Down Expand Up @@ -262,9 +261,7 @@ def premap(
conn_fields.append(connectivity)
codomains_counter[connectivity.codomain] += 1

if unknown_dims := [
dim for dim in codomains_counter.keys() if dim not in self.domain.dims
]:
if unknown_dims := [dim for dim in codomains_counter.keys() if dim not in self.domain.dims]:
raise ValueError(
f"Incompatible dimensions in the connectivity codomain(s) {unknown_dims}"
f"while pre-mapping a field with domain {self.domain}."
Expand Down Expand Up @@ -596,9 +593,7 @@ def _reshuffling_premap(
# Create identity connectivities for the missing domain dimensions
for dim in data.domain.dims:
if dim not in conn_map:
conn_map[dim] = utils.first(
_identity_connectivities(new_domain, [dim], cls=type(connectivity))
)
conn_map[dim] = _identity_connectivity(new_domain, dim, cls=type(connectivity))

# Take data
take_indices = tuple(conn_map[dim].ndarray for dim in data.domain.dims)
Expand Down Expand Up @@ -654,36 +649,26 @@ def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField
_ConnT = TypeVar("_ConnT", bound=common.ConnectivityField)


def _identity_connectivities(
domain: common.Domain,
codomains: Sequence[common.DimT],
*,
cls: type[_ConnT],
def _identity_connectivity(
domain: common.Domain, codomain: common.DimT, *, cls: type[_ConnT]
) -> tuple[_ConnT, ...]:
assert codomain in domain.dims
xp = cls.array_ns
shape = domain.shape
identities = []
for d in codomains:
assert d in domain.dims
d_idx = domain.dim_index(d)
indices = xp.arange(domain[d_idx].unit_range.start, domain[d_idx].unit_range.stop)
identities.append(
cls.from_array(
xp.broadcast_to(
indices[
tuple(
slice(None) if i == d_idx else None for i, dim in enumerate(domain.dims)
)
],
shape,
),
codomain=d,
domain=domain,
dtype=int,
)
)

return tuple(identities)
d_idx = domain.dim_index(codomain)
indices = xp.arange(domain[d_idx].unit_range.start, domain[d_idx].unit_range.stop)

return cls.from_array(
xp.broadcast_to(
indices[
tuple(slice(None) if i == d_idx else None for i, dim in enumerate(domain.dims))
],
shape,
),
codomain=codomain,
domain=domain,
dtype=int,
)


def _hyperslice(
Expand Down
29 changes: 14 additions & 15 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def test_remapping_premap():
assert np.all(result.ndarray == expected.ndarray)


def test_identity_connectivities():
def test_identity_connectivity():
D0 = Dimension("D0")
D1 = Dimension("D1")
D2 = Dimension("D2")
Expand All @@ -402,12 +402,8 @@ def test_identity_connectivities():
)
codomains = [D0, D1, D2]

result = nd_array_field._identity_connectivities(
domain, codomains, cls=nd_array_field.NumPyArrayConnectivityField
)

expected = (
nd_array_field.NumPyArrayConnectivityField.from_array(
expected = {
D0: nd_array_field.NumPyArrayConnectivityField.from_array(
np.array(
[
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
Expand All @@ -419,7 +415,7 @@ def test_identity_connectivities():
codomain=D0,
domain=domain,
),
nd_array_field.NumPyArrayConnectivityField.from_array(
D1: nd_array_field.NumPyArrayConnectivityField.from_array(
np.array(
[
[[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 3, 3]],
Expand All @@ -431,7 +427,7 @@ def test_identity_connectivities():
codomain=D1,
domain=domain,
),
nd_array_field.NumPyArrayConnectivityField.from_array(
D2: nd_array_field.NumPyArrayConnectivityField.from_array(
np.array(
[
[[0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4], [0, 1, 2, 3, 4]],
Expand All @@ -443,13 +439,16 @@ def test_identity_connectivities():
codomain=D2,
domain=domain,
),
)
}

for r, e in zip(result, expected):
assert r.codomain == e.codomain
assert r.domain == e.domain
assert r.dtype == e.dtype
assert np.all(r.ndarray == e.ndarray)
for codomain in codomains:
result = nd_array_field._identity_connectivity(
domain, codomain, cls=nd_array_field.NumPyArrayConnectivityField
)
assert result.codomain == expected[codomain].codomain
assert result.domain == expected[codomain].domain
assert result.dtype == expected[codomain].dtype
assert np.all(result.ndarray == expected[codomain].ndarray)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8d4cf62

Please sign in to comment.