fix[cartesian, dace]: warn about missing support for casting in variable k offsets (#1882)


## Description

We noticed that DaCe backends are currently missing support for casting
in variable k offsets. This PR

- adds a codegen test with a cast in a variable k offset
- adds a node validator for the DaCe backends that complains about the
missing support
- adds an `xfail` test for the node validator

This should be fixed down the road; issue #1881 tracks it. A minimal
stencil that triggers the new validator is sketched below.
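
For illustration, this is roughly what the problem looks like from the user
side (a minimal sketch assuming the `dace:cpu` backend name; the actual
codegen test is part of the diff below). The mixed `np.int32`/`np.int64`
arithmetic forces a `Cast` node inside the variable k offset, so compiling
this stencil against a DaCe backend now fails with the `ValueError` raised
by the new validator:

```python
import numpy as np

from gt4py.cartesian import gtscript
from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval


# With this PR, the decorator call below raises
#   ValueError: DaCe backends are currently missing support for casts
#   in variable k offsets. ...
# instead of generating broken code.
@gtscript.stencil(backend="dace:cpu")
def cast_in_variable_k_offset(
    in_field: Field[np.float64],
    i32: np.int32,
    i64: np.int64,
    out_field: Field[np.float64],
):
    with computation(PARALLEL), interval(...):
        # `i32 - i64` inserts an implicit cast into the variable k offset.
        out_field = in_field[0, 0, i32 - i64]
```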

The PR also contains two smaller, unrelated commits:

- 741c448 increases test coverage with
another codegen test that has a couple of read-after-write access
patterns which were breaking the "new bridge" (see
GEOS-ESM/NDSL#53).
- e98ddc5 just forwards all keyword
arguments when visiting offsets. This probably was not a problem until
now, but it is best practice to forward everything (a generic sketch of
the pattern follows this list).
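
For context, the kwargs-forwarding pattern looks roughly like the sketch
below. This is a generic, hypothetical visitor, not the actual eve/DaCe
expansion code; it only illustrates that dropping `**kwargs` in one
`visit_*` call cuts off context (such as `access_info` or `decl`) for every
node visited beneath it:

```python
class CodegenSketch:
    """Generic visitor sketch: every visit_* call forwards **kwargs so that
    context passed in at the top reaches nested nodes."""

    def visit(self, node, **kwargs):
        method = getattr(self, f"visit_{type(node).__name__}", self.generic_visit)
        return method(node, **kwargs)

    def generic_visit(self, node, **kwargs):
        raise NotImplementedError(type(node).__name__)

    def visit_FieldAccess(self, node, **kwargs):
        # Forwarding everything keeps e.g. `access_info` available when
        # visiting the (possibly variable) offset expression of this access.
        offset = self.visit(node.offset, **kwargs)
        return f"{node.name}[{offset}]"
```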

## Requirements

- [x] All fixes and/or new features come with corresponding tests.
- [ ] Important design decisions have been documented in the appropriate
ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md)
folder.
  N/A

---------

Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
Co-authored-by: Florian Deconinck <deconinck.florian@gmail.com>
3 people authored Feb 21, 2025
1 parent 1176b2d commit 1984691
Showing 4 changed files with 109 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/gt4py/cartesian/gtc/dace/daceir.py
@@ -734,7 +734,13 @@ class ScalarAccess(common.ScalarAccess, Expr):


class VariableKOffset(common.VariableKOffset[Expr]):
pass
@datamodels.validator("k")
def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Expr) -> None:
for part in expression.walk_values():
if isinstance(part, Cast):
raise ValueError(
"DaCe backends are currently missing support for casts in variable k offsets. See issue https://github.com/GridTools/gt4py/issues/1881."
)


class IndexAccess(common.FieldAccess, Expr):
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
@@ -44,7 +44,9 @@ def _visit_offset(
else:
int_sizes.append(None)
sym_offsets = [
dace.symbolic.pystr_to_symbolic(self.visit(off, **kwargs))
dace.symbolic.pystr_to_symbolic(
self.visit(off, access_info=access_info, decl=decl, **kwargs)
)
for off in (node.to_dict()["i"], node.to_dict()["j"], node.k)
]
for axis in access_info.variable_offset_axes:
6 changes: 6 additions & 0 deletions tests/cartesian_tests/definitions.py
@@ -51,6 +51,12 @@ def _get_backends_with_storage_info(storage_info_kind: str):
_PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")]
PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES]

DACE_BACKENDS = [
_backend_name_as_param(name)
for name in filter(lambda name: name.startswith("dace:"), _ALL_BACKEND_NAMES)
]
NON_DACE_BACKENDS = [backend for backend in ALL_BACKENDS if backend not in DACE_BACKENDS]


@pytest.fixture()
def id_version():
@@ -27,7 +27,13 @@
)
from gt4py.storage.cartesian import utils as storage_utils

from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library
from cartesian_tests.definitions import (
ALL_BACKENDS,
CPU_BACKENDS,
DACE_BACKENDS,
NON_DACE_BACKENDS,
get_array_library,
)
from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import (
EXTERNALS_REGISTRY as externals_registry,
REGISTRY as stencil_definitions,
@@ -762,3 +768,89 @@ def test(
out_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64)
test(in_arr, out_arr)
assert (out_arr[:, :, :] == 388.0).all()


@pytest.mark.parametrize("backend", NON_DACE_BACKENDS)
def test_cast_in_index(backend):
@gtscript.stencil(backend)
def cast_in_index(
in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64]
):
"""Simple copy stencil with forced cast in index calculation."""
with computation(PARALLEL), interval(...):
out_field = in_field[0, 0, i32 - i64]


@pytest.mark.parametrize("backend", DACE_BACKENDS)
@pytest.mark.xfail(raises=ValueError)
def test_dace_no_cast_in_index(backend):
@gtscript.stencil(backend)
def cast_in_index(
in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64]
):
"""Simple copy stencil with forced cast in index calculation."""
with computation(PARALLEL), interval(...):
out_field = in_field[0, 0, i32 - i64]


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_read_after_write_stencil(backend):
"""Stencil with multiple read after write access patterns."""

@gtscript.stencil(backend=backend)
def lagrangian_contributions(
q: Field[np.float64],
pe1: Field[np.float64],
pe2: Field[np.float64],
q4_1: Field[np.float64],
q4_2: Field[np.float64],
q4_3: Field[np.float64],
q4_4: Field[np.float64],
dp1: Field[np.float64],
lev: gtscript.Field[gtscript.IJ, np.int64],
):
"""
Args:
q (out):
pe1 (in):
pe2 (in):
q4_1 (in):
q4_2 (in):
q4_3 (in):
q4_4 (in):
dp1 (in):
lev (inout):
"""
with computation(FORWARD), interval(...):
pl = (pe2 - pe1[0, 0, lev]) / dp1[0, 0, lev]
if pe2[0, 0, 1] <= pe1[0, 0, lev + 1]:
pr = (pe2[0, 0, 1] - pe1[0, 0, lev]) / dp1[0, 0, lev]
q = (
q4_2[0, 0, lev]
+ 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (pr + pl)
- q4_4[0, 0, lev] * 1.0 / 3.0 * (pr * (pr + pl) + pl * pl)
)
else:
qsum = (pe1[0, 0, lev + 1] - pe2) * (
q4_2[0, 0, lev]
+ 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (1.0 + pl)
- q4_4[0, 0, lev] * 1.0 / 3.0 * (1.0 + pl * (1.0 + pl))
)
lev = lev + 1
while pe1[0, 0, lev + 1] < pe2[0, 0, 1]:
qsum += dp1[0, 0, lev] * q4_1[0, 0, lev]
lev = lev + 1
dp = pe2[0, 0, 1] - pe1[0, 0, lev]
esl = dp / dp1[0, 0, lev]
qsum += dp * (
q4_2[0, 0, lev]
+ 0.5
* esl
* (
q4_3[0, 0, lev]
- q4_2[0, 0, lev]
+ q4_4[0, 0, lev] * (1.0 - (2.0 / 3.0) * esl)
)
)
q = qsum / (pe2[0, 0, 1] - pe2)
lev = lev - 1
