Skip to content

Commit

Permalink
Merge branch 'main' into decouple_inferences
Browse files (browse the repository at this point in the history)
  • Loading branch information
SF-N authored Feb 24, 2025
2 parents 9e48581 + 1984691 commit 2820315
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 26 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _local
/src/__init__.py
/tests/__init__.py
.gt_cache/
.gt4py_cache/
.gt_cache_pytest*/

# DaCe
Expand Down
2 changes: 1 addition & 1 deletion ci/cscs-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ stages:
DOCKERFILE: ci/base.Dockerfile
# change to 'always' if you want to rebuild, even if target tag exists already (if-not-exists is the default, i.e. we could also skip the variable)
CSCS_REBUILD_POLICY: if-not-exists
DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "CUPY_VERSION=$CUPY_VERSION", "UBUNTU_VERSION=$UBUNTU_VERSION", "PYVERSION=$PYVERSION", "CI_PROJECT_DIR=$CI_PROJECT_DIR"]'
DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "CUPY_VERSION=$CUPY_VERSION", "UBUNTU_VERSION=$UBUNTU_VERSION", "PYVERSION=$PYVERSION"]'
.build_baseimage_x86_64:
extends: [.container-builder-cscs-zen2, .build_baseimage]
variables:
Expand Down
10 changes: 5 additions & 5 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_cartesian(
groups=["test"],
)

num_processes = session.env.get("NUM_PROCESSES", "auto")
num_processes = os.environ.get("NUM_PROCESSES", "auto")
markers = " and ".join(codegen_settings["markers"] + device_settings["markers"])

session.run(
Expand All @@ -111,7 +111,7 @@ def test_examples(session: nox.Session) -> None:
session.run(*"jupytext docs/user/next/QuickstartGuide.md --to .ipynb".split())
session.run(*"jupytext docs/user/next/advanced/*.md --to .ipynb".split())

num_processes = session.env.get("NUM_PROCESSES", "auto")
num_processes = os.environ.get("NUM_PROCESSES", "auto")
for notebook, extra_args in [
("docs/user/next/workshop/slides", None),
("docs/user/next/workshop/exercises", ["-k", "solutions"]),
Expand All @@ -131,7 +131,7 @@ def test_eve(session: nox.Session) -> None:

_install_session_venv(session, groups=["test"])

num_processes = session.env.get("NUM_PROCESSES", "auto")
num_processes = os.environ.get("NUM_PROCESSES", "auto")

session.run(
*f"pytest --cache-clear -sv -n {num_processes}".split(),
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_next(
groups=groups,
)

num_processes = session.env.get("NUM_PROCESSES", "auto")
num_processes = os.environ.get("NUM_PROCESSES", "auto")
markers = " and ".join(codegen_settings["markers"] + device_settings["markers"] + mesh_markers)

session.run(
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_storage(
session, extras=["performance", "testing", *device_settings["extras"]], groups=["test"]
)

num_processes = session.env.get("NUM_PROCESSES", "auto")
num_processes = os.environ.get("NUM_PROCESSES", "auto")
markers = " and ".join(device_settings["markers"])

session.run(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def isbool(self):
return self == self.BOOL

def isinteger(self):
return self in (self.INT8, self.INT32, self.INT64)
return self in (self.INT8, self.INT16, self.INT32, self.INT64)

def isfloat(self):
return self in (self.FLOAT32, self.FLOAT64)
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,13 @@ class ScalarAccess(common.ScalarAccess, Expr):


class VariableKOffset(common.VariableKOffset[Expr]):
pass
@datamodels.validator("k")
def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Expr) -> None:
for part in expression.walk_values():
if isinstance(part, Cast):
raise ValueError(
"DaCe backends are currently missing support for casts in variable k offsets. See issue https://github.com/GridTools/gt4py/issues/1881."
)


class IndexAccess(common.FieldAccess, Expr):
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def _visit_offset(
else:
int_sizes.append(None)
sym_offsets = [
dace.symbolic.pystr_to_symbolic(self.visit(off, **kwargs))
dace.symbolic.pystr_to_symbolic(
self.visit(off, access_info=access_info, decl=decl, **kwargs)
)
for off in (node.to_dict()["i"], node.to_dict()["j"], node.k)
]
for axis in access_info.variable_offset_axes:
Expand Down
39 changes: 30 additions & 9 deletions src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,36 @@ def connect(
dest: dace.nodes.AccessNode,
subset: dace_subsets.Range,
) -> None:
# retrieve the node which writes the result
last_node = self.state.in_edges(self.result.dc_node)[0].src
if isinstance(last_node, dace.nodes.Tasklet):
# the last transient node can be deleted
# Note that it could also be applied when `last_node` is a NestedSDFG,
# but an exception would be when the inner write to global data is a
# WCR memlet, because that prevents fusion of the outer map. This case
# happens for the reduce with skip values, which uses a map with WCR.
last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn
write_edge = self.state.in_edges(self.result.dc_node)[0]
write_size = write_edge.data.dst_subset.num_elements()
# check the kind of node which writes the result
if isinstance(write_edge.src, dace.nodes.Tasklet):
# The temporary data written by a tasklet can be safely deleted
assert write_size.is_constant()
remove_last_node = True
elif isinstance(write_edge.src, dace.nodes.NestedSDFG):
if write_size.is_constant():
# Temporary data with compile-time size is allocated on the stack
# and therefore is safe to keep. We decide to keep it as a workaround
# for a dace issue with memlet propagation in combination with
# nested SDFGs containing conditional blocks. The output memlet
# of such blocks will be marked as dynamic because dace is not able
# to detect the exact size of a conditional branch dataflow, even
# in case of if-else expressions with exact same output data.
remove_last_node = False
else:
# In case the output data has runtime size it is necessary to remove
# it in order to avoid dynamic memory allocation inside a parallel
# map scope. Otherwise, the memory allocation will for sure lead
# to performance degradation, and eventually illegal memory issues
# when the gpu runs out of local memory.
remove_last_node = True
else:
remove_last_node = False

if remove_last_node:
last_node = write_edge.src
last_node_connector = write_edge.src_conn
self.state.remove_node(self.result.dc_node)
else:
last_node = self.result.dc_node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@
"less_equal": "({} <= {})",
"greater": "({} > {})",
"greater_equal": "({} >= {})",
"and_": "({} & {})",
"or_": "({} | {})",
"xor_": "({} ^ {})",
"and_": "({} and {})",
"or_": "({} or {})",
"xor_": "({} != {})",
"mod": "({} % {})",
"not_": "(not {})", # ~ is not bitwise in numpy
"not_": "(not {})",
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,9 +718,9 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> Non
if not self.do_not_fuse:
gtx_transformations.MapFusionSerial.apply_to(
sdfg=sdfg,
map_exit_1=trivial_map_exit,
intermediate_access_node=access_node,
map_entry_2=second_map_entry,
first_map_exit=trivial_map_exit,
array=access_node,
second_map_entry=second_map_entry,
verify=True,
)

Expand Down
6 changes: 6 additions & 0 deletions tests/cartesian_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def _get_backends_with_storage_info(storage_info_kind: str):
_PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")]
PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES]

# Backend params whose names carry the "dace:" prefix, wrapped with
# _backend_name_as_param like the other *_BACKENDS lists in this module.
# Idiom fix: a comprehension condition instead of filter(lambda ...).
DACE_BACKENDS = [
    _backend_name_as_param(name) for name in _ALL_BACKEND_NAMES if name.startswith("dace:")
]
# All backend params except the DaCe ones; complements DACE_BACKENDS.
NON_DACE_BACKENDS = [backend for backend in ALL_BACKENDS if backend not in DACE_BACKENDS]


@pytest.fixture()
def id_version():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
)
from gt4py.storage.cartesian import utils as storage_utils

from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library
from cartesian_tests.definitions import (
ALL_BACKENDS,
CPU_BACKENDS,
DACE_BACKENDS,
NON_DACE_BACKENDS,
get_array_library,
)
from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import (
EXTERNALS_REGISTRY as externals_registry,
REGISTRY as stencil_definitions,
Expand Down Expand Up @@ -762,3 +768,89 @@ def test(
out_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64)
test(in_arr, out_arr)
assert (out_arr[:, :, :] == 388.0).all()


@pytest.mark.parametrize("backend", NON_DACE_BACKENDS)
def test_cast_in_index(backend):
    # Compilation-only test: the stencil is never run on data; successfully
    # building it is the assertion. Subtracting np.int32 from np.int64 forces
    # a cast inside the variable K offset, which non-DaCe backends must accept.
    @gtscript.stencil(backend)
    def cast_in_index(
        in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64]
    ):
        """Simple copy stencil with forced cast in index calculation."""
        with computation(PARALLEL), interval(...):
            out_field = in_field[0, 0, i32 - i64]


@pytest.mark.parametrize("backend", DACE_BACKENDS)
@pytest.mark.xfail(raises=ValueError)
def test_dace_no_cast_in_index(backend):
    # DaCe backends currently lack support for casts in variable K offsets
    # (https://github.com/GridTools/gt4py/issues/1881), so stencil construction
    # is expected to raise ValueError — hence the xfail(raises=ValueError).
    @gtscript.stencil(backend)
    def cast_in_index(
        in_field: Field[np.float64], i32: np.int32, i64: np.int64, out_field: Field[np.float64]
    ):
        """Simple copy stencil with forced cast in index calculation."""
        with computation(PARALLEL), interval(...):
            out_field = in_field[0, 0, i32 - i64]


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_read_after_write_stencil(backend):
    """Stencil with multiple read after write access patterns."""

    # Compilation-only test: building the stencil is the assertion; it is never
    # executed on data here. Both `lev` and `qsum` are written and then re-read
    # within the same column pass — the read-after-write pattern under test.
    # NOTE(review): appears to implement Lagrangian vertical remapping — confirm.
    @gtscript.stencil(backend=backend)
    def lagrangian_contributions(
        q: Field[np.float64],
        pe1: Field[np.float64],
        pe2: Field[np.float64],
        q4_1: Field[np.float64],
        q4_2: Field[np.float64],
        q4_3: Field[np.float64],
        q4_4: Field[np.float64],
        dp1: Field[np.float64],
        lev: gtscript.Field[gtscript.IJ, np.int64],
    ):
        """
        Args:
            q (out):
            pe1 (in):
            pe2 (in):
            q4_1 (in):
            q4_2 (in):
            q4_3 (in):
            q4_4 (in):
            dp1 (in):
            lev (inout):
        """
        with computation(FORWARD), interval(...):
            pl = (pe2 - pe1[0, 0, lev]) / dp1[0, 0, lev]
            if pe2[0, 0, 1] <= pe1[0, 0, lev + 1]:
                pr = (pe2[0, 0, 1] - pe1[0, 0, lev]) / dp1[0, 0, lev]
                q = (
                    q4_2[0, 0, lev]
                    + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (pr + pl)
                    - q4_4[0, 0, lev] * 1.0 / 3.0 * (pr * (pr + pl) + pl * pl)
                )
            else:
                qsum = (pe1[0, 0, lev + 1] - pe2) * (
                    q4_2[0, 0, lev]
                    + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (1.0 + pl)
                    - q4_4[0, 0, lev] * 1.0 / 3.0 * (1.0 + pl * (1.0 + pl))
                )
                # Write to `lev` that is re-read by every indexed access below.
                lev = lev + 1
                while pe1[0, 0, lev + 1] < pe2[0, 0, 1]:
                    # `qsum` read-after-write: accumulated across iterations.
                    qsum += dp1[0, 0, lev] * q4_1[0, 0, lev]
                    lev = lev + 1
                dp = pe2[0, 0, 1] - pe1[0, 0, lev]
                esl = dp / dp1[0, 0, lev]
                qsum += dp * (
                    q4_2[0, 0, lev]
                    + 0.5
                    * esl
                    * (
                        q4_3[0, 0, lev]
                        - q4_2[0, 0, lev]
                        + q4_4[0, 0, lev] * (1.0 - (2.0 / 3.0) * esl)
                    )
                )
                q = qsum / (pe2[0, 0, 1] - pe2)
                lev = lev - 1
18 changes: 18 additions & 0 deletions tests/cartesian_tests/unit_tests/test_gtc/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@
# - For testing non-leave nodes, introduce builders with defaults (for leave nodes as well)


def test_data_type_methods():
    """Check isbool/isinteger/isfloat classify every DataType member exactly.

    Exhaustive over the enum: each predicate must be True for precisely its
    own members and False for all others (INT16 included in isinteger).
    """
    # `dtype` instead of `type`: never shadow the builtin.
    for dtype in DataType:
        if dtype == DataType.BOOL:
            assert dtype.isbool()
        else:
            assert not dtype.isbool()

        if dtype in (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64):
            assert dtype.isinteger()
        else:
            assert not dtype.isinteger()

        if dtype in (DataType.FLOAT32, DataType.FLOAT64):
            assert dtype.isfloat()
        else:
            assert not dtype.isfloat()


class DummyExpr(Expr):
"""Fake expression for cases where a concrete expression is not needed."""

Expand Down

0 comments on commit 2820315

Please sign in to comment.