From c8ff8ed3160164e6d1d1ec74f766a2bcd1366cfd Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 15 Nov 2023 16:47:29 +0100 Subject: [PATCH 01/85] bug[next] Fix broken gpu tox setup (#1358) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit And update gpu test which was broken in refactoring. --------- Co-authored-by: Rico Häuselmann --- .../ffront_tests/test_gpu_backend.py | 24 +++++++------------ tox.ini | 4 ++-- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py index 80e9a8e07a..7054597831 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py @@ -15,6 +15,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.program_processors.runners import dace_iterator, gtfn from next_tests.integration_tests import cases @@ -26,26 +27,19 @@ @pytest.mark.requires_gpu @pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace_gpu, gtfn.run_gtfn_gpu]) -def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures +def test_copy(fieldview_backend): # noqa: F811 # fixtures import cupy as cp @gtx.field_operator(backend=fieldview_backend) def testee(a: cases.IJKField) -> cases.IJKField: return a - inp_arr = cp.full(shape=(3, 4, 5), fill_value=3, dtype=cp.int32) - outp_arr = cp.zeros_like(inp_arr) - inp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], inp_arr) - outp = gtx.as_field([cases.IDim, cases.JDim, cases.KDim], outp_arr) - - testee(inp, out=outp, offset_provider={}) - assert cp.allclose(inp_arr, outp_arr) - - inp_field = gtx.full( - [cases.IDim, cases.JDim, cases.KDim], fill_value=3, allocator=fieldview_backend - ) - out_field = gtx.zeros( - [cases.IDim, cases.JDim, cases.KDim], outp_arr, allocator=fieldview_backend - ) + domain = { + cases.IDim: common.unit_range(3), + cases.JDim: common.unit_range(4), + cases.KDim: common.unit_range(5), + } + inp_field = gtx.full(domain, fill_value=3, allocator=fieldview_backend, dtype=cp.int32) + out_field = gtx.zeros(domain, allocator=fieldview_backend, dtype=cp.int32) testee(inp_field, out=out_field, offset_provider={}) assert cp.allclose(inp_field.ndarray, out_field.ndarray) diff --git a/tox.ini b/tox.ini index 18a6ff8e84..5b644e7d97 100644 --- a/tox.ini +++ b/tox.ini @@ -82,9 +82,9 @@ set_env = PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} commands = nomesh-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and not requires_gpu" {posargs} tests{/}next_tests - nomesh-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests + nomesh-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests atlas-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and not requires_gpu" {posargs} tests{/}next_tests - atlas-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests + # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next [testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] From b8cda74e2eade6d2cfb8c9ee175e456dafa8adc8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 16 Nov 2023 11:38:10 +0100 Subject: [PATCH 02/85] feat[next]: add `where` to embedded field view (#1316) - unifies unary and binary builtin to general nary in NdArrayField - special case for `where` with tuples --- src/gt4py/next/embedded/nd_array_field.py | 111 ++++++++---------- src/gt4py/next/ffront/fbuiltins.py | 23 +++- .../embedded_tests/test_nd_array_field.py | 54 +++++++++ 3 files changed, 128 insertions(+), 60 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 527197e0bc..ea88948841 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,6 +15,8 @@ from __future__ import annotations import dataclasses +import functools +import operator from collections.abc import Callable, Sequence from types import ModuleType from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar @@ -39,40 +41,38 @@ jnp: Optional[ModuleType] = None # type:ignore[no-redef] -def _make_unary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_unary_op(a: NdArrayField) -> common.Field: - xp = a.__class__.array_ns +def _make_builtin(builtin_name: str, array_builtin_name: str) -> Callable[..., NdArrayField]: + def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: + first = fields[0] + assert isinstance(first, NdArrayField) + xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) - new_data = op(a.ndarray) - return a.__class__.from_array(new_data, domain=a.domain) - - _builtin_unary_op.__name__ = builtin_name - return _builtin_unary_op - - -def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_name: str) -> Callable: - def _builtin_binary_op(a: NdArrayField, b: common.Field) -> common.Field: - xp = a.__class__.array_ns - op = getattr(xp, array_builtin_name) - if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): - if not a.domain == b.domain: - domain_intersection = a.domain & b.domain - a_broadcasted = _broadcast(a, domain_intersection.dims) - b_broadcasted = _broadcast(b, domain_intersection.dims) - a_slices = _get_slices_from_domain_slice(a_broadcasted.domain, domain_intersection) - b_slices = _get_slices_from_domain_slice(b_broadcasted.domain, domain_intersection) - new_data = op(a_broadcasted.ndarray[a_slices], b_broadcasted.ndarray[b_slices]) - return a.__class__.from_array(new_data, domain=domain_intersection) - new_data = op(a.ndarray, xp.asarray(b.ndarray)) - else: - assert isinstance(b, core_defs.SCALAR_TYPES) - new_data = op(a.ndarray, b) - - return a.__class__.from_array(new_data, domain=a.domain) - - _builtin_binary_op.__name__ = builtin_name - return _builtin_binary_op + domain_intersection = functools.reduce( + operator.and_, + [f.domain for f in fields if common.is_field(f)], + common.Domain(dims=tuple(), ranges=tuple()), + ) + transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] + for f in fields: + if common.is_field(f): + if f.domain == domain_intersection: + transformed.append(xp.asarray(f.ndarray)) + else: + f_broadcasted = _broadcast(f, domain_intersection.dims) + f_slices = _get_slices_from_domain_slice( + f_broadcasted.domain, domain_intersection + ) + transformed.append(xp.asarray(f_broadcasted.ndarray[f_slices])) + else: + assert core_defs.is_scalar_type(f) + transformed.append(f) + + new_data = op(*transformed) + return first.__class__.from_array(new_data, domain=domain_intersection) + + _builtin_op.__name__ = builtin_name + return _builtin_op _Value: TypeAlias = common.Field | core_defs.ScalarT @@ -174,56 +174,50 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __call__ = None # type: ignore[assignment] # TODO: remap - __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") + __abs__ = _make_builtin("abs", "abs") - __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") + __neg__ = _make_builtin("neg", "negative") - __pos__ = _make_unary_array_field_intrinsic_func("pos", "positive") + __add__ = __radd__ = _make_builtin("add", "add") - __add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add") + __pos__ = _make_builtin("pos", "positive") - __sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract") + __sub__ = __rsub__ = _make_builtin("sub", "subtract") - __mul__ = __rmul__ = _make_binary_array_field_intrinsic_func("mul", "multiply") + __mul__ = __rmul__ = _make_builtin("mul", "multiply") - __truediv__ = __rtruediv__ = _make_binary_array_field_intrinsic_func("div", "divide") + __truediv__ = __rtruediv__ = _make_builtin("div", "divide") - __floordiv__ = __rfloordiv__ = _make_binary_array_field_intrinsic_func( - "floordiv", "floor_divide" - ) + __floordiv__ = __rfloordiv__ = _make_builtin("floordiv", "floor_divide") - __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") + __pow__ = _make_builtin("pow", "power") - __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") + __mod__ = __rmod__ = _make_builtin("mod", "mod") def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( - self, other - ) + return _make_builtin("logical_and", "logical_and")(self, other) raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") __rand__ = __and__ def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) + return _make_builtin("logical_or", "logical_or")(self, other) raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") __ror__ = __or__ def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( - self, other - ) + return _make_builtin("logical_xor", "logical_xor")(self, other) raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") __rxor__ = __xor__ def __invert__(self) -> NdArrayField: if self.dtype == core_defs.BoolDType(): - return _make_unary_array_field_intrinsic_func("invert", "invert")(self) + return _make_builtin("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") def _slice( @@ -241,7 +235,7 @@ def _slice( return new_domain, slice_ -# -- Specialized implementations for intrinsic operations on array fields -- +# -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined] @@ -254,19 +248,18 @@ def _slice( ): if name in ["abs", "power", "gamma"]: continue - NdArrayField.register_builtin_func( - getattr(fbuiltins, name), _make_unary_array_field_intrinsic_func(name, name) - ) + NdArrayField.register_builtin_func(getattr(fbuiltins, name), _make_builtin(name, name)) NdArrayField.register_builtin_func( - fbuiltins.minimum, _make_binary_array_field_intrinsic_func("minimum", "minimum") # type: ignore[attr-defined] + fbuiltins.minimum, _make_builtin("minimum", "minimum") # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.maximum, _make_binary_array_field_intrinsic_func("maximum", "maximum") # type: ignore[attr-defined] + fbuiltins.maximum, _make_builtin("maximum", "maximum") # type: ignore[attr-defined] ) NdArrayField.register_builtin_func( - fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] + fbuiltins.fmod, _make_builtin("fmod", "fmod") # type: ignore[attr-defined] ) +NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) def _np_cp_setitem( diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 52aae34b3f..13c21eb516 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -132,6 +132,27 @@ def builtin_function(fun: Callable[_P, _R]) -> BuiltInFunction[_R, _P]: return BuiltInFunction(fun) +MaskT = TypeVar("MaskT", bound=Field) +FieldT = TypeVar("FieldT", bound=Union[Field, gt4py_defs.Scalar, Tuple]) + + +class WhereBuiltinFunction( + BuiltInFunction[_R, [MaskT, FieldT, FieldT]], Generic[_R, MaskT, FieldT] +): + def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: + if isinstance(true_field, tuple) or isinstance(false_field, tuple): + if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): + raise ValueError( + f"Either both or none can be tuple in {true_field=} and {false_field=}." # TODO(havogt) find a strategy to unify parsing and embedded error messages + ) + if len(true_field) != len(false_field): + raise ValueError( + "Tuple of different size not allowed." + ) # TODO(havogt) find a strategy to unify parsing and embedded error messages + return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return super().__call__(mask, true_field, false_field) + + @builtin_function def neighbor_sum( field: Field, @@ -164,7 +185,7 @@ def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple[Dimension, ...], /) raise NotImplementedError() -@builtin_function +@WhereBuiltinFunction def where( mask: Field, true_field: Field | gt4py_defs.ScalarT | Tuple, diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 8a4b4cbd84..49aeece87e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -98,6 +98,60 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati assert np.allclose(result.ndarray, expected) +def test_where_builtin(nd_array_implementation): + cond = np.asarray([True, False]) + true_ = np.asarray([1.0, 2.0], dtype=np.float32) + false_ = np.asarray([3.0, 4.0], dtype=np.float32) + + field_inputs = [_make_field(inp, nd_array_implementation) for inp in [cond, true_, false_]] + expected = np.where(cond, true_, false_) + + result = fbuiltins.where(*field_inputs) + assert np.allclose(result.ndarray, expected) + + +def test_where_builtin_different_domain(nd_array_implementation): + cond = np.asarray([True, False]) + true_ = np.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + false_ = np.asarray([7.0, 8.0, 9.0, 10.0], dtype=np.float32) + + cond_field = common.field( + nd_array_implementation.asarray(cond), domain=common.domain({JDim: 2}) + ) + true_field = common.field( + nd_array_implementation.asarray(true_), + domain=common.domain({IDim: common.UnitRange(0, 2), JDim: common.UnitRange(-1, 2)}), + ) + false_field = common.field( + nd_array_implementation.asarray(false_), + domain=common.domain({JDim: common.UnitRange(-1, 3)}), + ) + + expected = np.where(cond[np.newaxis, :], true_[:, 1:], false_[np.newaxis, 1:-1]) + + result = fbuiltins.where(cond_field, true_field, false_field) + assert np.allclose(result.ndarray, expected) + + +def test_where_builtin_with_tuple(nd_array_implementation): + cond = np.asarray([True, False]) + true0 = np.asarray([1.0, 2.0], dtype=np.float32) + false0 = np.asarray([3.0, 4.0], dtype=np.float32) + true1 = np.asarray([11.0, 12.0], dtype=np.float32) + false1 = np.asarray([13.0, 14.0], dtype=np.float32) + + expected0 = np.where(cond, true0, false0) + expected1 = np.where(cond, true1, false1) + + cond_field = _make_field(cond, nd_array_implementation, dtype=bool) + field_true = tuple(_make_field(inp, nd_array_implementation) for inp in [true0, true1]) + field_false = tuple(_make_field(inp, nd_array_implementation) for inp in [false0, false1]) + + result = fbuiltins.where(cond_field, field_true, field_false) + assert np.allclose(result[0].ndarray, expected0) + assert np.allclose(result[1].ndarray, expected1) + + def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] From 87832eddc7ee0156ab97290b168e2499d7ab541b Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 16 Nov 2023 13:29:38 +0100 Subject: [PATCH 03/85] presentation slides --- docs/user/next/presentation_slides.md | 411 ++++++++++++++++++++++++++ docs/user/next/scan_operator.png | Bin 0 -> 8760 bytes docs/user/next/simple_offset.png | Bin 0 -> 10292 bytes 3 files changed, 411 insertions(+) create mode 100644 docs/user/next/presentation_slides.md create mode 100644 docs/user/next/scan_operator.png create mode 100644 docs/user/next/simple_offset.png diff --git a/docs/user/next/presentation_slides.md b/docs/user/next/presentation_slides.md new file mode 100644 index 0000000000..87cd2b7787 --- /dev/null +++ b/docs/user/next/presentation_slides.md @@ -0,0 +1,411 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +# GT4Py workshop + ++++ + +## GT4Py: GridTools for Python + +GT4Py is a Python library for generating high performance implementations of stencil kernels from a high-level definition using regular Python functions. + +GT4Py is part of the GridTools framework: a set of libraries and utilities to develop performance portable applications in the area of weather and climate modeling. + +**NOTE:** The `gt4py.next` subpackage contains a new and currently experimental version of GT4Py. + +## Description + +GT4Py is a Python library for expressing computational motifs as found in weather and climate applications. + +These computations are expressed in a domain specific language (GTScript) which is translated to high-performance implementations for CPUs and GPUs. + +The DSL expresses computations on a 3-dimensional Cartesian grid. The horizontal axes are always computed in parallel, while the vertical can be iterated in sequential, forward or backward, order. + +In addition, GT4Py provides functions to allocate arrays with memory layout suited for a particular backend. + +The following backends are supported: + +- `numpy`: Pure-Python backend +- `gt:cpu_ifirst`: GridTools C++ CPU backend using `I`-first data ordering +- `gt:cpu_kfirst`: GridTools C++ CPU backend using `K`-first data ordering +- `gt:gpu`: GridTools backend for CUDA +- `cuda`: CUDA backend minimally using utilities from GridTools +- `dace:cpu`: Dace code-generated CPU backend +- `dace:gpu`: Dace code-generated GPU backend + ++++ + +## Installation + +You can install the library directly from GitHub using pip: + +```{raw-cell} +pip install --upgrade git+https://github.com/gridtools/gt4py.git +``` + +```{code-cell} ipython3 +import warnings +warnings.filterwarnings('ignore') +``` + +```{code-cell} ipython3 +import numpy as np +import gt4py.next as gtx +from gt4py.next import float64, neighbor_sum, where +from gt4py.next.common import DimensionKind +``` + +## Key concepts and application structure + +- [Fields](#Fields), +- [Field operators](#Field-operators), and +- [Programs](#Programs). + ++++ + +### Fields +Fields are **multi-dimensional array** defined over a set of dimensions and a dtype: `gtx.Field[[dimensions], dtype]` + +The `as_field` builtin is used to define fields + +```{code-cell} ipython3 +CellDim = gtx.Dimension("Cell") +KDim = gtx.Dimension("K", kind=DimensionKind.VERTICAL) +grid_shape = (5, 6) +a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) +b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) + +print("a definition: \n {}".format(a)) +print("a array: \n {}".format(np.asarray(a))) +print("b array: \n {}".format(np.asarray(b))) +``` + +### Field operators + +Field operators perform operations on a set of fields, i.e. elementwise addition or reduction along a dimension. + +They are written as Python functions by using the `@field_operator` decorator. + +```{code-cell} ipython3 +@gtx.field_operator +def add(a: gtx.Field[[CellDim, KDim], float64], + b: gtx.Field[[CellDim, KDim], float64]) -> gtx.Field[[CellDim, KDim], float64]: + return a + b +``` + +Direct calls to field operators require two additional arguments: +- `out`: a field to write the return value to +- `offset_provider`: empty dict for now, explanation will follow + +```{code-cell} ipython3 +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +add(a, b, out=result, offset_provider={}) + +print("result array \n {}".format(np.asarray(result))) +``` + +### Programs + ++++ + +Programs are used to call field operators to mutate their arguments. + +They are written as Python functions by using the `@program` decorator. + +This example below calls the `add` field operator twice: + +```{code-cell} ipython3 +# @gtx.field_operator +# def add(a, b): +# return a + b + +@gtx.program +def run_add(a : gtx.Field[[CellDim, KDim], float64], + b : gtx.Field[[CellDim, KDim], float64], + result : gtx.Field[[CellDim, KDim], float64]): + add(a, b, out=result) # 2.0 + 3.0 = 5.0 + add(b, result, out=result) # 5.0 + 3.0 = 8.0 +``` + +```{code-cell} ipython3 +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +run_add(a, b, result, offset_provider={}) + +print("result array: \n {}".format(np.asarray(result))) +``` + +The fields in the subsequent code snippets are 1-dimensional, either over the cells or over the edges. The corresponding named dimensions are thus the following: + ++++ + +### Offsets +Fields can be offset by a predefined number of indices. + +Take an array with values ranging from 0 to 5: + +```{code-cell} ipython3 +a_off = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) + +print("a_off array: \n {}".format(np.asarray(a_off))) +``` + +Visually, offsetting this field by 1 would result in the following: + +| ![Coff](simple_offset.png) | +| :------------------------: | +| _CellDim Offset (Coff)_ | + ++++ + +Fields can be offeset by a predefined number of indices. + +Take an array with values ranging from 0 to 5: + +```{code-cell} ipython3 +Coff = gtx.FieldOffset("Coff", source=CellDim, target=(CellDim,)) + +@gtx.field_operator +def a_offset(a_off: gtx.Field[[CellDim], float64]) -> gtx.Field[[CellDim], float64]: + return a_off(Coff[1]) + +a_offset(a_off, out=a_off, offset_provider={"Coff": CellDim}) +print("result array: \n {}".format(np.asarray(a_off))) +``` + +## Defining the mesh and its connectivities +Take an unstructured mesh with numbered cells (in red) and edges (in blue). + +| ![grid_topo](connectivity_numbered_grid.svg) | +| :------------------------------------------: | +| _The mesh with the indices_ | + +```{code-cell} ipython3 +CellDim = gtx.Dimension("Cell") +EdgeDim = gtx.Dimension("Edge") +``` + +Connectivityy among mesh elements is expressed through connectivity tables. + +For example, `e2c_table` lists for each edge its adjacent rows. + +Similarly, `c2e_table` lists the edges that are neighbors to a particular cell. + +Note that if an edge is lying at the border, one entry will be filled with -1. + +```{code-cell} ipython3 +e2c_table = np.array([ + [0, -1], # edge 0 (neighbours: cell 0) + [2, -1], # edge 1 + [2, -1], # edge 2 + [3, -1], # edge 3 + [4, -1], # edge 4 + [5, -1], # edge 5 + [0, 5], # edge 6 (neighbours: cell 0, cell 5) + [0, 1], # edge 7 + [1, 2], # edge 8 + [1, 3], # edge 9 + [3, 4], # edge 10 + [4, 5] # edge 11 +]) + +c2e_table = np.array([ + [0, 6, 7], # cell 0 (neighbors: edge 0, edge 6, edge 7) + [7, 8, 9], # cell 1 + [1, 2, 8], # cell 2 + [3, 9, 10], # cell 3 + [4, 10, 11], # cell 4 + [5, 6, 11], # cell 5 +]) +``` + +#### Using connectivities in field operators + +Let's start by defining two fields: one over the cells and another one over the edges. The field over cells serves input for subsequent calculations and is therefore filled up with values, whereas the field over the edges stores the output of the calculations and is therefore left blank. + +```{code-cell} ipython3 +cell_field = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) +edge_field = gtx.as_field([EdgeDim], np.zeros((12,))) +``` + +| ![cell_values](connectivity_cell_field.svg) | +| :-----------------------------------------: | +| _Cell values_ | + ++++ + +`field_offset` is used as an argument to transform fields over one domain to another domain. + +For example, `E2C` can be used to shift a field over cells to edges with the following dimension transformation: + +[CellDim] -> CellDim(E2C) -> [EdgeDim, E2CDim] + +A field with an offset dimension is called a sparse field + +```{code-cell} ipython3 +E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) +E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim, E2CDim)) +``` + +```{code-cell} ipython3 +E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, EdgeDim, CellDim, 2) +``` + +```{code-cell} ipython3 +@gtx.field_operator +def nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: + return cell_field(E2C[0]) # 0th index to isolate edge dimension + +@gtx.program +def run_nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): + nearest_cell_to_edge(cell_field, out=edge_field) + +run_nearest_cell_to_edge(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) + +print("0th adjacent cell's value: {}".format(np.asarray(edge_field))) +``` + +Running the above snippet results in the following edge field: + +| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![grid_topo](connectivity_edge_0th_cell.svg) | +| :----------------------------------------------------: | :-------: | :------------------------------------------: | +| _Domain (edges)_ | | _Edge values_ | + ++++ + +### Using reductions on connected mesh elements + +To sum up all the cells adjacent to an edge the `neighbor_sum` builtin function can be called to operate along the `E2CDim` dimension. + +```{code-cell} ipython3 +@gtx.field_operator +def sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: + return neighbor_sum(cell_field(E2C), axis=E2CDim) + +@gtx.program +def run_sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): + sum_adjacent_cells(cell_field, out=edge_field) + +run_sum_adjacent_cells(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) + +print("sum of adjacent cells: {}".format(np.asarray(edge_field))) +``` + +For the border edges, the results are unchanged compared to the previous example, but the inner edges now contain the sum of the two adjacent cells: + +| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![cell_values](connectivity_edge_cell_sum.svg) | +| :----------------------------------------------------: | :-------: | :--------------------------------------------: | +| _Domain (edges)_ | | _Edge values_ | + ++++ + +#### Using conditionals on fields + +To filter operations such that they are performed on only certain cells instead of the whole field, the `where` builtin was developed. + +This function takes 3 input arguments: +- mask: a field of booleans or an expression evaluating to this type +- true branch: a tuple, a field, or a scalar +- false branch: a tuple, a field, of a scalar + +```{code-cell} ipython3 +mask = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0], dtype=bool)) +result = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0])) +b = 6.0 + +@gtx.field_operator +def conditional(mask: gtx.Field[[CellDim], bool], cell_field: gtx.Field[[CellDim], float64], b: float +) -> gtx.Field[[CellDim], float64]: + return where(mask, cell_field, b) + +conditional(mask, cell_field, b, out=result, offset_provider={}) +print("where return: {}".format(np.asarray(result))) +``` + +#### Using domain on fields + +Another way to filter parts of a field where to perform operations, is to use the `domain` keyword argument when calling the field operator. + +Note: domain needs both dimensions to be included with integer tuple values. + +```{code-cell} ipython3 +# @gtx.field_operator +# def add(a, b): +# return a + b + +@gtx.program +def run_add_domain(a : gtx.Field[[CellDim, KDim], float64], + b : gtx.Field[[CellDim, KDim], float64], + result : gtx.Field[[CellDim, KDim], float64]): + add(a, b, out=result, domain={CellDim: (1, 3), KDim: (1, 4)}) +``` + +```{code-cell} ipython3 +a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) +b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) +result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) +run_add_domain(a, b, result, offset_provider={}) + +print("result array: \n {}".format(np.asarray(result))) +``` + +#### Scan operators + +Scan operators work in a similar fashion to iterations in Python. + +```{code-cell} ipython3 +x = np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0]) +def x_iteration(x): + for i, x_i in enumerate(x): + if i > 0: + x[i] = x[i-1] + x[i] + return x + +print("result array: \n {}".format(x_iteration(x))) +``` + +Visually, this is what `x_iteration` is doing: + +| ![scan_operator](scan_operator.png) | +| :---------------------------------: | +| _Iterative sum over K_ | + ++++ + +`scan_operators` allow for the same computations and only require a return statement for the operation, for loops and indexing are handled in the background. The return state of the previous iteration is provided as its first argument. + +This decorator takes 3 input arguments: +- `axis`: vertical axis over which operations have to be performed +- `forward`: True if order of operations is from bottom to top, False if from top to bottom +- `init`: initialized decorator value with type float or tuple thereof + +```{code-cell} ipython3 +@gtx.scan_operator(axis=KDim, forward=True, init=0.0) +def add_scan(state: float, k: float) -> float: + return state + k +``` + +```{code-cell} ipython3 +k_field = gtx.as_field([KDim], np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0])) +result = gtx.as_field([KDim], np.zeros(shape=(7,))) + +add_scan(k_field, out=result, offset_provider={}) # Note: `state` is not an input here + +print("result array: \n {}".format(np.asarray(result))) +``` + +Note: `scan_operators` can be called from `field_operators` and `programs`. Likewise, `field_operators` can be called from `scan_operators` + +```{code-cell} ipython3 + +``` diff --git a/docs/user/next/scan_operator.png b/docs/user/next/scan_operator.png new file mode 100644 index 0000000000000000000000000000000000000000..f0c1d03636b2758296da39a29251c2adc5b321d3 GIT binary patch literal 8760 zcmb7q2{_d4`>#YsMPmuAjJ1en7{e%HW{hEGY%^xji_F3dGiHn#Tb3}1BD=IuM9Ef) zNGq=*T2!o>O|GBPnF4uh5XStvIxj*;kzVEM0Z%;S1mFrh3 zC@8486L7xZT&|$7Y#n3;xcd91>w0in7Uk=PQD}Iu@somr%1)^(Ln@37;S0D5CN|iA zuS~4rJW-U?#0F<#Z5=EYTXFcop^{)>lvM;*3Lb&?LJ@}_!sl}SZDS3$wy{JYEDzQ z63(HqX|4iB2%L@){kxiIt|W>tiuljst*sDNpyyI&X`Gn*?^6ypn$HFkTASFo0>%DE zgP=)7Fp%RP&3Xf+SR(%I$p4QfgJ_752!FUWGME+XCBY$3{+2*fOM}?aqW#(8s2IFA zJeVMcW1=`b6Kf3GBNz_17m&SV;o(%8J&MTZxncd7uC|_GG6aGh7sHR?BhkJBES2W% zD+u9J#If;Vwpdvhk>pL2$|SBXESV?6%|{%AqghiVD25B(#Sclqhe&t=JR&M8o*C*H z7wcjt65u0cWIE85AA=YXMMTAhqCD6>@JM%G8#fQL=TY-J{;~N(0VsT$uz?C>V|F=RsidxZWaC zya?y&#tpR>hqFjRFA+s7kTA$pMsTc@iUi99ZApH?_-H#Zcm-$B$Nd+EJEVlQMx(g_S7Is)wzOsClqnJ5A(Qc6OF!?{$@n-)T& zuo!p>n#zwwgL|&KeONG(&asR36$5Z%aCXsnmv97!CZGcg#9`2^7!*0iFC^Z^)y10> zA7SkoN@nrc@h(1WrZpNH%x8-HI6SodQdeJB7C<~c4#BafdEjEgtWmzm2sSE`3ET(9 zRzmd)59hkE@O*bxxEtHr-PblOMCOOKx0Ct#auIgE0v92HC<(Q}ibRn}8ZsJ5Lq+p# zqC(>lWHdSk5evk_6TO1P{&rse-V`?4o{YAUcnIx${q1CQGSwc9CZl6mIJTW_gguu; zh#}MBmmC*f#E{vF2vj)6)0!eCQC*j0BmyK3Z+=b^4NSF9(X&v(8iy` zBQhe9NUj@|=?9k*C9F_)K1l+1W$`6qA75(@A%>476Jz4&{P0LJTFML|c#CbMQQ;9_ z#t3v6INFnAXf8N5TuPA$5w5}hL|Yl3YYXSXQ379qABq$v6btF`NSPPTi-U^tisaxW zY$7Lw%?`si zO57vy#4uT;5P%sI=N&@l3TQa2j6xI=Xm}bG;g07gsJZ2;V9*6Rw2;8I+WK3k7n?1&Z!J&C`{cS0cev#3>QoMxZ$w#9R5$I^ws8GKp zPa%oMMtRv#BLwhpU$KaY_6p(QL%cb0p7xPawDg7VH1aA2zgAt^lu)f9}=R%DQoG4Z8ClAbboxI2VWM++Dr!Y<~+ zu<1Ry<+)7%qlQe2i@Bt``CH00*Im3vNjGHfXB;iS82`vczq@&3ex_`G$E0j>?+p5( z)65?;JNH@=<7eK44oxl&hE66vh&wqc#7WaMZw?mEVoA?6Ocxed64795Uh7}iCQ!Fe zI**oplb;Mq>H9tM#@tdF^M&MYJYrrlYf*~BOm8$xP|;MqLNU(Etv!XqTuxPl?Z#xv z|4>!NG^A>(zF4_ZE>Jbml$dBkuB`Mh=2&cjUNk8rbQgIWjF@l9%dJ1<++FN#2%}?C za%=Y}7{UUR^lZbloj#Gn1{wm6dlZ@9_VqE~+&ysTu;II#e=icFKHMwM-eC85@%Q{W zLZ0E&_@#%ww-2tRsEzcWqckP_lFUuN5`9{npFA5jP~&q&@_yqQogckPf|+MQ$v%q< zzl;JCdJ4DppCNDPT7Jn63SoXnr+yt9s13Vu#_ME*)iKN6D2JFYC*~$qwM@OwdY1Ml z>pR959@Md{NeVfi7JdJ?T?}%M5ws!bW|=v?fn*bM@03|ttl@ydm$uJVJ;!WoPrK}= z>-;mU_)&!`^=49HSbwz!xzwS^n%jdLO8B8RP{*XG4R>a0721S)u3h)>;YqJcp%2g8 z$=+bG#lo)!2bOttE@~)ohQH<|M6M^OrPYt#N?9(Qygm~a^hXQJ8CK_koBeUQ!IR(r zB=GS3&(Chy1L{?;uf|Zm4d2EFC?dB#vY^$S(lqh#ZckM{7yM^BQuOAU_oljd;j7D4 z;MV7)Q)1Hrp;8CX;*(P@X}r7nCQaO49Sk9P_T7O2J9BLM>X&JmG{h+==5QizDjA2& zRXvT8GbAw+FI`W&<(!dzdc6NkWx}uTVL7vZ?>2<9>XoGWVAE>3Dop{yq`1twnznrx{k53pW^ zcb=c^7cIErx3RYN@!Up62jp{XGY?zNf2gK&Jm3GRGTa zNt=z_4tw87_|>0tVP{nFdu>SbF6u4wv4t<$Mv#XmcX^3!{4!1a+OFc!^X|_Bl$kHD zw}N$y4P|7_#>|%HJUQz{IlK|wj}m|#dHLj4`qz`mi{`%2LJWY-5(^;yCDqT3AXv^)aDtKsUlgo(SmOLxT7Y;Difn|H!> zrYx^7&vjT+Xt}#MernVO*h$%R!X;Rzsm|X0I+p%@P5dVq-8YSFm7;A--~Hdf%*_9& z{zTSQ4WDW4fbqSU9VZ$q%29G7e3rM_Zd^yqn>e%KYo$Bm^1s}ekFcFoSkn^eJvTj3 zd+N4o8$Nse%cuwGP%n3YCG7KEQ)2^%#-7#_Lm!s;Eq{3_bj|(J9rV|f;X(y0FGVNG zVA+X`R9dEN&m+*K^{e*C06i!>Hn{~@?U<^^vxb0Af3Dd@)%bQ>H;4*X*Msi^(5+J6 z=6N|t04eV%#~!E#+td{QbBtP-IQ=4YY!TL?q>cH!xTkVno>x5DR8g)Yf7)(X|LdZ< zQEbxzbW2&?DYeavZGW`pe%;sMj!-O@x$F#G0Cu{?t8C{Ov!XdqW^KSc%&u9xSy{H! z8*zSPxnh|U;omkv_W;KaX#e39e1Cf2$PTQ0M zFUDm4jKa02FBT`n)&r!eKd|HUE3LV= z!?r{5ep@znR1;gB^$uOs+ocOp#+YIjKDe)jm4y#mKKg}gy1n05Q#BytLyfy8b{3)` zG&LL~m5yOHru6OSH5`M;FK)c5*u^<*oCaT|!4w@Db{d+&$?dRz&%7-itvHjUscO5Q zMJHNAzM;DQMBv(YQ2q{sm_frbIF%mxF0+Oj%=5=|NmssQJ$sydA=5TJqai%wf4$~C z!>OAF-e3ylHwrzKI`(ZOkuZPx1T_W@PK{nHiw<|~IFd2CN|*jP@NmG?re~(#pIbFZ zHe4@3$-h58d~UDi_AU~)Q0vHpl`zpGagh#W`9(9*l(_y^_Q&r@warHedgs>_5P*6D z>s+rNcHRPADZgQ10G)GE)&zu+(ifwwgc&;PY#7|XHDme^`JLK`6YK~JFvF=!ou9^Q z(c{iZlh4cdwD4q^^1;jkLbo5>APd=a)vB7K1689UWg`KxNoTlqr1#N4f7m) zZ^3>PUtjq3@#MLD*o)-!djT&l?S3)N%6Z)_S*@=%a{ASks8DmSRmS&mI%ak&ET+p> z$r}O7s?mGDt=YBUoU?1O)kDiJdi*){$n03{d8+v+;XU9IzfXTj4{1R1s2X?Iz8jMx zFLvK^h#0gl56v_yQ`BoYBl+at*=Bk9S8R3 z?Q~3>n|hgsQiP2wzYR_upP;#vg_SfI{NA=r#Xj9(+5SFVeXV=?`KN=f#pcyf`|=C~ z&l)x*{{Hz)`t_}g@OgzFa;a2H=i?{WJtbTQ2$Cp;D$=y!SimCVI^ru=&1Q zbp3)RgKIw??7)`AOy7mG(A<}^4$A& z`UN~_89EC8qRVlr58G3_f*JQt1bL{%`R!))*(#uo=XC`EnVERy20S%Au6C97mQ(7c zYU`1Gn+g7FDF(>Lbj|3BOXNtIgQtZi{;@{-AwvF|BWA{4&n)o`Ip1AUY^oy=iZFpuu~k6L;;qzM zQXnHx$fFc7Al@E^Ml$>7rIo97BUkJz>h8`nYDVnAX0UfwY>ptgy;OOtn5MEOdL$eM zaO^yoNRk?K3fJ}e>dj~qi_EL)_(93PWzFk{KR(a@q=S;To?G`hv%89Tq(AbFP?cw# zbv4I*cA>J_Y~jn*$$mET`5mURigt-$_(o$vM|$7g+k;Qlw#Xlic5biLUNs|s zH+A^;se;`$PgS%buaj&R!p)palX%;~^lh^xv&#$`wnEbjfi0x;ZC9=cq+7!~!##{O zSIWP!EXQQ|dYEsBJW~CQKH)-Zzh$7J*_*pg0r2l_CQwsAgNZd{0aa1xf>RcC*Kh+ z>v(f^jbti4d0oc!Tg3w`n!@&c$V#S zUuP3SeTL7tV-OB#t(fNtY=p$n*RsGI-t6eFtY5o&Ju+i%X3`7b(`Ps0bw@f>6&_sp zq>gD*WbnxRtCfornYFfavG@5v$7U0( zQ#gY(3T$#m^!TZK6QcPlgfY2A*M2!HbGXQY?rZY=>82BY+iyWl>vA$}?GEW~Hx+b2 zH|3c>KbNR97mRI{y!Vi8iJPBj;osV~?A&Zk;Ptk-kUip!+Y11b=cGpqRWvoqisfey z9-96g-DK9}@Tf6&U$b6nj%4l zA)i7tdbWcw2i@NLz`S+Sf#g4zO%-@KJt`AN8+R6kJS?kE81E^zBwpV;Lo+jJb!bnJ9tp!kCA!3m_VH*USTzPIRw!%otkGwhkRC#A`!p<{vle8&`c5o?$|tZkt6Q@JDe z;4qt|xWz6JB8Lrhc6F*bo4zzg$n&ocb+UJ8<=!4M%7hgz$Q6UPEja#t1>i4Uomshx zFue%(8_^*zpbN);D+Vu|TXnU9GpP$nJkaN)@ItGntS`m@U=to}&^hlN)T);~QWp(7 zR%_~>aI&vx^JV;9Wf(o%zUf(9wxjx=glZxWTjG8CFkJ5&!hLGpK_@36yx_;rUL)^JJ4lX?gNNlxg&NfAohkP96-OpRKT*7zL~ALg#U)|E7YA6NW->v zAX)Ud+S7CUMooyy_f8z8diS+s&fTTXAemHNJ`C(V^JLOoCJ1M7SEDCxv*_(l`+K$ym6jaEOD-6=iduIkTGBWNNsk0b1e>g5SC)7dFJN# z9*b>jm#-|JEwL|OF@Lkd_0Z#yy9GrzUxU!J|7Q5J2J%sx(CEj0oo&nZL;W*yGSwZU z4{gO^Di1ZU%s1_bUYQ)_u@9!LueC|MZ64k)d{m)){o9=l<_8epb=kW{kp+7z!$BUj z7uiHKax#s!ot_-h?|Gg)(yh@`V3sq7h z!?J*NUtx`?vr0?c@km)^G2}wwexcF=QDNwLBfF?dZ=14%boA7#q7#QH76wkHp4~?s zWaiCb1Ll4O8b)~Uhh>gi1KOmXZEfcwoR?;3Ennee%E*u0(thz?@FGw&#t zJM&VFX!<&G0GuhhBJ<)^-ic1T+JNfFCsCY-Z|Xj%d$uHwO6y+T>WX(@&D-w`vZ^n& ze|_W;y0sc?E!3jDCJYNG^4`&c_x3gW6k!FH6Cc2+)oO?6z)Dj3EG{?AuLX1`WL&FR zvgE+Vs@LZM+`|4G_H$h_^MjTXNlSFMLFc^rN`M1Ecjr&7nxy|%x;sAbzti2eNQe0_ z^&MZ0H9JrhUMHQVzV~i_67|sWqnKbv*M+>$Y&$F`Bpl?>`9;N+yL~|kBVf&AdwF+{ zf5RW>msg_9vNxjhL2mRuUQENL`A^n$yySZjC zXj_>sa_WAU@zd(Ox$}Y7<9qPx7J&~=pxL(WKR(_u7`z&D{=1dSiod_SzG`v6EG5_0 zjyGPLB`v%6v%S;m*n8%3qx-;JuX|FU_O5U%rQgN4o;3`~sTPh?hWt1E`oPw-(#n6E zv)S=j@irz#Hr9_gW4Hk`zWg0ObUS1qBdgWzc|*XC_p^m2tMYxWSS-8YH42tJD_FaY zk^7b#JNq@5YbL0<0Lwx`XOHYrElyM2do1B6qA};qMZNGRXVij!+Aw$S7x!2D^mNI0 zE3Q{f1VQ`54Wv8}>r|?nfJphHC%H2(L`%ojI-72lC}9$Gm(s(q$rux?uMQq#P)n^_8cv44Mn&&E=am%*CW zI8VzjnoVsYr1YtmPd?Gut8jNVL^imwl-2uLr1lN*jh_5q;C`dc3Btl{zapeMcIHBfX~CXk0%%g+7KD@TV2@_Tdg<}(Y1uL|30k57ts<(J5O8H zu~TyEoAT=Z827<_a(-hrWlhk^-`c;$e?0kP??JT zkYwzuZx?YwB2(PA0VQ~wg{bB`*HwO^C50VM-v(J?iQXgI%_`GZWS#XoNs9>RezWG` z`pT@ow2h!0R>MmLzNyhieT;^{>op6OUnj1J-zi!MRFImOUP~O+)Rn59)d1_<*|Ftp z((kz*WS4pP)4Fm{5b_eCi7M=!m93PN>^Itac>lD%! zv#}Yu*%>@F$>U1*ZS#=u+^rf08BnLAEQHSDi)HUxwkESav6MPMj#cw}s7>~{A7v5q zWe~}}v%1-9M4JumkgvIt)fi{#eQKXWNy&v17hOuSZjM&k#JtndoQ90{e*SG2G3foZ z9i#3UxG$fQ<(#qGW~~K{I&^rcwAm8zdJCwaortUo+wj)0+;Nv|_IvnmQYbTeld1;w zEm!54)nNLir4nCZ*SAJ?*XZLaqTmp8TJ&n5e)o}o^2N9IqBg*cxHewy3nNaQUx%&y z=~lgT_vil1_bo$7rwv^XR{O2#MV`H21iR-4{VD!a^Z3)RsjJ>uSSM>R%-r-C;oqqzfwYcEiPy?mVpJ?ZGbD z(*s%%+nW$oD&W?IlGT|YOuXIIxke$9oC2@BHd_LbCvAoPY`Za7hpJBnMzOuDz6&KSsG@pDV9n} zASq#T`XtZl$5pvrqlY)9)hmH00+CU0wE<*~u6EPM2fj=_6BTvNsQz@#_=8ebB3|1s z+z#!z#QuLw6cx)u#h;?L0g7#_i=}x1U*VWWY+LaLU1n`1-dffgt9eygm)aH1z)Uv23bz1w{F!UwNnHtOyjgQ>8j>>IAj zEmRq9_m0^fI)>zT(AI2%$OY@NT}B5MF2Z&uE!YD_$$NS}$P1)Lw1cX0ftVD1D;2)^ zM{hdRUF&oIxzo5Ed#CF>HUJmQ$#|{1sw4K^_6?Y;IKtLLN@I$!4I2wKT`(2=;N`IbopEbHB{- z_iEZ5b1qnQX&YR`ft2l@Z|G3NjJHAj?9=w2*&6SJ=+v9gt-K+{qUSky-wa@LH< Tz9PZDf+)DVdg2-|ds6=g`@|Ku literal 0 HcmV?d00001 diff --git a/docs/user/next/simple_offset.png b/docs/user/next/simple_offset.png new file mode 100644 index 0000000000000000000000000000000000000000..660abe87642151d390abb723a0c40e5ee9c22e00 GIT binary patch literal 10292 zcma)i2UJtr@_xXpAS#Lq8mSQh=_G*!1W^(QC834TODLg*1OlN25Ks^sXcQ5Gf(79! z2q=051*E7Lny5&XB2rYONEd1U9rWG%-dn%5-hX8!=bU|Zojvo-H?wob%EEY?&^{pu z1hUQ4#E=Mq@Ik>BCb$Wd;C0@2fD0d!XpDgry_Wb2fo!q}HF5~0NBa2FydiQ(?5{gH zgoa-bGgJ;~D2G6JFc@kSe-Gaf4?0sV&^r_qf%o(viocJ)H|19yga!hs3Wux0Q6xBA z4r!o;1V5TcHIxRz=~sOZU+=)*9hwJ4`O|0~atNH3ng;0Vpt*;ae_&`3(@zeG0nes^ zq24rb1;yak(gyrEfG=Ew4A&x~6~V0mgF*AQ_x8m5gW2G;kZ3g|3KSp2n~=;2atJKA zr}+nXgNw1ZS3nT23Dz$phz?3{aJY_|2JZt3Z9IHDLj3=&3*O4SLp^+dt%~XtVc{4W zhNpT)nt2BV2HH4k(SNNb+&hHn9~Ag|@d!1z8feLD9Li#N|0<<;hx>bh0TFUYBOut{ z5(HHOJ-|HHC2IwQq6+`j(Eqn2qtNE|XnHV%OtbgX)<$djN16rw>ICf+LiA)XBXFi3 zK7PzlytzHr#=+2u5ekPhHA0ACKmr&|x?v6Q@Z$|XD zh-4v*0wXoy)<`qFHPwsZNHRnjA_;aDIAfp~7-J188eFIZuw^(V&CZDBi43=oG|)tl z;HGwF7)!J-9HU{ZK_uWEiNRDIPqdl0E!@gHjAq5#OA8}3l0vnmc?2`Oy}YcnoiuI2 zHMD$?NK|Mz(<4gD!N@0ABizu;$;6lHW9SD)(zFfIw)Z8HJk5hG;Ix1MYl4F#&K7h< z471k3qbMi}4vxlQLp*6=rVa>AqhMPWJ{ZsNWQ1$mVS`aVpeDiC){$Zu#WD|u(@18- zD5Oy!2IJ&Qcf|XVtVm$`*hs9Ct%+}AUTfW%qZU?Mzi9W)FGgh-4B9O;N7 zSQD(7RzB96p}vu!7(WA`ShxVIC?5w_qyfeQWr)zgn;4lBtWB&gql3m6Z*PA`dLYv(&<JhW5e&4WFNxsG3PAW9n1in9aGa+nHJBbuhof!1 zeXTHbB&d&#Y? zorfulrD2N1vm9`wC3@Hjr1Q$z`k2MMhl$dBm4faHbVC}W(Xo9hmPbk8~mVxs! zv%v;nsZI_~lyHsUU>&$9N@F!OW-P~yQ1q}rL{fX(6H;Y=gCZ2*Z5 zo{gyXREnb(uZ<}yl0*t65qSm~7@>s;^e1T9S~_a`*hJWcFif;8!w`{nRy1qP0Lu^y z4LFmxzi2HjLyM?TxQ;fF8Xgsfqz3CyyiIioek@B?co@Th7G-5`8MuWtDVa@^Y@*oXid z%}kHo8n$<_*0KioPD=+omNk$Q4DeLUkf0Zac@8GW6Zy>M-&VsdqI>)_MqzT&kt zovInVNu7m0804qNn%0Ze-OaDrtC@fKFbGXq6)P(%;e!k=TZk;k8JjYRdl>NQiqviy zir#6ziY=JBGJ^+}7uaWdJ$&+~bZIds6O)4z^avL?$pRu>M&TPlo9>hEF0nViyle8= zlxs+(HJ1o-ejG+sVlZ`&4!g-4b0aTa%%a%Fv+MI+w8sTGjEb`%d5MI}kEI?R#K^~9 zFfV4a;q(jX^&$v23Pr(DREAQbw8ughBQH^W>zX2LS8kz7`lnK@Yoa?51CeQAc}?K` zt8lm7--S4my5a796|9WI(I|OB#z)^1F$uTrjDy%WtDKZK`9>GB9Uu9L96EI9O;^{M z`T215v(v7fRA{9^(rSL?-pi7AU=NQig{wH0{PC^5@cAxFJ!_GRY`7?1iTj3Z~I zpEbL40|lX$3R}wS^WyEfdb4k^89Ge)r%#`1gpIzZA@c*$T6a|mYX=lyv&aI{>NBYK zXVBa~hKgZ}I-#z@gyU-3qmk~i`P-^yJjQBNn{RG@O$|{~cCFAM3 zM_9j5YMNQv&|5G+aC8mX)h`Q6O)nXkTUn=7i ze zL_v(-hhE*|CkUuk^%)E4Y(+7Ice0PZ;K}bmndX*4{0DY686~n62OljU^M>j_Wz17MOu^UFdN+IMQAE}D#zj8h0-{n&WoI`upo1;enU{ucNFfq6S z`AM~@38F3E6?_WzI2%TJBy{J))$JIMCkL8)-k;dmew-Coc?GWw$xr)k)wpXrc6(FC{x5#of6NV7Y+>%zg1c3dMvz5SP-f^OsromacO`e zSc?ToP4nX)9#HCc-g7P$w$z(4G}?rz+d_5@+h18yfPnt|(2(?QEwGRL4?bL;7#}wx zlcixcgdS0T=uhWl&RJ^`Td@|@p}ur{8mj2ioN2^8xL@AL#((jK5C*$y)^Qb6dr+7s zA!f;q ztjsy79sUo5Af<(vdSP&?Ii)fxU1a}}JTaB~!c9xxL=mhEEccG>oLJ&0NAru!Nw7;^ z8=AndL%}1vTYU$HiBgSbd|*B^8F|C{o1uleiYKgeNq*zG_Te)I2YXUCH($Rs63|EX zK^UDpu@)={UCT`Bl!fJlrErq7%JEL>;(u8Vs5i<@)DMZASk{AX&(ExSx4BWye`Lg< z>oFK}+vV9gk!G%w%BRPx>R$Do%vE{+slw)C#q>J?$f^E3;x5ppF zXh~1`c#=_yOYozu8-q`vLB^-xOCvy0^`6onDbE;~aPRDY1&kxPp`R%J&_&5M z7MZXM;iOy!?<5)gvK3A;>ar(2h}Y9Vk`!!LQG~v)1=p{QBHv0fxPtHgcsA!SENAW_ z`>hhtuY)S-w>ejx?pGgLmu>ZarRSmRvaMi?|!KGlxWb#Ydg_+T|0=P%*A8Edc z$!7W;4x))k%fr37bvNC41_r$MeZRN+Rop7}qI(z{!YBJuf{uN2W>iJQJ66S3j;g>Y zr0KgbFD@A)FTmYBP`nS}d;J*dX?gk7>e8e%+r9tUu^&SUM;z;^;S)V)F3KAF=-UW# zLK@|;Jagwm82td|-t&YRJYBCa7(ia?=Gz=Qb&l*l`{ixgT0}MUQ0PHes^jqxbnMsS zxIw*zJ4)$izuRx)MAf*=BYa;==uLBF)7D}?>-JmwHxTk%0C%wcA@&QS9ejS=k2#aBfkA)USD$YPF={3o~W7Y zAYTxX#ocMyimAJrsKnhnozAZRy|$tH)Tu$~wX*4;PcP2&z21t+%f0+=+F&D;2SFc* zWn5v6$zmZy^uk0TMRlU4JfP^wh*#eeqHoQKjTPmWElK#|F-`(|b;Rr1XSe31f)mNn z;syl@Q4!-K$BbqCEFlJVFBIFED{t5eCUYAOH8wLysBj1xDOR1fE{kaJMQKlkl(_h$= z?QTx`8Nfrx%en^n=~F+iLy^XXAY`KdsQSWP zw<_1e)qnM;2J2t)`Rz~T-yX>GYFYwN136pH@?kv*mxQe56;17fFd-EeZ?y(N%J`~n zAgg<;;~?5ws>%F6f*@<8@o%vZblcq0c8Jub@3gsv8Ay-wN>_7`^;y>K!Ouj_$(>kYjZ6;*KT$RuQWWBieOwRtH1-oCSqAujR*B_a#s z%%VgNpF2$pzZd|yz|ZI>mP%B=Y#V!(W*TRun&t-i4v8OVBjkx8&>J$wo{33RLxEd^ zFucgya2TBO&HaLFlGTcB&0Th_CiBiNX3wD2@2QT%W(5ib-xs4H%K5o1WuvQ2Sr9w^ zz5NhGeTJ|}k!3$8`R^_7vvYejA2G99cdfom{d2f@d0-XgM*0Ww;H86cTR#3AJs!Ay z#^IT6Nv3J7&3#bZFD~tHn{LCbI|Rsex&bTE|=W8ngwp|K-|n)HfrwRZATqv zxutDej~rxmW3>e&9m04M-17R7t?Ymi#MllU(>sHlXn&Z%j;^coHWXW+2y$*&czqZI zj52Qb@Fnq|5qxEuOLb3KWx7T)1#g1mLb}d2Mp+c-Oh`fEY>!R6tW2vuygJ#SE^u#% zPq)5q@-D<}5cr6T>{{{_d{64zPll3-8_cUKi@Ar_rbXga&HJ25ORMtsn6}Bmt?#SK z5BH8uygNFsxqZ-dBRVd3*A@J`Ci8r}$}0{qfRMX)j+%i;QU5;yf?{Ud|t!=rUdANhr=#-D8S>UmIP-klDV z0QBvHcsG~V2A>^lUiv1N+VwO{St@QhxMF#_QG4wEbH)hA_tRPBqH69%*|tra_@Q%$ zkfAbwR|1IjVfluv>4w(@Q|*XT4?YiN}voP1lb|lh=NB z*tUxxlonkEn+*J#{$Yp(Cz`SkM}LeDsokvb@v==%j{Fk2v@c3oOvNF-t-yZp#Fide zxh`@dA$&Yv6)A8N&O;oYR~MyFJ9>J1Q-G)0bY!WkYr6g+%jFW z?RwuKMdcHRTteR(30z+Ol5>GE_Tbg0gzfs(MXu$HFE=L&k94^u5?I5R;Bk&>EVcBr z)Ak!VLu&8Z?8zy)_;e}u$7Xc@3qXGuJ;ASPa=SA3z;djM-=SvGUT0Ljo3>JQaF@(j zu-^!cILBQ6G1Xb(y(RG3qEF1~%=K4!=2!32x)3}{w9p$N9r*O_s`BY~jwamOtpzo2 zISFqv0RQ_;$U7Kk%MK2RU771pj<|#h4DHvSzS>cHa*J|Hy+p|ArLvwebxdAGNPPzI zwt84ly^S5bD=b=LvSP+!xc{QDl%rzWqaX#7*#neJ3@*yMlj9NPX4vSC4p_P0p9p843$G&BSae7jxMoa5o1Fuy!Rd81P-xjzTYAeC{UqS2F z=l<(%`u~nOrp3Z-HvS?jJY1k+)gH}{ta!`sax%%iUi>@a$hAK>CdMQFwZaIIo1XdZpW8oyGvK@J+St^3F|YKsszNq{w~5-{pGsVy@`*<|#TDd6&tU*Vou@ z^Xtn2c_-P%4-*HiWGI5rLjKWtFi{6JQzZH?_v2x}5W}9l@9)3x^4y;CmdND{^ezR9 z7hnIpUNoJ|`J0J@sF_PyC%g-N(c1x=r9B^3cYeGT>CO(OQ#$tM{+1E(8y!+21NJ z7k?!kS7AIpAMx@Ri=Kpymt#eLjr}38d|O2SF?4p!v?&irJ4=J4UuzF;C2$ZMv8RM> zx0rD+JF7vb(}#(|!$gtQd>FIKJN@RvhcjM{5jR*rzy1l2n)MR}Ih(YY(G)UZzKvL( z3~n_1>VJ!R#g3~edhPbu?;BzO`+w11==^{&fGN|wS6NvJi+`5^4AQz_-~kBFhty%I zSb3YoDk@jgY2h*b1EUH4inKb^pkUr*_|G(WvA-x~exmo>cKsiaQ_`{8xRiZiZ(=O< z7bz8U?arN*(P&)q-iwPN1n%Y1DSRCA6#e2hq;I~_A^_Ha;>x5}3*Bw4e_Fnp9&X$& zxHMW&>c}!ot#X!_*7@I11}z=Q&X$T9RwmbEY6S>i`WW56r5505;DhB{fin-e>Xz^U zpqG%bd9esUc+}>msQ_<8qFZ+R{8gU220E%<%!>rvMo0RKXB#HuQtJbB6lGF7c4H<1 znMfp*!EoHfZ&@uI5O!*@a=o7L)Ocs3oT|mXQU%KUryxx*l{!U|0*R;LOAY-p`#>HD zgA+CWUj}nYyed0!YlS!l7a+6nUulx0i1f$^|LTXQ|4A(kWG8DfZMKLx2#l-@XA|5- z(mbB>(gF1!M}so|PG)`W6$6oVBH+8ZnZ(+UBP)eR^ZxL7C;WA!zN1FSq*36fc*GE> z%^{TW9ZVwR6z_OkLJJ%#$P$!I_1KZ0qL#B^orO87-K_=DeiJbHO>a9vA`T|o+;c+X z<=_EcE+R7YgP;9+F2d;izj6^iPe}R8KAS*l!;ax~|96@PSh#E&@aNj^4#@1?~5w9ob3e*^7yVJ@xznI!&-N=-rWv$F%wPYoikFK#V*I;xZ` zp+g)8u(jGNC(ARwuqKeO+lOE+74)q$q}PYOsm#vJNe(r4KJa=FyAmZ4`(r?R>p(8Mw&QZ#>ZcR0DkKG^X-Fu(UUJh`j6zh z8J2U&avO+ocOrC-<_Q9K2t@}FJ2I$YE44L|)>^cD2o4 z7W(6)f$i1d)hk2Uf#lWc^asGfZevJ7ZDcG6G<3flbkPDM&$OTl-z-O|raxdJ|(EDp|Z?E?1 zk3r=lYhR3y)LvG!(Kz-qRE7NP8rxd&M&PwYGv6k1;Br(F!Fn}cW~O$F`p0X`*9Oti zx`%Xix0`(bMqb;eaE-wq6%}>gw7yHEOog1 z9Csab#<`o8y`Ky{aZzd@2eGFxmp5hYKv2DH?GMd;Q9WcbP^`&pjBaTE$@8gRs+8rH zSHSQVdIUKVNa&hJX_aw=2#hjKM6bo>4r5h7#u^6P)){-{z#uHUt7Q;Sl!G zH?=-kcfs3?|1mj^r`NB9=z0=^=l_#Na>AeS{D010ow3pDxS`2>IC;7snhAKuR6yCw zz9!)QbA*1=5IkK4#M8V`>)qTHs@diQ0C{Xhn3CQ$#oy)GlygPvl?Uo1{>lFh_2%XY znKcT45t;yRXiIYO8*g4}B*;Hvn;jwl9v5G1)v5X`t+3BtpOp8q(e>*k`|*~-Zljfj z$B!Sc`ew88O6tH}!;yuMs)y4{x)WvCPwKE%Z6Y{_*vZ5m+AQf3D8NvN?=+BAlJ5TS zp(8vzd_v(|b4M07Y3s#EIj7=gxsgK!G6;@~sKy)C&gx4?Q5vpOyKmcvw;E{EyC3<2 zW2)DCiZcb;mEF7hhg>@!e7HhE2WG*&1FerHi7$<|USM0}zZ7qs%8$ydu`z zANAxiKjPNh&n@ZhQTr*$PQ&pJA0I-e=Xcb;`f~Z9c|lcG6)Lr{-Sm7R@iK_?R9|&v zj%Bjp{*5(V$r}j#oS(E;KW3dRTMkir1=-*8rJ$#QzOhUL{}rR+~(GF&F?WX z&*pi`W4XI?9(Gt`o`}m;mRuFC1u?6As_bafqpF1AiA$*Ld-~D4Hd2VRW$LlyzYfE8 zop~sBU0y70dwXN&C3o@)Hda4d`#e8F5fwWVrrgs9^5T^< zhi3lsNDTSh@1Irsyn)ZkClbC0M`g#2pCK-JU>yxwL!#Qzu{Jz`7Us%dlsw8c^tn?! ztljCd{p0COVJ2V7KXdO`+;i+mBgDl?U~lP{?R+uFmkK}6GzfefCJJ#+w~HT}jfpwl zMA>xn>ju>>a0s?y?D2G~JgblZjuX;)LwDpIOy6tPMA^H6e`AcEbn4b$!W9~lynHI- z+3D|M;hk_rCFy?&r1OeT=3!W>DGD{67eANM4|YeBIg{rqBPoJ7zE^Zyvk)u)vQ8cL z{4mF_#in4(uD>^?Bi~Wy^*IjV{3qf0B5-WcYLIcuW#3pV+PTUfkRI{yP8=U+w|JFD zd0PM2X3p{>$=So_{$7TzV9p6PG%@e0Q^~(ZMdrrwfpm<}4i0wfjrhOiaprSWjnHY( z(5O7P4*cWnH1EFvzhvQjUqCF$OWv#Iq{(NH23S`v;%!uZx_39Z&y~33Ock&`3Z@@) z*Hc{T8XHct_+s#wr7KDEjmG+ZS(tc1)n1dYlIEUpQcZS9^WXO_Xiwn^ zHH)xexYy65Z^QjKrK1}&mEH*M@JB*trgtJnR6qS8YkOn0y!2KA{*yRtm;WD3#}QGJ zk_^8A3^(I~frW)$=v2@aV(mR@z)@XdfbULXHiwtS>n`Af8@ Ubnz(o8z98g$ilD)<9_1*0a&l~Y5)KL literal 0 HcmV?d00001 From 8039a900dbfb57526fe6dc2df5a4404bfa23965f Mon Sep 17 00:00:00 2001 From: nfarabullini Date: Thu, 16 Nov 2023 13:35:49 +0100 Subject: [PATCH 04/85] reverted accidental commit --- docs/user/next/presentation_slides.md | 411 -------------------------- docs/user/next/scan_operator.png | Bin 8760 -> 0 bytes docs/user/next/simple_offset.png | Bin 10292 -> 0 bytes 3 files changed, 411 deletions(-) delete mode 100644 docs/user/next/presentation_slides.md delete mode 100644 docs/user/next/scan_operator.png delete mode 100644 docs/user/next/simple_offset.png diff --git a/docs/user/next/presentation_slides.md b/docs/user/next/presentation_slides.md deleted file mode 100644 index 87cd2b7787..0000000000 --- a/docs/user/next/presentation_slides.md +++ /dev/null @@ -1,411 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.15.2 -kernelspec: - display_name: Python 3 (ipykernel) - language: python - name: python3 ---- - -# GT4Py workshop - -+++ - -## GT4Py: GridTools for Python - -GT4Py is a Python library for generating high performance implementations of stencil kernels from a high-level definition using regular Python functions. - -GT4Py is part of the GridTools framework: a set of libraries and utilities to develop performance portable applications in the area of weather and climate modeling. - -**NOTE:** The `gt4py.next` subpackage contains a new and currently experimental version of GT4Py. - -## Description - -GT4Py is a Python library for expressing computational motifs as found in weather and climate applications. - -These computations are expressed in a domain specific language (GTScript) which is translated to high-performance implementations for CPUs and GPUs. - -The DSL expresses computations on a 3-dimensional Cartesian grid. The horizontal axes are always computed in parallel, while the vertical can be iterated in sequential, forward or backward, order. - -In addition, GT4Py provides functions to allocate arrays with memory layout suited for a particular backend. - -The following backends are supported: - -- `numpy`: Pure-Python backend -- `gt:cpu_ifirst`: GridTools C++ CPU backend using `I`-first data ordering -- `gt:cpu_kfirst`: GridTools C++ CPU backend using `K`-first data ordering -- `gt:gpu`: GridTools backend for CUDA -- `cuda`: CUDA backend minimally using utilities from GridTools -- `dace:cpu`: Dace code-generated CPU backend -- `dace:gpu`: Dace code-generated GPU backend - -+++ - -## Installation - -You can install the library directly from GitHub using pip: - -```{raw-cell} -pip install --upgrade git+https://github.com/gridtools/gt4py.git -``` - -```{code-cell} ipython3 -import warnings -warnings.filterwarnings('ignore') -``` - -```{code-cell} ipython3 -import numpy as np -import gt4py.next as gtx -from gt4py.next import float64, neighbor_sum, where -from gt4py.next.common import DimensionKind -``` - -## Key concepts and application structure - -- [Fields](#Fields), -- [Field operators](#Field-operators), and -- [Programs](#Programs). - -+++ - -### Fields -Fields are **multi-dimensional array** defined over a set of dimensions and a dtype: `gtx.Field[[dimensions], dtype]` - -The `as_field` builtin is used to define fields - -```{code-cell} ipython3 -CellDim = gtx.Dimension("Cell") -KDim = gtx.Dimension("K", kind=DimensionKind.VERTICAL) -grid_shape = (5, 6) -a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) -b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) - -print("a definition: \n {}".format(a)) -print("a array: \n {}".format(np.asarray(a))) -print("b array: \n {}".format(np.asarray(b))) -``` - -### Field operators - -Field operators perform operations on a set of fields, i.e. elementwise addition or reduction along a dimension. - -They are written as Python functions by using the `@field_operator` decorator. - -```{code-cell} ipython3 -@gtx.field_operator -def add(a: gtx.Field[[CellDim, KDim], float64], - b: gtx.Field[[CellDim, KDim], float64]) -> gtx.Field[[CellDim, KDim], float64]: - return a + b -``` - -Direct calls to field operators require two additional arguments: -- `out`: a field to write the return value to -- `offset_provider`: empty dict for now, explanation will follow - -```{code-cell} ipython3 -result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) -add(a, b, out=result, offset_provider={}) - -print("result array \n {}".format(np.asarray(result))) -``` - -### Programs - -+++ - -Programs are used to call field operators to mutate their arguments. - -They are written as Python functions by using the `@program` decorator. - -This example below calls the `add` field operator twice: - -```{code-cell} ipython3 -# @gtx.field_operator -# def add(a, b): -# return a + b - -@gtx.program -def run_add(a : gtx.Field[[CellDim, KDim], float64], - b : gtx.Field[[CellDim, KDim], float64], - result : gtx.Field[[CellDim, KDim], float64]): - add(a, b, out=result) # 2.0 + 3.0 = 5.0 - add(b, result, out=result) # 5.0 + 3.0 = 8.0 -``` - -```{code-cell} ipython3 -result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) -run_add(a, b, result, offset_provider={}) - -print("result array: \n {}".format(np.asarray(result))) -``` - -The fields in the subsequent code snippets are 1-dimensional, either over the cells or over the edges. The corresponding named dimensions are thus the following: - -+++ - -### Offsets -Fields can be offset by a predefined number of indices. - -Take an array with values ranging from 0 to 5: - -```{code-cell} ipython3 -a_off = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) - -print("a_off array: \n {}".format(np.asarray(a_off))) -``` - -Visually, offsetting this field by 1 would result in the following: - -| ![Coff](simple_offset.png) | -| :------------------------: | -| _CellDim Offset (Coff)_ | - -+++ - -Fields can be offeset by a predefined number of indices. - -Take an array with values ranging from 0 to 5: - -```{code-cell} ipython3 -Coff = gtx.FieldOffset("Coff", source=CellDim, target=(CellDim,)) - -@gtx.field_operator -def a_offset(a_off: gtx.Field[[CellDim], float64]) -> gtx.Field[[CellDim], float64]: - return a_off(Coff[1]) - -a_offset(a_off, out=a_off, offset_provider={"Coff": CellDim}) -print("result array: \n {}".format(np.asarray(a_off))) -``` - -## Defining the mesh and its connectivities -Take an unstructured mesh with numbered cells (in red) and edges (in blue). - -| ![grid_topo](connectivity_numbered_grid.svg) | -| :------------------------------------------: | -| _The mesh with the indices_ | - -```{code-cell} ipython3 -CellDim = gtx.Dimension("Cell") -EdgeDim = gtx.Dimension("Edge") -``` - -Connectivityy among mesh elements is expressed through connectivity tables. - -For example, `e2c_table` lists for each edge its adjacent rows. - -Similarly, `c2e_table` lists the edges that are neighbors to a particular cell. - -Note that if an edge is lying at the border, one entry will be filled with -1. - -```{code-cell} ipython3 -e2c_table = np.array([ - [0, -1], # edge 0 (neighbours: cell 0) - [2, -1], # edge 1 - [2, -1], # edge 2 - [3, -1], # edge 3 - [4, -1], # edge 4 - [5, -1], # edge 5 - [0, 5], # edge 6 (neighbours: cell 0, cell 5) - [0, 1], # edge 7 - [1, 2], # edge 8 - [1, 3], # edge 9 - [3, 4], # edge 10 - [4, 5] # edge 11 -]) - -c2e_table = np.array([ - [0, 6, 7], # cell 0 (neighbors: edge 0, edge 6, edge 7) - [7, 8, 9], # cell 1 - [1, 2, 8], # cell 2 - [3, 9, 10], # cell 3 - [4, 10, 11], # cell 4 - [5, 6, 11], # cell 5 -]) -``` - -#### Using connectivities in field operators - -Let's start by defining two fields: one over the cells and another one over the edges. The field over cells serves input for subsequent calculations and is therefore filled up with values, whereas the field over the edges stores the output of the calculations and is therefore left blank. - -```{code-cell} ipython3 -cell_field = gtx.as_field([CellDim], np.array([1.0, 1.0, 2.0, 3.0, 5.0, 8.0])) -edge_field = gtx.as_field([EdgeDim], np.zeros((12,))) -``` - -| ![cell_values](connectivity_cell_field.svg) | -| :-----------------------------------------: | -| _Cell values_ | - -+++ - -`field_offset` is used as an argument to transform fields over one domain to another domain. - -For example, `E2C` can be used to shift a field over cells to edges with the following dimension transformation: - -[CellDim] -> CellDim(E2C) -> [EdgeDim, E2CDim] - -A field with an offset dimension is called a sparse field - -```{code-cell} ipython3 -E2CDim = gtx.Dimension("E2C", kind=gtx.DimensionKind.LOCAL) -E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim, E2CDim)) -``` - -```{code-cell} ipython3 -E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, EdgeDim, CellDim, 2) -``` - -```{code-cell} ipython3 -@gtx.field_operator -def nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: - return cell_field(E2C[0]) # 0th index to isolate edge dimension - -@gtx.program -def run_nearest_cell_to_edge(cell_field: gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): - nearest_cell_to_edge(cell_field, out=edge_field) - -run_nearest_cell_to_edge(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) - -print("0th adjacent cell's value: {}".format(np.asarray(edge_field))) -``` - -Running the above snippet results in the following edge field: - -| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![grid_topo](connectivity_edge_0th_cell.svg) | -| :----------------------------------------------------: | :-------: | :------------------------------------------: | -| _Domain (edges)_ | | _Edge values_ | - -+++ - -### Using reductions on connected mesh elements - -To sum up all the cells adjacent to an edge the `neighbor_sum` builtin function can be called to operate along the `E2CDim` dimension. - -```{code-cell} ipython3 -@gtx.field_operator -def sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]: - return neighbor_sum(cell_field(E2C), axis=E2CDim) - -@gtx.program -def run_sum_adjacent_cells(cell_field : gtx.Field[[CellDim], float64], edge_field: gtx.Field[[EdgeDim], float64]): - sum_adjacent_cells(cell_field, out=edge_field) - -run_sum_adjacent_cells(cell_field, edge_field, offset_provider={"E2C": E2C_offset_provider}) - -print("sum of adjacent cells: {}".format(np.asarray(edge_field))) -``` - -For the border edges, the results are unchanged compared to the previous example, but the inner edges now contain the sum of the two adjacent cells: - -| ![nearest_cell_values](connectivity_numbered_grid.svg) | $\mapsto$ | ![cell_values](connectivity_edge_cell_sum.svg) | -| :----------------------------------------------------: | :-------: | :--------------------------------------------: | -| _Domain (edges)_ | | _Edge values_ | - -+++ - -#### Using conditionals on fields - -To filter operations such that they are performed on only certain cells instead of the whole field, the `where` builtin was developed. - -This function takes 3 input arguments: -- mask: a field of booleans or an expression evaluating to this type -- true branch: a tuple, a field, or a scalar -- false branch: a tuple, a field, of a scalar - -```{code-cell} ipython3 -mask = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0], dtype=bool)) -result = gtx.as_field([CellDim], np.zeros(shape=grid_shape[0])) -b = 6.0 - -@gtx.field_operator -def conditional(mask: gtx.Field[[CellDim], bool], cell_field: gtx.Field[[CellDim], float64], b: float -) -> gtx.Field[[CellDim], float64]: - return where(mask, cell_field, b) - -conditional(mask, cell_field, b, out=result, offset_provider={}) -print("where return: {}".format(np.asarray(result))) -``` - -#### Using domain on fields - -Another way to filter parts of a field where to perform operations, is to use the `domain` keyword argument when calling the field operator. - -Note: domain needs both dimensions to be included with integer tuple values. - -```{code-cell} ipython3 -# @gtx.field_operator -# def add(a, b): -# return a + b - -@gtx.program -def run_add_domain(a : gtx.Field[[CellDim, KDim], float64], - b : gtx.Field[[CellDim, KDim], float64], - result : gtx.Field[[CellDim, KDim], float64]): - add(a, b, out=result, domain={CellDim: (1, 3), KDim: (1, 4)}) -``` - -```{code-cell} ipython3 -a = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=2.0, dtype=np.float64)) -b = gtx.as_field([CellDim, KDim], np.full(shape=grid_shape, fill_value=3.0, dtype=np.float64)) -result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) -run_add_domain(a, b, result, offset_provider={}) - -print("result array: \n {}".format(np.asarray(result))) -``` - -#### Scan operators - -Scan operators work in a similar fashion to iterations in Python. - -```{code-cell} ipython3 -x = np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0]) -def x_iteration(x): - for i, x_i in enumerate(x): - if i > 0: - x[i] = x[i-1] + x[i] - return x - -print("result array: \n {}".format(x_iteration(x))) -``` - -Visually, this is what `x_iteration` is doing: - -| ![scan_operator](scan_operator.png) | -| :---------------------------------: | -| _Iterative sum over K_ | - -+++ - -`scan_operators` allow for the same computations and only require a return statement for the operation, for loops and indexing are handled in the background. The return state of the previous iteration is provided as its first argument. - -This decorator takes 3 input arguments: -- `axis`: vertical axis over which operations have to be performed -- `forward`: True if order of operations is from bottom to top, False if from top to bottom -- `init`: initialized decorator value with type float or tuple thereof - -```{code-cell} ipython3 -@gtx.scan_operator(axis=KDim, forward=True, init=0.0) -def add_scan(state: float, k: float) -> float: - return state + k -``` - -```{code-cell} ipython3 -k_field = gtx.as_field([KDim], np.asarray([1.0, 2.0, 4.0, 6.0, 0.0, 2.0, 5.0])) -result = gtx.as_field([KDim], np.zeros(shape=(7,))) - -add_scan(k_field, out=result, offset_provider={}) # Note: `state` is not an input here - -print("result array: \n {}".format(np.asarray(result))) -``` - -Note: `scan_operators` can be called from `field_operators` and `programs`. Likewise, `field_operators` can be called from `scan_operators` - -```{code-cell} ipython3 - -``` diff --git a/docs/user/next/scan_operator.png b/docs/user/next/scan_operator.png deleted file mode 100644 index f0c1d03636b2758296da39a29251c2adc5b321d3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8760 zcmb7q2{_d4`>#YsMPmuAjJ1en7{e%HW{hEGY%^xji_F3dGiHn#Tb3}1BD=IuM9Ef) zNGq=*T2!o>O|GBPnF4uh5XStvIxj*;kzVEM0Z%;S1mFrh3 zC@8486L7xZT&|$7Y#n3;xcd91>w0in7Uk=PQD}Iu@somr%1)^(Ln@37;S0D5CN|iA zuS~4rJW-U?#0F<#Z5=EYTXFcop^{)>lvM;*3Lb&?LJ@}_!sl}SZDS3$wy{JYEDzQ z63(HqX|4iB2%L@){kxiIt|W>tiuljst*sDNpyyI&X`Gn*?^6ypn$HFkTASFo0>%DE zgP=)7Fp%RP&3Xf+SR(%I$p4QfgJ_752!FUWGME+XCBY$3{+2*fOM}?aqW#(8s2IFA zJeVMcW1=`b6Kf3GBNz_17m&SV;o(%8J&MTZxncd7uC|_GG6aGh7sHR?BhkJBES2W% zD+u9J#If;Vwpdvhk>pL2$|SBXESV?6%|{%AqghiVD25B(#Sclqhe&t=JR&M8o*C*H z7wcjt65u0cWIE85AA=YXMMTAhqCD6>@JM%G8#fQL=TY-J{;~N(0VsT$uz?C>V|F=RsidxZWaC zya?y&#tpR>hqFjRFA+s7kTA$pMsTc@iUi99ZApH?_-H#Zcm-$B$Nd+EJEVlQMx(g_S7Is)wzOsClqnJ5A(Qc6OF!?{$@n-)T& zuo!p>n#zwwgL|&KeONG(&asR36$5Z%aCXsnmv97!CZGcg#9`2^7!*0iFC^Z^)y10> zA7SkoN@nrc@h(1WrZpNH%x8-HI6SodQdeJB7C<~c4#BafdEjEgtWmzm2sSE`3ET(9 zRzmd)59hkE@O*bxxEtHr-PblOMCOOKx0Ct#auIgE0v92HC<(Q}ibRn}8ZsJ5Lq+p# zqC(>lWHdSk5evk_6TO1P{&rse-V`?4o{YAUcnIx${q1CQGSwc9CZl6mIJTW_gguu; zh#}MBmmC*f#E{vF2vj)6)0!eCQC*j0BmyK3Z+=b^4NSF9(X&v(8iy` zBQhe9NUj@|=?9k*C9F_)K1l+1W$`6qA75(@A%>476Jz4&{P0LJTFML|c#CbMQQ;9_ z#t3v6INFnAXf8N5TuPA$5w5}hL|Yl3YYXSXQ379qABq$v6btF`NSPPTi-U^tisaxW zY$7Lw%?`si zO57vy#4uT;5P%sI=N&@l3TQa2j6xI=Xm}bG;g07gsJZ2;V9*6Rw2;8I+WK3k7n?1&Z!J&C`{cS0cev#3>QoMxZ$w#9R5$I^ws8GKp zPa%oMMtRv#BLwhpU$KaY_6p(QL%cb0p7xPawDg7VH1aA2zgAt^lu)f9}=R%DQoG4Z8ClAbboxI2VWM++Dr!Y<~+ zu<1Ry<+)7%qlQe2i@Bt``CH00*Im3vNjGHfXB;iS82`vczq@&3ex_`G$E0j>?+p5( z)65?;JNH@=<7eK44oxl&hE66vh&wqc#7WaMZw?mEVoA?6Ocxed64795Uh7}iCQ!Fe zI**oplb;Mq>H9tM#@tdF^M&MYJYrrlYf*~BOm8$xP|;MqLNU(Etv!XqTuxPl?Z#xv z|4>!NG^A>(zF4_ZE>Jbml$dBkuB`Mh=2&cjUNk8rbQgIWjF@l9%dJ1<++FN#2%}?C za%=Y}7{UUR^lZbloj#Gn1{wm6dlZ@9_VqE~+&ysTu;II#e=icFKHMwM-eC85@%Q{W zLZ0E&_@#%ww-2tRsEzcWqckP_lFUuN5`9{npFA5jP~&q&@_yqQogckPf|+MQ$v%q< zzl;JCdJ4DppCNDPT7Jn63SoXnr+yt9s13Vu#_ME*)iKN6D2JFYC*~$qwM@OwdY1Ml z>pR959@Md{NeVfi7JdJ?T?}%M5ws!bW|=v?fn*bM@03|ttl@ydm$uJVJ;!WoPrK}= z>-;mU_)&!`^=49HSbwz!xzwS^n%jdLO8B8RP{*XG4R>a0721S)u3h)>;YqJcp%2g8 z$=+bG#lo)!2bOttE@~)ohQH<|M6M^OrPYt#N?9(Qygm~a^hXQJ8CK_koBeUQ!IR(r zB=GS3&(Chy1L{?;uf|Zm4d2EFC?dB#vY^$S(lqh#ZckM{7yM^BQuOAU_oljd;j7D4 z;MV7)Q)1Hrp;8CX;*(P@X}r7nCQaO49Sk9P_T7O2J9BLM>X&JmG{h+==5QizDjA2& zRXvT8GbAw+FI`W&<(!dzdc6NkWx}uTVL7vZ?>2<9>XoGWVAE>3Dop{yq`1twnznrx{k53pW^ zcb=c^7cIErx3RYN@!Up62jp{XGY?zNf2gK&Jm3GRGTa zNt=z_4tw87_|>0tVP{nFdu>SbF6u4wv4t<$Mv#XmcX^3!{4!1a+OFc!^X|_Bl$kHD zw}N$y4P|7_#>|%HJUQz{IlK|wj}m|#dHLj4`qz`mi{`%2LJWY-5(^;yCDqT3AXv^)aDtKsUlgo(SmOLxT7Y;Difn|H!> zrYx^7&vjT+Xt}#MernVO*h$%R!X;Rzsm|X0I+p%@P5dVq-8YSFm7;A--~Hdf%*_9& z{zTSQ4WDW4fbqSU9VZ$q%29G7e3rM_Zd^yqn>e%KYo$Bm^1s}ekFcFoSkn^eJvTj3 zd+N4o8$Nse%cuwGP%n3YCG7KEQ)2^%#-7#_Lm!s;Eq{3_bj|(J9rV|f;X(y0FGVNG zVA+X`R9dEN&m+*K^{e*C06i!>Hn{~@?U<^^vxb0Af3Dd@)%bQ>H;4*X*Msi^(5+J6 z=6N|t04eV%#~!E#+td{QbBtP-IQ=4YY!TL?q>cH!xTkVno>x5DR8g)Yf7)(X|LdZ< zQEbxzbW2&?DYeavZGW`pe%;sMj!-O@x$F#G0Cu{?t8C{Ov!XdqW^KSc%&u9xSy{H! z8*zSPxnh|U;omkv_W;KaX#e39e1Cf2$PTQ0M zFUDm4jKa02FBT`n)&r!eKd|HUE3LV= z!?r{5ep@znR1;gB^$uOs+ocOp#+YIjKDe)jm4y#mKKg}gy1n05Q#BytLyfy8b{3)` zG&LL~m5yOHru6OSH5`M;FK)c5*u^<*oCaT|!4w@Db{d+&$?dRz&%7-itvHjUscO5Q zMJHNAzM;DQMBv(YQ2q{sm_frbIF%mxF0+Oj%=5=|NmssQJ$sydA=5TJqai%wf4$~C z!>OAF-e3ylHwrzKI`(ZOkuZPx1T_W@PK{nHiw<|~IFd2CN|*jP@NmG?re~(#pIbFZ zHe4@3$-h58d~UDi_AU~)Q0vHpl`zpGagh#W`9(9*l(_y^_Q&r@warHedgs>_5P*6D z>s+rNcHRPADZgQ10G)GE)&zu+(ifwwgc&;PY#7|XHDme^`JLK`6YK~JFvF=!ou9^Q z(c{iZlh4cdwD4q^^1;jkLbo5>APd=a)vB7K1689UWg`KxNoTlqr1#N4f7m) zZ^3>PUtjq3@#MLD*o)-!djT&l?S3)N%6Z)_S*@=%a{ASks8DmSRmS&mI%ak&ET+p> z$r}O7s?mGDt=YBUoU?1O)kDiJdi*){$n03{d8+v+;XU9IzfXTj4{1R1s2X?Iz8jMx zFLvK^h#0gl56v_yQ`BoYBl+at*=Bk9S8R3 z?Q~3>n|hgsQiP2wzYR_upP;#vg_SfI{NA=r#Xj9(+5SFVeXV=?`KN=f#pcyf`|=C~ z&l)x*{{Hz)`t_}g@OgzFa;a2H=i?{WJtbTQ2$Cp;D$=y!SimCVI^ru=&1Q zbp3)RgKIw??7)`AOy7mG(A<}^4$A& z`UN~_89EC8qRVlr58G3_f*JQt1bL{%`R!))*(#uo=XC`EnVERy20S%Au6C97mQ(7c zYU`1Gn+g7FDF(>Lbj|3BOXNtIgQtZi{;@{-AwvF|BWA{4&n)o`Ip1AUY^oy=iZFpuu~k6L;;qzM zQXnHx$fFc7Al@E^Ml$>7rIo97BUkJz>h8`nYDVnAX0UfwY>ptgy;OOtn5MEOdL$eM zaO^yoNRk?K3fJ}e>dj~qi_EL)_(93PWzFk{KR(a@q=S;To?G`hv%89Tq(AbFP?cw# zbv4I*cA>J_Y~jn*$$mET`5mURigt-$_(o$vM|$7g+k;Qlw#Xlic5biLUNs|s zH+A^;se;`$PgS%buaj&R!p)palX%;~^lh^xv&#$`wnEbjfi0x;ZC9=cq+7!~!##{O zSIWP!EXQQ|dYEsBJW~CQKH)-Zzh$7J*_*pg0r2l_CQwsAgNZd{0aa1xf>RcC*Kh+ z>v(f^jbti4d0oc!Tg3w`n!@&c$V#S zUuP3SeTL7tV-OB#t(fNtY=p$n*RsGI-t6eFtY5o&Ju+i%X3`7b(`Ps0bw@f>6&_sp zq>gD*WbnxRtCfornYFfavG@5v$7U0( zQ#gY(3T$#m^!TZK6QcPlgfY2A*M2!HbGXQY?rZY=>82BY+iyWl>vA$}?GEW~Hx+b2 zH|3c>KbNR97mRI{y!Vi8iJPBj;osV~?A&Zk;Ptk-kUip!+Y11b=cGpqRWvoqisfey z9-96g-DK9}@Tf6&U$b6nj%4l zA)i7tdbWcw2i@NLz`S+Sf#g4zO%-@KJt`AN8+R6kJS?kE81E^zBwpV;Lo+jJb!bnJ9tp!kCA!3m_VH*USTzPIRw!%otkGwhkRC#A`!p<{vle8&`c5o?$|tZkt6Q@JDe z;4qt|xWz6JB8Lrhc6F*bo4zzg$n&ocb+UJ8<=!4M%7hgz$Q6UPEja#t1>i4Uomshx zFue%(8_^*zpbN);D+Vu|TXnU9GpP$nJkaN)@ItGntS`m@U=to}&^hlN)T);~QWp(7 zR%_~>aI&vx^JV;9Wf(o%zUf(9wxjx=glZxWTjG8CFkJ5&!hLGpK_@36yx_;rUL)^JJ4lX?gNNlxg&NfAohkP96-OpRKT*7zL~ALg#U)|E7YA6NW->v zAX)Ud+S7CUMooyy_f8z8diS+s&fTTXAemHNJ`C(V^JLOoCJ1M7SEDCxv*_(l`+K$ym6jaEOD-6=iduIkTGBWNNsk0b1e>g5SC)7dFJN# z9*b>jm#-|JEwL|OF@Lkd_0Z#yy9GrzUxU!J|7Q5J2J%sx(CEj0oo&nZL;W*yGSwZU z4{gO^Di1ZU%s1_bUYQ)_u@9!LueC|MZ64k)d{m)){o9=l<_8epb=kW{kp+7z!$BUj z7uiHKax#s!ot_-h?|Gg)(yh@`V3sq7h z!?J*NUtx`?vr0?c@km)^G2}wwexcF=QDNwLBfF?dZ=14%boA7#q7#QH76wkHp4~?s zWaiCb1Ll4O8b)~Uhh>gi1KOmXZEfcwoR?;3Ennee%E*u0(thz?@FGw&#t zJM&VFX!<&G0GuhhBJ<)^-ic1T+JNfFCsCY-Z|Xj%d$uHwO6y+T>WX(@&D-w`vZ^n& ze|_W;y0sc?E!3jDCJYNG^4`&c_x3gW6k!FH6Cc2+)oO?6z)Dj3EG{?AuLX1`WL&FR zvgE+Vs@LZM+`|4G_H$h_^MjTXNlSFMLFc^rN`M1Ecjr&7nxy|%x;sAbzti2eNQe0_ z^&MZ0H9JrhUMHQVzV~i_67|sWqnKbv*M+>$Y&$F`Bpl?>`9;N+yL~|kBVf&AdwF+{ zf5RW>msg_9vNxjhL2mRuUQENL`A^n$yySZjC zXj_>sa_WAU@zd(Ox$}Y7<9qPx7J&~=pxL(WKR(_u7`z&D{=1dSiod_SzG`v6EG5_0 zjyGPLB`v%6v%S;m*n8%3qx-;JuX|FU_O5U%rQgN4o;3`~sTPh?hWt1E`oPw-(#n6E zv)S=j@irz#Hr9_gW4Hk`zWg0ObUS1qBdgWzc|*XC_p^m2tMYxWSS-8YH42tJD_FaY zk^7b#JNq@5YbL0<0Lwx`XOHYrElyM2do1B6qA};qMZNGRXVij!+Aw$S7x!2D^mNI0 zE3Q{f1VQ`54Wv8}>r|?nfJphHC%H2(L`%ojI-72lC}9$Gm(s(q$rux?uMQq#P)n^_8cv44Mn&&E=am%*CW zI8VzjnoVsYr1YtmPd?Gut8jNVL^imwl-2uLr1lN*jh_5q;C`dc3Btl{zapeMcIHBfX~CXk0%%g+7KD@TV2@_Tdg<}(Y1uL|30k57ts<(J5O8H zu~TyEoAT=Z827<_a(-hrWlhk^-`c;$e?0kP??JT zkYwzuZx?YwB2(PA0VQ~wg{bB`*HwO^C50VM-v(J?iQXgI%_`GZWS#XoNs9>RezWG` z`pT@ow2h!0R>MmLzNyhieT;^{>op6OUnj1J-zi!MRFImOUP~O+)Rn59)d1_<*|Ftp z((kz*WS4pP)4Fm{5b_eCi7M=!m93PN>^Itac>lD%! zv#}Yu*%>@F$>U1*ZS#=u+^rf08BnLAEQHSDi)HUxwkESav6MPMj#cw}s7>~{A7v5q zWe~}}v%1-9M4JumkgvIt)fi{#eQKXWNy&v17hOuSZjM&k#JtndoQ90{e*SG2G3foZ z9i#3UxG$fQ<(#qGW~~K{I&^rcwAm8zdJCwaortUo+wj)0+;Nv|_IvnmQYbTeld1;w zEm!54)nNLir4nCZ*SAJ?*XZLaqTmp8TJ&n5e)o}o^2N9IqBg*cxHewy3nNaQUx%&y z=~lgT_vil1_bo$7rwv^XR{O2#MV`H21iR-4{VD!a^Z3)RsjJ>uSSM>R%-r-C;oqqzfwYcEiPy?mVpJ?ZGbD z(*s%%+nW$oD&W?IlGT|YOuXIIxke$9oC2@BHd_LbCvAoPY`Za7hpJBnMzOuDz6&KSsG@pDV9n} zASq#T`XtZl$5pvrqlY)9)hmH00+CU0wE<*~u6EPM2fj=_6BTvNsQz@#_=8ebB3|1s z+z#!z#QuLw6cx)u#h;?L0g7#_i=}x1U*VWWY+LaLU1n`1-dffgt9eygm)aH1z)Uv23bz1w{F!UwNnHtOyjgQ>8j>>IAj zEmRq9_m0^fI)>zT(AI2%$OY@NT}B5MF2Z&uE!YD_$$NS}$P1)Lw1cX0ftVD1D;2)^ zM{hdRUF&oIxzo5Ed#CF>HUJmQ$#|{1sw4K^_6?Y;IKtLLN@I$!4I2wKT`(2=;N`IbopEbHB{- z_iEZ5b1qnQX&YR`ft2l@Z|G3NjJHAj?9=w2*&6SJ=+v9gt-K+{qUSky-wa@LH< Tz9PZDf+)DVdg2-|ds6=g`@|Ku diff --git a/docs/user/next/simple_offset.png b/docs/user/next/simple_offset.png deleted file mode 100644 index 660abe87642151d390abb723a0c40e5ee9c22e00..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10292 zcma)i2UJtr@_xXpAS#Lq8mSQh=_G*!1W^(QC834TODLg*1OlN25Ks^sXcQ5Gf(79! z2q=051*E7Lny5&XB2rYONEd1U9rWG%-dn%5-hX8!=bU|Zojvo-H?wob%EEY?&^{pu z1hUQ4#E=Mq@Ik>BCb$Wd;C0@2fD0d!XpDgry_Wb2fo!q}HF5~0NBa2FydiQ(?5{gH zgoa-bGgJ;~D2G6JFc@kSe-Gaf4?0sV&^r_qf%o(viocJ)H|19yga!hs3Wux0Q6xBA z4r!o;1V5TcHIxRz=~sOZU+=)*9hwJ4`O|0~atNH3ng;0Vpt*;ae_&`3(@zeG0nes^ zq24rb1;yak(gyrEfG=Ew4A&x~6~V0mgF*AQ_x8m5gW2G;kZ3g|3KSp2n~=;2atJKA zr}+nXgNw1ZS3nT23Dz$phz?3{aJY_|2JZt3Z9IHDLj3=&3*O4SLp^+dt%~XtVc{4W zhNpT)nt2BV2HH4k(SNNb+&hHn9~Ag|@d!1z8feLD9Li#N|0<<;hx>bh0TFUYBOut{ z5(HHOJ-|HHC2IwQq6+`j(Eqn2qtNE|XnHV%OtbgX)<$djN16rw>ICf+LiA)XBXFi3 zK7PzlytzHr#=+2u5ekPhHA0ACKmr&|x?v6Q@Z$|XD zh-4v*0wXoy)<`qFHPwsZNHRnjA_;aDIAfp~7-J188eFIZuw^(V&CZDBi43=oG|)tl z;HGwF7)!J-9HU{ZK_uWEiNRDIPqdl0E!@gHjAq5#OA8}3l0vnmc?2`Oy}YcnoiuI2 zHMD$?NK|Mz(<4gD!N@0ABizu;$;6lHW9SD)(zFfIw)Z8HJk5hG;Ix1MYl4F#&K7h< z471k3qbMi}4vxlQLp*6=rVa>AqhMPWJ{ZsNWQ1$mVS`aVpeDiC){$Zu#WD|u(@18- zD5Oy!2IJ&Qcf|XVtVm$`*hs9Ct%+}AUTfW%qZU?Mzi9W)FGgh-4B9O;N7 zSQD(7RzB96p}vu!7(WA`ShxVIC?5w_qyfeQWr)zgn;4lBtWB&gql3m6Z*PA`dLYv(&<JhW5e&4WFNxsG3PAW9n1in9aGa+nHJBbuhof!1 zeXTHbB&d&#Y? zorfulrD2N1vm9`wC3@Hjr1Q$z`k2MMhl$dBm4faHbVC}W(Xo9hmPbk8~mVxs! zv%v;nsZI_~lyHsUU>&$9N@F!OW-P~yQ1q}rL{fX(6H;Y=gCZ2*Z5 zo{gyXREnb(uZ<}yl0*t65qSm~7@>s;^e1T9S~_a`*hJWcFif;8!w`{nRy1qP0Lu^y z4LFmxzi2HjLyM?TxQ;fF8Xgsfqz3CyyiIioek@B?co@Th7G-5`8MuWtDVa@^Y@*oXid z%}kHo8n$<_*0KioPD=+omNk$Q4DeLUkf0Zac@8GW6Zy>M-&VsdqI>)_MqzT&kt zovInVNu7m0804qNn%0Ze-OaDrtC@fKFbGXq6)P(%;e!k=TZk;k8JjYRdl>NQiqviy zir#6ziY=JBGJ^+}7uaWdJ$&+~bZIds6O)4z^avL?$pRu>M&TPlo9>hEF0nViyle8= zlxs+(HJ1o-ejG+sVlZ`&4!g-4b0aTa%%a%Fv+MI+w8sTGjEb`%d5MI}kEI?R#K^~9 zFfV4a;q(jX^&$v23Pr(DREAQbw8ughBQH^W>zX2LS8kz7`lnK@Yoa?51CeQAc}?K` zt8lm7--S4my5a796|9WI(I|OB#z)^1F$uTrjDy%WtDKZK`9>GB9Uu9L96EI9O;^{M z`T215v(v7fRA{9^(rSL?-pi7AU=NQig{wH0{PC^5@cAxFJ!_GRY`7?1iTj3Z~I zpEbL40|lX$3R}wS^WyEfdb4k^89Ge)r%#`1gpIzZA@c*$T6a|mYX=lyv&aI{>NBYK zXVBa~hKgZ}I-#z@gyU-3qmk~i`P-^yJjQBNn{RG@O$|{~cCFAM3 zM_9j5YMNQv&|5G+aC8mX)h`Q6O)nXkTUn=7i ze zL_v(-hhE*|CkUuk^%)E4Y(+7Ice0PZ;K}bmndX*4{0DY686~n62OljU^M>j_Wz17MOu^UFdN+IMQAE}D#zj8h0-{n&WoI`upo1;enU{ucNFfq6S z`AM~@38F3E6?_WzI2%TJBy{J))$JIMCkL8)-k;dmew-Coc?GWw$xr)k)wpXrc6(FC{x5#of6NV7Y+>%zg1c3dMvz5SP-f^OsromacO`e zSc?ToP4nX)9#HCc-g7P$w$z(4G}?rz+d_5@+h18yfPnt|(2(?QEwGRL4?bL;7#}wx zlcixcgdS0T=uhWl&RJ^`Td@|@p}ur{8mj2ioN2^8xL@AL#((jK5C*$y)^Qb6dr+7s zA!f;q ztjsy79sUo5Af<(vdSP&?Ii)fxU1a}}JTaB~!c9xxL=mhEEccG>oLJ&0NAru!Nw7;^ z8=AndL%}1vTYU$HiBgSbd|*B^8F|C{o1uleiYKgeNq*zG_Te)I2YXUCH($Rs63|EX zK^UDpu@)={UCT`Bl!fJlrErq7%JEL>;(u8Vs5i<@)DMZASk{AX&(ExSx4BWye`Lg< z>oFK}+vV9gk!G%w%BRPx>R$Do%vE{+slw)C#q>J?$f^E3;x5ppF zXh~1`c#=_yOYozu8-q`vLB^-xOCvy0^`6onDbE;~aPRDY1&kxPp`R%J&_&5M z7MZXM;iOy!?<5)gvK3A;>ar(2h}Y9Vk`!!LQG~v)1=p{QBHv0fxPtHgcsA!SENAW_ z`>hhtuY)S-w>ejx?pGgLmu>ZarRSmRvaMi?|!KGlxWb#Ydg_+T|0=P%*A8Edc z$!7W;4x))k%fr37bvNC41_r$MeZRN+Rop7}qI(z{!YBJuf{uN2W>iJQJ66S3j;g>Y zr0KgbFD@A)FTmYBP`nS}d;J*dX?gk7>e8e%+r9tUu^&SUM;z;^;S)V)F3KAF=-UW# zLK@|;Jagwm82td|-t&YRJYBCa7(ia?=Gz=Qb&l*l`{ixgT0}MUQ0PHes^jqxbnMsS zxIw*zJ4)$izuRx)MAf*=BYa;==uLBF)7D}?>-JmwHxTk%0C%wcA@&QS9ejS=k2#aBfkA)USD$YPF={3o~W7Y zAYTxX#ocMyimAJrsKnhnozAZRy|$tH)Tu$~wX*4;PcP2&z21t+%f0+=+F&D;2SFc* zWn5v6$zmZy^uk0TMRlU4JfP^wh*#eeqHoQKjTPmWElK#|F-`(|b;Rr1XSe31f)mNn z;syl@Q4!-K$BbqCEFlJVFBIFED{t5eCUYAOH8wLysBj1xDOR1fE{kaJMQKlkl(_h$= z?QTx`8Nfrx%en^n=~F+iLy^XXAY`KdsQSWP zw<_1e)qnM;2J2t)`Rz~T-yX>GYFYwN136pH@?kv*mxQe56;17fFd-EeZ?y(N%J`~n zAgg<;;~?5ws>%F6f*@<8@o%vZblcq0c8Jub@3gsv8Ay-wN>_7`^;y>K!Ouj_$(>kYjZ6;*KT$RuQWWBieOwRtH1-oCSqAujR*B_a#s z%%VgNpF2$pzZd|yz|ZI>mP%B=Y#V!(W*TRun&t-i4v8OVBjkx8&>J$wo{33RLxEd^ zFucgya2TBO&HaLFlGTcB&0Th_CiBiNX3wD2@2QT%W(5ib-xs4H%K5o1WuvQ2Sr9w^ zz5NhGeTJ|}k!3$8`R^_7vvYejA2G99cdfom{d2f@d0-XgM*0Ww;H86cTR#3AJs!Ay z#^IT6Nv3J7&3#bZFD~tHn{LCbI|Rsex&bTE|=W8ngwp|K-|n)HfrwRZATqv zxutDej~rxmW3>e&9m04M-17R7t?Ymi#MllU(>sHlXn&Z%j;^coHWXW+2y$*&czqZI zj52Qb@Fnq|5qxEuOLb3KWx7T)1#g1mLb}d2Mp+c-Oh`fEY>!R6tW2vuygJ#SE^u#% zPq)5q@-D<}5cr6T>{{{_d{64zPll3-8_cUKi@Ar_rbXga&HJ25ORMtsn6}Bmt?#SK z5BH8uygNFsxqZ-dBRVd3*A@J`Ci8r}$}0{qfRMX)j+%i;QU5;yf?{Ud|t!=rUdANhr=#-D8S>UmIP-klDV z0QBvHcsG~V2A>^lUiv1N+VwO{St@QhxMF#_QG4wEbH)hA_tRPBqH69%*|tra_@Q%$ zkfAbwR|1IjVfluv>4w(@Q|*XT4?YiN}voP1lb|lh=NB z*tUxxlonkEn+*J#{$Yp(Cz`SkM}LeDsokvb@v==%j{Fk2v@c3oOvNF-t-yZp#Fide zxh`@dA$&Yv6)A8N&O;oYR~MyFJ9>J1Q-G)0bY!WkYr6g+%jFW z?RwuKMdcHRTteR(30z+Ol5>GE_Tbg0gzfs(MXu$HFE=L&k94^u5?I5R;Bk&>EVcBr z)Ak!VLu&8Z?8zy)_;e}u$7Xc@3qXGuJ;ASPa=SA3z;djM-=SvGUT0Ljo3>JQaF@(j zu-^!cILBQ6G1Xb(y(RG3qEF1~%=K4!=2!32x)3}{w9p$N9r*O_s`BY~jwamOtpzo2 zISFqv0RQ_;$U7Kk%MK2RU771pj<|#h4DHvSzS>cHa*J|Hy+p|ArLvwebxdAGNPPzI zwt84ly^S5bD=b=LvSP+!xc{QDl%rzWqaX#7*#neJ3@*yMlj9NPX4vSC4p_P0p9p843$G&BSae7jxMoa5o1Fuy!Rd81P-xjzTYAeC{UqS2F z=l<(%`u~nOrp3Z-HvS?jJY1k+)gH}{ta!`sax%%iUi>@a$hAK>CdMQFwZaIIo1XdZpW8oyGvK@J+St^3F|YKsszNq{w~5-{pGsVy@`*<|#TDd6&tU*Vou@ z^Xtn2c_-P%4-*HiWGI5rLjKWtFi{6JQzZH?_v2x}5W}9l@9)3x^4y;CmdND{^ezR9 z7hnIpUNoJ|`J0J@sF_PyC%g-N(c1x=r9B^3cYeGT>CO(OQ#$tM{+1E(8y!+21NJ z7k?!kS7AIpAMx@Ri=Kpymt#eLjr}38d|O2SF?4p!v?&irJ4=J4UuzF;C2$ZMv8RM> zx0rD+JF7vb(}#(|!$gtQd>FIKJN@RvhcjM{5jR*rzy1l2n)MR}Ih(YY(G)UZzKvL( z3~n_1>VJ!R#g3~edhPbu?;BzO`+w11==^{&fGN|wS6NvJi+`5^4AQz_-~kBFhty%I zSb3YoDk@jgY2h*b1EUH4inKb^pkUr*_|G(WvA-x~exmo>cKsiaQ_`{8xRiZiZ(=O< z7bz8U?arN*(P&)q-iwPN1n%Y1DSRCA6#e2hq;I~_A^_Ha;>x5}3*Bw4e_Fnp9&X$& zxHMW&>c}!ot#X!_*7@I11}z=Q&X$T9RwmbEY6S>i`WW56r5505;DhB{fin-e>Xz^U zpqG%bd9esUc+}>msQ_<8qFZ+R{8gU220E%<%!>rvMo0RKXB#HuQtJbB6lGF7c4H<1 znMfp*!EoHfZ&@uI5O!*@a=o7L)Ocs3oT|mXQU%KUryxx*l{!U|0*R;LOAY-p`#>HD zgA+CWUj}nYyed0!YlS!l7a+6nUulx0i1f$^|LTXQ|4A(kWG8DfZMKLx2#l-@XA|5- z(mbB>(gF1!M}so|PG)`W6$6oVBH+8ZnZ(+UBP)eR^ZxL7C;WA!zN1FSq*36fc*GE> z%^{TW9ZVwR6z_OkLJJ%#$P$!I_1KZ0qL#B^orO87-K_=DeiJbHO>a9vA`T|o+;c+X z<=_EcE+R7YgP;9+F2d;izj6^iPe}R8KAS*l!;ax~|96@PSh#E&@aNj^4#@1?~5w9ob3e*^7yVJ@xznI!&-N=-rWv$F%wPYoikFK#V*I;xZ` zp+g)8u(jGNC(ARwuqKeO+lOE+74)q$q}PYOsm#vJNe(r4KJa=FyAmZ4`(r?R>p(8Mw&QZ#>ZcR0DkKG^X-Fu(UUJh`j6zh z8J2U&avO+ocOrC-<_Q9K2t@}FJ2I$YE44L|)>^cD2o4 z7W(6)f$i1d)hk2Uf#lWc^asGfZevJ7ZDcG6G<3flbkPDM&$OTl-z-O|raxdJ|(EDp|Z?E?1 zk3r=lYhR3y)LvG!(Kz-qRE7NP8rxd&M&PwYGv6k1;Br(F!Fn}cW~O$F`p0X`*9Oti zx`%Xix0`(bMqb;eaE-wq6%}>gw7yHEOog1 z9Csab#<`o8y`Ky{aZzd@2eGFxmp5hYKv2DH?GMd;Q9WcbP^`&pjBaTE$@8gRs+8rH zSHSQVdIUKVNa&hJX_aw=2#hjKM6bo>4r5h7#u^6P)){-{z#uHUt7Q;Sl!G zH?=-kcfs3?|1mj^r`NB9=z0=^=l_#Na>AeS{D010ow3pDxS`2>IC;7snhAKuR6yCw zz9!)QbA*1=5IkK4#M8V`>)qTHs@diQ0C{Xhn3CQ$#oy)GlygPvl?Uo1{>lFh_2%XY znKcT45t;yRXiIYO8*g4}B*;Hvn;jwl9v5G1)v5X`t+3BtpOp8q(e>*k`|*~-Zljfj z$B!Sc`ew88O6tH}!;yuMs)y4{x)WvCPwKE%Z6Y{_*vZ5m+AQf3D8NvN?=+BAlJ5TS zp(8vzd_v(|b4M07Y3s#EIj7=gxsgK!G6;@~sKy)C&gx4?Q5vpOyKmcvw;E{EyC3<2 zW2)DCiZcb;mEF7hhg>@!e7HhE2WG*&1FerHi7$<|USM0}zZ7qs%8$ydu`z zANAxiKjPNh&n@ZhQTr*$PQ&pJA0I-e=Xcb;`f~Z9c|lcG6)Lr{-Sm7R@iK_?R9|&v zj%Bjp{*5(V$r}j#oS(E;KW3dRTMkir1=-*8rJ$#QzOhUL{}rR+~(GF&F?WX z&*pi`W4XI?9(Gt`o`}m;mRuFC1u?6As_bafqpF1AiA$*Ld-~D4Hd2VRW$LlyzYfE8 zop~sBU0y70dwXN&C3o@)Hda4d`#e8F5fwWVrrgs9^5T^< zhi3lsNDTSh@1Irsyn)ZkClbC0M`g#2pCK-JU>yxwL!#Qzu{Jz`7Us%dlsw8c^tn?! ztljCd{p0COVJ2V7KXdO`+;i+mBgDl?U~lP{?R+uFmkK}6GzfefCJJ#+w~HT}jfpwl zMA>xn>ju>>a0s?y?D2G~JgblZjuX;)LwDpIOy6tPMA^H6e`AcEbn4b$!W9~lynHI- z+3D|M;hk_rCFy?&r1OeT=3!W>DGD{67eANM4|YeBIg{rqBPoJ7zE^Zyvk)u)vQ8cL z{4mF_#in4(uD>^?Bi~Wy^*IjV{3qf0B5-WcYLIcuW#3pV+PTUfkRI{yP8=U+w|JFD zd0PM2X3p{>$=So_{$7TzV9p6PG%@e0Q^~(ZMdrrwfpm<}4i0wfjrhOiaprSWjnHY( z(5O7P4*cWnH1EFvzhvQjUqCF$OWv#Iq{(NH23S`v;%!uZx_39Z&y~33Ock&`3Z@@) z*Hc{T8XHct_+s#wr7KDEjmG+ZS(tc1)n1dYlIEUpQcZS9^WXO_Xiwn^ zHH)xexYy65Z^QjKrK1}&mEH*M@JB*trgtJnR6qS8YkOn0y!2KA{*yRtm;WD3#}QGJ zk_^8A3^(I~frW)$=v2@aV(mR@z)@XdfbULXHiwtS>n`Af8@ Ubnz(o8z98g$ilD)<9_1*0a&l~Y5)KL From da1da20b0d6bde48e1a15b6ca6dee9e7d065f337 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 16 Nov 2023 15:10:27 +0100 Subject: [PATCH 05/85] feat[next]: DaCe support for can_deref (#1356) This small PR adds support for can_deref operator in DaCe backend. It also improves the code for preprocess ITIR transformations. --- pyproject.toml | 1 - .../runners/dace_iterator/__init__.py | 32 +++++++++++++++---- .../runners/dace_iterator/itir_to_tasklet.py | 30 +++++++++++++++-- .../iterator_tests/test_builtins.py | 1 - 4 files changed, 52 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e2d2a7dfe9..7690ae583e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -330,7 +330,6 @@ markers = [ 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', 'uses_applied_shifts: tests that require backend support for applied-shifts', - 'uses_can_deref: tests that require backend support for can_deref', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_if_stmts: tests that require backend support for if-statements', diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 9f67cb26da..e3fba87571 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -65,14 +65,29 @@ def convert_arg(arg: Any): return arg -def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[str, Any]): - program = apply_common_transforms( +def preprocess_program( + program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode +): + node = apply_common_transforms( program, - offset_provider=offset_provider, - lift_mode=LiftMode.FORCE_INLINE, common_subexpression_elimination=False, + lift_mode=lift_mode, + offset_provider=offset_provider, + unroll_reduce=False, ) - return program + # If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG. + # In this case, just retry with unrolled reductions. + if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): + fencil_definition = node + else: + fencil_definition = apply_common_transforms( + program, + common_subexpression_elimination=False, + lift_mode=lift_mode, + offset_provider=offset_provider, + unroll_reduce=True, + ) + return fencil_definition def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: @@ -156,11 +171,14 @@ def get_cache_id( def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: # build parameters auto_optimize = kwargs.get("auto_optimize", False) + build_cache = kwargs.get("build_cache", None) build_type = kwargs.get("build_type", "RelWithDebInfo") run_on_gpu = kwargs.get("run_on_gpu", False) - build_cache = kwargs.get("build_cache", None) # ITIR parameters column_axis = kwargs.get("column_axis", None) + lift_mode = ( + LiftMode.FORCE_INLINE + ) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] @@ -173,7 +191,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: sdfg = sdfg_program.sdfg else: # visit ITIR and generate SDFG - program = preprocess_program(program, offset_provider) + program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 5d47cad909..5b240ea2b7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -265,6 +265,29 @@ def builtin_neighbors( return [ValueExpr(result_access, iterator.dtype)] +def builtin_can_deref( + transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] +) -> list[ValueExpr]: + # first visit shift, to get set of indices for deref + can_deref_callable = node_args[0] + assert isinstance(can_deref_callable, itir.FunCall) + shift_callable = can_deref_callable.fun + assert isinstance(shift_callable, itir.FunCall) + assert isinstance(shift_callable.fun, itir.SymRef) + assert shift_callable.fun.id == "shift" + iterator = transformer._visit_shift(can_deref_callable) + + # create tasklet to check that field indices are non-negative (-1 is invalid) + args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions] + internals = [f"{arg.value.data}_v" for arg in args] + expr_code = " && ".join([f"{v} >= 0" for v in internals]) + + # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution + return transformer.add_expr_tasklet( + list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref" + ) + + def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -318,11 +341,12 @@ def builtin_undefined(*args: Any) -> Any: _GENERAL_BUILTIN_MAPPING: dict[ str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]] ] = { - "make_tuple": builtin_make_tuple, - "tuple_get": builtin_tuple_get, - "if_": builtin_if, + "can_deref": builtin_can_deref, "cast_": builtin_cast, + "if_": builtin_if, + "make_tuple": builtin_make_tuple, "neighbors": builtin_neighbors, + "tuple_get": builtin_tuple_get, } diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index d5d57c9024..2bcd0f8367 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -250,7 +250,6 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) -@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor From 67a618856331a97530560b1607a7d11c3c3ef802 Mon Sep 17 00:00:00 2001 From: ninaburg <83002751+ninaburg@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:49:20 +0100 Subject: [PATCH 06/85] feat[next]: Extend astype to work with tuples (#1352) * Extend astype() for tuples * Adapt existing test for arg types of astype() * Adress requested style change * Add extra type check * Use apply_to_primitive_constituents function on (nested) tuples * Adress 'nitpicking' change * Remove previous test and add integration test for casting (nested) tuples * Adapt visit_astype method with recursive func for nested tuples * Fix integration test * Call 'with_altered_scalar_kind' only once * Recursive 'process_elements' func to apply a func on the elts of a tuple * Fix execution tests * Adapt visit_astype for foast.Call and foast.Name * Fix tests * Rename args and refactor 'process_elements' * Fix tests --------- Co-authored-by: Nina Burgdorfer --- src/gt4py/next/ffront/fbuiltins.py | 6 +- .../ffront/foast_passes/type_deduction.py | 11 ++- src/gt4py/next/ffront/foast_to_itir.py | 35 ++++++++-- .../ffront_tests/test_execution.py | 70 +++++++++++++++++++ .../ffront_tests/test_type_deduction.py | 4 +- 5 files changed, 114 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 13c21eb516..7b96de8e89 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -196,7 +196,11 @@ def where( @builtin_function -def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: +def astype( + field: Field | gt4py_defs.ScalarT | Tuple[Field, ...], + type_: type, + /, +) -> Field | Tuple[Field, ...]: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 605b83a5f0..95c9128f87 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call: return self._visit_reduction(node, **kwargs) def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: + return_type: ts.TupleType | ts.ScalarType | ts.FieldType value, new_type = node.args assert isinstance( - value.type, (ts.FieldType, ts.ScalarType) + value.type, (ts.FieldType, ts.ScalarType, ts.TupleType) ) # already checked using generic mechanism + if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [ kind.name for kind in ts.ScalarKind ]: @@ -835,8 +837,11 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", ) - return_type = with_altered_scalar_kind( - value.type, getattr(ts.ScalarKind, new_type.id.upper()) + return_type = type_info.apply_to_primitive_constituents( + value.type, + lambda primitive_type: with_altered_scalar_kind( + primitive_type, getattr(ts.ScalarKind, new_type.id.upper()) + ), ) return foast.Call( diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 1902d71b3c..816b8581f1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -317,12 +317,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, dtype = node.args[0], node.args[1].id - - # TODO check that we test astype that results in a itir.map_ operation - return self._map( - im.lambda_("it")(im.call("cast_")("it", str(dtype))), - obj, + obj, new_type = node.args[0], node.args[1].id + return self._process_elements( + lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -403,6 +400,32 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + def _process_elements( + self, + process_func: Callable[[itir.Expr], itir.Expr], + obj: foast.Expr, + current_el_type: ts.TypeSpec, + current_el_expr: itir.Expr = im.ref("expr"), + ): + """Recursively applies a processing function to all primitive constituents of a tuple.""" + if isinstance(current_el_type, ts.TupleType): + # TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element. + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + *[ + self._process_elements( + process_func, + obj, + current_el_type.types[i], + im.tuple_get(i, current_el_expr), + ) + for i in range(len(current_el_type.types)) + ] + ) + elif type_info.contains_local_field(current_el_type): + raise NotImplementedError("Processing fields with local dimension is not implemented.") + else: + return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) + class FieldOperatorLoweringError(Exception): ... diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index d381a2242a..58181fd7a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -325,6 +325,76 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +@pytest.mark.uses_tuple_returns +def test_astype_on_tuples(cartesian_case): # noqa: F811 # fixtures + @gtx.field_operator + def field_op_returning_a_tuple( + a: cases.IFloatField, b: cases.IFloatField + ) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]: + tup = (a, b) + return tup + + @gtx.field_operator + def cast_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype(field_op_returning_a_tuple(a, b), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1] == b_casted_to_int_outside_of_gt4py, + ) + + @gtx.field_operator + def cast_nested_tuple( + a: cases.IFloatField, + b: cases.IFloatField, + a_casted_to_int_outside_of_gt4py: cases.IField, + b_casted_to_int_outside_of_gt4py: cases.IField, + ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: + result = astype((a, field_op_returning_a_tuple(a, b)), int32) + return ( + result[0] == a_casted_to_int_outside_of_gt4py, + result[1][0] == a_casted_to_int_outside_of_gt4py, + result[1][1] == b_casted_to_int_outside_of_gt4py, + ) + + a = cases.allocate(cartesian_case, cast_tuple, "a")() + b = cases.allocate(cartesian_case, cast_tuple, "b")() + a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) + b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) + out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() + out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() + + cases.verify( + cartesian_case, + cast_tuple, + a, + b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, + out=out_tuple, + ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), + ) + + cases.verify( + cartesian_case, + cast_nested_tuple, + a, + b, + a_casted_to_int_outside_of_gt4py, + b_casted_to_int_outside_of_gt4py, + out=out_nested_tuple, + ref=( + np.full_like(a, True, dtype=bool), + np.full_like(a, True, dtype=bool), + np.full_like(b, True, dtype=bool), + ), + ) + + def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 7800a30e41..dfa710e038 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]): def test_astype_wrong_value_type(): def simple_astype(a: Field[[TDim], float64]): - # we just use a tuple here but anything that is not a field or scalar works - return astype((1, 2), bool) + # we just use broadcast here but anything that is not a field, scalar or tuple thereof works + return astype(broadcast, bool) with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(simple_astype) From 39d1c0958c06da22973417660035ec8e023b6956 Mon Sep 17 00:00:00 2001 From: ninaburg <83002751+ninaburg@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:17:43 +0100 Subject: [PATCH 07/85] fix[next]: Names of variable in tests (#1362) --- .../ffront_tests/test_execution.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 58181fd7a8..8787b7d7bc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -338,33 +338,33 @@ def field_op_returning_a_tuple( def cast_tuple( a: cases.IFloatField, b: cases.IFloatField, - a_casted_to_int_outside_of_gt4py: cases.IField, - b_casted_to_int_outside_of_gt4py: cases.IField, + a_asint: cases.IField, + b_asint: cases.IField, ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: result = astype(field_op_returning_a_tuple(a, b), int32) return ( - result[0] == a_casted_to_int_outside_of_gt4py, - result[1] == b_casted_to_int_outside_of_gt4py, + result[0] == a_asint, + result[1] == b_asint, ) @gtx.field_operator def cast_nested_tuple( a: cases.IFloatField, b: cases.IFloatField, - a_casted_to_int_outside_of_gt4py: cases.IField, - b_casted_to_int_outside_of_gt4py: cases.IField, + a_asint: cases.IField, + b_asint: cases.IField, ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]: result = astype((a, field_op_returning_a_tuple(a, b)), int32) return ( - result[0] == a_casted_to_int_outside_of_gt4py, - result[1][0] == a_casted_to_int_outside_of_gt4py, - result[1][1] == b_casted_to_int_outside_of_gt4py, + result[0] == a_asint, + result[1][0] == a_asint, + result[1][1] == b_asint, ) a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) - b_casted_to_int_outside_of_gt4py = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) + a_asint = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) + b_asint = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -373,8 +373,8 @@ def cast_nested_tuple( cast_tuple, a, b, - a_casted_to_int_outside_of_gt4py, - b_casted_to_int_outside_of_gt4py, + a_asint, + b_asint, out=out_tuple, ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), ) @@ -384,8 +384,8 @@ def cast_nested_tuple( cast_nested_tuple, a, b, - a_casted_to_int_outside_of_gt4py, - b_casted_to_int_outside_of_gt4py, + a_asint, + b_asint, out=out_nested_tuple, ref=( np.full_like(a, True, dtype=bool), From ecd0b68a4492a6f01ac3f9a66e888da7c992a0c0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 17 Nov 2023 22:31:43 +0100 Subject: [PATCH 08/85] feat[next] Enable embedded field view in ffront_tests (#1361) Enables field view in ffront_tests New exclusion markers for some cases - cartesian and unstructured shifts - scan - check for a very concrete error message in parsing: we should match this later in embedded Adds the following features to embedded: - support for scalar broadcast, astype, binary functions - adds `__ne__` and `__eq__` to Field TODOs: - full comparison operators for UnitRange - full comparison operators for Fields --- pyproject.toml | 6 +- src/gt4py/_core/definitions.py | 3 + src/gt4py/next/common.py | 20 ++- src/gt4py/next/constructors.py | 2 + src/gt4py/next/embedded/nd_array_field.py | 30 +++-- src/gt4py/next/ffront/decorator.py | 55 +++++--- src/gt4py/next/ffront/fbuiltins.py | 117 +++++++++++------- src/gt4py/next/iterator/embedded.py | 12 ++ tests/next_tests/exclusion_matrices.py | 12 ++ .../ffront_tests/ffront_test_utils.py | 19 ++- .../ffront_tests/test_arg_call_interface.py | 4 + .../ffront_tests/test_execution.py | 66 ++++++++-- .../ffront_tests/test_external_local_field.py | 3 + .../ffront_tests/test_gt4py_builtins.py | 7 ++ .../test_math_builtin_execution.py | 3 + .../ffront_tests/test_math_unary_builtins.py | 17 ++- .../ffront_tests/test_program.py | 4 +- .../ffront_tests/test_icon_like_scan.py | 3 + .../ffront_tests/test_laplacian.py | 4 + .../embedded_tests/test_nd_array_field.py | 2 +- tests/next_tests/unit_tests/test_common.py | 16 +++ 21 files changed, 293 insertions(+), 112 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7690ae583e..041448e17d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -343,7 +343,11 @@ markers = [ 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple arguments', 'uses_tuple_returns: tests that require backend support for tuple results', - 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields' + 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields', + 'uses_cartesian_shift: tests that use a Cartesian connectivity', + 'uses_unstructured_shift: tests that use a unstructured connectivity', + 'uses_scan: tests that uses scan', + 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 7b318bc2de..79543a1849 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -446,6 +446,9 @@ def shape(self) -> tuple[int, ...]: def dtype(self) -> Any: ... + def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: + ... + def __getitem__(self, item: Any) -> NDArrayObject: ... diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index ffaa410563..66766be76b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -133,7 +133,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: else: raise IndexError("UnitRange index out of range") - def __and__(self, other: Set[Any]) -> UnitRange: + def __and__(self, other: Set[int]) -> UnitRange: if isinstance(other, UnitRange): start = max(self.start, other.start) stop = min(self.stop, other.stop) @@ -141,6 +141,16 @@ def __and__(self, other: Set[Any]) -> UnitRange: else: raise NotImplementedError("Can only find the intersection between UnitRange instances.") + def __le__(self, other: Set[int]): + if isinstance(other, UnitRange): + return self.start >= other.start and self.stop <= other.stop + elif len(self) == Infinity.positive(): + return False + else: + return Set.__le__(self, other) + + __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented + def __str__(self) -> str: return f"({self.start}:{self.stop})" @@ -486,6 +496,14 @@ def __neg__(self) -> Field: def __invert__(self) -> Field: """Only defined for `Field` of value type `bool`.""" + @abc.abstractmethod + def __eq__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool` + ... + + @abc.abstractmethod + def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool` + ... + @abc.abstractmethod def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 30ef8452aa..42b0bcda90 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -82,6 +82,8 @@ def empty( (3, 3) """ dtype = core_defs.dtype(dtype) + if allocator is None and device is None: + device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0) buffer = next_allocators.allocate( domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device ) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ea88948841..51e613ef81 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -135,25 +135,22 @@ def from_array( /, *, domain: common.DomainLike, - dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike + dtype: Optional[core_defs.DTypeLike] = None, ) -> NdArrayField: domain = common.domain(domain) xp = cls.array_ns - xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) + xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) array = xp.asarray(data, dtype=xp_dtype) - if dtype_like is not None: - assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type + if dtype is not None: + assert array.dtype.type == core_defs.dtype(dtype).scalar_type assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim - assert all( - len(r) == s or (s == 1 and r == common.UnitRange.infinity()) - for r, s in zip(domain.ranges, array.shape) - ) + assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape)) return cls(domain, array) @@ -194,6 +191,10 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __mod__ = __rmod__ = _make_builtin("mod", "mod") + __ne__ = _make_builtin("not_equal", "not_equal") # type: ignore[assignment] # mypy wants return `bool` + + __eq__ = _make_builtin("equal", "equal") # type: ignore[assignment] # mypy wants return `bool` + def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_and", "logical_and")(self, other) @@ -285,7 +286,7 @@ def _np_cp_setitem( _nd_array_implementations = [np] -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, eq=False) class NumPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = np @@ -298,7 +299,7 @@ class NumPyArrayField(NdArrayField): if cp: _nd_array_implementations.append(cp) - @dataclasses.dataclass(frozen=True) + @dataclasses.dataclass(frozen=True, eq=False) class CuPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = cp @@ -310,7 +311,7 @@ class CuPyArrayField(NdArrayField): if jnp: _nd_array_implementations.append(jnp) - @dataclasses.dataclass(frozen=True) + @dataclasses.dataclass(frozen=True, eq=False) class JaxArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = jnp @@ -351,6 +352,13 @@ def _builtins_broadcast( NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) +def _astype(field: NdArrayField, type_: type) -> NdArrayField: + return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) + + +NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field + + def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 2d12331513..107415eb06 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,7 +32,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators +from gt4py.next import allocators as next_allocators, common from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import ( dialect_ast_enums, @@ -171,14 +171,14 @@ class Program: past_node: past.Program closure_vars: dict[str, Any] definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = None + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = None, + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, grid_type: Optional[GridType] = None, ) -> Program: source_def = SourceDefinition.from_function(definition) @@ -282,27 +282,23 @@ def itir(self) -> itir.FencilDefinition: ) def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None: - if ( - self.backend is None and DEFAULT_BACKEND is None - ): # TODO(havogt): for now enable embedded execution by setting DEFAULT_BACKEND to None - self.definition(*args, **kwargs) - return - rewritten_args, size_args, kwargs = self._process_args(args, kwargs) - if not self.backend: + if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.itir.id}': Using default ({DEFAULT_BACKEND}) backend." + f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend." ) ) - backend = self.backend or DEFAULT_BACKEND - ppi.ensure_processor_kind(backend, ppi.ProgramExecutor) + self.definition(*rewritten_args, **kwargs) + return + + ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) if "debug" in kwargs: debug(self.itir) - backend( + self.backend( self.itir, *rewritten_args, *size_args, @@ -547,14 +543,14 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): foast_node: OperatorNodeT closure_vars: dict[str, Any] definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = None + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = None, + backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, grid_type: Optional[GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, @@ -687,9 +683,9 @@ def __call__( # if we are reaching this from a program call. if "out" in kwargs: out = kwargs.pop("out") - if "offset_provider" in kwargs: + offset_provider = kwargs.pop("offset_provider", None) + if self.backend is not None: # "out" and "offset_provider" -> field_operator as program - offset_provider = kwargs.pop("offset_provider") args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given # deduce argument types @@ -705,13 +701,34 @@ def __call__( ) else: # "out" -> field_operator called from program in embedded execution - out.ndarray[:] = self.definition(*args, **kwargs).ndarray[:] + # TODO(egparedes): put offset_provider in ctxt var here when implementing remap + domain = kwargs.pop("domain", None) + res = self.definition(*args, **kwargs) + _tuple_assign_field( + out, res, domain=None if domain is None else common.domain(domain) + ) return else: # field_operator called from other field_operator in embedded execution + assert self.backend is None return self.definition(*args, **kwargs) +def _tuple_assign_field( + target: tuple[common.Field | tuple, ...] | common.Field, + source: tuple[common.Field | tuple, ...] | common.Field, + domain: Optional[common.Domain], +): + if isinstance(target, tuple): + if not isinstance(source, tuple): + raise RuntimeError(f"Cannot assign {source} to {target}.") + for t, s in zip(target, source): + _tuple_assign_field(t, s, domain) + else: + domain = domain or target.domain + target[domain] = source[domain] + + @typing.overload def field_operator( definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 7b96de8e89..706b6a4606 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -28,10 +28,12 @@ cast, ) +import numpy as np from numpy import float32, float64, int32, int64 -from gt4py._core import definitions as gt4py_defs -from gt4py.next.common import Dimension, DimensionKind, Field +from gt4py._core import definitions as core_defs +from gt4py.next import common +from gt4py.next.common import Dimension, Field # direct import for TYPE_BUILTINS from gt4py.next.ffront.experimental import as_offset # noqa F401 from gt4py.next.iterator import runtime from gt4py.next.type_system import type_specifications as ts @@ -40,7 +42,14 @@ PYTHON_TYPE_BUILTINS = [bool, int, float, tuple] PYTHON_TYPE_BUILTIN_NAMES = [t.__name__ for t in PYTHON_TYPE_BUILTINS] -TYPE_BUILTINS = [Field, Dimension, int32, int64, float32, float64] + PYTHON_TYPE_BUILTINS +TYPE_BUILTINS = [ + Field, + Dimension, + int32, + int64, + float32, + float64, +] + PYTHON_TYPE_BUILTINS TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] # Be aware: Type aliases are not fully supported in the frontend yet, e.g. `IndexType(1)` will not @@ -54,11 +63,11 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSpec], ...]: - if t is Field: + if t is common.Field: return ts.FieldType - elif t is Dimension: + elif t is common.Dimension: return ts.DimensionType - elif t is gt4py_defs.ScalarT: + elif t is core_defs.ScalarT: return ts.ScalarType elif t is type: return ( @@ -128,12 +137,8 @@ def __gt_type__(self) -> ts.FunctionType: ) -def builtin_function(fun: Callable[_P, _R]) -> BuiltInFunction[_R, _P]: - return BuiltInFunction(fun) - - -MaskT = TypeVar("MaskT", bound=Field) -FieldT = TypeVar("FieldT", bound=Union[Field, gt4py_defs.Scalar, Tuple]) +MaskT = TypeVar("MaskT", bound=common.Field) +FieldT = TypeVar("FieldT", bound=Union[common.Field, core_defs.Scalar, Tuple]) class WhereBuiltinFunction( @@ -153,55 +158,71 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: return super().__call__(mask, true_field, false_field) -@builtin_function +@BuiltInFunction def neighbor_sum( - field: Field, + field: common.Field, /, - axis: Dimension, -) -> Field: + axis: common.Dimension, +) -> common.Field: raise NotImplementedError() -@builtin_function +@BuiltInFunction def max_over( - field: Field, + field: common.Field, /, - axis: Dimension, -) -> Field: + axis: common.Dimension, +) -> common.Field: raise NotImplementedError() -@builtin_function +@BuiltInFunction def min_over( - field: Field, + field: common.Field, /, - axis: Dimension, -) -> Field: + axis: common.Dimension, +) -> common.Field: raise NotImplementedError() -@builtin_function -def broadcast(field: Field | gt4py_defs.ScalarT, dims: Tuple[Dimension, ...], /) -> Field: - raise NotImplementedError() +@BuiltInFunction +def broadcast( + field: common.Field | core_defs.ScalarT, + dims: tuple[common.Dimension, ...], + /, +) -> common.Field: + assert core_defs.is_scalar_type( + field + ) # default implementation for scalars, Fields are handled via dispatch + return common.field( + np.asarray(field)[ + tuple([np.newaxis] * len(dims)) + ], # TODO(havogt) use FunctionField once available + domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinity()] * len(dims))), + ) @WhereBuiltinFunction def where( - mask: Field, - true_field: Field | gt4py_defs.ScalarT | Tuple, - false_field: Field | gt4py_defs.ScalarT | Tuple, + mask: common.Field, + true_field: common.Field | core_defs.ScalarT | Tuple, + false_field: common.Field | core_defs.ScalarT | Tuple, /, -) -> Field | Tuple: +) -> common.Field | Tuple: raise NotImplementedError() -@builtin_function +@BuiltInFunction def astype( - field: Field | gt4py_defs.ScalarT | Tuple[Field, ...], + value: Field | core_defs.ScalarT | Tuple, type_: type, /, -) -> Field | Tuple[Field, ...]: - raise NotImplementedError() +) -> Field | core_defs.ScalarT | Tuple: + if isinstance(value, tuple): + return tuple(astype(v, type_) for v in value) + # default implementation for scalars, Fields are handled via dispatch + assert core_defs.is_scalar_type(value) + return core_defs.dtype(type_).scalar_type(value) UNARY_MATH_NUMBER_BUILTIN_NAMES = ["abs"] @@ -233,11 +254,14 @@ def astype( def _make_unary_math_builtin(name): - def impl(value: Field | gt4py_defs.ScalarT, /) -> Field | gt4py_defs.ScalarT: + def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: + # TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`) + # assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # noqa: E800 # commented code + # return getattr(math, name)(value)# noqa: E800 # commented code raise NotImplementedError() impl.__name__ = name - globals()[name] = builtin_function(impl) + globals()[name] = BuiltInFunction(impl) for f in ( @@ -252,14 +276,17 @@ def impl(value: Field | gt4py_defs.ScalarT, /) -> Field | gt4py_defs.ScalarT: def _make_binary_math_builtin(name): def impl( - lhs: Field | gt4py_defs.ScalarT, - rhs: Field | gt4py_defs.ScalarT, + lhs: common.Field | core_defs.ScalarT, + rhs: common.Field | core_defs.ScalarT, /, - ) -> Field | gt4py_defs.ScalarT: - raise NotImplementedError() + ) -> common.Field | core_defs.ScalarT: + # default implementation for scalars, Fields are handled via dispatch + assert core_defs.is_scalar_type(lhs) + assert core_defs.is_scalar_type(rhs) + return getattr(np, name)(lhs, rhs) impl.__name__ = name - globals()[name] = builtin_function(impl) + globals()[name] = BuiltInFunction(impl) for f in BINARY_MATH_NUMBER_BUILTIN_NAMES: @@ -295,12 +322,12 @@ def impl( # guidelines for decision. @dataclasses.dataclass(frozen=True) class FieldOffset(runtime.Offset): - source: Dimension - target: tuple[Dimension] | tuple[Dimension, Dimension] + source: common.Dimension + target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] connectivity: Optional[Any] = None # TODO def __post_init__(self): - if len(self.target) == 2 and self.target[1].kind != DimensionKind.LOCAL: + if len(self.target) == 2 and self.target[1].kind != common.DimensionKind.LOCAL: raise ValueError("Second dimension in offset must be a local dimension.") def __gt_type__(self): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 674f99f61c..44294a3a71 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1093,6 +1093,12 @@ def __neg__(self) -> common.Field: def __invert__(self) -> common.Field: raise NotImplementedError() + def __eq__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + + def __ne__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() @@ -1194,6 +1200,12 @@ def __neg__(self) -> common.Field: def __invert__(self) -> common.Field: raise NotImplementedError() + def __eq__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + + def __ne__(self, other: Any) -> common.Field: # type: ignore[override] # mypy wants return `bool` + raise NotImplementedError() + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index ddea04649f..249e17d358 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -98,6 +98,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_TUPLE_ARGS = "uses_tuple_args" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" +USES_CARTESIAN_SHIFT = "uses_cartesian_shift" +USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift" +USES_SCAN = "uses_scan" +CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" @@ -114,10 +118,18 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] +EMBEDDED_SKIP_LIST = [ + (USES_CARTESIAN_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_UNSTRUCTURED_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), +] #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) BACKEND_SKIP_TEST_MATRIX = { + None: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: GTFN_SKIP_TEST_LIST + [ (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 386e64451d..fb753bf169 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -53,6 +53,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non definitions.ProgramBackendId.GTFN_CPU, definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, + None, ] + OPTIONAL_PROCESSORS, ids=lambda p: p.short_id() if p is not None else "None", @@ -65,19 +66,15 @@ def fieldview_backend(request): Check ADR 15 for details on the test-exclusion matrices. """ backend_id = request.param - if backend_id is None: - backend = None - else: - backend = backend_id.load() - - for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( - backend_id, [] - ): - if request.node.get_closest_marker(marker): - skip_mark(msg.format(marker=marker, backend=backend_id)) + backend = None if backend_id is None else backend_id.load() - backup_backend = decorator.DEFAULT_BACKEND + for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + backend_id, [] + ): + if request.node.get_closest_marker(marker): + skip_mark(msg.format(marker=marker, backend=backend_id)) + backup_backend = decorator.DEFAULT_BACKEND decorator.DEFAULT_BACKEND = no_backend yield backend decorator.DEFAULT_BACKEND = backup_backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index deb1382dfb..6957e628bb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -158,6 +158,7 @@ def testee( ) +@pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator def test_call_scan_operator_from_field_operator(cartesian_case): @scan_operator(axis=KDim, forward=True, init=0.0) @@ -183,6 +184,7 @@ def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField: cases.verify(cartesian_case, testee, a, b, out=out, ref=expected) +@pytest.mark.uses_scan def test_call_scan_operator_from_program(cartesian_case): @scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan(state: float, x: float, y: float) -> float: @@ -222,6 +224,7 @@ def testee( ) +@pytest.mark.uses_scan def test_scan_wrong_return_type(cartesian_case): with pytest.raises( errors.DSLError, @@ -239,6 +242,7 @@ def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float testee_scan(qc, param_1, param_2, scalar, out=(qc, param_1, param_2)) +@pytest.mark.uses_scan def test_scan_wrong_state_type(cartesian_case): with pytest.raises( errors.DSLError, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8787b7d7bc..8036c22670 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -76,6 +76,7 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> tuple[cases.IJKField, cases. cases.verify_with_default_data(cartesian_case, testee, ref=lambda a, b: (a, b)) +@pytest.mark.uses_cartesian_shift def test_cartesian_shift(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IJKField) -> cases.IJKField: @@ -87,6 +88,7 @@ def testee(a: cases.IJKField) -> cases.IJKField: cases.verify(cartesian_case, testee, a, out=out, ref=a[1:]) +@pytest.mark.uses_unstructured_shift def test_unstructured_shift(unstructured_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.VField) -> cases.EField: @@ -99,6 +101,7 @@ def testee(a: cases.VField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator def composed_shift_unstructured_flat(inp: cases.VField) -> cases.CField: @@ -143,6 +146,7 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: ) +@pytest.mark.uses_cartesian_shift def test_fold_shifts(cartesian_case): # noqa: F811 # fixtures """Shifting the result of an addition should work.""" @@ -206,6 +210,7 @@ def testee(a: int32) -> cases.VField: @pytest.mark.uses_index_fields +@pytest.mark.uses_cartesian_shift def test_scalar_arg_with_field(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IJKField, b: int32) -> cases.IJKField: @@ -246,6 +251,7 @@ def testee(size: gtx.IndexType, out: gtx.Field[[IDim], gtx.IndexType]): ) +@pytest.mark.uses_scan def test_scalar_scan(cartesian_case): # noqa: F811 # fixtures @gtx.scan_operator(axis=KDim, forward=True, init=(0.0)) def testee_scan(state: float, qc_in: float, scalar: float) -> float: @@ -264,6 +270,7 @@ def testee(qc: cases.IKFloatField, scalar: float): cases.verify(cartesian_case, testee, qc, scalar, inout=qc, ref=expected) +@pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator def test_tuple_scalar_scan(cartesian_case): # noqa: F811 # fixtures @gtx.scan_operator(axis=KDim, forward=True, init=0.0) @@ -285,6 +292,7 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_scan @pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures @gtx.scan_operator(axis=KDim, forward=True, init=(0.0)) @@ -363,8 +371,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_asint = gtx.np_as_located_field(IDim)(np.asarray(a).astype(int32)) - b_asint = gtx.np_as_located_field(IDim)(np.asarray(b).astype(int32)) + a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32)) + b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -483,6 +491,7 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, combine, ref=lambda a, b: a + a + b) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator @@ -504,6 +513,7 @@ def testee(a: cases.EField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift @pytest.mark.xfail(reason="Not yet supported in lowering, requires `map_`ing of inner reduce op.") def test_nested_reduction_shift_first(unstructured_case): @gtx.field_operator @@ -524,6 +534,7 @@ def testee(inp: cases.EField) -> cases.EField: ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_tuple_returns def test_tuple_return_2(unstructured_case): @gtx.field_operator @@ -543,6 +554,7 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_constant_fields def test_tuple_with_local_field_in_reduction_shifted(unstructured_case): @gtx.field_operator @@ -572,6 +584,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I ) +@pytest.mark.uses_scan @pytest.mark.parametrize("forward", [True, False]) def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 @@ -592,6 +605,7 @@ def simple_scan_operator(carry: float) -> float: cases.verify(cartesian_case, simple_scan_operator, out=out, ref=expected) +@pytest.mark.uses_scan @pytest.mark.uses_lift_expressions def test_solve_triag(cartesian_case): if cartesian_case.backend in [ @@ -680,6 +694,7 @@ def testee( ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): @gtx.field_operator @@ -698,6 +713,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: ) +@pytest.mark.uses_scan def test_ternary_scan(cartesian_case): if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @@ -720,6 +736,7 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) +@pytest.mark.uses_scan @pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: @@ -745,13 +762,14 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): cartesian_case, testee, ref=lambda: (expected + 1.0, (expected + 2.0, expected + 3.0)), - comparison=lambda ref, out: np.all(out[0] == ref[0]) - and np.all(out[1][0] == ref[1][0]) - and np.all(out[1][1] == ref[1][1]), + comparison=lambda ref, out: np.all(np.asarray(out[0]) == ref[0]) + and np.all(np.asarray(out[1][0]) == ref[1][0]) + and np.all(np.asarray(out[1][1]) == ref[1][1]), ) @pytest.mark.uses_tuple_args +@pytest.mark.uses_scan def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] @@ -824,7 +842,10 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - cases.verify(cartesian_case, program_domain, a, out, inout=out[1:9], ref=a[1:9] * 2) + ref = out.ndarray.copy() # ensure we are not overwriting out outside of the domain + ref[1:9] = a[1:9] * 2 + + cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) def test_domain_input_bounds(cartesian_case): @@ -855,6 +876,9 @@ def program_domain( inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() + ref = out.ndarray.copy() + ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2 + cases.verify( cartesian_case, program_domain, @@ -862,8 +886,8 @@ def program_domain( out, lower_i, upper_i, - inout=out[lower_i : int(upper_i / 2)], - ref=inp[lower_i : int(upper_i / 2)] * 2, + inout=out, + ref=ref, ) @@ -895,6 +919,11 @@ def program_domain( a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() + ref = out.ndarray.copy() + ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = ( + a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 + ) + cases.verify( cartesian_case, program_domain, @@ -904,8 +933,8 @@ def program_domain( upper_i, lower_j, upper_j, - inout=out[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j], - ref=a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2, + inout=out, + ref=ref, ) @@ -930,6 +959,11 @@ def program_domain_tuple( out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() + ref0 = out0.ndarray.copy() + ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6] + ref1 = out1.ndarray.copy() + ref1[1:9, 4:6] = inp1[1:9, 4:6] + cases.verify( cartesian_case, program_domain_tuple, @@ -937,11 +971,12 @@ def program_domain_tuple( inp1, out0, out1, - inout=(out0[1:9, 4:6], out1[1:9, 4:6]), - ref=(inp0[1:9, 4:6] + inp1[1:9, 4:6], inp1[1:9, 4:6]), + inout=(out0, out1), + ref=(ref0, ref1), ) +@pytest.mark.uses_cartesian_shift def test_where_k_offset(cartesian_case): @gtx.field_operator def fieldop_where_k_offset( @@ -1079,6 +1114,13 @@ def _invalid_unpack() -> tuple[int32, float64, int32]: def test_constant_closure_vars(cartesian_case): + if cartesian_case.backend is None: + # >>> field = gtx.zeros(domain) + # >>> np.int32(1)*field # steals the buffer from the field + # array([0.]) + + # TODO(havogt): remove `__array__`` from `NdArrayField` + pytest.xfail("Bug: Binary operation between np datatype and Field returns ndarray.") from gt4py.eve.utils import FrozenNamespace constants = FrozenNamespace( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 04b27c6c17..5135b3d47a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -26,6 +26,9 @@ ) +pytestmark = pytest.mark.uses_unstructured_shift + + def test_external_local_field(unstructured_case): @gtx.field_operator def testee( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 8213f54a45..1eba95e880 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -39,6 +39,7 @@ ) +@pytest.mark.uses_unstructured_shift @pytest.mark.parametrize( "strategy", [cases.UniqueInitializer(1), cases.UniqueInitializer(-100)], @@ -65,6 +66,7 @@ def testee(edge_f: cases.EField) -> cases.VField: cases.verify(unstructured_case, testee, inp, ref=ref, out=out) +@pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: @@ -77,6 +79,7 @@ def minover(edge_f: cases.EField) -> cases.VField: ) +@pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: @@ -93,6 +96,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_unstructured_shift @pytest.mark.uses_constant_fields def test_reduction_expression_in_call(unstructured_case): @gtx.field_operator @@ -113,6 +117,7 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: @@ -191,6 +196,7 @@ def broadcast_two_fields(inp1: cases.IField, inp2: gtx.Field[[JDim], int32]) -> ) +@pytest.mark.uses_cartesian_shift def test_broadcast_shifted(cartesian_case): @gtx.field_operator def simple_broadcast(inp: cases.IField) -> cases.IJField: @@ -249,6 +255,7 @@ def conditional_promotion(a: cases.IFloatField) -> cases.IFloatField: ) +@pytest.mark.uses_cartesian_shift def test_conditional_shifted(cartesian_case): @gtx.field_operator def conditional_shifted( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index a5d2b92719..a1839b8e17 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -116,6 +116,9 @@ def make_builtin_field_operator(builtin_name: str): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inputs): + if cartesian_case.backend is None: + # TODO(havogt) find a way that works for embedded + pytest.xfail("Test does not have a field view program.") if builtin_name == "gamma": # numpy has no gamma function ref_impl: Callable = np.vectorize(math.gamma) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 5a277f9440..59e11a7de8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -173,17 +173,14 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: def test_unary_not(cartesian_case): - @gtx.field_operator - def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: - return not inp1 + pytest.xfail( + "We accidentally supported `not` on fields. This is wrong, we should raise an error." + ) + with pytest.raises: # TODO `not` on a field should be illegal - size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, not_fieldop, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - out = cases.allocate(cartesian_case, not_fieldop, cases.RETURN)() - cases.verify(cartesian_case, not_fieldop, inp1, out=out, ref=~inp1) + @gtx.field_operator + def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: + return not inp1 # Trig builtins diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 7a1c827a0d..545abd2825 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -51,6 +51,7 @@ def test_identity_fo_execution(cartesian_case, identity_def): ) +@pytest.mark.uses_cartesian_shift def test_shift_by_one_execution(cartesian_case): @gtx.field_operator def shift_by_one(in_field: cases.IFloatField) -> cases.IFloatField: @@ -230,6 +231,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): assert re.search(msg, exc_info.value.__cause__.args[0]) is not None +@pytest.mark.checks_specific_error def test_dimensions_domain(cartesian_case): @gtx.field_operator def empty_domain_fieldop(a: cases.IJField): @@ -246,4 +248,4 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField): ValueError, match=(r"Dimensions in out field and field domain are not equivalent"), ): - empty_domain_program(a, out_field, offset_provider={}) + cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 108ee25862..eaae9a2a3e 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -25,6 +25,9 @@ ) +pytestmark = pytest.mark.uses_unstructured_shift + + Cell = gtx.Dimension("Cell") KDim = gtx.Dimension("KDim", kind=gtx.DimensionKind.VERTICAL) Koff = gtx.FieldOffset("Koff", KDim, (KDim,)) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index d275a977dd..9a1e968de0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import numpy as np +import pytest import gt4py.next as gtx @@ -23,6 +24,9 @@ ) +pytestmark = pytest.mark.uses_cartesian_shift + + @gtx.field_operator def lap(in_field: gtx.Field[[IDim, JDim], "float"]) -> gtx.Field[[IDim, JDim], "float"]: return ( diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 49aeece87e..00dbf68274 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -259,7 +259,7 @@ def test_mixed_fields(product_nd_array_implementation): def test_non_dispatched_function(): - @fbuiltins.builtin_function + @fbuiltins.BuiltInFunction def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: return a * b + c diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 31e35221ab..84008eb99c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +import operator from typing import Optional, Pattern import pytest @@ -150,6 +151,21 @@ def test_mixed_infinity_range(): assert len(mixed_inf_range) == Infinity.positive() +@pytest.mark.parametrize( + "op, rng1, rng2, expected", + [ + (operator.le, UnitRange(-1, 2), UnitRange(-2, 3), True), + (operator.le, UnitRange(-1, 2), {-1, 0, 1}, True), + (operator.le, UnitRange(-1, 2), {-1, 0}, False), + (operator.le, UnitRange(-1, 2), {-2, -1, 0, 1, 2}, True), + (operator.le, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True), + (operator.le, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False), + ], +) +def test_range_comparison(op, rng1, rng2, expected): + assert op(rng1, rng2) == expected + + @pytest.mark.parametrize( "named_rng_like", [ From 42912cc9d14e409801c1c71fc99a98f46e7c4a1b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 20 Nov 2023 11:13:36 +0100 Subject: [PATCH 09/85] feat[next] Enable GPU backend tests (#1357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - connectivities are implicitly copied to GPU if they are not already on GPU, this might be removed later - changes to cases: ensure we don't pass arrays to ConstInitializer --------- Co-authored-by: Rico Häuselmann --- src/gt4py/next/embedded/nd_array_field.py | 5 +- .../codegens/gtfn/codegen.py | 59 +++++++------- .../next/program_processors/runners/gtfn.py | 30 +++++-- tests/next_tests/exclusion_matrices.py | 5 ++ tests/next_tests/integration_tests/cases.py | 18 ++++- .../ffront_tests/ffront_test_utils.py | 1 + .../ffront_tests/test_execution.py | 33 ++++---- .../ffront_tests/test_external_local_field.py | 8 +- .../ffront_tests/test_gt4py_builtins.py | 18 ++--- .../test_math_builtin_execution.py | 4 +- .../ffront_tests/test_math_unary_builtins.py | 35 +++----- .../ffront_tests/test_program.py | 2 +- .../ffront_tests/test_icon_like_scan.py | 79 ++++++++++++------- .../ffront_tests/test_laplacian.py | 2 +- tests/next_tests/unit_tests/conftest.py | 1 + tox.ini | 2 +- 16 files changed, 176 insertions(+), 126 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 51e613ef81..9357570b05 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -121,7 +121,10 @@ def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: - return np.asarray(self._ndarray, dtype) + if self.array_ns == cp: + return np.asarray(cp.asnumpy(self._ndarray), dtype) + else: + return np.asarray(self._ndarray, dtype) @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 645d1f742f..23165854de 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -179,6 +179,10 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): """ ) + def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): + expr_ = "return " + self.visit(node.expr) + return self.generic_visit(node, expr_=expr_) + FunctionDefinition = as_mako( """ struct ${id} { @@ -206,24 +210,6 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs): """ ) - def visit_FunctionDefinition(self, node: gtfn_ir.FunctionDefinition, **kwargs): - expr_ = "return " + self.visit(node.expr) - return self.generic_visit(node, expr_=expr_) - - def visit_FencilDefinition( - self, node: gtfn_ir.FencilDefinition, **kwargs: Any - ) -> Union[str, Collection[str]]: - self.is_cartesian = node.grid_type == common.GridType.CARTESIAN - self.user_defined_function_ids = list( - str(fundef.id) for fundef in node.function_definitions - ) - return self.generic_visit( - node, - grid_type_str=self._grid_type_str[node.grid_type], - block_sizes=self._block_sizes(node.offset_definitions), - **kwargs, - ) - def visit_TemporaryAllocation(self, node, **kwargs): # TODO(tehrengruber): Revisit. We are currently converting an itir.NamedRange with # start and stop values into an gtfn_ir.(Cartesian|Unstructured)Domain with @@ -244,6 +230,20 @@ def visit_TemporaryAllocation(self, node, **kwargs): "auto {id} = gtfn::allocate_global_tmp<{dtype}>(tmp_alloc__, {tmp_sizes});" ) + def visit_FencilDefinition( + self, node: gtfn_ir.FencilDefinition, **kwargs: Any + ) -> Union[str, Collection[str]]: + self.is_cartesian = node.grid_type == common.GridType.CARTESIAN + self.user_defined_function_ids = list( + str(fundef.id) for fundef in node.function_definitions + ) + return self.generic_visit( + node, + grid_type_str=self._grid_type_str[node.grid_type], + block_sizes=self._block_sizes(node.offset_definitions), + **kwargs, + ) + FencilDefinition = as_mako( """ #include @@ -277,16 +277,19 @@ def visit_TemporaryAllocation(self, node, **kwargs): ) def _block_sizes(self, offset_definitions: list[gtfn_ir.TagDefinition]) -> str: - block_dims = [] - block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2) - for i, tag in enumerate(offset_definitions): - if tag.alias is None: - block_dims.append( - f"gridtools::meta::list<{tag.name.id}_t, " - f"gridtools::integral_constant>" - ) - sizes_str = ",\n".join(block_dims) - return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;" + if self.is_cartesian: + block_dims = [] + block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2) + for i, tag in enumerate(offset_definitions): + if tag.alias is None: + block_dims.append( + f"gridtools::meta::list<{tag.name.id}_t, " + f"gridtools::integral_constant>" + ) + sizes_str = ",\n".join(block_dims) + return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;" + else: + return "using block_sizes_t = gridtools::meta::list>, gridtools::meta::list>>;" @classmethod def apply(cls, root: Any, **kwargs: Any) -> str: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7233e7a893..5d4b450d39 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import functools +import warnings from typing import Any import numpy.typing as npt @@ -42,12 +44,14 @@ def convert_arg(arg: Any) -> Any: return arg -def convert_args(inp: stages.CompiledProgram) -> stages.CompiledProgram: +def convert_args( + inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU +) -> stages.CompiledProgram: def decorated_program( *args, offset_provider: dict[str, common.Connectivity | common.Dimension] ): converted_args = [convert_arg(arg) for arg in args] - conn_args = extract_connectivity_args(offset_provider) + conn_args = extract_connectivity_args(offset_provider, device) return inp( *converted_args, *conn_args, @@ -56,8 +60,22 @@ def decorated_program( return decorated_program +def _ensure_is_on_device( + connectivity_arg: npt.NDArray, device: core_defs.DeviceType +) -> npt.NDArray: + if device == core_defs.DeviceType.CUDA: + import cupy as cp + + if not isinstance(connectivity_arg, cp.ndarray): + warnings.warn( + "Copying connectivity to device. For performance make sure connectivity is provided on device." + ) + return cp.asarray(connectivity_arg) + return connectivity_arg + + def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType ) -> list[tuple[npt.NDArray, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] @@ -67,7 +85,9 @@ def extract_connectivity_args( raise NotImplementedError( "Only `NeighborTable` connectivities implemented at this point." ) - args.append((conn.table, tuple([0] * 2))) + # copying to device here is a fallback for easy testing and might be removed later + conn_arg = _ensure_is_on_device(conn.table, device) + args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass else: @@ -126,7 +146,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: translation=GTFN_GPU_TRANSLATION_STEP, bindings=nanobind.bind_source, compilation=GTFN_DEFAULT_COMPILE_STEP, - decoration=convert_args, + decoration=functools.partial(convert_args, device=core_defs.DeviceType.CUDA), ) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 249e17d358..ef30a61687 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -50,6 +50,7 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): GTFN_CPU_WITH_TEMPORARIES = ( "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" ) + GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.backend" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -148,6 +149,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + [ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 634d85e64c..730ce18fd5 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -25,6 +25,7 @@ import pytest import gt4py.next as gtx +from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self from gt4py.next import common, constructors @@ -73,7 +74,7 @@ E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) C2E = gtx.FieldOffset("E2V", source=Edge, target=(Cell, C2EDim)) -ScalarValue: TypeAlias = np.int32 | np.int64 | np.float32 | np.float64 | np.generic +ScalarValue: TypeAlias = core_defs.Scalar FieldValue: TypeAlias = gtx.Field FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...] FieldViewInout: TypeAlias = FieldValue | tuple["FieldViewInout", ...] @@ -117,12 +118,19 @@ def from_case( return self -@dataclasses.dataclass +@dataclasses.dataclass(init=False) class ConstInitializer(DataInitializer): """Initialize with a given value across the coordinate space.""" value: ScalarValue + def __init__(self, value: ScalarValue): + if not core_defs.is_scalar_type(value): + raise ValueError( + "`ConstInitializer` can not be used with non-scalars. Use `Case.as_field` instead." + ) + self.value = value + @property def scalar_value(self) -> ScalarValue: return self.value @@ -460,7 +468,7 @@ def verify_with_default_data( ``comparison(ref, )`` and should return a boolean. """ inps, kwfields = get_default_data(case, fieldop) - ref_args = tuple(i.ndarray if hasattr(i, "ndarray") else i for i in inps) + ref_args = tuple(i.__array__() if common.is_field(i) else i for i in inps) verify( case, fieldop, @@ -598,3 +606,7 @@ class Case: offset_provider: dict[str, common.Connectivity | gtx.Dimension] default_sizes: dict[gtx.Dimension, int] grid_type: common.GridType + + @property + def as_field(self): + return constructors.as_field.partial(allocator=self.backend) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index fb753bf169..01c78cf950 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -53,6 +53,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non definitions.ProgramBackendId.GTFN_CPU, definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, + pytest.param(definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu), None, ] + OPTIONAL_PROCESSORS, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8036c22670..fe18bda9e3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -371,8 +371,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32)) - b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32)) + a_asint = cartesian_case.as_field([IDim], np.asarray(a).astype(int32)) + b_asint = cartesian_case.as_field([IDim], np.asarray(b).astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -589,7 +589,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 expected = np.arange(init + 1.0, init + 1.0 + cartesian_case.default_sizes[IDim], 1) - out = gtx.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) + out = cartesian_case.as_field([KDim], np.zeros((cartesian_case.default_sizes[KDim],))) if not forward: expected = np.flip(expected) @@ -610,6 +610,7 @@ def simple_scan_operator(carry: float) -> float: def test_solve_triag(cartesian_case): if cartesian_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -723,8 +724,8 @@ def simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 k_size = cartesian_case.default_sizes[KDim] - a = gtx.as_field([KDim], 4.0 * np.ones((k_size,))) - out = gtx.as_field([KDim], np.zeros((k_size,))) + a = cartesian_case.as_field([KDim], 4.0 * np.ones((k_size,))) + out = cartesian_case.as_field([KDim], np.zeros((k_size,))) cases.verify( cartesian_case, @@ -773,16 +774,19 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): def test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] - inp1 = gtx.as_field([KDim], np.ones((k_size,))) - inp2 = gtx.as_field([KDim], np.arange(0.0, k_size, 1)) - out = gtx.as_field([KDim], np.zeros((k_size,))) + + inp1_np = np.ones((k_size,)) + inp2_np = np.arange(0.0, k_size, 1) + inp1 = cartesian_case.as_field([KDim], inp1_np) + inp2 = cartesian_case.as_field([KDim], inp2_np) + out = cartesian_case.as_field([KDim], np.zeros((k_size,))) def prev_levels_iterator(i): return range(i + 1) expected = np.asarray( [ - reduce(lambda prev, i: prev + inp1[i] + inp2[i], prev_levels_iterator(i), init) + reduce(lambda prev, i: prev + inp1_np[i] + inp2_np[i], prev_levels_iterator(i), init) for i in range(k_size) ] ) @@ -842,7 +846,7 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.ndarray.copy() # ensure we are not overwriting out outside of the domain + ref = np.asarray(out).copy() # ensure we are not overwriting `out` outside of the domain ref[1:9] = a[1:9] * 2 cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) @@ -851,6 +855,7 @@ def program_domain(a: cases.IField, out: cases.IField): def test_domain_input_bounds(cartesian_case): if cartesian_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -876,7 +881,7 @@ def program_domain( inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() - ref = out.ndarray.copy() + ref = np.asarray(out).copy() ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2 cases.verify( @@ -919,7 +924,7 @@ def program_domain( a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.ndarray.copy() + ref = np.asarray(out).copy() ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = ( a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 ) @@ -959,9 +964,9 @@ def program_domain_tuple( out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() - ref0 = out0.ndarray.copy() + ref0 = np.asarray(out0).copy() ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6] - ref1 = out1.ndarray.copy() + ref1 = np.asarray(out1).copy() ref1[1:9, 4:6] = inp1[1:9, 4:6] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 5135b3d47a..05adc63a45 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -38,7 +38,9 @@ def testee( inp * ones(V2E), axis=V2EDim ) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported - inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) + inp = unstructured_case.as_field( + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() cases.verify( @@ -59,7 +61,9 @@ def test_external_local_field_only(unstructured_case): def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32]: return neighbor_sum(inp, axis=V2EDim) - inp = gtx.as_field([Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table) + inp = unstructured_case.as_field( + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + ) cases.verify( unstructured_case, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 1eba95e880..8bc325d276 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -48,6 +48,7 @@ def test_maxover_execution_(unstructured_case, strategy): if unstructured_case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: @@ -142,10 +143,7 @@ def conditional_nested_tuple( return where(mask, ((a, b), (b, a)), ((5.0, 7.0), (7.0, 5.0))) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional_nested_tuple, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=size)) a = cases.allocate(cartesian_case, conditional_nested_tuple, "a")() b = cases.allocate(cartesian_case, conditional_nested_tuple, "b")() @@ -216,10 +214,7 @@ def conditional( return where(mask, a, b) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional, "a")() b = cases.allocate(cartesian_case, conditional, "b")() out = cases.allocate(cartesian_case, conditional, cases.RETURN)() @@ -233,10 +228,7 @@ def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases return where(mask, a, 10.0) size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - mask = cases.allocate(cartesian_case, conditional_promotion, "mask").strategy( - cases.ConstInitializer(bool_field) - )() + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_promotion, "a")() out = cases.allocate(cartesian_case, conditional_promotion, cases.RETURN)() @@ -274,7 +266,7 @@ def conditional_program( conditional_shifted(mask, a, b, out=out) size = cartesian_case.default_sizes[IDim] + 1 - mask = gtx.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_program, "a").extend({IDim: (0, 1)})() b = cases.allocate(cartesian_case, conditional_program, "b").extend({IDim: (0, 1)})() out = cases.allocate(cartesian_case, conditional_shifted, cases.RETURN)() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index a1839b8e17..937b05e087 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -125,9 +125,9 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp else: ref_impl: Callable = getattr(np, builtin_name) - inps = [gtx.as_field([IDim], np.asarray(input)) for input in inputs] + inps = [cartesian_case.as_field([IDim], np.asarray(input)) for input in inputs] expected = ref_impl(*inputs) - out = gtx.as_field([IDim], np.zeros_like(expected)) + out = cartesian_case.as_field([IDim], np.zeros_like(expected)) builtin_field_op = make_builtin_field_operator(builtin_name).with_backend( cartesian_case.backend diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 59e11a7de8..8660ecfdbd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -72,6 +72,7 @@ def test_floordiv(cartesian_case): gtfn.run_gtfn, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, + gtfn.run_gtfn_gpu, ]: pytest.xfail( "FloorDiv not yet supported." @@ -90,7 +91,7 @@ def test_mod(cartesian_case): def mod_fieldop(inp1: cases.IField) -> cases.IField: return inp1 % 2 - inp1 = gtx.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) + inp1 = cartesian_case.as_field([IDim], np.asarray(range(10), dtype=int32) - 5) out = cases.allocate(cartesian_case, mod_fieldop, cases.RETURN)() cases.verify(cartesian_case, mod_fieldop, inp1, out=out, ref=inp1 % 2) @@ -102,13 +103,8 @@ def binary_xor(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolFie return inp1 ^ inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, binary_xor, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, binary_xor, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, binary_xor, cases.RETURN)() cases.verify(cartesian_case, binary_xor, inp1, inp2, out=out, ref=inp1 ^ inp2) @@ -119,13 +115,8 @@ def bit_and(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: return inp1 & inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, bit_and, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, bit_and, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, bit_and, cases.RETURN)() cases.verify(cartesian_case, bit_and, inp1, inp2, out=out, ref=inp1 & inp2) @@ -136,13 +127,8 @@ def bit_or(inp1: cases.IBoolField, inp2: cases.IBoolField) -> cases.IBoolField: return inp1 | inp2 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, bit_or, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() - inp2 = cases.allocate(cartesian_case, bit_or, "inp2").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) + inp2 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, bit_or, cases.RETURN)() cases.verify(cartesian_case, bit_or, inp1, inp2, out=out, ref=inp1 | inp2) @@ -164,10 +150,7 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: return ~inp1 size = cartesian_case.default_sizes[IDim] - bool_field = np.random.choice(a=[False, True], size=(size)) - inp1 = cases.allocate(cartesian_case, tilde_fieldop, "inp1").strategy( - cases.ConstInitializer(bool_field) - )() + inp1 = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) out = cases.allocate(cartesian_case, tilde_fieldop, cases.RETURN)() cases.verify(cartesian_case, tilde_fieldop, inp1, out=out, ref=~inp1) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 545abd2825..b82cae25a8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -215,7 +215,7 @@ def prog( def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) - inp = gtx.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) + inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() with pytest.raises(TypeError) as exc_info: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index eaae9a2a3e..cd948ffa02 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -18,8 +18,11 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.program_processors.runners import gtfn, roundtrip +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import Cell, KDim, Koff from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( fieldview_backend, ) @@ -190,80 +193,97 @@ def reference( @pytest.fixture -def test_setup(): +def test_setup(fieldview_backend): + test_case = cases.Case( + fieldview_backend, + offset_provider={"Koff": KDim}, + default_sizes={Cell: 14, KDim: 10}, + grid_type=common.GridType.UNSTRUCTURED, + ) + @dataclass(frozen=True) class setup: - cell_size = 14 - k_size = 10 - z_alpha = gtx.as_field( + case: cases.Case = test_case + cell_size = case.default_sizes[Cell] + k_size = case.default_sizes[KDim] + z_alpha = case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) - z_beta = gtx.as_field( + z_beta = case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) - w = gtx.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) + z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) + w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = gtx.as_field([Cell, KDim], np.zeros((cell_size, k_size))) + dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend): - if fieldview_backend in [ +def test_solve_nonhydro_stencil_52_like_z_q(test_setup): + if test_setup.case.backend in [ gtfn.run_gtfn, + gtfn.run_gtfn_gpu, gtfn.run_gtfn_imperative, gtfn.run_gtfn_with_temporaries, ]: pytest.xfail("Needs implementation of scan projector.") - solve_nonhydro_stencil_52_like_z_q.with_backend(fieldview_backend)( + cases.verify( + test_setup.case, + solve_nonhydro_stencil_52_like_z_q, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.z_q_out, - offset_provider={"Koff": KDim}, + ref=test_setup.z_q_ref, + inout=test_setup.z_q_out, + comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." ) - if fieldview_backend == roundtrip.backend: + if test_setup.case.backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - solve_nonhydro_stencil_52_like_z_q_tup.with_backend(fieldview_backend)( + cases.verify( + test_setup.case, + solve_nonhydro_stencil_52_like_z_q_tup, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.z_q_out, - offset_provider={"Koff": KDim}, + ref=test_setup.z_q_ref, + inout=test_setup.z_q_out, + comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) - assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) - @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)( + + cases.run( + test_setup.case, + solve_nonhydro_stencil_52_like, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, test_setup.dummy, - offset_provider={"Koff": KDim}, ) assert np.allclose(test_setup.z_q_ref, test_setup.z_q) @@ -271,18 +291,19 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend): @pytest.mark.uses_tuple_returns -def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend): - if fieldview_backend in [gtfn.run_gtfn_with_temporaries]: +def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): + if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if fieldview_backend == roundtrip.backend: + if test_setup.case.backend == roundtrip.backend: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") - solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge.with_backend(fieldview_backend)( + cases.run( + test_setup.case, + solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge, test_setup.z_alpha, test_setup.z_beta, test_setup.z_q, test_setup.w, - offset_provider={"Koff": KDim}, ) assert np.allclose(test_setup.z_q_ref, test_setup.z_q) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 9a1e968de0..4f4d4969a9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -86,5 +86,5 @@ def test_ffront_lap(cartesian_case): in_field, out_field, inout=out_field[2:-2, 2:-2], - ref=lap_ref(lap_ref(np.asarray(in_field.ndarray))), + ref=lap_ref(lap_ref(in_field.array_ns.asarray(in_field.ndarray))), ) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index b43eeb3f91..372062d08a 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -60,6 +60,7 @@ def lift_mode(request): (definitions.ProgramBackendId.GTFN_CPU, True), (definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), (definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), + # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation (definitions.ProgramFormatterId.LISP_FORMATTER, False), (definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False), diff --git a/tox.ini b/tox.ini index 5b644e7d97..44dc912c8a 100644 --- a/tox.ini +++ b/tox.ini @@ -84,7 +84,7 @@ commands = nomesh-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and not requires_gpu" {posargs} tests{/}next_tests nomesh-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests atlas-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and not requires_gpu" {posargs} tests{/}next_tests - # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist + # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next [testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] From 6375445b4edd93ab734124325c1adfae42b2bb84 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 21 Nov 2023 12:50:30 +0100 Subject: [PATCH 10/85] feat[next] Embedded field remove __array__ (#1366) Add `.asnumpy` to `Field`. Implicit conversion via `__array__` creates a problem, because expression `np.float*field` will return ndarray instead of field, because `np.float`'s multiply operator will `asarray(rhs)`. Update all tests to do an explicit conversion to ndarray if needed. --------- Co-authored-by: nfarabullini --- src/gt4py/next/common.py | 4 ++ src/gt4py/next/embedded/nd_array_field.py | 6 +-- src/gt4py/next/iterator/embedded.py | 6 +++ src/gt4py/next/utils.py | 40 +++++++++++++++- tests/next_tests/integration_tests/cases.py | 13 +++--- .../ffront_tests/test_arg_call_interface.py | 14 +++--- .../ffront_tests/test_execution.py | 46 +++++++++---------- .../ffront_tests/test_gt4py_builtins.py | 21 ++++++--- .../test_math_builtin_execution.py | 2 +- .../ffront_tests/test_program.py | 10 ++-- .../ffront_tests/test_scalar_if.py | 8 ++-- .../iterator_tests/test_builtins.py | 12 ++--- .../iterator_tests/test_conditional.py | 4 +- .../iterator_tests/test_constant.py | 4 +- .../test_horizontal_indirection.py | 4 +- .../iterator_tests/test_implicit_fencil.py | 6 +-- .../feature_tests/iterator_tests/test_scan.py | 2 +- .../test_strided_offset_provider.py | 4 +- .../iterator_tests/test_trivial.py | 8 ++-- .../iterator_tests/test_tuple.py | 26 +++++------ .../feature_tests/test_util_cases.py | 18 ++++---- .../ffront_tests/test_icon_like_scan.py | 10 ++-- .../iterator_tests/test_anton_toy.py | 2 +- .../iterator_tests/test_column_stencil.py | 19 ++++---- .../iterator_tests/test_fvm_nabla.py | 40 ++++++++-------- .../iterator_tests/test_hdiff.py | 2 +- .../iterator_tests/test_vertical_advection.py | 2 +- .../test_with_toy_connectivity.py | 26 +++++------ .../otf_tests/test_gtfn_workflow.py | 2 +- .../embedded_tests/test_nd_array_field.py | 4 +- 30 files changed, 209 insertions(+), 156 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 66766be76b..51ad14f22d 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -467,6 +467,10 @@ def ndarray(self) -> core_defs.NDArrayObject: def __str__(self) -> str: return f"⟨{self.domain!s} → {self.dtype}⟩" + @abc.abstractmethod + def asnumpy(self) -> np.ndarray: + ... + @abc.abstractmethod def remap(self, index_field: Field) -> Field: ... diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9357570b05..a843772a20 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -120,11 +120,11 @@ def __gt_origin__(self) -> tuple[int, ...]: def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray - def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: + def asnumpy(self) -> np.ndarray: if self.array_ns == cp: - return np.asarray(cp.asnumpy(self._ndarray), dtype) + return cp.asnumpy(self._ndarray) else: - return np.asarray(self._ndarray, dtype) + return np.asarray(self._ndarray) @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 44294a3a71..9000b00d8f 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1068,6 +1068,9 @@ def dtype(self) -> core_defs.Int32DType: def ndarray(self) -> core_defs.NDArrayObject: raise AttributeError("Cannot get `ndarray` of an infinite Field.") + def asnumpy(self) -> np.ndarray: + raise NotImplementedError() + def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1180,6 +1183,9 @@ def dtype(self) -> core_defs.DType[core_defs.ScalarT]: def ndarray(self) -> core_defs.NDArrayObject: raise AttributeError("Cannot get `ndarray` of an infinite Field.") + def asnumpy(self) -> np.ndarray: + raise NotImplementedError() + def remap(self, index_field: common.Field) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 006b3057b0..baae8361c5 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar, TypeGuard, TypeVar +import functools +from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast + +import numpy as np + +from gt4py.next import common class RecursionGuard: @@ -53,6 +58,39 @@ def __exit__(self, *exc): _T = TypeVar("_T") +_P = ParamSpec("_P") +_R = TypeVar("_R") + def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) + + +def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: + """Apply `fun` to each entry of (possibly nested) tuples. + + Examples: + >>> tree_map(lambda x: x + 1)(((1, 2), 3)) + ((2, 3), 4) + + >>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6)) + ((5, 7), 9) + """ + + @functools.wraps(fun) + def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: + if isinstance(args[0], tuple): + assert all(isinstance(arg, tuple) and len(args[0]) == len(arg) for arg in args) + return tuple(impl(*arg) for arg in zip(*args)) + + return fun( + *cast(_P.args, args) + ) # mypy doesn't understand that `args` at this point is of type `_P.args` + + return impl + + +# TODO(havogt): consider moving to module like `field_utils` +@tree_map +def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: + return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 730ce18fd5..7ef724ee2f 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,7 +28,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common, constructors +from gt4py.next import common, constructors, utils from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -435,14 +435,13 @@ def verify( run(case, fieldview_prog, *args, offset_provider=offset_provider) out_comp = out or inout - out_comp_str = str(out_comp) assert out_comp is not None - if hasattr(out_comp, "ndarray"): - out_comp_str = str(out_comp.ndarray) - assert comparison(ref, out_comp), ( + out_comp_ndarray = utils.asnumpy(out_comp) + ref_ndarray = utils.asnumpy(ref) + assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" - f"\tref = {ref}\n\tout = {out_comp_str}" + f"\tref = {ref_ndarray}\n\tout = {str(out_comp_ndarray)}" ) @@ -468,7 +467,7 @@ def verify_with_default_data( ``comparison(ref, )`` and should return a boolean. """ inps, kwfields = get_default_data(case, fieldop) - ref_args = tuple(i.__array__() if common.is_field(i) else i for i in inps) + ref_args = tuple(i.asnumpy() if common.is_field(i) else i for i in inps) verify( case, fieldop, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 6957e628bb..6293ff76bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -63,9 +63,9 @@ def testee(a: IField, b: IField, c: IField) -> IField: *pos_args, **kw_args, out=out, offset_provider=cartesian_case.offset_provider ) - expected = np.asarray(args["a"]) * 2 * np.asarray(args["b"]) - np.asarray(args["c"]) + expected = args["a"] * 2 * args["b"] - args["c"] - assert np.allclose(out, expected) + assert np.allclose(out.asnumpy(), expected.asnumpy()) @pytest.mark.parametrize("arg_spec", _generate_arg_permutations(("a", "b", "out"))) @@ -89,9 +89,9 @@ def testee(a: IField, b: IField, out: IField): *pos_args, **kw_args, offset_provider=cartesian_case.offset_provider ) - expected = np.asarray(args["a"]) + 2 * np.asarray(args["b"]) + expected = args["a"] + 2 * args["b"] - assert np.allclose(args["out"], expected) + assert np.allclose(args["out"].asnumpy(), expected.asnumpy()) def test_call_field_operator_from_field_operator(cartesian_case): @@ -177,9 +177,7 @@ def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField: a, b, out = ( cases.allocate(cartesian_case, testee, name)() for name in ("a", "b", cases.RETURN) ) - expected = (1.0 + 3.0 + 5.0 + 7.0) * np.add.accumulate( - np.asarray(a) + 2.0 * np.asarray(b), axis=2 - ) + expected = (1.0 + 3.0 + 5.0 + 7.0) * np.add.accumulate(a.asnumpy() + 2.0 * b.asnumpy(), axis=2) cases.verify(cartesian_case, testee, a, b, out=out, ref=expected) @@ -210,7 +208,7 @@ def testee( for name in ("out1", "out2", "out3", "out4") ) - ref = np.add.accumulate(np.asarray(a) + 2 * np.asarray(b), axis=2) + ref = np.add.accumulate(a.asnumpy() + 2 * b.asnumpy(), axis=2) cases.verify( cartesian_case, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index fe18bda9e3..1f3b54d6f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -371,8 +371,8 @@ def cast_nested_tuple( a = cases.allocate(cartesian_case, cast_tuple, "a")() b = cases.allocate(cartesian_case, cast_tuple, "b")() - a_asint = cartesian_case.as_field([IDim], np.asarray(a).astype(int32)) - b_asint = cartesian_case.as_field([IDim], np.asarray(b).astype(int32)) + a_asint = cartesian_case.as_field([IDim], a.asnumpy().astype(int32)) + b_asint = cartesian_case.as_field([IDim], b.asnumpy().astype(int32)) out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)() out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)() @@ -384,7 +384,10 @@ def cast_nested_tuple( a_asint, b_asint, out=out_tuple, - ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)), + ref=( + np.full_like(a.asnumpy(), True, dtype=bool), + np.full_like(b.asnumpy(), True, dtype=bool), + ), ) cases.verify( @@ -396,9 +399,9 @@ def cast_nested_tuple( b_asint, out=out_nested_tuple, ref=( - np.full_like(a, True, dtype=bool), - np.full_like(a, True, dtype=bool), - np.full_like(b, True, dtype=bool), + np.full_like(a.asnumpy(), True, dtype=bool), + np.full_like(a.asnumpy(), True, dtype=bool), + np.full_like(b.asnumpy(), True, dtype=bool), ), ) @@ -473,7 +476,7 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD comparison=lambda out, ref: np.all(out == ref), ) - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) def test_nested_tuple_return(cartesian_case): @@ -846,8 +849,8 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = np.asarray(out).copy() # ensure we are not overwriting `out` outside of the domain - ref[1:9] = a[1:9] * 2 + ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain + ref[1:9] = a.asnumpy()[1:9] * 2 cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) @@ -881,8 +884,8 @@ def program_domain( inp = cases.allocate(cartesian_case, program_domain, "inp")() out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)() - ref = np.asarray(out).copy() - ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2 + ref = out.asnumpy().copy() + ref[lower_i : int(upper_i / 2)] = inp.asnumpy()[lower_i : int(upper_i / 2)] * 2 cases.verify( cartesian_case, @@ -924,9 +927,9 @@ def program_domain( a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = np.asarray(out).copy() + ref = out.asnumpy().copy() ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = ( - a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 + a.asnumpy()[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2 ) cases.verify( @@ -964,10 +967,10 @@ def program_domain_tuple( out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")() out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")() - ref0 = np.asarray(out0).copy() - ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6] - ref1 = np.asarray(out1).copy() - ref1[1:9, 4:6] = inp1[1:9, 4:6] + ref0 = out0.asnumpy().copy() + ref0[1:9, 4:6] = inp0.asnumpy()[1:9, 4:6] + inp1.asnumpy()[1:9, 4:6] + ref1 = out1.asnumpy().copy() + ref1[1:9, 4:6] = inp1.asnumpy()[1:9, 4:6] cases.verify( cartesian_case, @@ -995,7 +998,7 @@ def fieldop_where_k_offset( )() out = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() - ref = np.where(np.asarray(k_index) > 0, np.roll(inp, 1, axis=1), 2) + ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), 2) cases.verify(cartesian_case, fieldop_where_k_offset, inp, k_index, out=out, ref=ref) @@ -1119,13 +1122,6 @@ def _invalid_unpack() -> tuple[int32, float64, int32]: def test_constant_closure_vars(cartesian_case): - if cartesian_case.backend is None: - # >>> field = gtx.zeros(domain) - # >>> np.int32(1)*field # steals the buffer from the field - # array([0.]) - - # TODO(havogt): remove `__array__`` from `NdArrayField` - pytest.xfail("Bug: Binary operation between np datatype and Field returns ndarray.") from gt4py.eve.utils import FrozenNamespace constants = FrozenNamespace( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 8bc325d276..e2434d860a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -155,8 +155,8 @@ def conditional_nested_tuple( b, out=cases.allocate(cartesian_case, conditional_nested_tuple, cases.RETURN)(), ref=np.where( - mask, - ((a, b), (b, a)), + mask.asnumpy(), + ((a.asnumpy(), b.asnumpy()), (b.asnumpy(), a.asnumpy())), ((np.full(size, 5.0), np.full(size, 7.0)), (np.full(size, 7.0), np.full(size, 5.0))), ), ) @@ -219,7 +219,15 @@ def conditional( b = cases.allocate(cartesian_case, conditional, "b")() out = cases.allocate(cartesian_case, conditional, cases.RETURN)() - cases.verify(cartesian_case, conditional, mask, a, b, out=out, ref=np.where(mask, a, b)) + cases.verify( + cartesian_case, + conditional, + mask, + a, + b, + out=out, + ref=np.where(mask.asnumpy(), a.asnumpy(), b.asnumpy()), + ) def test_conditional_promotion(cartesian_case): @@ -231,10 +239,9 @@ def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size))) a = cases.allocate(cartesian_case, conditional_promotion, "a")() out = cases.allocate(cartesian_case, conditional_promotion, cases.RETURN)() + ref = np.where(mask.asnumpy(), a.asnumpy(), 10.0) - cases.verify( - cartesian_case, conditional_promotion, mask, a, out=out, ref=np.where(mask, a, 10.0) - ) + cases.verify(cartesian_case, conditional_promotion, mask, a, out=out, ref=ref) def test_conditional_compareop(cartesian_case): @@ -279,7 +286,7 @@ def conditional_program( b, out, inout=out, - ref=np.where(mask, a, b)[1:], + ref=np.where(mask.asnumpy(), a.asnumpy(), b.asnumpy())[1:], ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 937b05e087..8cfcff160c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -135,4 +135,4 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp builtin_field_op(*inps, out=out, offset_provider={}) - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index b82cae25a8..a0f69f332c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -152,7 +152,7 @@ def prog( cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={}) - assert np.allclose((a, b), (out_a, out_b)) + assert np.allclose((a.asnumpy(), b.asnumpy()), (out_a.asnumpy(), out_b.asnumpy())) def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case): @@ -178,7 +178,9 @@ def prog( cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={}) - assert np.allclose((a[1:], b[1:]), (out_a[1:], out_b[1:])) + assert np.allclose( + (a[1:].asnumpy(), b[1:].asnumpy()), (out_a[1:].asnumpy(), out_b[1:].asnumpy()) + ) assert out_a[0] == 0 and out_b[0] == 0 @@ -209,7 +211,9 @@ def prog( cases.run(cartesian_case, prog, a, b, c, out_a, out_b, out_c, offset_provider={}) - assert np.allclose((a, b, c), (out_a, out_b, out_c)) + assert np.allclose( + (a.asnumpy(), b.asnumpy(), c.asnumpy()), (out_a.asnumpy(), out_b.asnumpy(), out_c.asnumpy()) + ) def test_wrong_argument_type(cartesian_case, copy_program_def): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index e9c3ac8d19..84b480a23d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -315,10 +315,10 @@ def if_without_else( out = cases.allocate(cartesian_case, if_without_else, cases.RETURN)() ref = { - (True, True): np.asarray(a) + 2, - (True, False): np.asarray(a), - (False, True): np.asarray(b) + 1, - (False, False): np.asarray(b) + 1, + (True, True): a.asnumpy() + 2, + (True, False): a.asnumpy(), + (False, True): b.asnumpy() + 1, + (False, False): b.asnumpy() + 1, } cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 2bcd0f8367..c0d565bbf4 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -178,7 +178,7 @@ def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, exp fencil(builtin, out, *inps, processor=program_processor, as_column=as_column) if validate: - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) @@ -199,7 +199,7 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): ) # avoid inlining the function fencil(builtin, out, *inps, processor=gtfn_without_transforms) - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) @pytest.mark.parametrize("as_column", [False, True]) @@ -228,7 +228,7 @@ def test_math_function_builtins(program_processor, builtin_name, inputs, as_colu ) if validate: - assert np.allclose(np.asarray(out), expected) + assert np.allclose(out.asnumpy(), expected) Neighbor = offset("Neighbor") @@ -268,7 +268,7 @@ def test_can_deref(program_processor, stencil): ) if validate: - assert np.allclose(np.asarray(out), -1.0) + assert np.allclose(out.asnumpy(), -1.0) a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) run_processor( @@ -280,7 +280,7 @@ def test_can_deref(program_processor, stencil): ) if validate: - assert np.allclose(np.asarray(out), 1.0) + assert np.allclose(out.asnumpy(), 1.0) # def test_can_deref_lifted(program_processor): @@ -336,7 +336,7 @@ def test_cast(program_processor, as_column, input_value, dtype, np_dtype): def sten_cast(it, casted_valued): return eq(cast_(deref(it), dtype), deref(casted_valued)) - out = field_maker(np.zeros_like(inp, dtype=builtins.bool))[0] + out = field_maker(np.zeros_like(inp.asnumpy(), dtype=builtins.bool))[0] run_processor( sten_cast[{IDim: range(1)}], program_processor, diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index de7ebf2869..8536dbea90 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -51,5 +51,5 @@ def test_conditional_w_tuple(program_processor): offset_provider={}, ) if validate: - assert np.all(out.ndarray[np.asarray(inp) == 0] == 3.0) - assert np.all(out.ndarray[np.asarray(inp) == 1] == 7.0) + assert np.all(out.asnumpy()[inp.asnumpy() == 0] == 3.0) + assert np.all(out.asnumpy()[inp.asnumpy() == 1] == 7.0) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py index 83a86319b4..faae549086 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_constant.py @@ -31,8 +31,8 @@ def constant_stencil(): # this is traced as a lambda, TODO directly feed iterat return deref(inp) + deref(lift(constant_stencil)()) inp = gtx.as_field([IDim], np.asarray([0, 42], dtype=np.int32)) - res = gtx.as_field([IDim], np.zeros_like(inp)) + res = gtx.as_field([IDim], np.zeros_like(inp.asnumpy())) add_constant[{IDim: range(2)}](inp, out=res, offset_provider={}, backend=roundtrip.executor) - assert np.allclose(res, np.asarray([1, 43])) + assert np.allclose(res.asnumpy(), np.asarray([1, 43])) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index f9bd2cc33b..69f594a2bc 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -82,7 +82,7 @@ def test_simple_indirection(program_processor): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) @fundef @@ -113,4 +113,4 @@ def test_direct_offset_for_indirection(program_processor): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 2df7691f9e..6f600414db 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -53,7 +53,7 @@ def test_single_argument(program_processor, dom): run_processor(copy_stencil[dom], program_processor, inp, out=out, offset_provider={}) if validate: - assert np.allclose(inp, out) + assert np.allclose(inp.asnumpy(), out.asnumpy()) def test_2_arguments(program_processor, dom): @@ -70,7 +70,7 @@ def fun(inp0, inp1): run_processor(fun[dom], program_processor, inp0, inp1, out=out, offset_provider={}) if validate: - assert np.allclose(inp0 + inp1, out) + assert np.allclose(inp0.asnumpy() + inp1.asnumpy(), out.asnumpy()) def test_lambda_domain(program_processor): @@ -82,4 +82,4 @@ def test_lambda_domain(program_processor): run_processor(copy_stencil[dom], program_processor, inp, out=out, offset_provider={}) if validate: - assert np.allclose(inp, out) + assert np.allclose(inp.asnumpy(), out.asnumpy()) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index 3af0440c27..fce1aa3960 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -60,4 +60,4 @@ def wrapped(inp): ) if validate: - assert np.allclose(out[:, :-1], reference) + assert np.allclose(out[:, :-1].asnumpy(), reference) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index abdfffd74e..dd603fa3be 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -63,9 +63,9 @@ def test_strided_offset_provider(program_processor): ), ) out = gtx.as_field([LocA], np.zeros((LocA_size,))) - ref = np.sum(np.asarray(inp).reshape(LocA_size, max_neighbors), axis=-1) + ref = np.sum(inp.asnumpy().reshape(LocA_size, max_neighbors), axis=-1) run_processor(fencil, program_processor, LocA_size, out, inp) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 8c59f994ee..8e12647c1b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -65,7 +65,7 @@ def test_trivial(program_processor, lift_mode): ) if validate: - assert np.allclose(out[:, :, 0], out_s) + assert np.allclose(out[:, :, 0], out_s.asnumpy()) @fundef @@ -100,7 +100,7 @@ def test_shifted_arg_to_lift(program_processor, lift_mode): ) if validate: - assert np.allclose(out, out_s) + assert np.allclose(out, out_s.asnumpy()) @fendef @@ -137,7 +137,7 @@ def test_direct_deref(program_processor, lift_mode): ) if validate: - assert np.allclose(out, out_s) + assert np.allclose(out, out_s.asnumpy()) @fundef @@ -167,4 +167,4 @@ def test_vertical_shift_unstructured(program_processor): ) if validate: - assert np.allclose(inp_s[:, 1:], np.asarray(out_s)[:, :-1]) + assert np.allclose(inp_s[:, 1:].asnumpy(), out_s[:, :-1].asnumpy()) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 97a51508f5..add772e7ef 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -76,8 +76,8 @@ def test_tuple_output(program_processor, stencil): } run_processor(stencil[dom], program_processor, inp1, inp2, out=out, offset_provider={}) if validate: - assert np.allclose(inp1, out[0]) - assert np.allclose(inp2, out[1]) + assert np.allclose(inp1.asnumpy(), out[0].asnumpy()) + assert np.allclose(inp2.asnumpy(), out[1].asnumpy()) @fundef @@ -144,10 +144,10 @@ def stencil(inp1, inp2, inp3, inp4): offset_provider={}, ) if validate: - assert np.allclose(inp1, out[0][0]) - assert np.allclose(inp2, out[0][1]) - assert np.allclose(inp3, out[1][0]) - assert np.allclose(inp4, out[1][1]) + assert np.allclose(inp1.asnumpy(), out[0][0].asnumpy()) + assert np.allclose(inp2.asnumpy(), out[0][1].asnumpy()) + assert np.allclose(inp3.asnumpy(), out[1][0].asnumpy()) + assert np.allclose(inp4.asnumpy(), out[1][1].asnumpy()) @pytest.mark.parametrize( @@ -197,8 +197,8 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2): offset_provider={}, ) if validate: - assert np.allclose(inp1, out1) - assert np.allclose(inp2, out2) + assert np.allclose(inp1.asnumpy(), out1.asnumpy()) + assert np.allclose(inp2.asnumpy(), out2.asnumpy()) def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor): @@ -255,9 +255,9 @@ def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): offset_provider={}, ) if validate: - assert np.allclose(inp1, out1) - assert np.allclose(inp2, out2) - assert np.allclose(inp3, out3) + assert np.allclose(inp1.asnumpy(), out1.asnumpy()) + assert np.allclose(inp2.asnumpy(), out2.asnumpy()) + assert np.allclose(inp3.asnumpy(), out3.asnumpy()) @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") @@ -323,7 +323,7 @@ def test_tuple_field_input(program_processor): } run_processor(tuple_input[dom], program_processor, (inp1, inp2), out=out, offset_provider={}) if validate: - assert np.allclose(np.asarray(inp1) + np.asarray(inp2), out) + assert np.allclose(inp1.asnumpy() + inp2.asnumpy(), out.asnumpy()) @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") @@ -389,7 +389,7 @@ def test_tuple_of_tuple_of_field_input(program_processor): ) if validate: assert np.allclose( - (np.asarray(inp1) + np.asarray(inp2) + np.asarray(inp3) + np.asarray(inp4)), out + (inp1.asnumpy() + inp2.asnumpy() + inp3.asnumpy() + inp4.asnumpy()), out.asnumpy() ) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 3f229ef389..579dec11f8 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -41,30 +41,30 @@ def mixed_args( def test_allocate_default_unique(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, mixed_args, "a")() - assert np.min(a) == 0 - assert np.max(a) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 + assert np.min(a.asnumpy()) == 0 + assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 b = cases.allocate(cartesian_case, mixed_args, "b")() - assert b == np.max(a) + 1 + assert b == np.max(a.asnumpy()) + 1 c = cases.allocate(cartesian_case, mixed_args, "c")() - assert np.min(c) == b + 1 - assert np.max(c) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + assert np.min(c.asnumpy()) == b + 1 + assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 def test_allocate_return_default_zeros(cartesian_case): # noqa: F811 # fixtures a, (b, c) = cases.allocate(cartesian_case, mixed_args, cases.RETURN)() - assert np.all(np.asarray(a) == 0) - assert np.all(np.asarray(a) == b) - assert np.all(np.asarray(b) == c) + assert np.all(a.asnumpy() == 0) + assert np.all(b.asnumpy() == 0) + assert np.all(c.asnumpy() == 0) def test_allocate_const(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, mixed_args, "a").strategy(cases.ConstInitializer(42))() - assert np.all(np.asarray(a) == 42) + assert np.all(a.asnumpy() == 42) b = cases.allocate(cartesian_case, mixed_args, "b").strategy(cases.ConstInitializer(42))() assert b == 42.0 diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index cd948ffa02..8b4cedd98b 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -244,7 +244,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): comparison=lambda ref, a: np.allclose(ref[:, 1:], a[:, 1:]), ) - assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:]) + assert np.allclose(test_setup.z_q_ref[:, 1:], test_setup.z_q_out[:, 1:].asnumpy()) @pytest.mark.uses_tuple_returns @@ -286,8 +286,8 @@ def test_solve_nonhydro_stencil_52_like(test_setup): test_setup.dummy, ) - assert np.allclose(test_setup.z_q_ref, test_setup.z_q) - assert np.allclose(test_setup.w_ref, test_setup.w) + assert np.allclose(test_setup.z_q_ref, test_setup.z_q.asnumpy()) + assert np.allclose(test_setup.w_ref, test_setup.w.asnumpy()) @pytest.mark.uses_tuple_returns @@ -306,5 +306,5 @@ def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): test_setup.w, ) - assert np.allclose(test_setup.z_q_ref, test_setup.z_q) - assert np.allclose(test_setup.w_ref, test_setup.w) + assert np.allclose(test_setup.z_q_ref, test_setup.z_q.asnumpy()) + assert np.allclose(test_setup.w_ref, test_setup.w.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 829bc497cb..806ab7eb9a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -103,4 +103,4 @@ def test_anton_toy(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index d05b14d73d..fd571514ac 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -16,6 +16,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import utils from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset @@ -89,7 +90,7 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ) out = gtx.as_field([IDim, KDim], np.zeros(shape)) - ref = ref_fun(inp) + ref = ref_fun(inp.asnumpy()) run_processor( stencil[{IDim: range(0, shape[0]), KDim: range(0, shape[1])}], @@ -102,7 +103,7 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) @fundef @@ -157,7 +158,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct k_size = 5 inp = inp_function(k_size) - ref = ref_function(inp) + ref = ref_function(utils.asnumpy(inp)) out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32)) @@ -173,7 +174,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct ) if validate: - np.allclose(ref, out) + np.allclose(ref, out.asnumpy()) @fundef @@ -222,7 +223,7 @@ def test_ksum_scan(program_processor, lift_mode, kstart, reference): ) if validate: - assert np.allclose(reference, np.asarray(out)) + assert np.allclose(reference, out.asnumpy()) @fundef @@ -260,7 +261,7 @@ def test_ksum_back_scan(program_processor, lift_mode): ) if validate: - assert np.allclose(ref, np.asarray(out)) + assert np.allclose(ref, out.asnumpy()) @fundef @@ -366,7 +367,7 @@ def test_different_vertical_sizes(program_processor): ) if validate: - assert np.allclose(ref[1:], out[1:]) + assert np.allclose(ref[1:], out.asnumpy()[1:]) @fundef @@ -392,7 +393,7 @@ def test_different_vertical_sizes_with_origin(program_processor): inp0 = gtx.as_field([KDim], np.arange(0, k_size)) inp1 = gtx.as_field([KDim], np.arange(0, k_size + 1), origin={KDim: 1}) out = gtx.as_field([KDim], np.zeros(k_size, dtype=np.int64)) - ref = np.asarray(inp0) + np.asarray(inp1)[:-1] + ref = inp0.asnumpy() + inp1.asnumpy()[:-1] run_processor( sum_fencil, @@ -405,7 +406,7 @@ def test_different_vertical_sizes_with_origin(program_processor): ) if validate: - assert np.allclose(ref, out) + assert np.allclose(ref, out.asnumpy()) # TODO(havogt) test tuple_get builtin on a Column diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 47867b9a64..e1d959aba9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -159,8 +159,8 @@ def test_compute_zavgS(program_processor, lift_mode): ) if validate: - assert_close(-199755464.25741270, np.min(zavgS)) - assert_close(388241977.58389181, np.max(zavgS)) + assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) + assert_close(388241977.58389181, np.max(zavgS.asnumpy())) run_processor( compute_zavgS_fencil, @@ -173,8 +173,8 @@ def test_compute_zavgS(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert_close(-1000788897.3202186, np.min(zavgS)) - assert_close(1000788897.3202186, np.max(zavgS)) + assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) + assert_close(1000788897.3202186, np.max(zavgS.asnumpy())) @fendef @@ -222,11 +222,11 @@ def test_compute_zavgS2(program_processor, lift_mode): ) if validate: - assert_close(-199755464.25741270, np.min(zavgS[0])) - assert_close(388241977.58389181, np.max(zavgS[0])) + assert_close(-199755464.25741270, np.min(zavgS[0].asnumpy())) + assert_close(388241977.58389181, np.max(zavgS[0].asnumpy())) - assert_close(-1000788897.3202186, np.min(zavgS[1])) - assert_close(1000788897.3202186, np.max(zavgS[1])) + assert_close(-1000788897.3202186, np.min(zavgS[1].asnumpy())) + assert_close(1000788897.3202186, np.max(zavgS[1].asnumpy())) @pytest.mark.requires_atlas @@ -266,10 +266,10 @@ def test_nabla(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY.asnumpy())) @fendef @@ -322,10 +322,10 @@ def test_nabla2(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY.asnumpy())) @fundef @@ -407,7 +407,7 @@ def test_nabla_sign(program_processor, lift_mode): ) if validate: - assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX)) - assert_close(3.5455427772565435e-003, np.max(pnabla_MXX)) - assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY)) - assert_close(3.3540113705465301e-003, np.max(pnabla_MYY)) + assert_close(-3.5455427772566003e-003, np.min(pnabla_MXX.asnumpy())) + assert_close(3.5455427772565435e-003, np.max(pnabla_MXX.asnumpy())) + assert_close(-3.3540113705465301e-003, np.min(pnabla_MYY.asnumpy())) + assert_close(3.3540113705465301e-003, np.max(pnabla_MYY.asnumpy())) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 8aabd18267..9bba1ab89c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -97,4 +97,4 @@ def test_hdiff(hdiff_reference, program_processor, lift_mode): ) if validate: - assert np.allclose(out[:, :, 0], out_s) + assert np.allclose(out[:, :, 0], out_s.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 29c82442ea..f2a6505a7e 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -158,4 +158,4 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): ) if validate: - assert np.allclose(x, x_s) + assert np.allclose(x, x_s.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6354e45451..000d3c4822 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -99,7 +99,7 @@ def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -122,7 +122,7 @@ def test_map_neighbors(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -146,7 +146,7 @@ def test_map_make_const_list(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -172,7 +172,7 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -200,7 +200,7 @@ def test_sparse_input_field(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) def test_sparse_input_field_v2v(program_processor, lift_mode): @@ -226,7 +226,7 @@ def test_sparse_input_field_v2v(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -254,7 +254,7 @@ def test_slice_sparse(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -309,7 +309,7 @@ def test_shift_sliced_sparse(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -337,7 +337,7 @@ def test_slice_shifted_sparse(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -365,7 +365,7 @@ def test_lift(program_processor, lift_mode): lift_mode=lift_mode, ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -390,7 +390,7 @@ def test_shift_sparse_input_field(program_processor, lift_mode): ) if validate: - assert np.allclose(out, ref) + assert np.allclose(out.asnumpy(), ref) @fundef @@ -443,7 +443,7 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): ) if validate: - assert np.allclose(out1, out2) + assert np.allclose(out1.asnumpy(), out2.asnumpy()) @fundef @@ -484,4 +484,4 @@ def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): ) if validate: - assert np.allclose(np.asarray(out), ref) + assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py index d851c5560a..c91be04999 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py @@ -41,4 +41,4 @@ def copy(inp: gtx.Field[[IDim, JDim], gtx.int32]) -> gtx.Field[[IDim, JDim], gtx copy(inp, out=out, offset_provider={}) - assert np.allclose(inp[:out_nx, :out_ny], out) + assert np.allclose(inp[:out_nx, :out_ny].asnumpy(), out.asnumpy()) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 00dbf68274..436e672cc5 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -555,8 +555,8 @@ def test_setitem(index, value): domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) - expected = np.copy(field.ndarray) - expected[index] = value + expected = np.copy(field.asnumpy()) + expected[index] = value.asnumpy() if common.is_field(value) else value field[index] = value From 67e5270729a5951f3902942052afa423d778b965 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 21 Nov 2023 22:17:54 +0100 Subject: [PATCH 11/85] feature[next]: remap and connectivity implementations for embedded (#1309) Adds - ConnectivityField protocol - NdArrayFieldConnectivity for unstructured remap, CartesianConnectivity for Cartesian remap - implements neighbor_sum, max_over, min_over for field TODOs for next PR: - support for remap with connectivities with has_skip_values=True --------- Co-authored-by: Enrique Gonzalez Paredes Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/_core/definitions.py | 27 ++ src/gt4py/next/__init__.py | 9 +- src/gt4py/next/common.py | 297 +++++++++++++++++- src/gt4py/next/constructors.py | 70 ++++- src/gt4py/next/embedded/__init__.py | 10 + src/gt4py/next/embedded/common.py | 3 +- src/gt4py/next/embedded/context.py | 64 ++++ src/gt4py/next/embedded/nd_array_field.py | 283 +++++++++++++++-- src/gt4py/next/ffront/decorator.py | 25 +- src/gt4py/next/ffront/fbuiltins.py | 82 +++-- src/gt4py/next/iterator/embedded.py | 25 +- tests/next_tests/exclusion_matrices.py | 2 - tests/next_tests/integration_tests/cases.py | 2 +- .../ffront_tests/ffront_test_utils.py | 2 +- .../ffront_tests/test_execution.py | 6 +- .../ffront_tests/test_icon_like_scan.py | 2 +- .../embedded_tests/test_basic_program.py | 47 +++ .../embedded_tests/test_nd_array_field.py | 66 +++- tests/next_tests/unit_tests/test_common.py | 129 ++++++++ 19 files changed, 1057 insertions(+), 94 deletions(-) create mode 100644 src/gt4py/next/embedded/context.py create mode 100644 tests/next_tests/unit_tests/embedded_tests/test_basic_program.py diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 79543a1849..0e6301ae0f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -490,3 +490,30 @@ def __rtruediv__(self, other: Any) -> NDArrayObject: def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + + def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + ... + + def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + ... + + def __gt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __ge__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __lt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __le__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable + ... + + def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... + + def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... + + def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: + ... diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 696c4f174c..cbd5735949 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -24,8 +24,8 @@ """ from . import common, ffront, iterator, program_processors, type_inference -from .common import Dimension, DimensionKind, Field, GridType -from .constructors import as_field, empty, full, ones, zeros +from .common import Dimension, DimensionKind, Domain, Field, GridType, UnitRange, domain, unit_range +from .constructors import as_connectivity, as_field, empty, full, ones, zeros from .embedded import ( # Just for registering field implementations nd_array_field as _nd_array_field, ) @@ -53,12 +53,17 @@ "DimensionKind", "Field", "GridType", + "domain", + "Domain", + "unit_range", + "UnitRange", # from constructors "empty", "zeros", "ones", "full", "as_field", + "as_connectivity", # from iterator "NeighborTableOffsetProvider", "StridedNeighborOffsetProvider", diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 51ad14f22d..7f1ad8c0bb 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -19,10 +19,10 @@ import dataclasses import enum import functools +import numbers import sys import types from collections.abc import Mapping, Sequence, Set -from typing import overload import numpy as np import numpy.typing as npt @@ -33,6 +33,7 @@ Any, Callable, ClassVar, + Never, Optional, ParamSpec, Protocol, @@ -41,14 +42,14 @@ TypeVar, cast, extended_runtime_checkable, + overload, runtime_checkable, ) from gt4py.eve.type_definitions import StrEnum -DimsT = TypeVar( - "DimsT", covariant=True -) # bound to `Sequence[Dimension]` if instance of Dimension would be a type +DimT = TypeVar("DimT", bound="Dimension") # , covariant=True) +DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) class Infinity(int): @@ -61,6 +62,9 @@ def negative(cls) -> Infinity: return cls(-sys.maxsize) +Tag: TypeAlias = str + + @enum.unique class DimensionKind(StrEnum): HORIZONTAL = "horizontal" @@ -96,6 +100,7 @@ def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScal object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) + # TODO: the whole infinity idea and implementation is broken and should be replaced @classmethod def infinity(cls) -> UnitRange: return cls(Infinity.negative(), Infinity.positive()) @@ -113,10 +118,10 @@ def __getitem__(self, index: int) -> int: ... @overload - def __getitem__(self, index: slice) -> UnitRange: + def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unused ... - def __getitem__(self, index: int | slice) -> int | UnitRange: + def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: @@ -149,6 +154,32 @@ def __le__(self, other: Set[int]): else: return Set.__le__(self, other) + def __add__(self, other: int | Set[int]) -> UnitRange: + if isinstance(other, int): + if other == Infinity.positive(): + return UnitRange.infinity() + elif other == Infinity.negative(): + return UnitRange(0, 0) + return UnitRange( + *( + s if s in [Infinity.negative(), Infinity.positive()] else s + other + for s in (self.start, self.stop) + ) + ) + else: + raise NotImplementedError("Can only compute union with int instances.") + + def __sub__(self, other: int | Set[int]) -> UnitRange: + if isinstance(other, int): + if other == Infinity.negative(): + return self + Infinity.positive() + elif other == Infinity.positive(): + return self + Infinity.negative() + else: + return self + (-other) + else: + raise NotImplementedError("Can only compute substraction with int instances.") + __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented def __str__(self) -> str: @@ -184,8 +215,8 @@ def unit_range(r: RangeLike) -> UnitRange: IntIndex: TypeAlias = int | core_defs.IntegralScalar -NamedIndex: TypeAlias = tuple[Dimension, IntIndex] -NamedRange: TypeAlias = tuple[Dimension, UnitRange] +NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple +NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement @@ -260,8 +291,8 @@ class Domain(Sequence[NamedRange]): def __init__( self, *args: NamedRange, - dims: Optional[tuple[Dimension, ...]] = None, - ranges: Optional[tuple[UnitRange, ...]] = None, + dims: Optional[Sequence[Dimension]] = None, + ranges: Optional[Sequence[UnitRange]] = None, ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: @@ -285,8 +316,8 @@ def __init__( f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." ) - object.__setattr__(self, "dims", dims) - object.__setattr__(self, "ranges", ranges) + object.__setattr__(self, "dims", tuple(dims)) + object.__setattr__(self, "ranges", tuple(ranges)) else: if not all(is_named_range(arg) for arg in args): raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") @@ -300,6 +331,10 @@ def __init__( def __len__(self) -> int: return len(self.ranges) + @property + def ndim(self) -> int: + return len(self.dims) + @property def shape(self) -> tuple[int, ...]: return tuple(len(r) for r in self.ranges) @@ -309,14 +344,16 @@ def __getitem__(self, index: int) -> NamedRange: ... @overload - def __getitem__(self, index: slice) -> Domain: + def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused ... @overload - def __getitem__(self, index: Dimension) -> NamedRange: + def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused ... - def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: + def __getitem__( # noqa: F811 # redefine unused + self, index: int | slice | Dimension + ) -> NamedRange | Domain: # noqa: F811 # redefine unused if isinstance(index, int): return self.dims[index], self.ranges[index] elif isinstance(index, slice): @@ -360,6 +397,36 @@ def __and__(self, other: Domain) -> Domain: def __str__(self) -> str: return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + def dim_index(self, dim: Dimension) -> Optional[int]: + return self.dims.index(dim) if dim in self.dims else None + + def pop(self, index: int | Dimension = -1) -> Domain: + return self.replace(index) + + def insert(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: + if isinstance(index, int) and index == len(self.dims): + new_dims, new_ranges = zip(*named_ranges) + return Domain(dims=self.dims + new_dims, ranges=self.ranges + new_ranges) + else: + return self.replace(index, *named_ranges) + + def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: + assert all(is_named_range(nr) for nr in named_ranges) + if isinstance(index, Dimension): + dim_index = self.dim_index(index) + if dim_index is None: + raise ValueError(f"Dimension {index} not found in Domain.") + index = dim_index + if not (-len(self.dims) <= index < len(self.dims)): + raise IndexError(f"Index {index} out of bounds for Domain of length {len(self.dims)}.") + if index < 0: + index += len(self.dims) + new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) + dims = self.dims[:index] + new_dims + self.dims[index + 1 :] + ranges = self.ranges[:index] + new_ranges + self.ranges[index + 1 :] + + return Domain(dims=dims, ranges=ranges) + DomainLike: TypeAlias = ( Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] @@ -456,6 +523,10 @@ class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, co def domain(self) -> Domain: ... + @property + def codomain(self) -> type[core_defs.ScalarT] | Dimension: + ... + @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... @@ -472,7 +543,7 @@ def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def remap(self, index_field: Field) -> Field: + def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod @@ -481,7 +552,7 @@ def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: # Operators @abc.abstractmethod - def __call__(self, index_field: Field) -> Field: + def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod @@ -592,6 +663,100 @@ def is_mutable_field( return isinstance(v, MutableField) # type: ignore[misc] # we use extended_runtime_checkable +class ConnectivityKind(enum.Flag): + MODIFY_DIMS = enum.auto() + MODIFY_RANK = enum.auto() + MODIFY_STRUCTURE = enum.auto() + + +@extended_runtime_checkable +class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): # type: ignore[misc] # DimT should be covariant, but break in another place + @property + @abc.abstractmethod + def codomain(self) -> DimT: + ... + + @property + def kind(self) -> ConnectivityKind: + return ( + ConnectivityKind.MODIFY_DIMS + | ConnectivityKind.MODIFY_RANK + | ConnectivityKind.MODIFY_STRUCTURE + ) + + @abc.abstractmethod + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: + ... + + # Operators + def __abs__(self) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __neg__(self) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __invert__(self) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __eq__(self, other: Any) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __ne__(self, other: Any) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe + raise TypeError("ConnectivityField does not support this operation") + + def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: + raise TypeError("ConnectivityField does not support this operation") + + +def is_connectivity_field( + v: Any, +) -> TypeGuard[ConnectivityField]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed + return isinstance(v, ConnectivityField) # type: ignore[misc] # we use extended_runtime_checkable + + @functools.singledispatch def field( definition: Any, @@ -603,6 +768,18 @@ def field( raise NotImplementedError +@functools.singledispatch +def connectivity( + definition: Any, + /, + codomain: Dimension, + *, + domain: Optional[DomainLike] = None, + dtype: Optional[core_defs.DType] = None, +) -> ConnectivityField: + raise NotImplementedError + + @dataclasses.dataclass(frozen=True) class GTInfo: definition: Any @@ -638,6 +815,92 @@ class NeighborTable(Connectivity, Protocol): table: npt.NDArray +OffsetProviderElem: TypeAlias = Dimension | Connectivity +OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] + + +@dataclasses.dataclass(frozen=True, eq=False) +class CartesianConnectivity(ConnectivityField[DimsT, DimT]): + dimension: DimT + offset: int = 0 + + @classmethod + def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override] + raise NotImplementedError() + + @property + def ndarray(self) -> Never: + raise NotImplementedError() + + def asnumpy(self) -> Never: + raise NotImplementedError() + + @functools.cached_property + def domain(self) -> Domain: + return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),)) + + @property + def __gt_dims__(self) -> tuple[Dimension, ...]: + return self.domain.dims + + @property + def __gt_origin__(self) -> Never: + raise TypeError("CartesianConnectivity does not support this operation") + + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + + @functools.cached_property + def codomain(self) -> DimT: + return self.dimension + + @functools.cached_property + def kind(self) -> ConnectivityKind: + return ConnectivityKind(0) + + @classmethod + def from_offset( + cls, + definition: int, + /, + codomain: DimT, + *, + domain: Optional[DomainLike] = None, + dtype: Optional[core_defs.DTypeLike] = None, + ) -> CartesianConnectivity: + assert domain is None + assert dtype is None + return cls(codomain, definition) + + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: + if not isinstance(image_range, UnitRange): + if image_range[0] != self.codomain: + raise ValueError( + f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + ) + + image_range = image_range[1] + + assert isinstance(image_range, UnitRange) + return ((self.codomain, image_range - self.offset),) + + def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> ConnectivityField: + raise NotImplementedError() + + __call__ = remap + + def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar: + if is_int_index(index): + return index + self.offset + raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case + + __getitem__ = restrict + + +connectivity.register(numbers.Integral, CartesianConnectivity.from_offset) + + @enum.unique class GridType(StrEnum): CARTESIAN = "cartesian" diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 42b0bcda90..63fde1cfde 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -212,7 +212,7 @@ def as_field( This function supports partial binding of arguments, see :class:`eve.utils.partial` for details. See :func:`empty` for further details about the meaning of the extra keyword arguments. - Parameters: + Arguments: domain: Definition of the domain of the field (and consequently of the shape of the allocated field buffer). In addition to the values allowed in `empty`, it can also just be a sequence of dimensions, in which case the sizes of each dimension will then be taken from the shape of `data`. @@ -283,7 +283,7 @@ def as_field( dtype = core_defs.dtype(dtype) assert dtype.tensor_shape == () # TODO - if allocator is device is None and xtyping.supports_dlpack(data): + if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) field = empty( @@ -297,3 +297,69 @@ def as_field( field[...] = field.array_ns.asarray(data) return field + + +@eve.utils.with_fluid_partial +def as_connectivity( + domain: common.DomainLike | Sequence[common.Dimension], + codomain: common.Dimension, + data: core_defs.NDArrayObject, + dtype: Optional[core_defs.DType] = None, + *, + allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, + device: Optional[core_defs.Device] = None, + # copy=False, TODO +) -> common.ConnectivityField: + """ + Construct a connectivity field from the given domain, codomain, and data. + + Arguments: + domain: The domain of the connectivity field. It can be either a `common.DomainLike` object or a + sequence of `common.Dimension` objects. + codomain: The codomain dimension of the connectivity field. + data: The data used to construct the connectivity field. + dtype: The data type of the connectivity field. If not provided, it will be inferred from the data. + allocator: The allocator used to allocate the buffer for the connectivity field. If not provided, + a default allocator will be used. + device: The device on which the connectivity field will be allocated. If not provided, the default + device will be used. + + Returns: + The constructed connectivity field. + + Raises: + ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. + """ + if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): + domain = cast(Sequence[common.Dimension], domain) + if len(domain) != data.ndim: + raise ValueError( + f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + ) + actual_domain = common.domain([(d, (0, s)) for d, s in zip(domain, data.shape)]) + else: + actual_domain = common.domain(cast(common.DomainLike, domain)) + + if not isinstance(codomain, common.Dimension): + raise ValueError(f"Invalid codomain dimension `{codomain}`") + + # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has + # already the correct layout and device. + shape = storage_utils.asarray(data).shape + if shape != actual_domain.shape: + raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + if dtype is None: + dtype = storage_utils.asarray(data).dtype + dtype = core_defs.dtype(dtype) + assert dtype.tensor_shape == () # TODO + + if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): + device = core_defs.Device(*data.__dlpack_device__()) + buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) + buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] # TODO(havogt): consider addin MutableNDArrayObject + connectivity_field = common.connectivity( + buffer.ndarray, codomain=codomain, domain=actual_domain + ) + assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField) + + return connectivity_field diff --git a/src/gt4py/next/embedded/__init__.py b/src/gt4py/next/embedded/__init__.py index 6c43e2f12a..e0cb114148 100644 --- a/src/gt4py/next/embedded/__init__.py +++ b/src/gt4py/next/embedded/__init__.py @@ -11,3 +11,13 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + +from . import common, context, exceptions, nd_array_field + + +__all__ = [ + "common", + "context", + "exceptions", + "nd_array_field", +] diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 37ba4954f3..d796189ab3 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -12,8 +12,9 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Optional, Sequence, cast +from __future__ import annotations +from gt4py.eve.extended_typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py new file mode 100644 index 0000000000..5fbdbc6f25 --- /dev/null +++ b/src/gt4py/next/embedded/context.py @@ -0,0 +1,64 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +import contextlib +import contextvars as cvars +from typing import Any + +import gt4py.eve as eve +import gt4py.next.common as common + + +#: Column range used in column mode (`column_axis != None`) in the current embedded iterator +#: closure execution context. +closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range") + +_undefined_offset_provider: common.OffsetProvider = {} + +#: Offset provider dict in the current embedded execution context. +offset_provider: cvars.ContextVar[common.OffsetProvider] = cvars.ContextVar( + "offset_provider", default=_undefined_offset_provider +) + + +@contextlib.contextmanager +def new_context( + *, + closure_column_range: range | eve.NothingType = eve.NOTHING, + offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING, +): + import gt4py.next.embedded.context as this_module + + updates: list[tuple[cvars.ContextVar[Any], Any]] = [] + if closure_column_range is not eve.NOTHING: + updates.append((this_module.closure_column_range, closure_column_range)) + if offset_provider is not eve.NOTHING: + updates.append((this_module.offset_provider, offset_provider)) + + # Create new context with provided values + ctx = cvars.copy_context() + + def ctx_updater(*args): + for cvar, value in args: + cvar.set(value) + + ctx.run(ctx_updater, *updates) + + yield ctx + + +def within_context() -> bool: + return offset_provider.get() is not _undefined_offset_provider diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index a843772a20..ff6a2ceac7 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -19,12 +19,13 @@ import operator from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar +from typing import ClassVar import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs +from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar from gt4py.next import common from gt4py.next.embedded import common as embedded_common from gt4py.next.ffront import fbuiltins @@ -126,6 +127,10 @@ def asnumpy(self) -> np.ndarray: else: return np.asarray(self._ndarray) + @property + def codomain(self) -> type[core_defs.ScalarT]: + return self.dtype.scalar_type + @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: return core_defs.dtype(self._ndarray.dtype.type) @@ -153,12 +158,53 @@ def from_array( assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim - assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape)) + assert all(s == 1 or len(r) == s for r, s in zip(domain.ranges, array.shape)) return cls(domain, array) - def remap(self: NdArrayField, connectivity) -> NdArrayField: - raise NotImplementedError() + def remap( + self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset + ) -> NdArrayField: + # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField + if not common.is_connectivity_field(connectivity): + assert isinstance(connectivity, fbuiltins.FieldOffset) + connectivity = connectivity.as_connectivity_field() + + assert common.is_connectivity_field(connectivity) + + # Compute the new domain + dim = connectivity.codomain + dim_idx = self.domain.dim_index(dim) + if dim_idx is None: + raise ValueError(f"Incompatible index field, expected a field with dimension {dim}.") + + current_range: common.UnitRange = self.domain[dim_idx][1] + new_ranges = connectivity.inverse_image(current_range) + new_domain = self.domain.replace(dim_idx, *new_ranges) + + # perform contramap + if not (connectivity.kind & common.ConnectivityKind.MODIFY_STRUCTURE): + # shortcut for compact remap: don't change the array, only the domain + new_buffer = self._ndarray + else: + # general case: first restrict the connectivity to the new domain + restricted_connectivity_domain = common.Domain(*new_ranges) + restricted_connectivity = ( + connectivity.restrict(restricted_connectivity_domain) + if restricted_connectivity_domain != connectivity.domain + else connectivity + ) + assert common.is_connectivity_field(restricted_connectivity) + + # then compute the index array + xp = self.array_ns + new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start + # finally, take the new array + new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx) + + return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype) + + __call__ = remap # type: ignore[assignment] def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: new_domain, buffer_slice = self._slice(index) @@ -172,7 +218,22 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __getitem__ = restrict - __call__ = None # type: ignore[assignment] # TODO: remap + def __setitem__( + self: NdArrayField[common.DimsT, core_defs.ScalarT], + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, + ) -> None: + target_domain, target_slice = self._slice(index) + + if common.is_field(value): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + ) + value = value.ndarray + + assert hasattr(self.ndarray, "__setitem__") + self._ndarray[target_slice] = value # type: ignore[index] # np and cp allow index assignment, jax overrides __abs__ = _make_builtin("abs", "abs") @@ -194,9 +255,17 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala __mod__ = __rmod__ = _make_builtin("mod", "mod") - __ne__ = _make_builtin("not_equal", "not_equal") # type: ignore[assignment] # mypy wants return `bool` + __ne__ = _make_builtin("not_equal", "not_equal") # type: ignore # mypy wants return `bool` + + __eq__ = _make_builtin("equal", "equal") # type: ignore # mypy wants return `bool` + + __gt__ = _make_builtin("greater", "greater") + + __ge__ = _make_builtin("greater_equal", "greater_equal") + + __lt__ = _make_builtin("less", "less") - __eq__ = _make_builtin("equal", "equal") # type: ignore[assignment] # mypy wants return `bool` + __le__ = _make_builtin("less_equal", "less_equal") def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): @@ -239,6 +308,144 @@ def _slice( return new_domain, slice_ +@dataclasses.dataclass(frozen=True) +class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ + common.ConnectivityField[common.DimsT, common.DimT], + NdArrayField[common.DimsT, core_defs.IntegralScalar], +): + _codomain: common.DimT + + @functools.cached_property + def _cache(self) -> dict: + return {} + + @classmethod + def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ignore[override] + raise NotImplementedError() + + @property + def codomain(self) -> common.DimT: # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + return self._codomain + + @functools.cached_property + def kind(self) -> common.ConnectivityKind: + kind = common.ConnectivityKind.MODIFY_STRUCTURE + if self.domain.ndim > 1: + kind |= common.ConnectivityKind.MODIFY_RANK + kind |= common.ConnectivityKind.MODIFY_DIMS + if self.domain.dim_index(self.codomain) is None: + kind |= common.ConnectivityKind.MODIFY_DIMS + + return kind + + @classmethod + def from_array( # type: ignore[override] + cls, + data: npt.ArrayLike | core_defs.NDArrayObject, + /, + codomain: common.DimT, + *, + domain: common.DomainLike, + dtype: Optional[core_defs.DTypeLike] = None, + ) -> NdArrayConnectivityField: + domain = common.domain(domain) + xp = cls.array_ns + + xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type) + array = xp.asarray(data, dtype=xp_dtype) + + if dtype is not None: + assert array.dtype.type == core_defs.dtype(dtype).scalar_type + + assert issubclass(array.dtype.type, core_defs.INTEGRAL_TYPES) + + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain + assert len(domain) == array.ndim + assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape)) + + assert isinstance(codomain, common.Dimension) + + return cls(domain, array, codomain) + + def inverse_image( + self, image_range: common.UnitRange | common.NamedRange + ) -> Sequence[common.NamedRange]: + cache_key = hash((id(self.ndarray), self.domain, image_range)) + + if (new_dims := self._cache.get(cache_key, None)) is None: + xp = self.array_ns + + if not isinstance( + image_range, common.UnitRange + ): # TODO(havogt): cleanup duplication with CartesianConnectivity + if image_range[0] != self.codomain: + raise ValueError( + f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + ) + + image_range = image_range[1] + + assert isinstance(image_range, common.UnitRange) + + restricted_mask = (self._ndarray >= image_range.start) & ( + self._ndarray < image_range.stop + ) + # indices of non-zero elements in each dimension + nnz: tuple[core_defs.NDArrayObject, ...] = xp.nonzero(restricted_mask) + + new_dims = [] + non_contiguous_dims = [] + + for i, dim_nnz_indices in enumerate(nnz): + # Check if the indices are contiguous + first_data_index = dim_nnz_indices[0] + assert isinstance(first_data_index, core_defs.INTEGRAL_TYPES) + last_data_index = dim_nnz_indices[-1] + assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES) + indices, counts = xp.unique(dim_nnz_indices, return_counts=True) + if len(xp.unique(counts)) == 1 and ( + len(indices) == last_data_index - first_data_index + 1 + ): + dim_range = self._domain[i] + idx_offset = dim_range[1].start + start = idx_offset + first_data_index + assert common.is_int_index(start) + stop = idx_offset + last_data_index + 1 + assert common.is_int_index(stop) + new_dims.append( + common.named_range( + ( + dim_range[0], + (start, stop), + ) + ) + ) + else: + non_contiguous_dims.append(dim_range[0]) + + if non_contiguous_dims: + raise ValueError( + f"Restriction generates non-contiguous dimensions {non_contiguous_dims}" + ) + + return new_dims + + def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar: + cache_key = (id(self.ndarray), self.domain, index) + + if (restricted_connectivity := self._cache.get(cache_key, None)) is None: + cls = self.__class__ + xp = cls.array_ns + new_domain, buffer_slice = self._slice(index) + new_buffer = xp.asarray(self.ndarray[buffer_slice]) + restricted_connectivity = cls(new_domain, new_buffer, self.codomain) + self._cache[cache_key] = restricted_connectivity + + return restricted_connectivity + + __getitem__ = restrict + + # -- Specialized implementations for builtin operations on array fields -- NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] @@ -266,22 +473,30 @@ def _slice( NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _np_cp_setitem( - self: NdArrayField[common.DimsT, core_defs.ScalarT], - index: common.AnyIndexSpec, - value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, -) -> None: - target_domain, target_slice = self._slice(index) +def _make_reduction( + builtin_name: str, array_builtin_name: str +) -> Callable[..., NdArrayField[common.DimsT, core_defs.ScalarT],]: + def _builtin_op( + field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension + ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: + if not axis.kind == common.DimensionKind.LOCAL: + raise ValueError("Can only reduce local dimensions.") + if axis not in field.domain.dims: + raise ValueError(f"Field doesn't have dimension {axis}. Cannot reduce.") + reduce_dim_index = field.domain.dims.index(axis) + new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) + return field.__class__.from_array( + getattr(field.array_ns, array_builtin_name)(field.ndarray, axis=reduce_dim_index), + domain=new_domain, + ) + + _builtin_op.__name__ = builtin_name + return _builtin_op - if common.is_field(value): - if not value.domain == target_domain: - raise ValueError( - f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." - ) - value = value.ndarray - assert hasattr(self.ndarray, "__setitem__") - self.ndarray[target_slice] = value +NdArrayField.register_builtin_func(fbuiltins.neighbor_sum, _make_reduction("neighbor_sum", "sum")) +NdArrayField.register_builtin_func(fbuiltins.max_over, _make_reduction("max_over", "max")) +NdArrayField.register_builtin_func(fbuiltins.min_over, _make_reduction("min_over", "min")) # -- Concrete array implementations -- @@ -293,11 +508,17 @@ def _np_cp_setitem( class NumPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = np - __setitem__ = _np_cp_setitem - common.field.register(np.ndarray, NumPyArrayField.from_array) + +@dataclasses.dataclass(frozen=True, eq=False) +class NumPyArrayConnectivityField(NdArrayConnectivityField): + array_ns: ClassVar[ModuleType] = np + + +common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array) + # CuPy if cp: _nd_array_implementations.append(cp) @@ -306,10 +527,14 @@ class NumPyArrayField(NdArrayField): class CuPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = cp - __setitem__ = _np_cp_setitem - common.field.register(cp.ndarray, CuPyArrayField.from_array) + @dataclasses.dataclass(frozen=True, eq=False) + class CuPyArrayConnectivityField(NdArrayConnectivityField): + array_ns: ClassVar[ModuleType] = cp + + common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array) + # JAX if jnp: _nd_array_implementations.append(jnp) @@ -355,11 +580,13 @@ def _builtins_broadcast( NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) -def _astype(field: NdArrayField, type_: type) -> NdArrayField: - return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) +def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdArrayField: + if isinstance(field, NdArrayField): + return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) + raise AssertionError("This is the NdArrayField implementation of `fbuiltins.astype`.") -NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field +NdArrayField.register_builtin_func(fbuiltins.astype, _astype) def _get_slices_from_domain_slice( diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 107415eb06..7572040e13 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,7 +32,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, common +from gt4py.next import allocators as next_allocators, common, embedded as next_embedded from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import ( dialect_ast_enums, @@ -290,8 +290,8 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend." ) ) - - self.definition(*rewritten_args, **kwargs) + with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: + ctx.run(self.definition, *rewritten_args, **kwargs) return ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor) @@ -686,6 +686,9 @@ def __call__( offset_provider = kwargs.pop("offset_provider", None) if self.backend is not None: # "out" and "offset_provider" -> field_operator as program + # When backend is None, we are in embedded execution and for now + # we disable the program generation since it would involve generating + # Python source code from a PAST node. args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given # deduce argument types @@ -700,10 +703,20 @@ def __call__( *args, out, offset_provider=offset_provider, **kwargs ) else: - # "out" -> field_operator called from program in embedded execution - # TODO(egparedes): put offset_provider in ctxt var here when implementing remap + # "out" -> field_operator called from program in embedded execution or + # field_operator called directly from Python in embedded execution domain = kwargs.pop("domain", None) - res = self.definition(*args, **kwargs) + if not next_embedded.context.within_context(): + # field_operator from Python in embedded execution + with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: + res = ctx.run(self.definition, *args, **kwargs) + else: + # field_operator from program in embedded execution (offset_provicer is already set) + assert ( + offset_provider is None + or next_embedded.context.offset_provider.get() is offset_provider + ) + res = self.definition(*args, **kwargs) _tuple_assign_field( out, res, domain=None if domain is None else common.domain(domain) ) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 706b6a4606..8230e35a35 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -13,28 +13,19 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses +import functools import inspect from builtins import bool, float, int, tuple -from typing import ( - Any, - Callable, - Generic, - Optional, - ParamSpec, - Tuple, - TypeAlias, - TypeVar, - Union, - cast, -) +from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np from numpy import float32, float64, int32, int64 +import gt4py.next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import common -from gt4py.next.common import Dimension, Field # direct import for TYPE_BUILTINS -from gt4py.next.ffront.experimental import as_offset # noqa F401 +from gt4py.next import common, embedded +from gt4py.next.common import Dimension, Field # noqa: F401 # direct import for TYPE_BUILTINS +from gt4py.next.ffront.experimental import as_offset # noqa: F401 from gt4py.next.iterator import runtime from gt4py.next.type_system import type_specifications as ts @@ -43,8 +34,8 @@ PYTHON_TYPE_BUILTIN_NAMES = [t.__name__ for t in PYTHON_TYPE_BUILTINS] TYPE_BUILTINS = [ - Field, - Dimension, + common.Field, + common.Dimension, int32, int64, float32, @@ -214,10 +205,10 @@ def where( @BuiltInFunction def astype( - value: Field | core_defs.ScalarT | Tuple, + value: common.Field | core_defs.ScalarT | Tuple, type_: type, /, -) -> Field | core_defs.ScalarT | Tuple: +) -> common.Field | core_defs.ScalarT | Tuple: if isinstance(value, tuple): return tuple(astype(v, type_) for v in value) # default implementation for scalars, Fields are handled via dispatch @@ -324,7 +315,10 @@ def impl( class FieldOffset(runtime.Offset): source: common.Dimension target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] - connectivity: Optional[Any] = None # TODO + + @functools.cached_property + def _cache(self) -> dict: + return {} def __post_init__(self): if len(self.target) == 2 and self.target[1].kind != common.DimensionKind.LOCAL: @@ -332,3 +326,51 @@ def __post_init__(self): def __gt_type__(self): return ts.OffsetType(source=self.source, target=self.target) + + def __getitem__(self, offset: int) -> common.ConnectivityField: + """Serve as a connectivity factory.""" + assert isinstance(self.value, str) + current_offset_provider = embedded.context.offset_provider.get(None) + assert current_offset_provider is not None + offset_definition = current_offset_provider[self.value] + + connectivity: common.ConnectivityField + if isinstance(offset_definition, common.Dimension): + connectivity = common.CartesianConnectivity(offset_definition, offset) + elif isinstance( + offset_definition, gtx.NeighborTableOffsetProvider + ) or common.is_connectivity_field(offset_definition): + unrestricted_connectivity = self.as_connectivity_field() + assert unrestricted_connectivity.domain.ndim > 1 + named_index = (self.target[-1], offset) + connectivity = unrestricted_connectivity[named_index] + else: + raise NotImplementedError() + + return connectivity + + def as_connectivity_field(self): + """Convert to connectivity field using the offset providers in current embedded execution context.""" + assert isinstance(self.value, str) + current_offset_provider = embedded.context.offset_provider.get(None) + assert current_offset_provider is not None + offset_definition = current_offset_provider[self.value] + + cache_key = id(offset_definition) + if (connectivity := self._cache.get(cache_key, None)) is None: + if common.is_connectivity_field(offset_definition): + connectivity = offset_definition + elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): + assert not offset_definition.has_skip_values + connectivity = gtx.as_connectivity( + domain=self.target, + codomain=self.source, + data=offset_definition.table, + dtype=offset_definition.index_type, + ) + else: + raise NotImplementedError() + + self._cache[cache_key] = connectivity + + return connectivity diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 9000b00d8f..b02d6c8d72 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -51,8 +51,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping -from gt4py.next import common +from gt4py.next import common, embedded as next_embedded from gt4py.next.embedded import exceptions as embedded_exceptions +from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime @@ -60,7 +61,7 @@ # Atoms -Tag: TypeAlias = str +Tag: TypeAlias = common.Tag ArrayIndex: TypeAlias = slice | common.IntIndex ArrayIndexOrIndices: TypeAlias = ArrayIndex | tuple[ArrayIndex, ...] @@ -129,8 +130,8 @@ def mapped_index( # Offsets OffsetPart: TypeAlias = Tag | common.IntIndex CompleteOffset: TypeAlias = tuple[Tag, common.IntIndex] -OffsetProviderElem: TypeAlias = common.Dimension | common.Connectivity -OffsetProvider: TypeAlias = dict[Tag, OffsetProviderElem] +OffsetProviderElem: TypeAlias = common.OffsetProviderElem +OffsetProvider: TypeAlias = common.OffsetProvider # Positions SparsePositionEntry = list[int] @@ -195,9 +196,9 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: #: Column range used in column mode (`column_axis != None`) in the current closure execution context. -column_range_cvar: cvars.ContextVar[range] = cvars.ContextVar("column_range") +column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range #: Offset provider dict in the current closure execution context. -offset_provider_cvar: cvars.ContextVar[OffsetProvider] = cvars.ContextVar("offset_provider") +offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider class Column(np.lib.mixins.NDArrayOperatorsMixin): @@ -1060,6 +1061,10 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override def domain(self) -> common.Domain: return common.Domain((self._dimension, common.UnitRange.infinity())) + @property + def codomain(self) -> type[core_defs.int32]: + return core_defs.int32 + @property def dtype(self) -> core_defs.Int32DType: return core_defs.Int32DType() @@ -1071,7 +1076,7 @@ def ndarray(self) -> core_defs.NDArrayObject: def asnumpy(self) -> np.ndarray: raise NotImplementedError() - def remap(self, index_field: common.Field) -> common.Field: + def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1179,6 +1184,10 @@ def domain(self) -> common.Domain: def dtype(self) -> core_defs.DType[core_defs.ScalarT]: return core_defs.dtype(type(self._value)) + @property + def codomain(self) -> type[core_defs.ScalarT]: + return self.dtype.scalar_type + @property def ndarray(self) -> core_defs.NDArrayObject: raise AttributeError("Cannot get `ndarray` of an infinite Field.") @@ -1186,7 +1195,7 @@ def ndarray(self) -> core_defs.NDArrayObject: def asnumpy(self) -> np.ndarray: raise NotImplementedError() - def remap(self, index_field: common.Field) -> common.Field: + def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index ef30a61687..a8a508b2fb 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -120,8 +120,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ - (USES_CARTESIAN_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_UNSTRUCTURED_SHIFT, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 7ef724ee2f..81f216397b 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -72,7 +72,7 @@ C2EDim = gtx.Dimension("C2E", kind=common.DimensionKind.LOCAL) V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) -C2E = gtx.FieldOffset("E2V", source=Edge, target=(Cell, C2EDim)) +C2E = gtx.FieldOffset("C2E", source=Edge, target=(Cell, C2EDim)) ScalarValue: TypeAlias = core_defs.Scalar FieldValue: TypeAlias = gtx.Field diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 01c78cf950..1537c01642 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -209,7 +209,7 @@ def reduction_setup(): inp=gtx.as_field([Edge], np.arange(num_edges, dtype=np.int32)), out=gtx.as_field([Vertex], np.zeros([num_vertices], dtype=np.int32)), offset_provider={ - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4, has_skip_values=False), "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False), "C2V": gtx.NeighborTableOffsetProvider(c2v_arr, Cell, Vertex, 4, has_skip_values=False), "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4, has_skip_values=False), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1f3b54d6f0..cf273a4524 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -992,6 +992,10 @@ def fieldop_where_k_offset( ) -> cases.IKField: return where(k_index > 0, inp(Koff[-1]), 2) + @gtx.program + def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cases.IKField): + fieldop_where_k_offset(inp, k_index, out=out, domain={IDim: (0, 10), KDim: (1, 10)}) + inp = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")() k_index = cases.allocate( cartesian_case, fieldop_where_k_offset, "k_index", strategy=cases.IndexInitializer() @@ -1000,7 +1004,7 @@ def fieldop_where_k_offset( ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), 2) - cases.verify(cartesian_case, fieldop_where_k_offset, inp, k_index, out=out, ref=ref) + cases.verify(cartesian_case, prog, inp, k_index, out=out[:, 1:], ref=ref[:, 1:]) def test_undefined_symbols(cartesian_case): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 8b4cedd98b..130f6bd29c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -28,7 +28,7 @@ ) -pytestmark = pytest.mark.uses_unstructured_shift +pytestmark = [pytest.mark.uses_unstructured_shift, pytest.mark.uses_scan] Cell = gtx.Dimension("Cell") diff --git a/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py b/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py new file mode 100644 index 0000000000..335a08571f --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/test_basic_program.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +import gt4py.next as gtx + + +IDim = gtx.Dimension("IDim") +IOff = gtx.FieldOffset("IOff", source=IDim, target=(IDim,)) + + +@gtx.field_operator +def fop( + a: gtx.Field[[IDim], gtx.float64], b: gtx.Field[[IDim], gtx.float64] +) -> gtx.Field[[IDim], gtx.float64]: + return a(IOff[1]) + b + + +@gtx.program +def prog( + a: gtx.Field[[IDim], gtx.float64], + b: gtx.Field[[IDim], gtx.float64], + out: gtx.Field[[IDim], gtx.float64], +): + fop(a, b, out=out) + + +def test_basic(): + a = gtx.as_field([(IDim, gtx.common.UnitRange(1, 5))], np.asarray([0.0, 1.0, 2.0, 3.0])) + b = gtx.as_field([(IDim, gtx.common.UnitRange(0, 4))], np.asarray([0.0, 1.0, 2.0, 3.0])) + out = gtx.as_field([(IDim, gtx.common.UnitRange(0, 4))], np.asarray([0.0, 0.0, 0.0, 0.0])) + + prog(a, b, out, offset_provider={"IOff": IDim}) + assert out.domain == b.domain + assert np.allclose(out.ndarray, a.ndarray + b.ndarray) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 436e672cc5..2b78eb9114 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -20,7 +20,7 @@ import numpy as np import pytest -from gt4py.next import common, constructors +from gt4py.next import common, embedded from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice @@ -70,12 +70,17 @@ def unary_logical_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): +def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=None): if not dtype: dtype = nd_array_implementation.float32 + buffer = nd_array_implementation.asarray(lst, dtype=dtype) + if domain is None: + domain = tuple( + (common.Dimension(f"D{i}"), common.UnitRange(0, s)) for i, s in enumerate(buffer.shape) + ) return common.field( - nd_array_implementation.asarray(lst, dtype=dtype), - domain={common.Dimension("foo"): (0, len(lst))}, + buffer, + domain=domain, ) @@ -277,6 +282,59 @@ def fma(a: common.Field, b: common.Field, c: common.Field, /) -> common.Field: assert np.allclose(result.ndarray, expected) +def test_remap_implementation(): + V = Dimension("V") + E = Dimension("E") + + V_START, V_STOP = 2, 7 + E_START, E_STOP = 0, 10 + v_field = common.field( + -0.1 * np.arange(V_START, V_STOP), + domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START, V_STOP),)), + ) + e2v_conn = common.connectivity( + np.arange(E_START, E_STOP), + domain=common.Domain( + dims=(E,), + ranges=[ + UnitRange(E_START, E_STOP), + ], + ), + codomain=V, + ) + + result = v_field.remap(e2v_conn) + expected = common.field( + -0.1 * np.arange(V_START, V_STOP), + domain=common.Domain(dims=(E,), ranges=(UnitRange(V_START, V_STOP),)), + ) + + assert result.domain == expected.domain + assert np.all(result.ndarray == expected.ndarray) + + +def test_cartesian_remap_implementation(): + V = Dimension("V") + E = Dimension("E") + + V_START, V_STOP = 2, 7 + OFFSET = 2 + v_field = common.field( + -0.1 * np.arange(V_START, V_STOP), + domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START, V_STOP),)), + ) + v2_conn = common.connectivity(OFFSET, V) + + result = v_field.remap(v2_conn) + expected = common.field( + v_field.ndarray, + domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START - OFFSET, V_STOP - OFFSET),)), + ) + + assert result.domain == expected.domain + assert np.all(result.ndarray == expected.ndarray) + + @pytest.mark.parametrize( "new_dims,field,expected_domain", [ diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 84008eb99c..da63536953 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -21,6 +21,7 @@ DimensionKind, Domain, Infinity, + NamedRange, UnitRange, domain, named_range, @@ -319,6 +320,134 @@ def test_domain_dims_ranges_length_mismatch(): Domain(dims=dims, ranges=ranges) +def test_domain_dim_index(): + dims = [Dimension("X"), Dimension("Y"), Dimension("Z")] + ranges = [UnitRange(0, 1), UnitRange(0, 1), UnitRange(0, 1)] + domain = Domain(dims=dims, ranges=ranges) + + domain.dim_index(Dimension("Y")) == 1 + + domain.dim_index(Dimension("Foo")) == None + + +def test_domain_pop(): + dims = [Dimension("X"), Dimension("Y"), Dimension("Z")] + ranges = [UnitRange(0, 1), UnitRange(0, 1), UnitRange(0, 1)] + domain = Domain(dims=dims, ranges=ranges) + + domain.pop(Dimension("X")) == Domain(dims=dims[1:], ranges=ranges[1:]) + + domain.pop(0) == Domain(dims=dims[1:], ranges=ranges[1:]) + + domain.pop(-1) == Domain(dims=dims[:-1], ranges=ranges[:-1]) + + +@pytest.mark.parametrize( + "index, named_ranges, domain, expected", + [ + # Valid index and named ranges + ( + 0, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("X"), UnitRange(100, 110)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + ), + ( + 1, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("X"), UnitRange(100, 110)), + (Dimension("K"), UnitRange(0, 10)), + ), + ), + ( + -1, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("X"), UnitRange(100, 110)), + ), + ), + ( + Dimension("J"), + [(Dimension("X"), UnitRange(100, 110)), (Dimension("Z"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("X"), UnitRange(100, 110)), + (Dimension("Z"), UnitRange(100, 110)), + (Dimension("K"), UnitRange(0, 10)), + ), + ), + # Invalid indices + ( + 3, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + IndexError, + ), + ( + -4, + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + IndexError, + ), + ( + Dimension("Foo"), + [(Dimension("X"), UnitRange(100, 110))], + Domain( + (Dimension("I"), UnitRange(0, 10)), + (Dimension("J"), UnitRange(0, 10)), + (Dimension("K"), UnitRange(0, 10)), + ), + ValueError, + ), + ], +) +def test_domain_replace(index, named_ranges, domain, expected): + if expected is ValueError: + with pytest.raises(ValueError): + domain.replace(index, *named_ranges) + elif expected is IndexError: + with pytest.raises(IndexError): + domain.replace(index, *named_ranges) + else: + new_domain = domain.replace(index, *named_ranges) + assert new_domain == expected + + def dimension_promotion_cases() -> ( list[tuple[list[list[Dimension]], list[Dimension] | None, None | Pattern]] ): From 5a409c560444b73b8807e25552837b10d33cf8f7 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 23 Nov 2023 13:43:17 +0100 Subject: [PATCH 12/85] feat[next]: DaCe backend - enable GPU test execution (#1360) This commit enables test execution on the DaCe GPU backend: * Small fix in DaCe SDFG generation for GPU execution. The fix is about handling of in/out fields, for which the program argument is copied to a transient array. We need to use for the transient array the same storage as the program argument (i.e. gpu storage), otherwise code generation will throw an error because of mixed storage for inputs to the closure map. * Minor code refactoring (test exclusion matrix, dace backend processor interface) * Cleanup test exclusion matrix (some left-overs after rebase of previous dace PR) Note that 3 testcases are disabled because the fix needs to be delivered to DaCe repo and a new DaCe release should be provided, in order to update the GT4Py dependency list. --- .../ADRs/0015-Test_Exclusion_Matrices.md | 4 +- pyproject.toml | 2 + .../runners/dace_iterator/__init__.py | 48 +++++++++-------- .../runners/dace_iterator/itir_to_sdfg.py | 14 ++++- tests/next_tests/exclusion_matrices.py | 54 +++++++++---------- .../ffront_tests/ffront_test_utils.py | 3 ++ .../ffront_tests/test_external_local_field.py | 10 ++++ .../ffront_tests/test_gt4py_builtins.py | 10 ++++ .../ffront_tests/test_math_unary_builtins.py | 12 +---- tests/next_tests/unit_tests/conftest.py | 6 +++ 10 files changed, 101 insertions(+), 62 deletions(-) diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md index 6c6a043560..b338169d61 100644 --- a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md +++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md @@ -47,10 +47,12 @@ by calling `next_tests.get_processor_id()`, which returns the so-called processo The following backend processors are defined: ```python -DACE = "dace_iterator.run_dace_iterator" +DACE_CPU = "dace_iterator.run_dace_cpu" +DACE_GPU = "dace_iterator.run_dace_gpu" GTFN_CPU = "otf_compile_executor.run_gtfn" GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative" GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries" +GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ``` Following the previous example, the GTFN backend with temporaries does not support yet dynamic offsets in ITIR: diff --git a/pyproject.toml b/pyproject.toml index 041448e17d..2cf4fb12e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -332,12 +332,14 @@ markers = [ 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', + 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', 'uses_lift_expressions: tests that require backend support for lift expressions', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 'uses_origin: tests that require backend support for domain origin', 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', + 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', 'uses_sparse_fields: tests that require backend support for sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index e3fba87571..40b6d24b0e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -12,6 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import hashlib +import warnings from typing import Any, Mapping, Optional, Sequence import dace @@ -22,11 +23,11 @@ import gt4py.next.allocators as next_allocators import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.otf_compile_executor as otf_exec +import gt4py.next.program_processors.processor_interface as ppi from gt4py.next.common import Dimension, Domain, UnitRange, is_field from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache -from gt4py.next.program_processors.processor_interface import program_executor from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG @@ -94,10 +95,26 @@ def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: return {name.id: convert_arg(arg) for name, arg in zip(params, args)} +def _ensure_is_on_device( + connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType +) -> np.typing.NDArray: + if device == dace.dtypes.DeviceType.GPU: + if not isinstance(connectivity_arg, cp.ndarray): + warnings.warn( + "Copying connectivity to device. For performance make sure connectivity is provided on device." + ) + return cp.asarray(connectivity_arg) + return connectivity_arg + + def get_connectivity_args( - neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]] + neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]], + device: dace.dtypes.DeviceType, ) -> dict[str, Any]: - return {connectivity_identifier(offset): table.table for offset, table in neighbor_tables} + return { + connectivity_identifier(offset): _ensure_is_on_device(table.table, device) + for offset, table in neighbor_tables + } def get_shape_args( @@ -167,7 +184,6 @@ def get_cache_id( return m.hexdigest() -@program_executor def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: # build parameters auto_optimize = kwargs.get("auto_optimize", False) @@ -182,6 +198,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] + device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) @@ -192,26 +209,16 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: else: # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() - # set array storage for GPU execution - if run_on_gpu: - device = dace.DeviceType.GPU - sdfg._name = f"{sdfg.name}_gpu" - for _, _, array in sdfg.arrays_recursive(): - if not array.transient: - array.storage = dace.dtypes.StorageType.GPU_Global - else: - device = dace.DeviceType.CPU - # run DaCe auto-optimization heuristics if auto_optimize: # TODO Investigate how symbol definitions improve autoopt transformations, # in which case the cache table should take the symbols map into account. symbols: dict[str, int] = {} - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols) + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=run_on_gpu) # compile SDFG and retrieve SDFG program sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" @@ -226,7 +233,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_args = get_args(program.params, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = get_connectivity_args(neighbor_tables) + dace_conn_args = get_connectivity_args(neighbor_tables, device) dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) @@ -254,7 +261,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: sdfg_program(**expected_args) -@program_executor def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_iterator( program, @@ -267,13 +273,12 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_cpu = otf_exec.OTFBackend( - executor=_run_dace_cpu, + executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"), allocator=next_allocators.StandardCPUFieldBufferAllocator(), ) if cp: - @program_executor def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: run_dace_iterator( program, @@ -286,12 +291,11 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: else: - @program_executor def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: raise RuntimeError("Missing `cupy` dependency for GPU execution.") run_dace_gpu = otf_exec.OTFBackend( - executor=_run_dace_gpu, + executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"), allocator=next_allocators.StandardGPUFieldBufferAllocator(), ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 9e9cc4bf29..a7cecf5fad 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -96,17 +96,20 @@ class ItirToSDFG(eve.NodeVisitor): offset_provider: dict[str, Any] node_types: dict[int, next_typing.Type] unique_id: int + use_gpu_storage: bool def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTableOffsetProvider], column_axis: Optional[Dimension] = None, + use_gpu_storage: bool = False, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} + self.use_gpu_storage = use_gpu_storage def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): @@ -118,7 +121,14 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset else None ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) + storage = ( + dace.dtypes.StorageType.GPU_Global + if self.use_gpu_storage + else dace.dtypes.StorageType.Default + ) + sdfg.add_array( + name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage + ) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: @@ -225,6 +235,7 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, + storage=array_table[name].storage, transient=True, ) closure_init_state.add_nedge( @@ -239,6 +250,7 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, + storage=array_table[name].storage, ) else: assert isinstance(self.storage_types[name], ts.ScalarType) diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index a8a508b2fb..a6a302e143 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -57,6 +57,7 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): DACE_CPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_cpu" + DACE_GPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_gpu" class ProgramExecutorId(_PythonObjectIdMixin, str, enum.Enum): @@ -83,9 +84,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # Test markers REQUIRES_ATLAS = "requires_atlas" USES_APPLIED_SHIFTS = "uses_applied_shifts" -USES_CAN_DEREF = "uses_can_deref" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" +USES_FLOORDIV = "uses_floordiv" USES_IF_STMTS = "uses_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" USES_LIFT_EXPRESSIONS = "uses_lift_expressions" @@ -111,7 +112,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): "We cannot unroll a reduction on a sparse field only (not clear if it is legal ITIR)" ) # Common list of feature markers to skip -GTFN_SKIP_TEST_LIST = [ +COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), @@ -119,46 +120,45 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), ] +DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), +] EMBEDDED_SKIP_LIST = [ (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ] +GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ + # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), +] #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) BACKEND_SKIP_TEST_MATRIX = { None: EMBEDDED_SKIP_LIST, - OptionalProgramBackendId.DACE_CPU: GTFN_SKIP_TEST_LIST - + [ - (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), - (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - ], - ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST - + [ - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - ], - ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST - + [ - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - ], - ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + [ - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + # awaiting dace fix, see https://github.com/spcl/dace/pull/1442 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST, + ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST, + ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 1537c01642..f8a3f6a975 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -45,6 +45,9 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append(definitions.OptionalProgramBackendId.DACE_CPU) + OPTIONAL_PROCESSORS.append( + pytest.param(definitions.OptionalProgramBackendId.DACE_GPU, marks=pytest.mark.requires_gpu) + ), @pytest.fixture( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 05adc63a45..42938e2f4b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -30,6 +30,16 @@ def test_external_local_field(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e2434d860a..bbbac6c139 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -120,6 +120,16 @@ def fencil(edge_f: cases.EField, out: cases.VField): @pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 8660ecfdbd..c2ab43773f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -37,7 +37,6 @@ tanh, trunc, ) -from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case @@ -67,17 +66,8 @@ def pow(inp1: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, pow, ref=lambda inp1: inp1**2) +@pytest.mark.uses_floordiv def test_floordiv(cartesian_case): - if cartesian_case.backend in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, - gtfn.run_gtfn_gpu, - ]: - pytest.xfail( - "FloorDiv not yet supported." - ) # see https://github.com/GridTools/gt4py/issues/1136 - @gtx.field_operator def floorDiv(inp1: cases.IField) -> cases.IField: return inp1 // 2 diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 372062d08a..6f91557e46 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -50,6 +50,12 @@ def lift_mode(request): OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append((definitions.OptionalProgramBackendId.DACE_CPU, True)) + # TODO(havogt): update tests to use proper allocation + # OPTIONAL_PROCESSORS.append( + # pytest.param( + # (definitions.OptionalProgramBackendId.DACE_GPU, True), marks=pytest.mark.requires_gpu + # ) + # ), @pytest.fixture( From 6bea0074305d2261844746ee4f801a72a8e1c435 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 24 Nov 2023 12:58:20 +0100 Subject: [PATCH 13/85] fix[next]: Fix for GPU test execution (#1368) This commit removes test_gpu_backend.py Spack build of icon4py was broken because dace is an optional module, not installed in the default environment, and the dace backend is not available in test execution. This caused an ImportError exception in test_gpu_backend.py, because this test is bypassing the test exclusion matrix. The initial proposed fix was to use try/except to handle this case. However, all tests in baseline are already executed on the GPU backends (both GTFN and DaCe), therefore this simple test is no longer needed. --- .../ffront_tests/test_gpu_backend.py | 45 ------------------- 1 file changed, 45 deletions(-) delete mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py deleted file mode 100644 index 7054597831..0000000000 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py +++ /dev/null @@ -1,45 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import pytest - -import gt4py.next as gtx -from gt4py.next import common -from gt4py.next.program_processors.runners import dace_iterator, gtfn - -from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import cartesian_case # noqa: F401 -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 - fieldview_backend, -) - - -@pytest.mark.requires_gpu -@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace_gpu, gtfn.run_gtfn_gpu]) -def test_copy(fieldview_backend): # noqa: F811 # fixtures - import cupy as cp - - @gtx.field_operator(backend=fieldview_backend) - def testee(a: cases.IJKField) -> cases.IJKField: - return a - - domain = { - cases.IDim: common.unit_range(3), - cases.JDim: common.unit_range(4), - cases.KDim: common.unit_range(5), - } - inp_field = gtx.full(domain, fill_value=3, allocator=fieldview_backend, dtype=cp.int32) - out_field = gtx.zeros(domain, allocator=fieldview_backend, dtype=cp.int32) - testee(inp_field, out=out_field, offset_provider={}) - assert cp.allclose(inp_field.ndarray, out_field.ndarray) From 1e486f2cd8caebcea5f3cc342a9727fc8a9f1d03 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 27 Nov 2023 16:01:27 +0100 Subject: [PATCH 14/85] test[next]: Fix warnings that cause Spack to crash (#1369) Solves warning about invalid escape sequence in some regex strings. These warnings cause spack build to crash. --- .../feature_tests/ffront_tests/test_program.py | 4 ++-- .../feature_tests/ffront_tests/test_type_deduction.py | 4 ++-- tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py | 2 +- tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index a0f69f332c..4c0613a33c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -228,8 +228,8 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program(inp, out, offset_provider={}) msgs = [ - "- Expected argument `in_field` to be of type `Field\[\[IDim], float64\]`," - " but got `Field\[\[JDim\], float64\]`.", + r"- Expected argument `in_field` to be of type `Field\[\[IDim], float64\]`," + r" but got `Field\[\[JDim\], float64\]`.", ] for msg in msgs: assert re.search(msg, exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index dfa710e038..d1a5f24f79 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -299,7 +299,7 @@ def callable_type_info_cases(): [ts.TupleType(types=[float_type, field_type])], {}, [ - "Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" + r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" ], ts.VoidType(), ), @@ -308,7 +308,7 @@ def callable_type_info_cases(): [int_type], {}, [ - "Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" + r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" ], ts.VoidType(), ), diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 6e617f77a2..1d1a1efad4 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -206,7 +206,7 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): assert exc_info.match("Invalid call to `domain_format_3`") assert ( - re.search("Missing required keyword argument\(s\) `out`.", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) is not None ) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index e56dc85322..c4fe30c596 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -187,6 +187,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): # is not None # ) assert ( - re.search("Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) is not None ) From 5a912cf1d97e3c5b3f555f1c104b8650df282263 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 28 Nov 2023 12:52:16 +0100 Subject: [PATCH 15/85] fix[next]: DaCe backend - symbol propagation in lambda scope (#1367) Full-inlining of ITIR lift operator can result in nested lambda calls, which are translated to nested SDFGs in DaCe backend. The problem was that inner lambda SDFGs were not inheriting the symbols in scope from the parent SDFG, instead the DaCe backend was only mapping the lambda arguments. The solution implemented in this PR is to run a pre-pass to discover all symbols used in the nested lambdas, and propagate the required data containers from outer to inner SDFG. This PR also contains some cleanup of the scan visitor. The overall goal is to rely as much as possible on the visitor for itir.FunCall to generate the scan body. --- .../runners/dace_iterator/itir_to_sdfg.py | 188 +++++------ .../runners/dace_iterator/itir_to_tasklet.py | 295 +++++++++++------- 2 files changed, 263 insertions(+), 220 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a7cecf5fad..94878fd46d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -24,9 +24,10 @@ from .itir_to_tasklet import ( Context, - IteratorExpr, + GatherOutputSymbolsPass, PythonTaskletCodegen, SymbolExpr, + TaskletExpr, ValueExpr, closure_to_tasklet_sdfg, is_scan, @@ -136,8 +137,13 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset self.storage_types[name] = type_ def get_output_nodes( - self, closure: itir.StencilClosure, context: Context + self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState ) -> dict[str, dace.nodes.AccessNode]: + # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes + output_symbols_pass = GatherOutputSymbolsPass(sdfg, state, self.node_types) + output_symbols_pass.visit(closure.output) + # Visit output node again to generate the corresponding tasklet + context = Context(sdfg, state, output_symbols_pass.symbol_refs) translator = PythonTaskletCodegen(self.offset_provider, context, self.node_types) output_nodes = flatten_list(translator.visit(closure.output)) return {node.value.data: node.value for node in output_nodes} @@ -212,19 +218,16 @@ def visit_StencilClosure( closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") - program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} - closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - neighbor_tables = filter_neighbor_tables(self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + neighbor_tables = filter_neighbor_tables(self.offset_provider) + connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - output_nodes = self.get_output_nodes(node, closure_ctx) + output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) output_names = [k for k, _ in output_nodes.items()] # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. input_transients_mapping = {} - for name in [*input_names, *conn_names, *output_names]: + for name in [*input_names, *connectivity_names, *output_names]: if name in closure_sdfg.arrays: assert name in input_names and name in output_names # In case of closures with in/out fields, there is risk of race condition @@ -268,6 +271,7 @@ def visit_StencilClosure( ) # Update symbol table and get output domain of the closure + program_arg_syms: dict[str, TaskletExpr] = {} for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): dtype = as_dace_type(type_) @@ -285,10 +289,11 @@ def visit_StencilClosure( program_arg_syms[name] = value else: program_arg_syms[name] = SymbolExpr(name, dtype) + closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) closure_domain = self._visit_domain(node.domain, closure_ctx) # Map SDFG tasklet arguments to parameters - input_access_names = [ + input_local_names = [ input_transients_mapping[input_name] if input_name in input_transients_mapping else input_name @@ -297,9 +302,9 @@ def visit_StencilClosure( for input_name in input_names ] input_memlets = [ - create_memlet_full(name, closure_sdfg.arrays[name]) for name in input_access_names + create_memlet_full(name, closure_sdfg.arrays[name]) + for name in [*input_local_names, *connectivity_names] ] - conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names] # create and write to transient that is then copied back to actual output array to avoid aliasing of # same memory in nested SDFG with different names @@ -340,18 +345,18 @@ def visit_StencilClosure( for output_name in output_connectors_mapping.values() ] - input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} - output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)} - conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlets)} + input_mapping = { + param: arg for param, arg in zip([*input_names, *connectivity_names], input_memlets) + } + output_mapping = {param: memlet for param, memlet in zip(results, output_memlets)} - array_mapping = {**input_mapping, **conn_mapping} - symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping) + symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, input_mapping) nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, map_ranges=map_ranges or {"__dummy": "0"}, - inputs=array_mapping, + inputs=input_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, @@ -376,7 +381,7 @@ def visit_StencilClosure( closure_state.remove_edge(edge) access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] - return closure_sdfg, input_field_names + conn_names, output_names + return closure_sdfg, input_field_names + connectivity_names, output_names def _visit_scan_stencil_closure( self, @@ -422,6 +427,23 @@ def _visit_scan_stencil_closure( lambda_state = scan_sdfg.add_state("lambda_compute") end_state = scan_sdfg.add_state("end") + # the carry value of the scan operator exists only in the scope of the scan sdfg + scan_carry_name = unique_var_name() + scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) + + # tasklet for initialization of carry + carry_init_tasklet = start_state.add_tasklet( + "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" + ) + start_state.add_edge( + carry_init_tasklet, + "__result", + start_state.add_access(scan_carry_name), + None, + dace.Memlet.simple(scan_carry_name, "0"), + ) + + # TODO(edopao): replace state machine with dace loop construct scan_sdfg.add_loop( start_state, lambda_state, @@ -434,7 +456,7 @@ def _visit_scan_stencil_closure( increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1", ) - # add access nodes to SDFG for inputs + # add storage to scan SDFG for inputs for name in [*input_names, *connectivity_names]: assert name not in scan_sdfg.arrays if isinstance(self.storage_types[name], ts.FieldType): @@ -448,116 +470,76 @@ def _visit_scan_stencil_closure( scan_sdfg.add_scalar( name, dtype=as_dace_type(cast(ts.ScalarType, self.storage_types[name])) ) + # add storage to scan SDFG for output + scan_sdfg.add_array( + output_name, + shape=(array_table[node.output.id].shape[scan_dim_index],), + strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), + dtype=array_table[node.output.id].dtype, + ) + # implement the lambda function as a nested SDFG that computes a single item in the scan dimension + lambda_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} + input_arrays = [(scan_carry_name, scan_dtype)] + [ + (name, self.storage_types[name]) for name in input_names + ] connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] - - # implement the lambda closure as a nested SDFG that computes a single item of the map domain - lambda_context, lambda_inputs, lambda_outputs = closure_to_tasklet_sdfg( + lambda_context, lambda_outputs = closure_to_tasklet_sdfg( node, self.offset_provider, - {}, - [], + lambda_domain, + input_arrays, connectivity_arrays, self.node_types, ) + lambda_input_names = [name for name, _ in input_arrays] + lambda_output_names = [connector.value.data for connector in lambda_outputs] + + input_memlets = [ + create_memlet_full(name, scan_sdfg.arrays[name]) for name in lambda_input_names + ] connectivity_memlets = [ create_memlet_full(name, scan_sdfg.arrays[name]) for name in connectivity_names ] + input_mapping = {param: arg for param, arg in zip(lambda_input_names, input_memlets)} connectivity_mapping = { param: arg for param, arg in zip(connectivity_names, connectivity_memlets) } - - lambda_input_names = [inner_name for inner_name, _ in lambda_inputs] - symbol_mapping = map_nested_sdfg_symbols( - scan_sdfg, lambda_context.body, connectivity_mapping - ) + array_mapping = {**input_mapping, **connectivity_mapping} + symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) scan_inner_node = lambda_state.add_nested_sdfg( lambda_context.body, parent=scan_sdfg, inputs=set(lambda_input_names) | set(connectivity_names), - outputs={connector.value.label for connector in lambda_outputs}, + outputs=set(lambda_output_names), symbol_mapping=symbol_mapping, ) - # the carry value of the scan operator exists in the scope of the scan sdfg - scan_carry_name = unique_var_name() - lambda_carry_name, _ = lambda_inputs[0] - scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) - - carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" - ) - carry_node1 = start_state.add_access(scan_carry_name) - start_state.add_edge( - carry_init_tasklet, - "__result", - carry_node1, - None, - dace.Memlet.simple(scan_carry_name, "0"), - ) - - carry_node2 = lambda_state.add_access(scan_carry_name) - lambda_state.add_memlet_path( - carry_node2, - scan_inner_node, - memlet=dace.Memlet.simple(scan_carry_name, "0"), - src_conn=None, - dst_conn=lambda_carry_name, - ) - - # connect access nodes to lambda inputs - for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - if isinstance(self.storage_types[data_name], ts.FieldType): - memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain)) - else: - memlet = dace.Memlet.simple(data_name, "0") - lambda_state.add_memlet_path( - lambda_state.add_access(data_name), - scan_inner_node, - memlet=memlet, - src_conn=None, - dst_conn=inner_name, - ) - - for inner_name, memlet in connectivity_mapping.items(): - access_node = lambda_state.add_access(inner_name) - lambda_state.add_memlet_path( - access_node, - scan_inner_node, - memlet=memlet, - src_conn=None, - dst_conn=inner_name, - propagate=True, - ) + # connect scan SDFG to lambda inputs + for name, memlet in array_mapping.items(): + access_node = lambda_state.add_access(name) + lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = [output_name] - assert len(lambda_outputs) == 1 - # connect lambda output to access node - for lambda_connector, data_name in zip(lambda_outputs, output_names): - scan_sdfg.add_array( - data_name, - shape=(array_table[node.output.id].shape[scan_dim_index],), - strides=(array_table[node.output.id].strides[scan_dim_index],), - offset=(array_table[node.output.id].offset[scan_dim_index],), - dtype=array_table[node.output.id].dtype, - ) - lambda_state.add_memlet_path( + assert len(lambda_output_names) == 1 + # connect lambda output to scan SDFG + for name, connector in zip(output_names, lambda_output_names): + lambda_state.add_edge( scan_inner_node, - lambda_state.add_access(data_name), - memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), - src_conn=lambda_connector.value.label, - dst_conn=None, + connector, + lambda_state.add_access(name), + None, + dace.Memlet.simple(name, f"i_{scan_dim}"), ) # add state to scan SDFG to update the carry value at each loop iteration lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") - result_node = lambda_update_state.add_access(output_names[0]) - carry_node3 = lambda_update_state.add_access(scan_carry_name) lambda_update_state.add_memlet_path( - result_node, - carry_node3, + lambda_update_state.add_access(output_name), + lambda_update_state.add_access(scan_carry_name), memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) @@ -586,14 +568,14 @@ def _visit_parallel_stencil_closure( index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} input_arrays = [(name, self.storage_types[name]) for name in input_names] - conn_arrays = [(array_table[name], name) for name in conn_names] + connectivity_arrays = [(array_table[name], name) for name in conn_names] - context, _, results = closure_to_tasklet_sdfg( + context, results = closure_to_tasklet_sdfg( node, self.offset_provider, index_domain, input_arrays, - conn_arrays, + connectivity_arrays, self.node_types, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 5b240ea2b7..da54f9be14 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -11,11 +11,10 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - import dataclasses import itertools from collections.abc import Sequence -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional, TypeAlias, cast import dace import numpy as np @@ -23,6 +22,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen +from gt4py import eve from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider @@ -151,11 +151,15 @@ class IteratorExpr: dimensions: list[str] +# Union of possible expression types +TaskletExpr: TypeAlias = IteratorExpr | SymbolExpr | ValueExpr + + @dataclasses.dataclass class Context: body: dace.SDFG state: dace.SDFGState - symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] + symbol_map: dict[str, TaskletExpr] # if we encounter a reduction node, the reduction state needs to be pushed to child nodes reduce_limit: int reduce_wcr: Optional[str] @@ -164,13 +168,15 @@ def __init__( self, body: dace.SDFG, state: dace.SDFGState, - symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr], + symbol_map: dict[str, TaskletExpr], + reduce_limit: int = 0, + reduce_wcr: Optional[str] = None, ): self.body = body self.state = state self.symbol_map = symbol_map - self.reduce_limit = 0 - self.reduce_wcr = None + self.reduce_limit = reduce_limit + self.reduce_wcr = reduce_wcr def builtin_neighbors( @@ -350,6 +356,104 @@ def builtin_undefined(*args: Any) -> Any: } +class GatherLambdaSymbolsPass(eve.NodeVisitor): + _sdfg: dace.SDFG + _state: dace.SDFGState + _symbol_map: dict[str, TaskletExpr] + _parent_symbol_map: dict[str, TaskletExpr] + + def __init__( + self, + sdfg, + state, + parent_symbol_map, + ): + self._sdfg = sdfg + self._state = state + self._symbol_map = {} + self._parent_symbol_map = parent_symbol_map + + @property + def symbol_refs(self): + """Dictionary of symbols referenced from the lambda expression.""" + return self._symbol_map + + def _add_symbol(self, param, arg): + if isinstance(arg, ValueExpr): + # create storage in lambda sdfg + self._sdfg.add_scalar(param, dtype=arg.dtype) + # update table of lambda symbol + self._symbol_map[param] = ValueExpr(self._state.add_access(param), arg.dtype) + elif isinstance(arg, IteratorExpr): + # create storage in lambda sdfg + ndims = len(arg.dimensions) + shape = tuple( + dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims) + ) + strides = tuple( + dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims) + ) + self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) + index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} + for _, index_name in index_names.items(): + self._sdfg.add_scalar(index_name, dtype=dace.int64) + # update table of lambda symbol + field = self._state.add_access(param) + indices = { + dim: self._state.add_access(index_arg) for dim, index_arg in index_names.items() + } + self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) + else: + assert isinstance(arg, SymbolExpr) + self._symbol_map[param] = arg + + def visit_SymRef(self, node: itir.SymRef): + name = str(node.id) + if name in self._parent_symbol_map and name not in self._symbol_map: + arg = self._parent_symbol_map[name] + self._add_symbol(name, arg) + + def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): + if args is not None: + assert len(node.params) == len(args) + for param, arg in zip(node.params, args): + self._add_symbol(str(param.id), arg) + self.visit(node.expr) + + +class GatherOutputSymbolsPass(eve.NodeVisitor): + _sdfg: dace.SDFG + _state: dace.SDFGState + _node_types: dict[int, next_typing.Type] + _symbol_map: dict[str, TaskletExpr] + + @property + def symbol_refs(self): + """Dictionary of symbols referenced from the output expression.""" + return self._symbol_map + + def __init__( + self, + sdfg, + state, + node_types, + ): + self._sdfg = sdfg + self._state = state + self._node_types = node_types + self._symbol_map = {} + + def visit_SymRef(self, node: itir.SymRef): + param = str(node.id) + if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: + node_type = self._node_types[id(node)] + assert isinstance(node_type, Val) + access_node = self._state.add_access(param) + self._symbol_map[param] = ValueExpr( + access_node, dtype=itir_type_as_dace_type(node_type.dtype) + ) + + class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): offset_provider: dict[str, Any] context: Context @@ -369,7 +473,7 @@ def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): raise NotImplementedError() def visit_Lambda( - self, node: itir.Lambda, args: Sequence[ValueExpr | SymbolExpr] + self, node: itir.Lambda, args: Sequence[TaskletExpr] ) -> tuple[ Context, list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], @@ -377,62 +481,38 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = filter_neighbor_tables(self.offset_provider) - param_names = [str(p.id) for p in node.params] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] - - assert len(param_names) == len(args) - symbols = { - **{param: arg for param, arg in zip(param_names, args)}, - } - - # Create the SDFG for the function's body - prev_context = self.context - context_sdfg = dace.SDFG(func_name) - context_state = context_sdfg.add_state(f"{func_name}_entry", True) - symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} - value: ValueExpr | IteratorExpr - for param, arg in symbols.items(): - if isinstance(arg, ValueExpr): - value = ValueExpr(context_state.add_access(param), arg.dtype) - else: - assert isinstance(arg, IteratorExpr) - field = context_state.add_access(param) - indices = { - dim: context_state.add_access(f"__{param}_i_{dim}") - for dim in arg.indices.keys() - } - value = IteratorExpr(field, indices, arg.dtype, arg.dimensions) - symbol_map[param] = value - context = Context(context_sdfg, context_state, symbol_map) - context.reduce_limit = prev_context.reduce_limit - context.reduce_wcr = prev_context.reduce_wcr - self.context = context + connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + + # Create the SDFG for the lambda's body + lambda_sdfg = dace.SDFG(func_name) + lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True) - # Add input parameters as arrays + lambda_symbols_pass = GatherLambdaSymbolsPass( + lambda_sdfg, lambda_state, self.context.symbol_map + ) + lambda_symbols_pass.visit(node, args=args) + + # Add for input nodes for lambda symbols inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = [] - for name, arg in symbols.items(): - if isinstance(arg, ValueExpr): - dtype = arg.dtype - context.body.add_scalar(name, dtype=dtype) - inputs.append((name, arg)) + for sym, input_node in lambda_symbols_pass.symbol_refs.items(): + arg = next((arg for param, arg in zip(node.params, args) if param.id == sym), None) + if arg: + outer_node = arg else: - assert isinstance(arg, IteratorExpr) - ndims = len(arg.dimensions) - shape = tuple( - dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims) - ) - strides = tuple( - dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims) - ) - dtype = arg.dtype - context.body.add_array(name, shape=shape, strides=strides, dtype=dtype) - index_names = {dim: f"__{name}_i_{dim}" for dim in arg.indices.keys()} - for _, index_name in index_names.items(): - context.body.add_scalar(index_name, dtype=dace.int64) - inputs.append(((name, index_names), arg)) + # the symbol is not found among lambda arguments, then it is inherited from parent scope + outer_node = self.context.symbol_map[sym] + if isinstance(input_node, IteratorExpr): + assert isinstance(outer_node, IteratorExpr) + index_params = { + dim: index_node.data for dim, index_node in input_node.indices.items() + } + inputs.append(((sym, index_params), outer_node)) + elif isinstance(input_node, ValueExpr): + assert isinstance(outer_node, ValueExpr) + inputs.append((sym, outer_node)) # Add connectivities as arrays - for name in conn_names: + for name in connectivity_names: shape = ( dace.symbol(unique_var_name() + "__shp", dace.int64), dace.symbol(unique_var_name() + "__shp", dace.int64), @@ -441,50 +521,53 @@ def visit_Lambda( dace.symbol(unique_var_name() + "__strd", dace.int64), dace.symbol(unique_var_name() + "__strd", dace.int64), ) - dtype = prev_context.body.arrays[name].dtype - context.body.add_array(name, shape=shape, strides=strides, dtype=dtype) + dtype = self.context.body.arrays[name].dtype + lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + + # Translate the lambda's body in its own context + lambda_context = Context( + lambda_sdfg, + lambda_state, + lambda_symbols_pass.symbol_refs, + reduce_limit=self.context.reduce_limit, + reduce_wcr=self.context.reduce_wcr, + ) + lambda_taskgen = PythonTaskletCodegen(self.offset_provider, lambda_context, self.node_types) - # Translate the function's body results: list[ValueExpr] = [] # We are flattening the returned list of value expressions because the multiple outputs of a lamda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. - for expr in flatten_list(self.visit(node.expr)): + for expr in flatten_list(lambda_taskgen.visit(node.expr)): if isinstance(expr, ValueExpr): result_name = unique_var_name() - self.context.body.add_scalar(result_name, expr.dtype, transient=True) - result_access = self.context.state.add_access(result_name) - self.context.state.add_edge( + lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) + result_access = lambda_state.add_access(result_name) + lambda_state.add_edge( expr.value, None, result_access, None, # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet.simple(result_access.data, "0", wcr_str=context.reduce_wcr), + dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr), ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] - self.context.body.arrays[result.value.data].transient = False + result = lambda_taskgen.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + lambda_sdfg.arrays[result.value.data].transient = False results.append(result) - self.context = prev_context - for node in context.state.nodes(): - if isinstance(node, dace.nodes.AccessNode): - if context.state.out_degree(node) == 0 and context.state.in_degree(node) == 0: - context.state.remove_node(node) + # remove isolated access nodes for connectivity arrays not consumed by lambda + for sub_node in lambda_state.nodes(): + if isinstance(sub_node, dace.nodes.AccessNode): + if lambda_state.out_degree(sub_node) == 0 and lambda_state.in_degree(sub_node) == 0: + lambda_state.remove_node(sub_node) - return context, inputs, results + return lambda_context, inputs, results def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: - if node.id not in self.context.symbol_map: - acc = self.context.state.add_access(node.id) - node_type = self.node_types[id(node)] - assert isinstance(node_type, Val) - self.context.symbol_map[node.id] = ValueExpr( - value=acc, dtype=itir_type_as_dace_type(node_type.dtype) - ) - value = self.context.symbol_map[node.id] + param = str(node.id) + value = self.context.symbol_map[param] if isinstance(value, (ValueExpr, SymbolExpr)): return [value] return value @@ -952,29 +1035,6 @@ def is_scan(node: itir.Node) -> bool: return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") -def _visit_scan_closure_callable( - node: itir.StencilClosure, - tlet_codegen: PythonTaskletCodegen, -) -> tuple[Context, Sequence[tuple[str, ValueExpr]], Sequence[ValueExpr]]: - stencil = cast(FunCall, node.stencil) - assert isinstance(stencil.args[0], Lambda) - fun_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) - - args = list(itertools.chain(tlet_codegen.visit(node.output), *tlet_codegen.visit(node.inputs))) - return tlet_codegen.visit(fun_node, args=args) - - -def _visit_closure_callable( - node: itir.StencilClosure, - tlet_codegen: PythonTaskletCodegen, - input_names: Sequence[str], -) -> Sequence[ValueExpr]: - args = [itir.SymRef(id=name) for name in input_names] - fun_node = itir.FunCall(fun=node.stencil, args=args) - - return tlet_codegen.visit(fun_node) - - def closure_to_tasklet_sdfg( node: itir.StencilClosure, offset_provider: dict[str, Any], @@ -982,10 +1042,10 @@ def closure_to_tasklet_sdfg( inputs: Sequence[tuple[str, ts.TypeSpec]], connectivities: Sequence[tuple[dace.ndarray, str]], node_types: dict[int, next_typing.Type], -) -> tuple[Context, Sequence[tuple[str, ValueExpr]], Sequence[ValueExpr]]: +) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") state = body.add_state("tasklet_toplevel_entry") - symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + symbol_map: dict[str, TaskletExpr] = {} idx_accesses = {} for dim, idx in domain.items(): @@ -1023,16 +1083,17 @@ def closure_to_tasklet_sdfg( context = Context(body, state, symbol_map) translator = PythonTaskletCodegen(offset_provider, context, node_types) + args = [itir.SymRef(id=name) for name, _ in inputs] if is_scan(node.stencil): - context, inner_inputs, inner_outputs = _visit_scan_closure_callable(node, translator) + stencil = cast(FunCall, node.stencil) + assert isinstance(stencil.args[0], Lambda) + lambda_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) + fun_node = itir.FunCall(fun=lambda_node, args=args) else: - inner_inputs = [] - inner_outputs = _visit_closure_callable( - node, - translator, - [name for name, _ in inputs], - ) - for output in inner_outputs: - context.body.arrays[output.value.data].transient = False + fun_node = itir.FunCall(fun=node.stencil, args=args) + + results = translator.visit(fun_node) + for r in results: + context.body.arrays[r.value.data].transient = False - return context, inner_inputs, inner_outputs + return context, results From 91307b10e2ca1edb76a72cd8a3bebdd66898da60 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 28 Nov 2023 14:49:45 +0100 Subject: [PATCH 16/85] feature[next]: Cache direct field operator call (`as_program`) (#1254) Cache direct calls to field operators by storing the autogenerated programs in a cache. --- src/gt4py/next/ffront/decorator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 7572040e13..67272f88b8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -545,6 +545,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): definition: Optional[types.FunctionType] = None backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None + _program_cache: dict = dataclasses.field(default_factory=dict) @classmethod def from_function( @@ -613,6 +614,13 @@ def as_program( # of arg and kwarg types # TODO(tehrengruber): check foast operator has no out argument that clashes # with the out argument of the program we generate here. + hash_ = eve_utils.content_hash( + (tuple(arg_types), tuple((name, arg) for name, arg in kwarg_types.items())) + ) + try: + return self._program_cache[hash_] + except KeyError: + pass loc = self.foast_node.location param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching @@ -666,12 +674,13 @@ def as_program( untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) past_node = ProgramTypeDeduction.apply(untyped_past_node) - return Program( + self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, backend=self.backend, grid_type=self.grid_type, ) + return self._program_cache[hash_] def __call__( self, From 4c022866d75c4bdbbff2d24775dcbc40e3a9a0db Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 29 Nov 2023 17:59:54 +0100 Subject: [PATCH 17/85] refactor[next]: Move iterator utils to dedicated module (#1371) Move `gt4py.next.iterator.ir_makers` and `gt4py.next.iterator.transforms.common_pattern_matcher` into a new module named `gt4py.next.iterator.ir_utils`. Just a small refactoring in preparation of #1350. --- src/gt4py/next/ffront/decorator.py | 7 ++++++- src/gt4py/next/ffront/foast_to_itir.py | 3 ++- src/gt4py/next/iterator/ir_utils/__init__.py | 13 +++++++++++++ .../common_pattern_matcher.py | 0 src/gt4py/next/iterator/{ => ir_utils}/ir_makers.py | 0 src/gt4py/next/iterator/pretty_parser.py | 3 ++- src/gt4py/next/iterator/tracing.py | 3 ++- .../next/iterator/transforms/constant_folding.py | 3 ++- src/gt4py/next/iterator/transforms/cse.py | 2 +- src/gt4py/next/iterator/transforms/global_tmps.py | 5 +++-- .../next/iterator/transforms/inline_lambdas.py | 2 +- src/gt4py/next/iterator/transforms/inline_lifts.py | 3 ++- .../next/iterator/transforms/symbol_ref_utils.py | 2 +- src/gt4py/next/iterator/transforms/unroll_reduce.py | 2 +- .../unit_tests/ffront_tests/test_foast_to_itir.py | 3 ++- .../iterator_tests/test_type_inference.py | 3 ++- .../transforms_tests/test_collapse_tuple.py | 4 +--- .../transforms_tests/test_constant_folding.py | 2 +- .../iterator_tests/transforms_tests/test_cse.py | 3 ++- .../transforms_tests/test_global_tmps.py | 3 ++- .../transforms_tests/test_inline_lambdas.py | 2 +- .../transforms_tests/test_inline_lifts.py | 2 +- .../transforms_tests/test_propagate_deref.py | 2 +- .../transforms_tests/test_trace_shifts.py | 3 ++- 24 files changed, 51 insertions(+), 24 deletions(-) create mode 100644 src/gt4py/next/iterator/ir_utils/__init__.py rename src/gt4py/next/iterator/{transforms => ir_utils}/common_pattern_matcher.py (100%) rename src/gt4py/next/iterator/{ => ir_utils}/ir_makers.py (100%) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 67272f88b8..e06c651b13 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -53,7 +53,12 @@ from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_makers import literal_from_value, promote_to_const_iterator, ref, sym +from gt4py.next.iterator.ir_utils.ir_makers import ( + literal_from_value, + promote_to_const_iterator, + ref, + sym, +) from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 816b8581f1..3030c03fd1 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -25,7 +25,8 @@ ) from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts diff --git a/src/gt4py/next/iterator/ir_utils/__init__.py b/src/gt4py/next/iterator/ir_utils/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/src/gt4py/next/iterator/ir_utils/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/gt4py/next/iterator/transforms/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py similarity index 100% rename from src/gt4py/next/iterator/transforms/common_pattern_matcher.py rename to src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py diff --git a/src/gt4py/next/iterator/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py similarity index 100% rename from src/gt4py/next/iterator/ir_makers.py rename to src/gt4py/next/iterator/ir_utils/ir_makers.py diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index a541e985ad..2b1c8169fb 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -16,7 +16,8 @@ from lark import lark, lexer as lark_lexer, visitors as lark_visitors -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im GRAMMAR = """ diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index fbe6a2ae82..d1f6bba8d6 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -20,7 +20,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import Node from gt4py.next import common, iterator -from gt4py.next.iterator import builtins, ir_makers as im +from gt4py.next.iterator import builtins from gt4py.next.iterator.ir import ( AxisLiteral, Expr, @@ -34,6 +34,7 @@ Sym, SymRef, ) +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications, type_translation diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index cda422f30d..fa326760b0 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -13,7 +13,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.next.iterator import embedded, ir, ir_makers as im +from gt4py.next.iterator import embedded, ir +from gt4py.next.iterator.ir_utils import ir_makers as im class ConstantFolding(NodeTranslator): diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 672e23c5e7..cc70e11413 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -233,7 +233,7 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_makers as im + >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1b697e0bc..d9d3d18213 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,10 +22,11 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im, type_inference +from gt4py.next.iterator import ir, type_inference +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.pretty_printer import PrettyPrinter from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.cse import extract_subexpression from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index fc268f85e3..eac4338345 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -17,7 +17,7 @@ from gt4py.eve import NodeTranslator from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 8d62450e67..d7d8e5e612 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -19,7 +19,8 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1f604d62b9..1c587fb9d6 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -36,7 +36,7 @@ def apply( Count references to given or all symbols in scope. Examples: - >>> import gt4py.next.iterator.ir_makers as im + >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> CountSymbolRefs.apply(expr) {'x': 2, 'y': 2, 'z': 1} diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index e3084eaba5..60a5db7e96 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -20,7 +20,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift def _is_shifted(arg: itir.Expr) -> TypeGuard[itir.FunCall]: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index dd66beb522..2dd4b91c48 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -29,7 +29,8 @@ from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir, ir_makers as im +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts, type_translation diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 1526e97d74..cacdb7b070 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -15,7 +15,8 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.iterator import ir, ir_makers as im, type_inference as ti +from gt4py.next.iterator import ir, type_inference as ti +from gt4py.next.iterator.ir_utils import ir_makers as im def test_unsatisfiable_constraints(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 736bf04d64..1444b0a64f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -12,9 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytest - -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 5d052b1989..275412a537 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 5d9e0933a7..065095e1c2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -14,7 +14,8 @@ import textwrap from gt4py.eve.utils import UIDGenerator -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.cse import ( CommonSubexpressionElimination as CSE, extract_subexpression, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 88f6ed517b..86c3c98c62 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -15,7 +15,8 @@ import gt4py.next as gtx from gt4py.eve.utils import UIDs -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.global_tmps import ( AUTO_DOMAIN, FencilWithTemporaries, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index b9f2ca16a1..88e554f349 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py index 1da2b8a044..e1d440044d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lifts import InlineLifts diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py index ffbf2c2c8e..e2e29cd4db 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_propagate_deref.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir_makers as im +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 2624a17ebd..47db632a5e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -12,7 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.next.iterator import ir, ir_makers as im +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.trace_shifts import Sentinel, TraceShifts From e564cdc14277155cf820c94f3077b9194986a01d Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 30 Nov 2023 11:14:34 +0100 Subject: [PATCH 18/85] feat[next]: Add option to ITIR transformation to inline lambda args (#1370) Full-inlining of unrolled reduce in DaCe backend requires lambda arguments to be inlined, in order to generate the corresponding taskgraph. This PR adds an option to the existing ITIR transformation InlineLambdas to enable this additional transformation pass, disabled by default. --- .../iterator/transforms/inline_lambdas.py | 14 ++++++++++++ .../next/iterator/transforms/pass_manager.py | 7 +++++- .../runners/dace_iterator/__init__.py | 1 + .../transforms_tests/test_inline_lambdas.py | 22 +++++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index eac4338345..a56ad5cb10 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -29,6 +29,7 @@ def inline_lambda( # noqa: C901 # see todo above opcount_preserving=False, force_inline_lift_args=False, force_inline_trivial_lift_args=False, + force_inline_lambda_args=False, eligible_params: Optional[list[bool]] = None, ): assert isinstance(node.fun, ir.Lambda) @@ -59,6 +60,12 @@ def inline_lambda( # noqa: C901 # see todo above if is_applied_lift(arg) and len(arg.args) == 0: eligible_params[i] = True + # inline lambdas passed as arguments + if force_inline_lambda_args: + for i, arg in enumerate(node.args): + if isinstance(arg, ir.Lambda): + eligible_params[i] = True + if node.fun.params and not any(eligible_params): return node @@ -120,6 +127,8 @@ class InlineLambdas(NodeTranslator): opcount_preserving: bool + force_inline_lambda_args: bool + force_inline_lift_args: bool force_inline_trivial_lift_args: bool @@ -129,6 +138,7 @@ def apply( cls, node: ir.Node, opcount_preserving=False, + force_inline_lambda_args=False, force_inline_lift_args=False, force_inline_trivial_lift_args=False, ): @@ -146,6 +156,8 @@ def apply( opcount_preserving: Preserve the number of operations, i.e. only inline lambda call if the resulting call has the same number of operations. + force_inline_lambda_args: Inline all arguments that are lambda calls, i.e. + `(λ(p) → p(a, a))(λ(x, y) → x+y)` force_inline_lift_args: Inline all arguments that are applied lifts, i.e. `lift(λ(...) → ...)(...)`. force_inline_trivial_lift_args: Inline all arguments that are trivial @@ -154,6 +166,7 @@ def apply( """ return cls( opcount_preserving=opcount_preserving, + force_inline_lambda_args=force_inline_lambda_args, force_inline_lift_args=force_inline_lift_args, force_inline_trivial_lift_args=force_inline_trivial_lift_args, ).visit(node) @@ -164,6 +177,7 @@ def visit_FunCall(self, node: ir.FunCall): return inline_lambda( node, opcount_preserving=self.opcount_preserving, + force_inline_lambda_args=self.force_inline_lambda_args, force_inline_lift_args=self.force_inline_lift_args, force_inline_trivial_lift_args=self.force_inline_trivial_lift_args, ) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b0db04eb5f..e2feb79c44 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -79,6 +79,7 @@ def apply_common_transforms( offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, + force_inline_lambda_args=False, unconditionally_collapse_tuples=False, ): if lift_mode is None: @@ -160,6 +161,10 @@ def apply_common_transforms( ir = CommonSubexpressionElimination().visit(ir) ir = MergeLet().visit(ir) - ir = InlineLambdas.apply(ir, opcount_preserving=True) + ir = InlineLambdas.apply( + ir, + opcount_preserving=True, + force_inline_lambda_args=force_inline_lambda_args, + ) return ir diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 40b6d24b0e..d77792664e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -84,6 +84,7 @@ def preprocess_program( fencil_definition = apply_common_transforms( program, common_subexpression_elimination=False, + force_inline_lambda_args=True, lift_mode=lift_mode, offset_provider=offset_provider, unroll_reduce=True, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 88e554f349..bf26889882 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -56,3 +56,25 @@ def test(name, opcount_preserving, testee, expected): inlined = InlineLambdas.apply(testee, opcount_preserving=opcount_preserving) assert inlined == expected + + +def test_inline_lambda_args(): + testee = im.let("reduce_step", im.lambda_("x", "y")(im.plus("x", "y")))( + im.lambda_("a")( + im.call("reduce_step")(im.call("reduce_step")(im.call("reduce_step")("a", 1), 2), 3) + ) + ) + expected = im.lambda_("a")( + im.call(im.lambda_("x", "y")(im.plus("x", "y")))( + im.call(im.lambda_("x", "y")(im.plus("x", "y")))( + im.call(im.lambda_("x", "y")(im.plus("x", "y")))("a", 1), 2 + ), + 3, + ) + ) + inlined = InlineLambdas.apply( + testee, + opcount_preserving=True, + force_inline_lambda_args=True, + ) + assert inlined == expected From 6e133543480f46202f5eb94e9906cb4d2356301a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 30 Nov 2023 09:00:06 -0500 Subject: [PATCH 19/85] fix[cartesian]: DaceIR bridge for DaCe v0.15 (#1373) * Adapt state struct codegen for indexing of dace:* stencil backend * Add new "Default" schedule type for dace <> gt4py schedule mapping * Missing key for make template render * Fix typo * Typo of the typo (tm) --- src/gt4py/cartesian/backend/dace_backend.py | 10 +++++++++- src/gt4py/cartesian/gtc/daceir.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 60da2c36ff..11cd1fa895 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -451,7 +451,7 @@ class DaCeComputationCodegen: const int __I = domain[0]; const int __J = domain[1]; const int __K = domain[2]; - ${name}_t dace_handle; + ${name}_${state_suffix} dace_handle; ${backend_specifics} auto allocator = gt::sid::cached_allocator(&${allocator}); ${"\\n".join(tmp_allocs)} @@ -561,6 +561,13 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S else: omp_threads = "" omp_header = "" + + # Backward compatible state struct name change in DaCe >=0.15.x + try: + dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") + except (KeyError, TypeError): + dace_state_suffix = "t" # old structure name + interface = cls.template.definition.render( name=sdfg.name, backend_specifics=omp_threads, @@ -568,6 +575,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S functor_args=self.generate_functor_args(sdfg), tmp_allocs=self.generate_tmp_allocs(sdfg), allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", + state_suffix=dace_state_suffix, ) generated_code = textwrap.dedent( f"""#include diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index dc749a984b..28ebc8cd8e 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -101,6 +101,7 @@ def from_dace_schedule(cls, schedule): dace.ScheduleType.Default: MapSchedule.Default, dace.ScheduleType.Sequential: MapSchedule.Sequential, dace.ScheduleType.CPU_Multicore: MapSchedule.CPU_Multicore, + dace.ScheduleType.GPU_Default: MapSchedule.GPU_Device, dace.ScheduleType.GPU_Device: MapSchedule.GPU_Device, dace.ScheduleType.GPU_ThreadBlock: MapSchedule.GPU_ThreadBlock, }[schedule] From b1f9c9a567e01c14e7236b41fa95f09ae1bb3e2a Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 4 Dec 2023 09:15:16 +0100 Subject: [PATCH 20/85] feat[next][dace]: Support for sparse fields and reductions over lift expressions (#1377) This PR adds support to DaCe backend for sparse fields and reductions over lift expressions. --- .../runners/dace_iterator/itir_to_sdfg.py | 6 +- .../runners/dace_iterator/itir_to_tasklet.py | 181 +++++++++++++----- .../runners/dace_iterator/utility.py | 7 + tests/next_tests/exclusion_matrices.py | 3 - .../ffront_tests/test_execution.py | 1 + 5 files changed, 142 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 94878fd46d..271a79c04b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -150,7 +150,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) - last_state = program_sdfg.add_state("program_entry") + last_state = program_sdfg.add_state("program_entry", True) self.node_types = itir_typing.infer_all(node) # Filter neighbor tables from offset providers. @@ -216,7 +216,7 @@ def visit_StencilClosure( # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") - closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") + closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) input_names = [str(inp.id) for inp in node.inputs] neighbor_tables = filter_neighbor_tables(self.offset_provider) @@ -423,7 +423,7 @@ def _visit_scan_stencil_closure( scan_sdfg = dace.SDFG(name="scan") # create a state machine for lambda call over the scan dimension - start_state = scan_sdfg.add_state("start") + start_state = scan_sdfg.add_state("start", True) lambda_state = scan_sdfg.add_state("lambda_compute") end_state = scan_sdfg.add_state("end") diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index da54f9be14..de18446bbe 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -18,6 +18,7 @@ import dace import numpy as np +from dace import subsets from dace.transformation.dataflow import MapFusion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols @@ -39,6 +40,7 @@ filter_neighbor_tables, flatten_list, map_nested_sdfg_symbols, + new_array_symbols, unique_name, unique_var_name, ) @@ -131,9 +133,13 @@ def get_reduce_identity_value(op_name_: str, type_: Any): } +# Define type of variables used for field indexing +_INDEX_DTYPE = _TYPE_MAPPING["int64"] + + @dataclasses.dataclass class SymbolExpr: - value: str | dace.symbolic.sympy.Basic + value: dace.symbolic.SymbolicType dtype: dace.typeclass @@ -226,7 +232,7 @@ def builtin_neighbors( outputs={"__result"}, ) idx_name = unique_var_name() - sdfg.add_scalar(idx_name, dace.int64, transient=True) + sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True) state.add_memlet_path( state.add_access(table_name), me, @@ -283,10 +289,12 @@ def builtin_can_deref( assert shift_callable.fun.id == "shift" iterator = transformer._visit_shift(can_deref_callable) + # this iterator is accessing a neighbor table, so it should return an index + assert iterator.dtype in dace.dtypes.INTEGER_TYPES # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions] + args = [ValueExpr(access_node, iterator.dtype) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " && ".join([f"{v} >= 0" for v in internals]) + expr_code = " and ".join([f"{v} >= 0" for v in internals]) # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( @@ -309,6 +317,26 @@ def builtin_if( return transformer.add_expr_tasklet(expr_args, expr, type_, "if") +def builtin_list_get( + transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] +) -> list[ValueExpr]: + args = list(itertools.chain(*transformer.visit(node_args))) + assert len(args) == 2 + # index node + assert isinstance(args[0], (SymbolExpr, ValueExpr)) + # 1D-array node + assert isinstance(args[1], ValueExpr) + # source node should be a 1D array + assert len(transformer.context.body.arrays[args[1].value.data].shape) == 1 + + expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] + internals = [ + arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args + ] + expr = f"{internals[1]}[{internals[0]}]" + return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get") + + def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -340,16 +368,13 @@ def builtin_tuple_get( raise ValueError("Tuple can only be subscripted with compile-time constants") -def builtin_undefined(*args: Any) -> Any: - raise NotImplementedError() - - _GENERAL_BUILTIN_MAPPING: dict[ str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]] ] = { "can_deref": builtin_can_deref, "cast_": builtin_cast, "if_": builtin_if, + "list_get": builtin_list_get, "make_tuple": builtin_make_tuple, "neighbors": builtin_neighbors, "tuple_get": builtin_tuple_get, @@ -387,16 +412,11 @@ def _add_symbol(self, param, arg): elif isinstance(arg, IteratorExpr): # create storage in lambda sdfg ndims = len(arg.dimensions) - shape = tuple( - dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims) - ) - strides = tuple( - dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims) - ) + shape, strides = new_array_symbols(param, ndims) self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} for _, index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=dace.int64) + self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) # update table of lambda symbol field = self._state.add_access(param) indices = { @@ -513,14 +533,7 @@ def visit_Lambda( # Add connectivities as arrays for name in connectivity_names: - shape = ( - dace.symbol(unique_var_name() + "__shp", dace.int64), - dace.symbol(unique_var_name() + "__shp", dace.int64), - ) - strides = ( - dace.symbol(unique_var_name() + "__strd", dace.int64), - dace.symbol(unique_var_name() + "__strd", dace.int64), - ) + shape, strides = new_array_symbols(name, ndim=2) dtype = self.context.body.arrays[name].dtype lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -542,11 +555,9 @@ def visit_Lambda( result_name = unique_var_name() lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) result_access = lambda_state.add_access(result_name) - lambda_state.add_edge( + lambda_state.add_nedge( expr.value, - None, result_access, - None, # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr), ) @@ -587,12 +598,13 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: return self._visit_reduce(node) if isinstance(node.fun, itir.SymRef): - if str(node.fun.id) in _MATH_BUILTINS_MAPPING: + builtin_name = str(node.fun.id) + if builtin_name in _MATH_BUILTINS_MAPPING: return self._visit_numeric_builtin(node) - elif str(node.fun.id) in _GENERAL_BUILTIN_MAPPING: + elif builtin_name in _GENERAL_BUILTIN_MAPPING: return self._visit_general_builtin(node) else: - raise NotImplementedError() + raise NotImplementedError(f"{builtin_name} not implemented") return self._visit_call(node) def _visit_call(self, node: itir.FunCall): @@ -697,7 +709,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: for dim in sorted_dims ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices + ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in iterator.indices ] internals = [f"{arg.value.data}_v" for arg in args] @@ -726,14 +738,88 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] - else: + elif all([dim in iterator.indices for dim in iterator.dimensions]): + # The deref iterator has index values on all dimensions: the result will be a scalar args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims + ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + else: + # Not all dimensions are included in the deref index list: + # this means the ND-field will be sliced along one or more dimensions and the result will be an array + field_array = self.context.body.arrays[iterator.field.data] + result_shape = tuple( + dim_size + for dim, dim_size in zip(sorted_dims, field_array.shape) + if dim not in iterator.indices + ) + result_name = unique_var_name() + self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) + result_array = self.context.body.arrays[result_name] + result_node = self.context.state.add_access(result_name) + + deref_connectors = ["_inp"] + [ + f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices + ] + deref_nodes = [iterator.field] + [ + iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices + ] + deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [ + dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:] + ] + + # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset + deref_sdfg = dace.SDFG("deref") + deref_sdfg.add_array( + "_inp", field_array.shape, iterator.dtype, strides=field_array.strides + ) + for connector in deref_connectors[1:]: + deref_sdfg.add_scalar(connector, _INDEX_DTYPE) + deref_sdfg.add_array("_out", result_shape, iterator.dtype) + deref_init_state = deref_sdfg.add_state("init", True) + deref_access_state = deref_sdfg.add_state("access") + deref_sdfg.add_edge( + deref_init_state, + deref_access_state, + dace.InterstateEdge( + assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]} + ), + ) + # we access the size in source field shape as symbols set on the nested sdfg + source_subset = tuple( + f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}" + for dim, size in zip(sorted_dims, field_array.shape) + ) + deref_access_state.add_nedge( + deref_access_state.add_access("_inp"), + deref_access_state.add_access("_out"), + dace.Memlet( + data="_out", + subset=subsets.Range.from_array(result_array), + other_subset=",".join(source_subset), + ), + ) + + deref_node = self.context.state.add_nested_sdfg( + deref_sdfg, + self.context.body, + inputs=set(deref_connectors), + outputs={"_out"}, + ) + for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets): + self.context.state.add_edge(node, None, deref_node, connector, memlet) + self.context.state.add_edge( + deref_node, + "_out", + result_node, + None, + dace.Memlet.from_array(result_name, result_array), + ) + return [ValueExpr(result_node, iterator.dtype)] + def _split_shift_args( self, args: list[itir.Expr] ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: @@ -760,6 +846,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: offset_dim = tail[0].value assert isinstance(offset_dim, str) offset_node = self.visit(tail[1])[0] + assert offset_node.dtype in dace.dtypes.INTEGER_TYPES if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): offset_provider = self.offset_provider[offset_dim] @@ -769,7 +856,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: target_dim = offset_provider.neighbor_axis.value args = [ ValueExpr(connectivity, offset_provider.table.dtype), - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] @@ -780,7 +867,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value args = [ - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] @@ -791,14 +878,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shifted_dim = self.offset_provider[offset_dim].value target_dim = shifted_dim args = [ - ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "shift" + list(zip(args, internals)), expr, offset_node.dtype, "shift" )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -811,7 +898,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() - self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True) + self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) offset_node = self.context.state.add_access(offset_var) tasklet_node = self.context.state.add_tasklet( "get_offset", {}, {"__out"}, f"__out = {offset}" @@ -906,7 +993,7 @@ def _visit_reduce(self, node: itir.FunCall): # initialize the reduction result based on type of operation init_value = get_reduce_identity_value(op_name.id, result_dtype) - init_state = self.context.body.add_state_before(self.context.state, "init") + init_state = self.context.body.add_state_before(self.context.state, "init", True) init_tasklet = init_state.add_tasklet( "init_reduce", {}, {"__out"}, f"__out = {init_value}" ) @@ -1044,13 +1131,13 @@ def closure_to_tasklet_sdfg( node_types: dict[int, next_typing.Type], ) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") - state = body.add_state("tasklet_toplevel_entry") + state = body.add_state("tasklet_toplevel_entry", True) symbol_map: dict[str, TaskletExpr] = {} idx_accesses = {} for dim, idx in domain.items(): name = f"{idx}_value" - body.add_scalar(name, dtype=dace.int64, transient=True) + body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") access = state.add_access(name) idx_accesses[dim] = access @@ -1058,15 +1145,10 @@ def closure_to_tasklet_sdfg( for name, ty in inputs: if isinstance(ty, ts.FieldType): ndim = len(ty.dims) - shape = [ - dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(ndim) - ] - stride = [ - dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim) - ] + shape, strides = new_array_symbols(name, ndim) dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=stride, dtype=dtype) + body.add_array(name, shape=shape, strides=strides, dtype=dtype) field = state.add_access(name) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) @@ -1076,9 +1158,8 @@ def closure_to_tasklet_sdfg( body.add_scalar(name, dtype=dtype) symbol_map[name] = ValueExpr(state.add_access(name), dtype) for arr, name in connectivities: - shape = [dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(2)] - stride = [dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(2)] - body.add_array(name, shape=shape, strides=stride, dtype=arr.dtype) + shape, strides = new_array_symbols(name, ndim=2) + body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) translator = PythonTaskletCodegen(offset_provider, context, node_types) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index c17a39ef2d..5ae4676cd7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -166,6 +166,13 @@ def unique_var_name(): return unique_name("__var") +def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: + dtype = dace.int64 + shape = [dace.symbol(unique_name(f"{name}_shp{i}"), dtype) for i in range(ndim)] + strides = [dace.symbol(unique_name(f"{name}_strd{i}"), dtype) for i in range(ndim)] + return shape, strides + + def flatten_list(node_list: list[Any]) -> list[Any]: return list( itertools.chain.from_iterable( diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index a6a302e143..84287e209f 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -122,12 +122,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_REDUCTION_OVER_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index cf273a4524..7f37b41383 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -698,6 +698,7 @@ def testee( ) +@pytest.mark.uses_constant_fields @pytest.mark.uses_unstructured_shift @pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): From 8a22ba7f1f6f49c2a0065b9b14fb6c417d3bbb78 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 4 Dec 2023 12:27:30 +0100 Subject: [PATCH 21/85] feat[next][dace]: Support for reduce-unroll special case (#1381) During integration of icon4py stencils with the DaCe backend, it was found that reduce-unroll can generate an ITIR containing can_deref on a scalar value. Such expression should always evaluate to true, so it can be evaluated at compile-time. Note that in theory such case could be detected by the ITIR pass, once ITIR type inference is replaced by a new solution. At that time, the solution proposed here should be removed. --- .../runners/dace_iterator/itir_to_tasklet.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index de18446bbe..4fa5ae239c 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -289,10 +289,25 @@ def builtin_can_deref( assert shift_callable.fun.id == "shift" iterator = transformer._visit_shift(can_deref_callable) - # this iterator is accessing a neighbor table, so it should return an index - assert iterator.dtype in dace.dtypes.INTEGER_TYPES + # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it + if not isinstance(iterator, IteratorExpr): + assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr) + # We can always deref a value expression, therefore hard-code `can_deref` to True. + # Returning a SymbolExpr would be preferable, but it requires update to type-checking. + result_name = unique_var_name() + transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) + result_node = transformer.context.state.add_access(result_name) + transformer.context.state.add_edge( + transformer.context.state.add_tasklet("can_always_deref", {}, {"_out"}, "_out = True"), + "_out", + result_node, + None, + dace.Memlet.simple(result_name, "0"), + ) + return [ValueExpr(result_node, dace.dtypes.bool)] + # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(access_node, iterator.dtype) for access_node in iterator.indices.values()] + args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] expr_code = " and ".join([f"{v} >= 0" for v in internals]) @@ -833,7 +848,7 @@ def _make_shift_for_rest(self, rest, iterator): fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] ) - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -841,6 +856,12 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) else: iterator = self.visit(node.args[0]) + if not isinstance(iterator, IteratorExpr): + # shift cannot be applied because the argument is not iterable + # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it + assert isinstance(iterator, list) and len(iterator) == 1 + assert isinstance(iterator[0], ValueExpr) + return iterator assert isinstance(tail[0], itir.OffsetLiteral) offset_dim = tail[0].value From ebb70b65487d29ef53c1f1fe95d74509c33146aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 4 Dec 2023 12:56:12 +0100 Subject: [PATCH 22/85] fix[next]: Proper calling signature in DaCe (#1374) This commit adds positional arguments to the generated SDFG. It also improves the naming of some automatically generated symbols, such as the shape and stride. --- .../runners/dace_iterator/__init__.py | 31 ++++++++++++++++--- .../runners/dace_iterator/itir_to_sdfg.py | 9 ++++-- .../runners/dace_iterator/utility.py | 15 ++++----- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index d77792664e..735c6b6284 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -185,12 +185,15 @@ def get_cache_id( return m.hexdigest() -def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: +def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Optional[dace.SDFG]: # build parameters auto_optimize = kwargs.get("auto_optimize", False) build_cache = kwargs.get("build_cache", None) build_type = kwargs.get("build_type", "RelWithDebInfo") run_on_gpu = kwargs.get("run_on_gpu", False) + # Return parameter + return_sdfg = kwargs.get("return_sdfg", False) + run_sdfg = kwargs.get("run_sdfg", True) # ITIR parameters column_axis = kwargs.get("column_axis", None) lift_mode = ( @@ -212,6 +215,18 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu) sdfg = sdfg_genenerator.visit(program) + + # All arguments required by the SDFG, regardless if explicit and implicit, are added + # as positional arguments. In the front are all arguments to the Fencil, in that + # order, they are followed by the arguments created by the translation process, + # their order is determined by DaCe and unspecific. + assert len(sdfg.arg_names) == 0 + arg_list = [str(a) for a in program.params] + sig_list = sdfg.signature_arglist(with_types=False) + implicit_args = set(sig_list) - set(arg_list) + call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] + sdfg.arg_names = call_params + sdfg.simplify() # run DaCe auto-optimization heuristics @@ -256,10 +271,16 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: if key in sdfg.signature_arglist(with_types=False) } - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("frontend", "check_args", value=True) - sdfg_program(**expected_args) + if run_sdfg: + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + dace.config.Config.set("frontend", "check_args", value=True) + sdfg_program(**expected_args) + # + + if return_sdfg: + return sdfg + return None def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 271a79c04b..7a6f359771 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -42,6 +42,8 @@ flatten_list, get_sorted_dims, map_nested_sdfg_symbols, + new_array_symbols, + unique_name, unique_var_name, ) @@ -114,10 +116,9 @@ def __init__( def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): - shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] - strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + shape, strides = new_array_symbols(name, len(type_.dims)) offset = ( - [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + [dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))] if has_offset else None ) @@ -130,8 +131,10 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset sdfg.add_array( name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage ) + elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) + else: raise NotImplementedError() self.storage_types[name] = type_ diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 5ae4676cd7..cb14b89e8a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -153,23 +153,20 @@ def add_mapped_nested_sdfg( return nsdfg_node, map_entry, map_exit -_unique_id = 0 - - def unique_name(prefix): - global _unique_id - _unique_id += 1 - return f"{prefix}_{_unique_id}" + unique_id = getattr(unique_name, "_unique_id", 0) # noqa: B010 # static variable + setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 # static variable + return f"{prefix}_{unique_id}" def unique_var_name(): - return unique_name("__var") + return unique_name("_var") def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: dtype = dace.int64 - shape = [dace.symbol(unique_name(f"{name}_shp{i}"), dtype) for i in range(ndim)] - strides = [dace.symbol(unique_name(f"{name}_strd{i}"), dtype) for i in range(ndim)] + shape = [dace.symbol(unique_name(f"{name}_shape{i}"), dtype) for i in range(ndim)] + strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i in range(ndim)] return shape, strides From d7cf10fb31de4e60b33c25a6807e07605d5ecde0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 4 Dec 2023 09:19:33 -0500 Subject: [PATCH 23/85] DaCe 0.15 suffix state struct hotfix (#1382) --- src/gt4py/cartesian/backend/dace_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 11cd1fa895..b1e559a41e 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -451,7 +451,7 @@ class DaCeComputationCodegen: const int __I = domain[0]; const int __J = domain[1]; const int __K = domain[2]; - ${name}_${state_suffix} dace_handle; + ${name}${state_suffix} dace_handle; ${backend_specifics} auto allocator = gt::sid::cached_allocator(&${allocator}); ${"\\n".join(tmp_allocs)} @@ -566,7 +566,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S try: dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") except (KeyError, TypeError): - dace_state_suffix = "t" # old structure name + dace_state_suffix = "_t" # old structure name interface = cls.template.definition.render( name=sdfg.name, From 9f2ed1e41b50bd1d01a2a861999b5b44d6c9114b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 5 Dec 2023 08:07:57 +0100 Subject: [PATCH 24/85] feat[next]: Separates ITIR -> SDFG translation from running (#1379) Before it was only possible to translate ITIR to SDFG and execute it and it was not possible to extract the SDFG. This commits splits this task into two parts and thus allows to perform the ITIR to SDFG translation without executing it. --- .../runners/dace_iterator/__init__.py | 117 +++++++++++------- .../runners/dace_iterator/itir_to_sdfg.py | 10 ++ 2 files changed, 79 insertions(+), 48 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 735c6b6284..34ba2d2d95 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -185,58 +185,85 @@ def get_cache_id( return m.hexdigest() -def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Optional[dace.SDFG]: +def build_sdfg_from_itir( + program: itir.FencilDefinition, + *args, + offset_provider: dict[str, Any], + auto_optimize: bool = False, + on_gpu: bool = False, + column_axis: Optional[Dimension] = None, + lift_mode: LiftMode = LiftMode.FORCE_INLINE, +) -> dace.SDFG: + """Translate a Fencil into an SDFG. + + Args: + program: The Fencil that should be translated. + *args: Arguments for which the fencil should be called. + offset_provider: The set of offset providers that should be used. + auto_optimize: Apply DaCe's `auto_optimize` heuristic. + on_gpu: Performs the translation for GPU, defaults to `False`. + column_axis: The column axis to be used, defaults to `None`. + lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. + + Notes: + Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored. + """ + # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force + # `lift_more` to `FORCE_INLINE` mode. + lift_mode = LiftMode.FORCE_INLINE + + arg_types = [type_translation.from_value(arg) for arg in args] + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + # visit ITIR and generate SDFG + program = preprocess_program(program, offset_provider, lift_mode) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) + sdfg = sdfg_genenerator.visit(program) + sdfg.simplify() + + # run DaCe auto-optimization heuristics + if auto_optimize: + # TODO Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. + symbols: dict[str, int] = {} + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + + return sdfg + + +def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): # build parameters - auto_optimize = kwargs.get("auto_optimize", False) build_cache = kwargs.get("build_cache", None) build_type = kwargs.get("build_type", "RelWithDebInfo") - run_on_gpu = kwargs.get("run_on_gpu", False) - # Return parameter - return_sdfg = kwargs.get("return_sdfg", False) - run_sdfg = kwargs.get("run_sdfg", True) + on_gpu = kwargs.get("on_gpu", False) + auto_optimize = kwargs.get("auto_optimize", False) + lift_mode = kwargs.get("lift_mode", LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) - lift_mode = ( - LiftMode.FORCE_INLINE - ) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) + sdfg: Optional[dace.SDFG] = None if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] sdfg = sdfg_program.sdfg + else: - # visit ITIR and generate SDFG - program = preprocess_program(program, offset_provider, lift_mode) - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu) - sdfg = sdfg_genenerator.visit(program) - - # All arguments required by the SDFG, regardless if explicit and implicit, are added - # as positional arguments. In the front are all arguments to the Fencil, in that - # order, they are followed by the arguments created by the translation process, - # their order is determined by DaCe and unspecific. - assert len(sdfg.arg_names) == 0 - arg_list = [str(a) for a in program.params] - sig_list = sdfg.signature_arglist(with_types=False) - implicit_args = set(sig_list) - set(arg_list) - call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] - sdfg.arg_names = call_params - - sdfg.simplify() - - # run DaCe auto-optimization heuristics - if auto_optimize: - # TODO Investigate how symbol definitions improve autoopt transformations, - # in which case the cache table should take the symbols map into account. - symbols: dict[str, int] = {} - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=run_on_gpu) - - # compile SDFG and retrieve SDFG program + sdfg = build_sdfg_from_itir( + program, + *args, + offset_provider=offset_provider, + auto_optimize=auto_optimize, + on_gpu=on_gpu, + column_axis=column_axis, + lift_mode=lift_mode, + ) + sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) @@ -271,16 +298,10 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Option if key in sdfg.signature_arglist(with_types=False) } - if run_sdfg: - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("frontend", "check_args", value=True) - sdfg_program(**expected_args) - # - - if return_sdfg: - return sdfg - return None + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + dace.config.Config.set("frontend", "check_args", value=True) + sdfg_program(**expected_args) def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: @@ -290,7 +311,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: **kwargs, build_cache=_build_cache_cpu, build_type=_build_type, - run_on_gpu=False, + on_gpu=False, ) @@ -308,7 +329,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: **kwargs, build_cache=_build_cache_gpu, build_type=_build_type, - run_on_gpu=True, + on_gpu=True, ) else: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 7a6f359771..b3e6662623 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -208,6 +208,16 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): access_node = last_state.add_access(inner_name) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) + # Create the call signature for the SDFG. + # All arguments required by the SDFG, regardless if explicit and implicit, are added + # as positional arguments. In the front are all arguments to the Fencil, in that + # order, they are followed by the arguments created by the translation process, + arg_list = [str(a) for a in node.params] + sig_list = program_sdfg.signature_arglist(with_types=False) + implicit_args = set(sig_list) - set(arg_list) + call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] + program_sdfg.arg_names = call_params + program_sdfg.validate() return program_sdfg From c547f536930c039f03a734ccb85f464fe4ba062d Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 5 Dec 2023 09:37:34 +0100 Subject: [PATCH 25/85] fix[next][dace]: Add check on neighbor index (#1383) Most part of this PR is about cleaning up and refactoring the code in DaCe backend for translation of neighbor reduction. As part of the refactoring, the backend is now using the reduce library node from DaCe library. However, this PR also contains one functional change, which is a fix. Neighbor reduction should check for validity of neighbor index. This means for neighbor-tables to check that the neighbor index stored in the table is not -1. For neighbor strided offsets, we should check that the neighbor index does not access the origin field out of boundary. --- .../runners/dace_iterator/itir_to_tasklet.py | 308 ++++++------------ .../ffront_tests/test_gt4py_builtins.py | 30 ++ 2 files changed, 137 insertions(+), 201 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 4fa5ae239c..f6f197859b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -20,7 +20,6 @@ import numpy as np from dace import subsets from dace.transformation.dataflow import MapFusion -from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen from gt4py import eve @@ -167,22 +166,19 @@ class Context: state: dace.SDFGState symbol_map: dict[str, TaskletExpr] # if we encounter a reduction node, the reduction state needs to be pushed to child nodes - reduce_limit: int - reduce_wcr: Optional[str] + reduce_identity: Optional[SymbolExpr] def __init__( self, body: dace.SDFG, state: dace.SDFGState, symbol_map: dict[str, TaskletExpr], - reduce_limit: int = 0, - reduce_wcr: Optional[str] = None, + reduce_identity: Optional[SymbolExpr] = None, ): self.body = body self.state = state self.symbol_map = symbol_map - self.reduce_limit = reduce_limit - self.reduce_wcr = reduce_wcr + self.reduce_identity = reduce_identity def builtin_neighbors( @@ -193,42 +189,53 @@ def builtin_neighbors( offset_dim = offset_literal.value assert isinstance(offset_dim, str) iterator = transformer.visit(data) - table: NeighborTableOffsetProvider = transformer.offset_provider[offset_dim] - assert isinstance(table, NeighborTableOffsetProvider) - - offset = transformer.offset_provider[offset_dim] - if isinstance(offset, Dimension): + assert isinstance(iterator, IteratorExpr) + field_desc = iterator.field.desc(transformer.context.body) + + field_index = "__field_idx" + offset_provider = transformer.offset_provider[offset_dim] + if isinstance(offset_provider, NeighborTableOffsetProvider): + neighbor_check = f"{field_index} >= 0" + elif isinstance(offset_provider, StridedNeighborOffsetProvider): + neighbor_check = f"{field_index} < {field_desc.shape[offset_provider.neighbor_axis.value]}" + else: + assert isinstance(offset_provider, Dimension) raise NotImplementedError( "Neighbor reductions for cartesian grids not implemented in DaCe backend." ) + assert transformer.context.reduce_identity is not None + sdfg: dace.SDFG = transformer.context.body state: dace.SDFGState = transformer.context.state - shifted_dim = table.origin_axis.value + shifted_dim = offset_provider.origin_axis.value result_name = unique_var_name() - sdfg.add_array(result_name, dtype=iterator.dtype, shape=(table.max_neighbors,), transient=True) + sdfg.add_array( + result_name, dtype=iterator.dtype, shape=(offset_provider.max_neighbors,), transient=True + ) result_access = state.add_access(result_name) - table_name = connectivity_identifier(offset_dim) - # generate unique map index name to avoid conflict with other maps inside same state - index_name = unique_name("__neigh_idx") + neighbor_index = unique_name("neighbor_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", - ndrange={index_name: f"0:{table.max_neighbors}"}, + ndrange={neighbor_index: f"0:{offset_provider.max_neighbors}"}, ) + table_name = connectivity_identifier(offset_dim) + table_subset = (f"0:{sdfg.arrays[table_name].shape[0]}", neighbor_index) + shift_tasklet = state.add_tasklet( "shift", - code=f"__result = __table[__idx, {index_name}]", + code="__result = __table[__idx]", inputs={"__table", "__idx"}, outputs={"__result"}, ) data_access_tasklet = state.add_tasklet( "data_access", - code="__result = __field[__idx]", - inputs={"__field", "__idx"}, + code=f"__result = __field[{field_index}] if {neighbor_check} else {transformer.context.reduce_identity.value}", + inputs={"__field", field_index}, outputs={"__result"}, ) idx_name = unique_var_name() @@ -237,7 +244,7 @@ def builtin_neighbors( state.add_access(table_name), me, shift_tasklet, - memlet=create_memlet_full(table_name, sdfg.arrays[table_name]), + memlet=create_memlet_at(table_name, table_subset), dst_conn="__table", ) state.add_memlet_path( @@ -247,17 +254,11 @@ def builtin_neighbors( memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), dst_conn="__idx", ) - state.add_edge( - shift_tasklet, - "__result", - data_access_tasklet, - "__idx", - dace.Memlet.simple(idx_name, "0"), - ) + state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet()) # select full shape only in the neighbor-axis dimension field_subset = tuple( - f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" - for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) + f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) ) state.add_memlet_path( iterator.field, @@ -270,7 +271,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet.simple(result_name, index_name), + memlet=dace.Memlet.simple(result_name, neighbor_index), src_conn="__result", ) @@ -508,14 +509,16 @@ def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): raise NotImplementedError() def visit_Lambda( - self, node: itir.Lambda, args: Sequence[TaskletExpr] + self, node: itir.Lambda, args: Sequence[TaskletExpr], use_neighbor_tables: bool = True ) -> tuple[ Context, list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], list[ValueExpr], ]: func_name = f"lambda_{abs(hash(node)):x}" - neighbor_tables = filter_neighbor_tables(self.offset_provider) + neighbor_tables = ( + filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else [] + ) connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # Create the SDFG for the lambda's body @@ -557,13 +560,12 @@ def visit_Lambda( lambda_sdfg, lambda_state, lambda_symbols_pass.symbol_refs, - reduce_limit=self.context.reduce_limit, - reduce_wcr=self.context.reduce_wcr, + reduce_identity=self.context.reduce_identity, ) lambda_taskgen = PythonTaskletCodegen(self.offset_provider, lambda_context, self.node_types) results: list[ValueExpr] = [] - # We are flattening the returned list of value expressions because the multiple outputs of a lamda + # We are flattening the returned list of value expressions because the multiple outputs of a lambda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. for expr in flatten_list(lambda_taskgen.visit(node.expr)): if isinstance(expr, ValueExpr): @@ -573,8 +575,7 @@ def visit_Lambda( lambda_state.add_nedge( expr.value, result_access, - # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution - dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr), + dace.Memlet.simple(result_access.data, "0"), ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: @@ -700,60 +701,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: args: list[ValueExpr] sorted_dims = sorted(iterator.dimensions) - if self.context.reduce_limit: - # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing - result_name = unique_var_name() - self.context.body.add_array( - result_name, - dtype=iterator.dtype, - shape=(self.context.reduce_limit,), - transient=True, - ) - result_access = self.context.state.add_access(result_name) - - # generate unique map index name to avoid conflict with other maps inside same state - index_name = unique_name("__deref_idx") - me, mx = self.context.state.add_map( - "deref_map", - ndrange={index_name: f"0:{self.context.reduce_limit}"}, - ) - - # if dim is not found in iterator indices, we take the neighbor index over the reduction domain - flat_index = [ - f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name - for dim in sorted_dims - ] - args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in iterator.indices - ] - internals = [f"{arg.value.data}_v" for arg in args] - - deref_tasklet = self.context.state.add_tasklet( - name="deref", - inputs=set(internals), - outputs={"__result"}, - code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", - ) - - for arg, internal in zip(args, internals): - input_memlet = create_memlet_full( - arg.value.data, self.context.body.arrays[arg.value.data] - ) - self.context.state.add_memlet_path( - arg.value, me, deref_tasklet, memlet=input_memlet, dst_conn=internal - ) - - self.context.state.add_memlet_path( - deref_tasklet, - mx, - result_access, - memlet=dace.Memlet.simple(result_name, index_name), - src_conn="__result", - ) - - return [ValueExpr(value=result_access, dtype=iterator.dtype)] - - elif all([dim in iterator.indices for dim in iterator.dimensions]): + if all([dim in iterator.indices for dim in iterator.dimensions]): # The deref iterator has index values on all dimensions: the result will be a scalar args = [ValueExpr(iterator.field, iterator.dtype)] + [ ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims @@ -930,8 +878,9 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] def _visit_reduce(self, node: itir.FunCall): - result_name = unique_var_name() - result_access = self.context.state.add_access(result_name) + node_type = self.node_types[id(node)] + assert isinstance(node_type, itir_typing.Val) + reduce_dtype = itir_type_as_dace_type(node_type.dtype) if len(node.args) == 1: assert ( @@ -939,131 +888,70 @@ def _visit_reduce(self, node: itir.FunCall): and isinstance(node.args[0].fun, itir.SymRef) and node.args[0].fun.id == "neighbors" ) - args = self.visit(node.args) - assert len(args) == 1 - args = args[0] - assert len(args) == 1 - neighbors_expr = args[0] - result_dtype = neighbors_expr.dtype assert isinstance(node.fun, itir.FunCall) op_name = node.fun.args[0] assert isinstance(op_name, itir.SymRef) - init = node.fun.args[1] + reduce_identity = node.fun.args[1] + assert isinstance(reduce_identity, itir.Literal) - reduce_array_desc = neighbors_expr.value.desc(self.context.body) + # set reduction state + self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) + + args = self.visit(node.args) + + assert len(args) == 1 and len(args[0]) == 1 + reduce_input_node = args[0][0].value - self.context.body.add_scalar(result_name, result_dtype, transient=True) - op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]") - reduce_tasklet = self.context.state.add_tasklet( - "reduce", - code=f"__result = {init}\nfor __idx in range({reduce_array_desc.shape[0]}):\n __result = {op_str}", - inputs={"__values"}, - outputs={"__result"}, - ) - self.context.state.add_edge( - args[0].value, - None, - reduce_tasklet, - "__values", - create_memlet_full(neighbors_expr.value.data, reduce_array_desc), - ) - self.context.state.add_edge( - reduce_tasklet, - "__result", - result_access, - None, - dace.Memlet.simple(result_name, "0"), - ) else: assert isinstance(node.fun, itir.FunCall) assert isinstance(node.fun.args[0], itir.Lambda) fun_node = node.fun.args[0] + assert isinstance(fun_node.expr, itir.FunCall) - args = [] - for node_arg in node.args: - if ( - isinstance(node_arg, itir.FunCall) - and isinstance(node_arg.fun, itir.SymRef) - and node_arg.fun.id == "neighbors" - ): - expr = self.visit(node_arg) - args.append(*expr) - else: - args.append(None) - - # first visit only arguments for neighbor selection, all other arguments are none - neighbor_args = [arg for arg in args if arg] - - # check that all neighbors expression have the same range - assert ( - len( - set([self.context.body.arrays[expr.value.data].shape for expr in neighbor_args]) - ) - == 1 - ) + op_name = fun_node.expr.fun + assert isinstance(op_name, itir.SymRef) + reduce_identity = get_reduce_identity_value(op_name.id, reduce_dtype) - nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0] - nreduce_domain = {"__idx": f"0:{nreduce}"} + # set reduction state in visit context + self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - result_dtype = neighbor_args[0].dtype - self.context.body.add_scalar(result_name, result_dtype, transient=True) + args = flatten_list(self.visit(node.args)) - assert isinstance(fun_node.expr, itir.FunCall) - op_name = fun_node.expr.fun - assert isinstance(op_name, itir.SymRef) + # clear context + self.context.reduce_identity = None - # initialize the reduction result based on type of operation - init_value = get_reduce_identity_value(op_name.id, result_dtype) - init_state = self.context.body.add_state_before(self.context.state, "init", True) - init_tasklet = init_state.add_tasklet( - "init_reduce", {}, {"__out"}, f"__out = {init_value}" - ) - init_state.add_edge( - init_tasklet, - "__out", - init_state.add_access(result_name), - None, - dace.Memlet.simple(result_name, "0"), + # check that all neighbor expressions have the same shape + nreduce_shape = args[1].value.desc(self.context.body).shape + assert all( + [arg.value.desc(self.context.body).shape == nreduce_shape for arg in args[2:]] ) - # set reduction state to enable dereference of neighbors in input fields and to set WCR on reduce tasklet - self.context.reduce_limit = nreduce - self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format( - "x", "y" - ) + nreduce_index = tuple(f"_i{i}" for i in range(len(nreduce_shape))) + nreduce_domain = {idx: f"0:{size}" for idx, size in zip(nreduce_index, nreduce_shape)} - # visit child nodes for input arguments - for i, node_arg in enumerate(node.args): - if not args[i]: - args[i] = self.visit(node_arg)[0] + reduce_input_name = unique_var_name() + self.context.body.add_array( + reduce_input_name, nreduce_shape, reduce_dtype, transient=True + ) lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) - lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) - - # clear context - self.context.reduce_limit = 0 - self.context.reduce_wcr = None - - # the connectivity arrays (neighbor tables) are not needed inside the reduce lambda SDFG - neighbor_tables = filter_neighbor_tables(self.offset_provider) - for conn, _ in neighbor_tables: - var = connectivity_identifier(conn) - lambda_context.body.remove_data(var) - # cleanup symbols previously used for shape and stride of connectivity arrays - p = RemoveUnusedSymbols() - p.apply_pass(lambda_context.body, {}) - - input_memlets = [ - dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) - ] - output_memlet = dace.Memlet.simple(result_name, "0") + lambda_context, inner_inputs, inner_outputs = self.visit( + lambda_node, args=args, use_neighbor_tables=False + ) - input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)} - output_mapping = {inner_outputs[0].value.data: output_memlet} + input_mapping = { + param: create_memlet_at(arg.value.data, nreduce_index) + for (param, _), arg in zip(inner_inputs, args) + } + output_mapping = { + inner_outputs[0].value.data: create_memlet_at(reduce_input_name, nreduce_index) + } symbol_mapping = map_nested_sdfg_symbols( self.context.body, lambda_context.body, input_mapping ) + reduce_input_node = self.context.state.add_access(reduce_input_name) + nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( self.context.state, sdfg=lambda_context.body, @@ -1072,14 +960,32 @@ def _visit_reduce(self, node: itir.FunCall): outputs=output_mapping, symbol_mapping=symbol_mapping, input_nodes={arg.value.data: arg.value for arg in args}, - output_nodes={result_name: result_access}, + output_nodes={reduce_input_name: reduce_input_node}, ) - # we apply map fusion only to the nested-SDFG which is generated for the reduction operator - # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG - self.context.body.apply_transformations_repeated([MapFusion], validate=False) + reduce_input_desc = reduce_input_node.desc(self.context.body) + + result_name = unique_var_name() + # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node + self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) + result_access = self.context.state.add_access(result_name) + + reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") + reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) + self.context.state.add_nedge( + reduce_input_node, + reduce_node, + dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc), + ) + self.context.state.add_nedge( + reduce_node, result_access, dace.Memlet.simple(result_name, "0") + ) + + # we apply map fusion only to the nested-SDFG which is generated for the reduction operator + # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG + self.context.body.apply_transformations_repeated([MapFusion], validate=False) - return [ValueExpr(result_access, result_dtype)] + return [ValueExpr(result_access, reduce_dtype)] def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index bbbac6c139..e8d0c8b163 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,6 +46,16 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + if unstructured_case.backend in [ gtfn.run_gtfn, gtfn.run_gtfn_gpu, @@ -69,6 +79,16 @@ def testee(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -82,6 +102,16 @@ def minover(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): + # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 + try: + from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu + + if unstructured_case.backend == run_dace_gpu: + # see https://github.com/spcl/dace/pull/1442 + pytest.xfail("requires fix in dace module for cuda codegen") + except ImportError: + pass + @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: return neighbor_sum(edge_f(V2E), axis=V2EDim) From 8e644585361aac30bf97f753e652117c0884bde5 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 12 Dec 2023 13:04:03 +0100 Subject: [PATCH 26/85] feat[next][dace]: Add support for if expressions with tuple argument (#1393) Some icon4py stencils require support for if expressions with tuple arguments. This PR adds support to the DaCe backend in the visitor of builtin_if function. Additionally, this PR contains one fix in the result of builtin_tuple_get, which should return a list. --- .../runners/dace_iterator/__init__.py | 1 - .../runners/dace_iterator/itir_to_tasklet.py | 42 ++++++++++++++----- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 34ba2d2d95..acfa06b456 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -247,7 +247,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) - sdfg: Optional[dace.SDFG] = None if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index f6f197859b..32b8cbf2b1 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -321,16 +321,36 @@ def builtin_can_deref( def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: - args = [arg for li in transformer.visit(node_args) for arg in li] - expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args + args = transformer.visit(node_args) + assert len(args) == 3 + if_node = args[0][0] if isinstance(args[0], list) else args[0] + + # the argument could be a list of elements on each branch representing the result of `make_tuple` + # however, the normal case is to find one value expression + assert len(args[1]) == len(args[2]) + if_expr_args = [ + (a[0] if isinstance(a, list) else a, b[0] if isinstance(b, list) else b) + for a, b in zip(args[1], args[2]) ] - expr = "({1} if {0} else {2})".format(*internals) - node_type = transformer.node_types[id(node)] - assert isinstance(node_type, itir_typing.Val) - type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(expr_args, expr, type_, "if") + + # in case of tuple arguments, generate one if-tasklet for each element of the output tuple + if_expr_values = [] + for a, b in if_expr_args: + assert a.dtype == b.dtype + expr_args = [ + (arg, f"{arg.value.data}_v") + for arg in (if_node, a, b) + if not isinstance(arg, SymbolExpr) + ] + internals = [ + arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" + for arg in (if_node, a, b) + ] + expr = "({1} if {0} else {2})".format(*internals) + if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if") + if_expr_values.append(if_expr[0]) + + return if_expr_values def builtin_list_get( @@ -356,7 +376,7 @@ def builtin_list_get( def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: - args = [transformer.visit(node_args[0])[0]] + args = transformer.visit(node_args[0]) internals = [f"{arg.value.data}_v" for arg in args] target_type = node_args[1] assert isinstance(target_type, itir.SymRef) @@ -380,7 +400,7 @@ def builtin_tuple_get( elements = transformer.visit(node_args[1]) index = node_args[0] if isinstance(index, itir.Literal): - return elements[int(index.value)] + return [elements[int(index.value)]] raise ValueError("Tuple can only be subscripted with compile-time constants") From a14ad09f6dd3043114238fc820d68621480cfc4e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 12 Dec 2023 13:24:51 +0100 Subject: [PATCH 27/85] feat[next]: Embedded field scan (#1365) Adds the scalar scan operator for embedded field view. --- .gitpod.yml | 2 +- src/gt4py/next/embedded/common.py | 17 ++ src/gt4py/next/embedded/context.py | 4 +- src/gt4py/next/embedded/nd_array_field.py | 8 +- src/gt4py/next/embedded/operators.py | 168 ++++++++++++++++++ src/gt4py/next/ffront/decorator.py | 95 ++++------ src/gt4py/next/field_utils.py | 22 +++ src/gt4py/next/iterator/embedded.py | 19 +- src/gt4py/next/utils.py | 22 ++- tests/next_tests/exclusion_matrices.py | 1 - tests/next_tests/integration_tests/cases.py | 6 +- .../ffront_tests/test_execution.py | 80 +++++++++ .../iterator_tests/test_column_stencil.py | 4 +- .../unit_tests/embedded_tests/test_common.py | 14 +- .../iterator_tests/test_embedded_internals.py | 8 +- 15 files changed, 372 insertions(+), 98 deletions(-) create mode 100644 src/gt4py/next/embedded/operators.py create mode 100644 src/gt4py/next/field_utils.py diff --git a/.gitpod.yml b/.gitpod.yml index 1d579d88eb..802d87796a 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -5,7 +5,7 @@ image: tasks: - name: Setup venv and dev tools init: | - ln -s /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode + ln -sfn /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode python -m venv .venv source .venv/bin/activate pip install --upgrade pip setuptools wheel diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index d796189ab3..558730cb82 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -14,6 +14,10 @@ from __future__ import annotations +import functools +import itertools +import operator + from gt4py.eve.extended_typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -90,6 +94,19 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) +def intersect_domains(*domains: common.Domain) -> common.Domain: + return functools.reduce( + operator.and_, + domains, + common.Domain(dims=tuple(), ranges=tuple()), + ) + + +def iterate_domain(domain: common.Domain): + for i in itertools.product(*[list(r) for r in domain.ranges]): + yield tuple(zip(domain.dims, i)) + + def _expand_ellipsis( indices: common.RelativeIndexSequence, target_size: int ) -> tuple[common.IntIndex | slice, ...]: diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py index 5fbdbc6f25..93942a5959 100644 --- a/src/gt4py/next/embedded/context.py +++ b/src/gt4py/next/embedded/context.py @@ -24,7 +24,7 @@ #: Column range used in column mode (`column_axis != None`) in the current embedded iterator #: closure execution context. -closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range") +closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range") _undefined_offset_provider: common.OffsetProvider = {} @@ -37,7 +37,7 @@ @contextlib.contextmanager def new_context( *, - closure_column_range: range | eve.NothingType = eve.NOTHING, + closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING, offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING, ): import gt4py.next.embedded.context as this_module diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ff6a2ceac7..6b69e8f8cc 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -16,7 +16,6 @@ import dataclasses import functools -import operator from collections.abc import Callable, Sequence from types import ModuleType from typing import ClassVar @@ -49,11 +48,10 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) - domain_intersection = functools.reduce( - operator.and_, - [f.domain for f in fields if common.is_field(f)], - common.Domain(dims=tuple(), ranges=tuple()), + domain_intersection = embedded_common.intersect_domains( + *[f.domain for f in fields if common.is_field(f)] ) + transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] for f in fields: if common.is_field(f): diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py new file mode 100644 index 0000000000..f50ace7687 --- /dev/null +++ b/src/gt4py/next/embedded/operators.py @@ -0,0 +1,168 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar + +from gt4py import eve +from gt4py._core import definitions as core_defs +from gt4py.next import common, constructors, utils +from gt4py.next.embedded import common as embedded_common, context as embedded_context + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +@dataclasses.dataclass(frozen=True) +class EmbeddedOperator(Generic[_R, _P]): + fun: Callable[_P, _R] + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + return self.fun(*args, **kwargs) + + +@dataclasses.dataclass(frozen=True) +class ScanOperator(EmbeddedOperator[_R, _P]): + forward: bool + init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] + axis: common.Dimension + + def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun + scan_range = embedded_context.closure_column_range.get() + assert self.axis == scan_range[0] + scan_axis = scan_range[0] + domain_intersection = _intersect_scan_args(*args, *kwargs.values()) + non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) + + out_domain = common.Domain( + *[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection] + ) + if scan_axis not in out_domain.dims: + # even if the scan dimension is not in the input, we can scan over it + out_domain = common.Domain(*out_domain, (scan_range)) + + res = _construct_scan_array(out_domain)(self.init) + + def scan_loop(hpos): + acc = self.init + for k in scan_range[1] if self.forward else reversed(scan_range[1]): + pos = (*hpos, (scan_axis, k)) + new_args = [_tuple_at(pos, arg) for arg in args] + new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} + acc = self.fun(acc, *new_args, **new_kwargs) + _tuple_assign_value(pos, res, acc) + + if len(non_scan_domain) == 0: + # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop + scan_loop(()) + else: + for hpos in embedded_common.iterate_domain(non_scan_domain): + scan_loop(hpos) + + return res + + +def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): + if "out" in kwargs: + # called from program or direct field_operator as program + offset_provider = kwargs.pop("offset_provider", None) + + new_context_kwargs = {} + if embedded_context.within_context(): + # called from program + assert offset_provider is None + else: + # field_operator as program + new_context_kwargs["offset_provider"] = offset_provider + + out = kwargs.pop("out") + domain = kwargs.pop("domain", None) + + flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) + assert all(f.domain == flattened_out[0].domain for f in flattened_out) + + out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain + + new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) + + with embedded_context.new_context(**new_context_kwargs) as ctx: + res = ctx.run(op, *args, **kwargs) + _tuple_assign_field( + out, + res, + domain=out_domain, + ) + else: + # called from other field_operator + return op(*args, **kwargs) + + +def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: + vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL] + assert len(vertical_dim_filtered) <= 1 + return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING + + +def _tuple_assign_field( + target: tuple[common.MutableField | tuple, ...] | common.MutableField, + source: tuple[common.Field | tuple, ...] | common.Field, + domain: common.Domain, +): + @utils.tree_map + def impl(target: common.MutableField, source: common.Field): + target[domain] = source[domain] + + impl(target, source) + + +def _intersect_scan_args( + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] +) -> common.Domain: + return embedded_common.intersect_domains( + *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] + ) + + +def _construct_scan_array(domain: common.Domain): + @utils.tree_map + def impl(init: core_defs.Scalar) -> common.Field: + return constructors.empty(domain, dtype=type(init)) + + return impl + + +def _tuple_assign_value( + pos: Sequence[common.NamedIndex], + target: common.MutableField | tuple[common.MutableField | tuple, ...], + source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...], +) -> None: + @utils.tree_map + def impl(target: common.MutableField, source: core_defs.Scalar): + target[pos] = source + + impl(target, source) + + +def _tuple_at( + pos: Sequence[common.NamedIndex], + field: common.Field | core_defs.Scalar | tuple[common.Field | core_defs.Scalar | tuple, ...], +) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]: + @utils.tree_map + def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: + res = field[pos] if common.is_field(field) else field + assert core_defs.is_scalar_type(res) + return res + + return impl(field) # type: ignore[return-value] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index e06c651b13..8202cda6f5 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,8 +32,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, common, embedded as next_embedded +from gt4py.next import allocators as next_allocators, embedded as next_embedded from gt4py.next.common import Dimension, DimensionKind, GridType +from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, @@ -550,6 +551,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): definition: Optional[types.FunctionType] = None backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None + operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field(default_factory=dict) @classmethod @@ -586,6 +588,7 @@ def from_function( definition=definition, backend=backend, grid_type=grid_type, + operator_attributes=operator_attributes, ) def __gt_type__(self) -> ts.CallableType: @@ -692,68 +695,38 @@ def __call__( *args, **kwargs, ) -> None: - # TODO(havogt): Don't select mode based on existence of kwargs, - # because now we cannot provide nice error messages. E.g. set context var - # if we are reaching this from a program call. - if "out" in kwargs: - out = kwargs.pop("out") + if not next_embedded.context.within_context() and self.backend is not None: + # non embedded execution offset_provider = kwargs.pop("offset_provider", None) - if self.backend is not None: - # "out" and "offset_provider" -> field_operator as program - # When backend is None, we are in embedded execution and for now - # we disable the program generation since it would involve generating - # Python source code from a PAST node. - args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) - # TODO(tehrengruber): check all offset providers are given - # deduce argument types - arg_types = [] - for arg in args: - arg_types.append(type_translation.from_value(arg)) - kwarg_types = {} - for name, arg in kwargs.items(): - kwarg_types[name] = type_translation.from_value(arg) - - return self.as_program(arg_types, kwarg_types)( - *args, out, offset_provider=offset_provider, **kwargs - ) - else: - # "out" -> field_operator called from program in embedded execution or - # field_operator called directly from Python in embedded execution - domain = kwargs.pop("domain", None) - if not next_embedded.context.within_context(): - # field_operator from Python in embedded execution - with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: - res = ctx.run(self.definition, *args, **kwargs) - else: - # field_operator from program in embedded execution (offset_provicer is already set) - assert ( - offset_provider is None - or next_embedded.context.offset_provider.get() is offset_provider - ) - res = self.definition(*args, **kwargs) - _tuple_assign_field( - out, res, domain=None if domain is None else common.domain(domain) - ) - return + out = kwargs.pop("out") + args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) + # TODO(tehrengruber): check all offset providers are given + # deduce argument types + arg_types = [] + for arg in args: + arg_types.append(type_translation.from_value(arg)) + kwarg_types = {} + for name, arg in kwargs.items(): + kwarg_types[name] = type_translation.from_value(arg) + + return self.as_program(arg_types, kwarg_types)( + *args, out, offset_provider=offset_provider, **kwargs + ) else: - # field_operator called from other field_operator in embedded execution - assert self.backend is None - return self.definition(*args, **kwargs) - - -def _tuple_assign_field( - target: tuple[common.Field | tuple, ...] | common.Field, - source: tuple[common.Field | tuple, ...] | common.Field, - domain: Optional[common.Domain], -): - if isinstance(target, tuple): - if not isinstance(source, tuple): - raise RuntimeError(f"Cannot assign {source} to {target}.") - for t, s in zip(target, source): - _tuple_assign_field(t, s, domain) - else: - domain = domain or target.domain - target[domain] = source[domain] + if self.operator_attributes is not None and any( + has_scan_op_attribute := [ + attribute in self.operator_attributes + for attribute in ["init", "axis", "forward"] + ] + ): + assert all(has_scan_op_attribute) + forward = self.operator_attributes["forward"] + init = self.operator_attributes["init"] + axis = self.operator_attributes["axis"] + op = embedded_operators.ScanOperator(self.definition, forward, init, axis) + else: + op = embedded_operators.EmbeddedOperator(self.definition) + return embedded_operators.field_operator_call(op, args, kwargs) @typing.overload diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py new file mode 100644 index 0000000000..14b7c3838c --- /dev/null +++ b/src/gt4py/next/field_utils.py @@ -0,0 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +from gt4py.next import common, utils + + +@utils.tree_map +def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: + return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b02d6c8d72..b00e53bfd9 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -196,7 +196,7 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: #: Column range used in column mode (`column_axis != None`) in the current closure execution context. -column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range +column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range #: Offset provider dict in the current closure execution context. offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider @@ -211,8 +211,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin): def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 - column_range = column_range_cvar.get() - self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range), data) + column_range: common.NamedRange = column_range_cvar.get() + self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -746,7 +746,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] assert column_range is not None col: list[ @@ -823,7 +823,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range)) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -864,7 +864,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1479,7 +1479,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") @@ -1532,7 +1532,10 @@ def closure( column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] - column_range = column.col_range + column_range = ( + column_axis, + common.UnitRange(column.col_range.start, column.col_range.stop), + ) out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index baae8361c5..ec459906e0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -15,10 +15,6 @@ import functools from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast -import numpy as np - -from gt4py.next import common - class RecursionGuard: """ @@ -57,7 +53,6 @@ def __exit__(self, *exc): _T = TypeVar("_T") - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -66,8 +61,17 @@ def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) +# TODO(havogt): remove flatten duplications in the whole codebase +def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]: + if isinstance(value, tuple): + return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting + else: + return (value,) + + def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: - """Apply `fun` to each entry of (possibly nested) tuples. + """ + Apply `fun` to each entry of (possibly nested) tuples. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) @@ -88,9 +92,3 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: ) # mypy doesn't understand that `args` at this point is of type `_P.args` return impl - - -# TODO(havogt): consider moving to module like `field_utils` -@tree_map -def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: - return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 84287e209f..3c42a180dd 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -130,7 +130,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ] diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 81f216397b..b1e26b40cb 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,7 +28,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common, constructors, utils +from gt4py.next import common, constructors, field_utils from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -436,8 +436,8 @@ def verify( out_comp = out or inout assert out_comp is not None - out_comp_ndarray = utils.asnumpy(out_comp) - ref_ndarray = utils.asnumpy(ref) + out_comp_ndarray = field_utils.asnumpy(out_comp) + ref_ndarray = field_utils.asnumpy(ref) assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7f37b41383..51f853d41d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -292,6 +292,7 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_cartesian_shift @pytest.mark.uses_scan @pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures @@ -802,6 +803,85 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: cases.verify(cartesian_case, simple_scan_operator, (inp1, inp2), out=out, ref=expected) +@pytest.mark.uses_scan +def test_scan_different_domain_in_tuple(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp1_np = np.ones( + ( + i_size + 1, + k_size, + ) + ) # i_size bigger than in the other argument + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp1 = cartesian_case.as_field([IDim, KDim], inp1_np) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + inp1_np[:-1, k] + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo( + inp1: gtx.Field[[IDim, KDim], float], inp2: gtx.Field[[IDim, KDim], float] + ) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, inp1, inp2, out=out, ref=expected) + + +@pytest.mark.uses_scan +def test_scan_tuple_field_scalar_mixed(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + 1.0 + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo(inp1: float, inp2: gtx.Field[[IDim, KDim], float]) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, 1.0, inp2, out=out, ref=expected) + + def test_docstring(cartesian_case): @gtx.field_operator def fieldop_with_docstring(a: cases.IField) -> cases.IField: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index fd571514ac..9ba8eef3a3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -16,7 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import utils +from gt4py.next import field_utils from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset @@ -158,7 +158,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct k_size = 5 inp = inp_function(k_size) - ref = ref_function(utils.asnumpy(inp)) + ref = ref_function(field_utils.asnumpy(inp)) out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32)) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 640ed326bb..de511fdabb 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -19,7 +19,7 @@ from gt4py.next import common from gt4py.next.common import UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions -from gt4py.next.embedded.common import _slice_range, sub_domain +from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain @pytest.mark.parametrize( @@ -135,3 +135,15 @@ def test_sub_domain(domain, index, expected): expected = common.domain(expected) result = sub_domain(domain, index) assert result == expected + + +def test_iterate_domain(): + domain = common.domain({I: 2, J: 3}) + ref = [] + for i in domain[I][1]: + for j in domain[J][1]: + ref.append(((I, i), (J, j))) + + testee = list(iterate_domain(domain)) + + assert testee == ref diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py index 3a35570ca2..9238cd4f7a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py @@ -19,13 +19,14 @@ import numpy as np import pytest +from gt4py.next import common from gt4py.next.iterator import embedded def _run_within_context( func: Callable[[], Any], *, - column_range: Optional[range] = None, + column_range: Optional[common.NamedRange] = None, offset_provider: Optional[embedded.OffsetProvider] = None, ) -> Any: def wrapped_func(): @@ -59,7 +60,10 @@ def test_func(data_a: int, data_b: int): # Setting an invalid column_range here shouldn't affect other contexts embedded.column_range_cvar.set(range(2, 999)) - _run_within_context(lambda: test_func(2, 3), column_range=range(0, 3)) + _run_within_context( + lambda: test_func(2, 3), + column_range=(common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3)), + ) def test_column_ufunc_with_scalar(): From 3f595ffd6206b5bf3344b7288f98ac8e82adba52 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 12 Dec 2023 17:29:32 +0100 Subject: [PATCH 28/85] feat[next][dace]: Fix for broken DaCe test (#1396) Fix for broken DaCe test in baseline: - use `flatten_list` to get `ValueExpr` arguments to numeric builtin function Additionally, enable test for DaCe backend (left-over from PR #1393). --- .../runners/dace_iterator/itir_to_tasklet.py | 4 +--- .../feature_tests/iterator_tests/test_conditional.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 32b8cbf2b1..d10a14a1ee 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -1010,9 +1010,7 @@ def _visit_reduce(self, node: itir.FunCall): def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args: list[SymbolExpr | ValueExpr] = list( - itertools.chain(*[self.visit(arg) for arg in node.args]) - ) + args = flatten_list(self.visit(node.args)) expr_args = [ (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr) ] diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 8536dbea90..db7776b2f4 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -31,7 +31,6 @@ def stencil_conditional(inp): return tuple_get(0, tmp) + tuple_get(1, tmp) -@pytest.mark.uses_tuple_returns def test_conditional_w_tuple(program_processor): program_processor, validate = program_processor From a5b2450e282add00fe90b8cf98cd68d96d42b1ea Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Wed, 13 Dec 2023 11:41:16 +0100 Subject: [PATCH 29/85] style[next]: standardize error messages. (#1386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - add style guide to the coding guidelines - fix existing error messages in next - deal with ensuing qa errorrs / test updates - unshadow one test and fix the code it wasn't testing Co-authored-by: Rico Häuselmann Co-authored-by: Enrique González Paredes --- CODING_GUIDELINES.md | 38 +++++ src/gt4py/_core/definitions.py | 22 ++- src/gt4py/next/allocators.py | 8 +- src/gt4py/next/common.py | 91 ++++++------ src/gt4py/next/constructors.py | 15 +- src/gt4py/next/embedded/common.py | 6 +- src/gt4py/next/embedded/nd_array_field.py | 37 ++--- src/gt4py/next/errors/exceptions.py | 10 +- .../next/ffront/ast_passes/simple_assign.py | 2 +- .../ffront/ast_passes/single_static_assign.py | 2 +- src/gt4py/next/ffront/decorator.py | 29 ++-- src/gt4py/next/ffront/fbuiltins.py | 7 +- src/gt4py/next/ffront/foast_introspection.py | 2 +- .../foast_passes/closure_var_folding.py | 2 +- .../ffront/foast_passes/type_deduction.py | 133 +++++++++--------- src/gt4py/next/ffront/foast_pretty_printer.py | 2 +- src/gt4py/next/ffront/foast_to_itir.py | 10 +- src/gt4py/next/ffront/func_to_foast.py | 23 +-- src/gt4py/next/ffront/func_to_past.py | 4 +- .../next/ffront/past_passes/type_deduction.py | 42 +++--- src/gt4py/next/ffront/past_to_itir.py | 36 ++--- src/gt4py/next/ffront/source_utils.py | 8 +- src/gt4py/next/ffront/type_info.py | 2 +- src/gt4py/next/iterator/dispatcher.py | 2 +- src/gt4py/next/iterator/embedded.py | 26 ++-- src/gt4py/next/iterator/ir.py | 8 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 2 +- src/gt4py/next/iterator/runtime.py | 2 +- src/gt4py/next/iterator/tracing.py | 6 +- src/gt4py/next/iterator/transforms/cse.py | 4 +- .../next/iterator/transforms/pass_manager.py | 4 +- .../next/iterator/transforms/unroll_reduce.py | 6 +- src/gt4py/next/iterator/type_inference.py | 29 ++-- src/gt4py/next/otf/binding/nanobind.py | 2 +- .../compilation/build_systems/cmake_lists.py | 4 +- src/gt4py/next/otf/compilation/compiler.py | 2 +- src/gt4py/next/otf/stages.py | 2 +- src/gt4py/next/otf/workflow.py | 6 +- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 6 +- .../codegens/gtfn/gtfn_module.py | 6 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 20 +-- .../program_processors/processor_interface.py | 88 ++++++++---- .../runners/dace_iterator/__init__.py | 4 +- .../runners/dace_iterator/itir_to_tasklet.py | 4 +- .../runners/dace_iterator/utility.py | 2 +- .../next/program_processors/runners/gtfn.py | 6 +- src/gt4py/next/type_system/type_info.py | 48 ++++--- .../next/type_system/type_translation.py | 34 ++--- tests/next_tests/integration_tests/cases.py | 16 +-- .../ffront_tests/ffront_test_utils.py | 3 +- .../ffront_tests/test_arg_call_interface.py | 8 +- .../ffront_tests/test_execution.py | 8 +- .../test_math_builtin_execution.py | 2 +- .../ffront_tests/test_math_unary_builtins.py | 4 +- .../ffront_tests/test_program.py | 4 +- .../ffront_tests/test_scalar_if.py | 4 +- .../ffront_tests/test_type_deduction.py | 68 ++++----- .../iterator_tests/test_builtins.py | 2 +- tests/next_tests/unit_tests/conftest.py | 2 +- .../embedded_tests/test_nd_array_field.py | 2 +- .../ffront_tests/test_func_to_foast.py | 14 +- .../ffront_tests/test_func_to_past.py | 18 +-- .../ffront_tests/test_past_to_itir.py | 4 +- .../iterator_tests/test_runtime_domain.py | 2 +- .../test_processor_interface.py | 4 +- .../next_tests/unit_tests/test_allocators.py | 2 +- tests/next_tests/unit_tests/test_common.py | 2 +- .../unit_tests/test_constructors.py | 4 +- .../test_type_translation.py | 2 +- 69 files changed, 571 insertions(+), 458 deletions(-) diff --git a/CODING_GUIDELINES.md b/CODING_GUIDELINES.md index 957df0fb04..9376644064 100644 --- a/CODING_GUIDELINES.md +++ b/CODING_GUIDELINES.md @@ -51,6 +51,44 @@ We deviate from the [Google Python Style Guide][google-style-guide] only in the - Client code (like tests, doctests and examples) should use the above style for public FieldView API - Library code should always import the defining module and use qualified names. +### Error messages + +Error messages should be written as sentences, starting with a capital letter and ending with a period (avoid exclamation marks). Try to be informative without being verbose. Code objects such as 'ClassNames' and 'function_names' should be enclosed in single quotes, and so should string values used for message interpolation. + +Examples: + +```python +raise ValueError(f"Invalid argument 'dimension': should be of type 'Dimension', got '{dimension.type}'.") +``` + +Interpolated integer values do not need double quotes, if they are indicating an amount. Example: + +```python +raise ValueError(f"Invalid number of arguments: expected 3 arguments, got {len(args)}.") +``` + +The double quotes can also be dropped when presenting a sequence of values. In this case the message should be rephrased so the sequence is separated from the text by a colon ':'. + +```python +raise ValueError(f"unexpected keyword arguments: {', '.join(set(kwarg_names} - set(expected_kwarg_names)))}.") +``` + +The message should be kept to one sentence if reasonably possible. Ideally the sentence should be kept short and avoid unneccessary words. Examples: + +```python +# too many sentences +raise ValueError(f"Received an unexpeted number of arguments. Should receive 5 arguments, but got {len(args)}. Please provide the correct number of arguments.") +# better +raise ValueError(f"Wrong number of arguments: expected 5, got {len(args)}.") + +# less extreme +raise TypeError(f"Wrong argument type. Can only accept 'int's, got '{type(arg)}' instead.") +# but can still be improved +raise TypeError(f"Wrong argument type: 'int' expected, got '{type(arg)}'") +``` + +The terseness vs. helpfulness tradeoff should be more in favor of terseness for internal error messages and more in favor of helpfulness for `DSLError` and it's subclassses, where additional sentences are encouraged if they point out likely hidden sources of the problem or common fixes. + ### Docstrings We generate the API documentation automatically from the docstrings using [Sphinx][sphinx] and some extensions such as [Sphinx-autodoc][sphinx-autodoc] and [Sphinx-napoleon][sphinx-napoleon]. These follow the Google Python Style Guide docstring conventions to automatically format the generated documentation. A complete overview can be found here: [Example Google Style Python Docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google). diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 0e6301ae0f..091fa77e3f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -73,17 +73,23 @@ BoolScalar: TypeAlias = Union[bool_, bool] BoolT = TypeVar("BoolT", bound=BoolScalar) -BOOL_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], BoolScalar.__args__) # type: ignore[attr-defined] +BOOL_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], BoolScalar.__args__ # type: ignore[attr-defined] +) IntScalar: TypeAlias = Union[int8, int16, int32, int64, int] IntT = TypeVar("IntT", bound=IntScalar) -INT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], IntScalar.__args__) # type: ignore[attr-defined] +INT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], IntScalar.__args__ # type: ignore[attr-defined] +) UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64] UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar) -UINT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], UnsignedIntScalar.__args__) # type: ignore[attr-defined] +UINT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], UnsignedIntScalar.__args__ # type: ignore[attr-defined] +) IntegralScalar: TypeAlias = Union[IntScalar, UnsignedIntScalar] @@ -93,7 +99,9 @@ FloatingScalar: TypeAlias = Union[float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) -FLOAT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], FloatingScalar.__args__) # type: ignore[attr-defined] +FLOAT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], FloatingScalar.__args__ # type: ignore[attr-defined] +) #: Type alias for all scalar types supported by GT4Py @@ -195,7 +203,7 @@ def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: if issubclass(sc_type, numbers.Complex): return DTypeKind.COMPLEX - raise TypeError("Unknown scalar type kind") + raise TypeError("Unknown scalar type kind.") @dataclasses.dataclass(frozen=True) @@ -491,10 +499,10 @@ def __rtruediv__(self, other: Any) -> NDArrayObject: def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... - def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... def __gt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 58600d8cda..97e83276fe 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -142,7 +142,9 @@ def get_allocator( elif not strict or is_field_allocator(default): return default else: - raise TypeError(f"Object {obj} is neither a field allocator nor a field allocator factory") + raise TypeError( + f"Object '{obj}' is neither a field allocator nor a field allocator factory." + ) @dataclasses.dataclass(frozen=True) @@ -331,7 +333,7 @@ def allocate( """ if device is None and allocator is None: - raise ValueError("No 'device' or 'allocator' specified") + raise ValueError("No 'device' or 'allocator' specified.") actual_allocator = get_allocator(allocator) if actual_allocator is None: assert device is not None # for mypy @@ -339,7 +341,7 @@ def allocate( elif device is None: device = core_defs.Device(actual_allocator.__gt_device_type__, 0) elif device.device_type != actual_allocator.__gt_device_type__: - raise ValueError(f"Device {device} and allocator {actual_allocator} are incompatible") + raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") return actual_allocator.__gt_allocate__( domain=common.domain(domain), diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 7f1ad8c0bb..3e1fe52f31 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -125,7 +125,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: - raise ValueError("UnitRange: step required to be `1`.") + raise ValueError("'UnitRange': step required to be '1'.") new_start = self.start + (start or 0) new_stop = (self.start if stop > 0 else self.stop) + stop return UnitRange(new_start, new_stop) @@ -136,7 +136,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re if 0 <= index < len(self): return self.start + index else: - raise IndexError("UnitRange index out of range") + raise IndexError("'UnitRange' index out of range") def __and__(self, other: Set[int]) -> UnitRange: if isinstance(other, UnitRange): @@ -144,7 +144,9 @@ def __and__(self, other: Set[int]) -> UnitRange: stop = min(self.stop, other.stop) return UnitRange(start, stop) else: - raise NotImplementedError("Can only find the intersection between UnitRange instances.") + raise NotImplementedError( + "Can only find the intersection between 'UnitRange' instances." + ) def __le__(self, other: Set[int]): if isinstance(other, UnitRange): @@ -167,7 +169,7 @@ def __add__(self, other: int | Set[int]) -> UnitRange: ) ) else: - raise NotImplementedError("Can only compute union with int instances.") + raise NotImplementedError("Can only compute union with 'int' instances.") def __sub__(self, other: int | Set[int]) -> UnitRange: if isinstance(other, int): @@ -178,7 +180,7 @@ def __sub__(self, other: int | Set[int]) -> UnitRange: else: return self + (-other) else: - raise NotImplementedError("Can only compute substraction with int instances.") + raise NotImplementedError("Can only compute substraction with 'int' instances.") __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented @@ -199,7 +201,7 @@ def unit_range(r: RangeLike) -> UnitRange: return r if isinstance(r, range): if r.step != 1: - raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") + raise ValueError(f"'UnitRange' requires step size 1, got '{r.step}'.") return UnitRange(r.start, r.stop) # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) # once the related mypy bug (#16358) gets fixed @@ -211,7 +213,7 @@ def unit_range(r: RangeLike) -> UnitRange: return UnitRange(r[0], r[1]) if isinstance(r, core_defs.INTEGRAL_TYPES): return UnitRange(0, cast(core_defs.IntegralScalar, r)) - raise ValueError(f"`{r!r}` cannot be interpreted as `UnitRange`.") + raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") IntIndex: TypeAlias = int | core_defs.IntegralScalar @@ -296,20 +298,20 @@ def __init__( ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: - raise ValueError("Either both none of `dims` and `ranges` must be specified.") + raise ValueError("Either both none of 'dims' and 'ranges' must be specified.") if len(args) > 0: raise ValueError( - "No extra `args` allowed when constructing fomr `dims` and `ranges`." + "No extra 'args' allowed when constructing fomr 'dims' and 'ranges'." ) assert dims is not None and ranges is not None # for mypy if not all(isinstance(dim, Dimension) for dim in dims): raise ValueError( - f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." + f"'dims' argument needs to be a 'tuple[Dimension, ...]', got '{dims}'." ) if not all(isinstance(rng, UnitRange) for rng in ranges): raise ValueError( - f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." + f"'ranges' argument needs to be a 'tuple[UnitRange, ...]', got '{ranges}'." ) if len(dims) != len(ranges): raise ValueError( @@ -320,13 +322,15 @@ def __init__( object.__setattr__(self, "ranges", tuple(ranges)) else: if not all(is_named_range(arg) for arg in args): - raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") + raise ValueError( + f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." + ) dims, ranges = zip(*args) if args else ((), ()) object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) if len(set(self.dims)) != len(self.dims): - raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") + raise NotImplementedError(f"Domain dimensions must be unique, not '{self.dims}'.") def __len__(self) -> int: return len(self.ranges) @@ -365,7 +369,7 @@ def __getitem__( # noqa: F811 # redefine unused index_pos = self.dims.index(index) return self.dims[index_pos], self.ranges[index_pos] except ValueError: - raise KeyError(f"No Dimension of type {index} is present in the Domain.") + raise KeyError(f"No Dimension of type '{index}' is present in the Domain.") else: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") @@ -415,10 +419,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: if isinstance(index, Dimension): dim_index = self.dim_index(index) if dim_index is None: - raise ValueError(f"Dimension {index} not found in Domain.") + raise ValueError(f"Dimension '{index}' not found in Domain.") index = dim_index if not (-len(self.dims) <= index < len(self.dims)): - raise IndexError(f"Index {index} out of bounds for Domain of length {len(self.dims)}.") + raise IndexError( + f"Index '{index}' out of bounds for Domain of length {len(self.dims)}." + ) if index < 0: index += len(self.dims) new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ()) @@ -462,13 +468,16 @@ def domain(domain_like: DomainLike) -> Domain: if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()): return Domain( dims=tuple(domain_like.keys()), - ranges=tuple(UnitRange(0, s) for s in domain_like.values()), # type: ignore[arg-type] # type of `s` is checked in condition + ranges=tuple( + UnitRange(0, s) # type: ignore[arg-type] # type of `s` is checked in condition + for s in domain_like.values() + ), ) return Domain( dims=tuple(domain_like.keys()), ranges=tuple(unit_range(r) for r in domain_like.values()), ) - raise ValueError(f"`{domain_like}` is not `DomainLike`.") + raise ValueError(f"'{domain_like}' is not 'DomainLike'.") def _broadcast_ranges( @@ -670,7 +679,8 @@ class ConnectivityKind(enum.Flag): @extended_runtime_checkable -class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): # type: ignore[misc] # DimT should be covariant, but break in another place +# type: ignore[misc] # DimT should be covariant, but break in another place +class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod def codomain(self) -> DimT: @@ -690,61 +700,61 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa # Operators def __abs__(self) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __neg__(self) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __invert__(self) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __eq__(self, other: Any) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __ne__(self, other: Any) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def is_connectivity_field( @@ -845,7 +855,7 @@ def __gt_dims__(self) -> tuple[Dimension, ...]: @property def __gt_origin__(self) -> Never: - raise TypeError("CartesianConnectivity does not support this operation") + raise TypeError("'CartesianConnectivity' does not support this operation.") @property def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: @@ -877,7 +887,7 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa if not isinstance(image_range, UnitRange): if image_range[0] != self.codomain: raise ValueError( - f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." ) image_range = image_range[1] @@ -1017,3 +1027,4 @@ def register_builtin_func( @classmethod def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Callable[_P, _R]: return cls._builtin_func_map.get(func, NotImplemented) + return cls._builtin_func_map.get(func, NotImplemented) diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 63fde1cfde..9bb4cf17e5 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -254,12 +254,12 @@ def as_field( domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: raise ValueError( - f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." ) if origin: domain_dims = set(domain) if unknown_dims := set(origin.keys()) - domain_dims: - raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}") + raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}.") else: origin = {} actual_domain = common.domain( @@ -277,7 +277,7 @@ def as_field( # already the correct layout and device. shape = storage_utils.asarray(data).shape if shape != actual_domain.shape: - raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + raise ValueError(f"Cannot construct 'Field' from array of shape '{shape}'.") if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) @@ -334,20 +334,20 @@ def as_connectivity( domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: raise ValueError( - f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." ) actual_domain = common.domain([(d, (0, s)) for d, s in zip(domain, data.shape)]) else: actual_domain = common.domain(cast(common.DomainLike, domain)) if not isinstance(codomain, common.Dimension): - raise ValueError(f"Invalid codomain dimension `{codomain}`") + raise ValueError(f"Invalid codomain dimension '{codomain}'.") # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has # already the correct layout and device. shape = storage_utils.asarray(data).shape if shape != actual_domain.shape: - raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + raise ValueError(f"Cannot construct 'Field' from array of shape '{shape}'.") if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) @@ -356,7 +356,8 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) - buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] # TODO(havogt): consider addin MutableNDArrayObject + # TODO(havogt): consider addin MutableNDArrayObject + buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] connectivity_field = common.connectivity( buffer.ndarray, codomain=codomain, domain=actual_domain ) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 558730cb82..87e0800a10 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -32,7 +32,7 @@ def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Doma if common.is_relative_index_sequence(index_sequence): return _relative_sub_domain(domain, index_sequence) - raise IndexError(f"Unsupported index type: {index}") + raise IndexError(f"Unsupported index type: '{index}'.") def _relative_sub_domain( @@ -42,7 +42,9 @@ def _relative_sub_domain( expanded = _expand_ellipsis(index, len(domain)) if len(domain) < len(expanded): - raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") + raise IndexError( + f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions." + ) expanded += (slice(None),) * (len(domain) - len(expanded)) for (dim, rng), idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 6b69e8f8cc..fbfe64ac42 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -174,7 +174,7 @@ def remap( dim = connectivity.codomain dim_idx = self.domain.dim_index(dim) if dim_idx is None: - raise ValueError(f"Incompatible index field, expected a field with dimension {dim}.") + raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") current_range: common.UnitRange = self.domain[dim_idx][1] new_ranges = connectivity.inverse_image(current_range) @@ -226,7 +226,7 @@ def __setitem__( if common.is_field(value): if not value.domain == target_domain: raise ValueError( - f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + f"Incompatible 'Domain' in assignment. Source domain = '{value.domain}', target domain = '{target_domain}'." ) value = value.ndarray @@ -268,28 +268,28 @@ def __setitem__( def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_and", "logical_and")(self, other) - raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__and__' not implemented for non-'bool' fields.") __rand__ = __and__ def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_or", "logical_or")(self, other) - raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__or__' not implemented for non-'bool' fields.") __ror__ = __or__ def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_xor", "logical_xor")(self, other) - raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__xor__' not implemented for non-'bool' fields.") __rxor__ = __xor__ def __invert__(self) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("invert", "invert")(self) - raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__invert__' not implemented for non-'bool' fields.") def _slice( self, index: common.AnyIndexSpec @@ -322,7 +322,8 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig raise NotImplementedError() @property - def codomain(self) -> common.DimT: # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + def codomain(self) -> common.DimT: return self._codomain @functools.cached_property @@ -378,7 +379,7 @@ def inverse_image( ): # TODO(havogt): cleanup duplication with CartesianConnectivity if image_range[0] != self.codomain: raise ValueError( - f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." ) image_range = image_range[1] @@ -423,7 +424,7 @@ def inverse_image( if non_contiguous_dims: raise ValueError( - f"Restriction generates non-contiguous dimensions {non_contiguous_dims}" + f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'." ) return new_dims @@ -446,8 +447,12 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ # -- Specialized implementations for builtin operations on array fields -- -NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] -NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined] +NdArrayField.register_builtin_func( + fbuiltins.abs, NdArrayField.__abs__ # type: ignore[attr-defined] +) +NdArrayField.register_builtin_func( + fbuiltins.power, NdArrayField.__pow__ # type: ignore[attr-defined] +) # TODO gamma for name in ( @@ -480,7 +485,7 @@ def _builtin_op( if not axis.kind == common.DimensionKind.LOCAL: raise ValueError("Can only reduce local dimensions.") if axis not in field.domain.dims: - raise ValueError(f"Field doesn't have dimension {axis}. Cannot reduce.") + raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.") reduce_dim_index = field.domain.dims.index(axis) new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) return field.__class__.from_array( @@ -547,7 +552,7 @@ def __setitem__( value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` - raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") + raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") common.field.register(jnp.ndarray, JaxArrayField.from_array) @@ -572,7 +577,7 @@ def _builtins_broadcast( ) -> common.Field: # separated for typing reasons if common.is_field(field): return _broadcast(field, new_dimensions) - raise AssertionError("Scalar case not reachable from `fbuiltins.broadcast`.") + raise AssertionError("Scalar case not reachable from 'fbuiltins.broadcast'.") NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) @@ -581,7 +586,7 @@ def _builtins_broadcast( def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdArrayField: if isinstance(field, NdArrayField): return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) - raise AssertionError("This is the NdArrayField implementation of `fbuiltins.astype`.") + raise AssertionError("This is the NdArrayField implementation of 'fbuiltins.astype'.") NdArrayField.register_builtin_func(fbuiltins.astype, _astype) @@ -643,4 +648,4 @@ def _compute_slice( elif common.is_int_index(rng): return rng - domain.ranges[pos].start else: - raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") + raise ValueError(f"Can only use integer or UnitRange ranges, provided type: '{type(rng)}'.") diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index e956858549..081453c023 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -61,7 +61,7 @@ class UnsupportedPythonFeatureError(DSLError): feature: str def __init__(self, location: Optional[SourceLocation], feature: str) -> None: - super().__init__(location, f"unsupported Python syntax: '{feature}'") + super().__init__(location, f"Unsupported Python syntax: '{feature}'.") self.feature = feature @@ -69,7 +69,7 @@ class UndefinedSymbolError(DSLError): sym_name: str def __init__(self, location: Optional[SourceLocation], name: str) -> None: - super().__init__(location, f"name '{name}' is not defined") + super().__init__(location, f"Name '{name}' is not defined.") self.sym_name = name @@ -77,7 +77,7 @@ class MissingAttributeError(DSLError): attr_name: str def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: - super().__init__(location, f"object does not have attribute '{attr_name}'") + super().__init__(location, f"Object does not have attribute '{attr_name}'.") self.attr_name = attr_name @@ -90,7 +90,7 @@ class MissingParameterAnnotationError(TypeError_): param_name: str def __init__(self, location: Optional[SourceLocation], param_name: str) -> None: - super().__init__(location, f"parameter '{param_name}' is missing type annotations") + super().__init__(location, f"Parameter '{param_name}' is missing type annotations.") self.param_name = param_name @@ -100,7 +100,7 @@ class InvalidParameterAnnotationError(TypeError_): def __init__(self, location: Optional[SourceLocation], param_name: str, type_: Any) -> None: super().__init__( - location, f"parameter '{param_name}' has invalid type annotation '{type_}'" + location, f"Parameter '{param_name}' has invalid type annotation '{type_}'." ) self.param_name = param_name self.annotated_type = type_ diff --git a/src/gt4py/next/ffront/ast_passes/simple_assign.py b/src/gt4py/next/ffront/ast_passes/simple_assign.py index e2e6439e37..8b079bb8c1 100644 --- a/src/gt4py/next/ffront/ast_passes/simple_assign.py +++ b/src/gt4py/next/ffront/ast_passes/simple_assign.py @@ -22,7 +22,7 @@ class NodeYielder(ast.NodeTransformer): def apply(cls, node: ast.AST) -> ast.AST: result = list(cls().visit(node)) if len(result) != 1: - raise ValueError("AST was split or lost during the pass. Use `.visit()` instead.") + raise ValueError("AST was split or lost during the pass, use '.visit()' instead.") return result[0] def visit(self, node: ast.AST) -> Iterator[ast.AST]: diff --git a/src/gt4py/next/ffront/ast_passes/single_static_assign.py b/src/gt4py/next/ffront/ast_passes/single_static_assign.py index 4181d7f449..ee1e29a8e8 100644 --- a/src/gt4py/next/ffront/ast_passes/single_static_assign.py +++ b/src/gt4py/next/ffront/ast_passes/single_static_assign.py @@ -65,7 +65,7 @@ class _AssignmentTracker: def define(self, name: str) -> None: if name in self.names(): - raise ValueError(f"Variable {name} is already defined.") + raise ValueError(f"Variable '{name}' is already defined.") # -1 signifies a self._counts[name] = -1 diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 8202cda6f5..4abd8f156a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -88,7 +88,7 @@ def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any raise NotImplementedError( f"Using closure vars with same name but different value " f"across functions is not implemented yet. \n" - f"Collisions: {'`, `'.join(collisions)}" + f"Collisions: '{', '.join(collisions)}'." ) all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) @@ -125,7 +125,7 @@ def is_cartesian_offset(o: FieldOffset): if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED: raise ValueError( - "grid_type == GridType.CARTESIAN was requested, but unstructured `FieldOffset` or local `Dimension` was found." + "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." ) return deduced_grid_type if requested_grid_type is None else requested_grid_type @@ -147,7 +147,7 @@ def _field_constituents_shape_and_dims( elif isinstance(arg_type, ts.ScalarType): yield (None, []) else: - raise ValueError("Expected `FieldType` or `TupleType` thereof.") + raise ValueError("Expected 'FieldType' or 'TupleType' thereof.") # TODO(tehrengruber): Decide if and how programs can call other programs. As a @@ -208,7 +208,7 @@ def __post_init__(self): ] if misnamed_functions: raise RuntimeError( - f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}" + f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}." ) undefined_symbols = [ @@ -218,7 +218,7 @@ def __post_init__(self): ] if undefined_symbols: raise RuntimeError( - f"The following closure variables are undefined: {', '.join(undefined_symbols)}" + f"The following closure variables are undefined: {', '.join(undefined_symbols)}." ) @functools.cached_property @@ -228,7 +228,7 @@ def __gt_allocator__( if self.backend: return self.backend.__gt_allocator__ else: - raise RuntimeError(f"Program {self} does not have a backend set.") + raise RuntimeError(f"Program '{self}' does not have a backend set.") def with_backend(self, backend: ppi.ProgramExecutor) -> Program: return dataclasses.replace(self, backend=backend) @@ -263,7 +263,7 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: """ for key in kwargs.keys(): if all(key != param.id for param in self.past_node.params): - raise TypeError(f"Keyword argument `{key}` is not a valid program parameter.") + raise TypeError(f"Keyword argument '{key}' is not a valid program parameter.") return ProgramWithBoundArgs( bound_args=kwargs, @@ -344,7 +344,7 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise TypeError(f"Invalid argument types in call to `{self.past_node.id}`!") from err + raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) @@ -397,9 +397,10 @@ def _column_axis(self): ] raise TypeError( - "Only `ScanOperator`s defined on the same axis " - + "can be used in a `Program`, but found:\n" + "Only 'ScanOperator's defined on the same axis " + + "can be used in a 'Program', found:\n" + "\n".join(scanops_per_axis_strs) + + "." ) return iter(scanops_per_axis.keys()).__next__() @@ -436,7 +437,7 @@ def _process_args(self, args: tuple, kwargs: dict): # a better error message. for name in self.bound_args.keys(): if name in kwargs: - raise ValueError(f"Parameter `{name}` already set as a bound argument.") + raise ValueError(f"Parameter '{name}' already set as a bound argument.") type_info.accepts_args( new_type, @@ -445,10 +446,10 @@ def _process_args(self, args: tuple, kwargs: dict): raise_exception=True, ) except ValueError as err: - bound_arg_names = ", ".join([f"`{bound_arg}`" for bound_arg in self.bound_args.keys()]) + bound_arg_names = ", ".join([f"'{bound_arg}'" for bound_arg in self.bound_args.keys()]) raise TypeError( - f"Invalid argument types in call to program `{self.past_node.id}` with " - f"bound arguments {bound_arg_names}!" + f"Invalid argument types in call to program '{self.past_node.id}' with " + f"bound arguments '{bound_arg_names}'." ) from err full_args = [*args] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8230e35a35..93f17b1eb8 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -139,13 +139,16 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( - f"Either both or none can be tuple in {true_field=} and {false_field=}." # TODO(havogt) find a strategy to unify parsing and embedded error messages + # TODO(havogt) find a strategy to unify parsing and embedded error messages + f"Either both or none can be tuple in '{true_field=}' and '{false_field=}'." ) if len(true_field) != len(false_field): raise ValueError( "Tuple of different size not allowed." ) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return tuple( + where(mask, t, f) for t, f in zip(true_field, false_field) + ) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) diff --git a/src/gt4py/next/ffront/foast_introspection.py b/src/gt4py/next/ffront/foast_introspection.py index 805df465b8..404b99d1a0 100644 --- a/src/gt4py/next/ffront/foast_introspection.py +++ b/src/gt4py/next/ffront/foast_introspection.py @@ -73,4 +73,4 @@ def deduce_stmt_return_kind(node: foast.Stmt) -> StmtReturnKind: elif isinstance(node, (foast.Assign, foast.TupleTargetAssign)): return StmtReturnKind.NO_RETURN else: - raise AssertionError(f"Statements of type `{type(node).__name__}` not understood.") + raise AssertionError(f"Statements of type '{type(node).__name__}' not understood.") diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 9afd22de2c..0561a80659 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -56,7 +56,7 @@ def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) raise errors.MissingAttributeError(node.location, node.attr) - raise errors.DSLError(node.location, "attribute access only applicable to constants") + raise errors.DSLError(node.location, "Attribute access only applicable to constants.") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 95c9128f87..639e5ff009 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -53,7 +53,7 @@ def with_altered_scalar_kind( elif isinstance(type_spec, ts.ScalarType): return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) else: - raise ValueError(f"Expected field or scalar type, but got {type_spec}.") + raise ValueError(f"Expected field or scalar type, got '{type_spec}'.") def construct_tuple_type( @@ -113,7 +113,9 @@ def promote_to_mask_type( item in input_type.dims for item in mask_type.dims ): return_dtype = input_type.dtype if isinstance(input_type, ts.FieldType) else input_type - return type_info.promote(input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype)) # type: ignore + return type_info.promote( + input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype) + ) # type: ignore else: return input_type @@ -148,7 +150,7 @@ def deduce_stmt_return_type( else: raise errors.DSLError( stmt.location, - f"If statement contains return statements with inconsistent types:" + "If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", ) return_type = return_types[0] or return_types[1] @@ -160,12 +162,12 @@ def deduce_stmt_return_type( elif isinstance(stmt, (foast.Assign, foast.TupleTargetAssign)): return_type = None else: - raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") + raise AssertionError(f"Nodes of type '{type(stmt).__name__}' not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: raise errors.DSLError( stmt.location, - f"If statement contains return statements with inconsistent types:" + "If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", ) @@ -179,7 +181,7 @@ def deduce_stmt_return_type( # If the node was constructed by the foast parsing we should never get here, but instead # we should have gotten an error there. raise AssertionError( - "Malformed block statement. Expected a return statement in this context, " + "Malformed block statement: expected a return statement in this context, " "but none was found. Please submit a bug report." ) @@ -195,7 +197,7 @@ def apply(cls, node: foast.LocatedNode) -> None: cls().visit(node, incomplete_nodes=incomplete_nodes) if incomplete_nodes: - raise AssertionError("FOAST expression is not fully typed.") + raise AssertionError("'FOAST' expression is not fully typed.") def visit_LocatedNode( self, node: foast.LocatedNode, *, incomplete_nodes: list[foast.LocatedNode] @@ -251,7 +253,7 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): raise errors.DSLError( node.location, - f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", + f"Function must return 'DataType', 'DeferredType', or 'VoidType', got '{return_type}'.", ) new_type = ts.FunctionType( pos_only_args=[], @@ -283,17 +285,17 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp if not isinstance(new_axis.type, ts.DimensionType): raise errors.DSLError( node.location, - f"Argument `axis` to scan operator `{node.id}` must be a dimension.", + f"Argument 'axis' to scan operator '{node.id}' must be a dimension.", ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: raise errors.DSLError( node.location, - f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", + f"Argument 'axis' to scan operator '{node.id}' must be a vertical dimension.", ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: raise errors.DSLError( - node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." + node.location, f"Argument 'forward' to scan operator '{node.id}' must be a boolean." ) new_init = self.visit(node.init, **kwargs) if not all( @@ -302,8 +304,8 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp ): raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must " - f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", + f"Argument 'init' to scan operator '{node.id}' must " + "be an arithmetic type or a logical type or a composite of arithmetic and logical types.", ) new_definition = self.visit(node.definition, **kwargs) new_def_type = new_definition.type @@ -311,15 +313,15 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp if new_init.type != new_def_type.returns: raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must have same type as its return. " - f"Expected `{new_def_type.returns}`, but got `{new_init.type}`", + f"Argument 'init' to scan operator '{node.id}' must have same type as its return: " + f"expected '{new_def_type.returns}', got '{new_init.type}'.", ) elif new_init.type != carry_type: carry_arg_name = list(new_def_type.pos_or_kw_args.keys())[0] raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must have same type as `{carry_arg_name}` argument. " - f"Expected `{carry_type}`, but got `{new_init.type}`", + f"Argument 'init' to scan operator '{node.id}' must have same type as '{carry_arg_name}' argument: " + f"expected '{carry_type}', got '{new_init.type}'.", ) new_type = ts_ffront.ScanOperatorType( @@ -339,7 +341,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise errors.DSLError(node.location, f"Undeclared symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared symbol '{node.id}'.") symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -362,9 +364,9 @@ def visit_TupleTargetAssign( targets: TargetType = node.targets indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) - if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: + if not any(isinstance(i, tuple) for i in indices) and len(targets) != num_elts: raise errors.DSLError( - node.location, f"Too many values to unpack (expected {len(indices)})." + node.location, f"Too many values to unpack (expected {len(targets)})." ) new_targets: TargetType = [] @@ -396,7 +398,7 @@ def visit_TupleTargetAssign( new_targets.append(new_target) else: raise errors.DSLError( - node.location, f"Assignment value must be of type tuple! Got: {values.type}" + node.location, f"Assignment value must be of type tuple, got '{values.type}'." ) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -416,15 +418,14 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: if not isinstance(new_node.condition.type, ts.ScalarType): raise errors.DSLError( node.location, - "Condition for `if` must be scalar. " - f"But got `{new_node.condition.type}` instead.", + "Condition for 'if' must be scalar, " f"got '{new_node.condition.type}' instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: raise errors.DSLError( node.location, - "Condition for `if` must be of boolean type. " - f"But got `{new_node.condition.type}` instead.", + "Condition for 'if' must be of boolean type, " + f"got '{new_node.condition.type}' instead.", ) for sym in node.annex.propagated_symbols.keys(): @@ -433,8 +434,8 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ): raise errors.DSLError( node.location, - f"Inconsistent types between two branches for variable `{sym}`. " - f"Got types `{true_type}` and `{false_type}.", + f"Inconsistent types between two branches for variable '{sym}': " + f"got types '{true_type}' and '{false_type}.", ) # TODO: properly patch symtable (new node?) symtable[sym].type = new_node.annex.propagated_symbols[ @@ -455,8 +456,8 @@ def visit_Symbol( raise errors.DSLError( node.location, ( - "type inconsistency: expression was deduced to be " - f"of type {refine_type}, instead of the expected type {node.type}" + "Type inconsistency: expression was deduced to be " + f"of type '{refine_type}', instead of the expected type '{node.type}'." ), ) new_node: foast.Symbol = foast.Symbol( @@ -490,7 +491,7 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = new_value.type case _: raise errors.DSLError( - new_value.location, "Could not deduce type of subscript expression!" + new_value.location, "Could not deduce type of subscript expression." ) return foast.Subscript( @@ -531,13 +532,13 @@ def _deduce_ternaryexpr_type( if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): raise errors.DSLError( condition.location, - f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", + f"Condition is of type '{condition.type}', should be of type 'bool'.", ) if true_expr.type != false_expr.type: raise errors.DSLError( node.location, - f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", + f"Left and right types are not the same: '{true_expr.type}' and '{false_expr.type}'", ) return true_expr.type @@ -556,7 +557,7 @@ def _deduce_compare_type( for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -571,8 +572,8 @@ def _deduce_compare_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left.type}` and `{right.type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left.type}' and '{right.type}' to common type" + f" in call to '{node.op}'.", ) from ex def _deduce_binop_type( @@ -594,7 +595,7 @@ def _deduce_binop_type( for arg in (left, right): if not is_compatible(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) left_type = cast(ts.FieldType | ts.ScalarType, left.type) @@ -608,7 +609,7 @@ def _deduce_binop_type( ): raise errors.DSLError( arg.location, - f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: @@ -616,8 +617,8 @@ def _deduce_binop_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left_type}` and `{right_type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left_type}' and '{right_type}' to common type" + f" in call to '{node.op}'.", ) from ex def _check_operand_dtypes_match( @@ -627,7 +628,7 @@ def _check_operand_dtypes_match( if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): raise errors.DSLError( node.location, - f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", + f"Incompatible datatypes in operator '{node.op}': '{left.type}' and '{right.type}'.", ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: @@ -644,7 +645,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: if not is_compatible(new_operand.type): raise errors.DSLError( node.location, - f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", + f"Incompatible type for unary operator '{node.op}': '{new_operand.type}'.", ) return foast.UnaryOp( op=node.op, operand=new_operand, location=node.location, type=new_operand.type @@ -674,13 +675,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise errors.DSLError(node.location, "Functions can only be called directly!") + raise errors.DSLError(node.location, "Functions can only be called directly.") elif isinstance(new_func.type, ts.FieldType): pass else: raise errors.DSLError( node.location, - f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", + f"Expression of type '{new_func.type}' is not callable, must be a 'Function', 'FieldOperator', 'ScanOperator' or 'Field'.", ) # ensure signature is valid @@ -693,7 +694,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: ) except ValueError as err: raise errors.DSLError( - node.location, f"Invalid argument types in call to `{new_func}`!" + node.location, f"Invalid argument types in call to '{new_func}'." ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) @@ -727,7 +728,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: func_name = cast(foast.Name, node.func).id # validate arguments - error_msg_preamble = f"Incompatible argument in call to `{func_name}`." + error_msg_preamble = f"Incompatible argument in call to '{func_name}'." error_msg_for_validator = { type_info.is_arithmetic: "an arithmetic", type_info.is_floating_point: "a floating point", @@ -741,13 +742,13 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: elif func_name in fbuiltins.BINARY_MATH_NUMBER_BUILTIN_NAMES: arg_validator = type_info.is_arithmetic else: - raise AssertionError(f"Unknown math builtin `{func_name}`.") + raise AssertionError(f"Unknown math builtin '{func_name}'.") error_msgs = [] for i, arg in enumerate(node.args): if not arg_validator(arg.type): error_msgs.append( - f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." + f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, got '{arg.type}'." ) if error_msgs: raise errors.DSLError( @@ -756,7 +757,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: ) if func_name == "power" and all(type_info.is_integral(arg.type) for arg in node.args): - print(f"Warning: return type of {func_name} might be inconsistent (not implemented).") + print(f"Warning: return type of '{func_name}' might be inconsistent (not implemented).") # deduce return type return_type: Optional[ts.FieldType | ts.ScalarType] = None @@ -777,7 +778,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: except ValueError as ex: raise errors.DSLError(node.location, error_msg_preamble) from ex else: - raise AssertionError(f"Unknown math builtin `{func_name}`.") + raise AssertionError(f"Unknown math builtin '{func_name}'.") return foast.Call( func=node.func, @@ -796,9 +797,9 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) raise errors.DSLError( node.location, - f"Incompatible field argument in call to `{str(node.func)}`. " - f"Expected a field with dimension {reduction_dim}, but got " - f"{field_dims_str}.", + f"Incompatible field argument in call to '{str(node.func)}'. " + f"Expected a field with dimension '{reduction_dim}', got " + f"'{field_dims_str}'.", ) return_type = ts.FieldType( dims=[dim for dim in field_type.dims if dim != reduction_dim], @@ -834,7 +835,7 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: ]: raise errors.DSLError( node.location, - f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", + f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.", ) return_type = type_info.apply_to_primitive_constituents( @@ -860,16 +861,16 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_integral(arg_1): raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. " - f"Excepted integer for offset field dtype, but got {arg_1.dtype}" + f"Incompatible argument in call to '{str(node.func)}': " + f"expected integer for offset field dtype, got '{arg_1.dtype}'. " f"{node.location}", ) if arg_0.source not in arg_1.dims: raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. " - f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " + f"Incompatible argument in call to '{str(node.func)}': " + f"'{arg_0.source}' not in list of offset field dimensions '{arg_1.dims}'. " f"{node.location}", ) @@ -889,8 +890,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_logical(mask_type): raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. Expected " - f"a field with dtype `bool`, but got `{mask_type}`.", + f"Incompatible argument in call to '{str(node.func)}': expected " + f"a field with dtype 'bool', got '{mask_type}'.", ) try: @@ -907,8 +908,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: ): raise errors.DSLError( node.location, - f"Return arguments need to be of same type in {str(node.func)}, but got: " - f"{node.args[1].type} and {node.args[2].type}", + f"Return arguments need to be of same type in '{str(node.func)}', got " + f"'{node.args[1].type}' and '{node.args[2].type}'.", ) else: true_branch_fieldtype = cast(ts.FieldType, true_branch_type) @@ -919,7 +920,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: except ValueError as ex: raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`.", + f"Incompatible argument in call to '{str(node.func)}'.", ) from ex return foast.Call( @@ -937,8 +938,8 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): raise errors.DSLError( node.location, - f"Incompatible broadcast dimension type in {str(node.func)}. Expected " - f"all broadcast dimensions to be of type Dimension.", + f"Incompatible broadcast dimension type in '{str(node.func)}': expected " + f"all broadcast dimensions to be of type 'Dimension'.", ) broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] @@ -946,8 +947,8 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): raise errors.DSLError( node.location, - f"Incompatible broadcast dimensions in {str(node.func)}. Expected " - f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", + f"Incompatible broadcast dimensions in '{str(node.func)}': expected " + f"broadcast dimension(s) '{set(arg_dims).difference(set(broadcast_dims))}' missing", ) return_type = ts.FieldType( diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 3b81c85265..9275cdda95 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -110,7 +110,7 @@ def apply(cls, node: foast.LocatedNode, **kwargs) -> str: # type: ignore[overri node_type_name = type(node).__name__ if not hasattr(cls, node_type_name) and not hasattr(cls, f"visit_{node_type_name}"): raise NotImplementedError( - f"Pretty printer does not support nodes of type " f"`{node_type_name}`." + f"Pretty printer does not support nodes of type '{node_type_name}'." ) return cls().visit(node, **kwargs) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 3030c03fd1..c4d518d279 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -230,7 +230,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: dtype = type_info.extract_dtype(node.type) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"{node.op} is only supported on `bool`s.") + raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") return self._map("not_", node.operand) return self._map( @@ -313,7 +313,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: return im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) raise AssertionError( - f"Call to object of type {type(node.func.type).__name__} not understood." + f"Call to object of type '{type(node.func.type).__name__}' not understood." ) def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -371,7 +371,9 @@ def _visit_type_constr(self, node: foast.Call, **kwargs) -> itir.Expr: im.literal(str(bool(source_type(node.args[0].value))), "bool") ) return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) - raise FieldOperatorLoweringError(f"Encountered a type cast, which is not supported: {node}") + raise FieldOperatorLoweringError( + f"Encountered a type cast, which is not supported: {node}." + ) def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; @@ -388,7 +390,7 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: elif isinstance(type_, ts.ScalarType): typename = type_.kind.name.lower() return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type {type_}.") + raise ValueError(f"Unsupported literal type '{type_}'.") def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr: return self._make_literal(node.value, node.type) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index c7c4c3a23f..0fd263308e 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -107,8 +107,9 @@ def _postprocess_dialect_ast( if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented raise errors.DSLError( foast_node.location, - f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - f", but got `{annotated_return_type}`.", + "Annotated return type does not match deduced return type: expected " + f"'{foast_node.type.returns}'" # type: ignore[union-attr] # revisit when 'type_info.return_type' is implemented + f", got '{annotated_return_type}'.", ) return foast_node @@ -167,7 +168,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise errors.DSLError(loc, "Function is expected to return a value.") + raise errors.DSLError(loc, "'Function' is expected to return a value.") return foast.FunctionDefinition( id=node.name, @@ -224,7 +225,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple ) if not isinstance(target, ast.Name): - raise errors.DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "Can only assign to names.") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -246,7 +247,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise errors.DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "Can only assign to names.") if node.annotation is not None: assert isinstance( @@ -281,14 +282,14 @@ def _match_index(node: ast.expr) -> int: return -node.operand.value if isinstance(node.op, ast.UAdd): return node.operand.value - raise ValueError(f"Not an index: {node}") + raise ValueError(f"Not an index: '{node}'.") def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: raise errors.DSLError( - self.get_location(node.slice), "expected an integral index" + self.get_location(node.slice), "eXpected an integral index." ) from None return foast.Subscript( @@ -310,7 +311,7 @@ def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: loc = self.get_location(node) if not node.value: - raise errors.DSLError(loc, "must return a value, not None") + raise errors.DSLError(loc, "Must return a value, not None") return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: @@ -442,11 +443,11 @@ def _verify_builtin_type_constructor(self, node: ast.Call): if len(node.args) > 0 and not isinstance(node.args[0], ast.Constant): raise errors.DSLError( self.get_location(node), - f"{self._func_name(node)}() only takes literal arguments!", + f"'{self._func_name(node)}()' only takes literal arguments.", ) def _func_name(self, node: ast.Call) -> str: - return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. + return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call: # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? @@ -468,7 +469,7 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: type_ = type_translation.from_value(node.value) except ValueError: raise errors.DSLError( - loc, f"constants of type {type(node.value)} are not permitted" + loc, f"Constants of type {type(node.value)} are not permitted." ) from None return foast.Constant( diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 7b04e90902..5b4dd934b9 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -129,7 +129,7 @@ def visit_Call(self, node: ast.Call) -> past.Call: new_func = self.visit(node.func) if not isinstance(new_func, past.Name): raise errors.DSLError( - loc, "functions must be referenced by their name in function calls" + loc, "Functions must be referenced by their name in function calls." ) return past.Call( @@ -166,7 +166,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) return past.Constant(value=-node.operand.value, type=symbol_type, location=loc) - raise errors.DSLError(loc, "unary operators are only applicable to literals") + raise errors.DSLError(loc, "Unary operators are only applicable to literals.") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index ed3bdae3ff..fc353d64e4 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -33,7 +33,7 @@ def _ensure_no_sliced_field(entry: past.Expr): For example, if argument is of type past.Subscript, this function will throw an error as both slicing and domain are being applied """ if not isinstance(entry, past.Name) and not isinstance(entry, past.TupleExpr): - raise ValueError("Either only domain or slicing allowed") + raise ValueError("Either only domain or slicing allowed.") elif isinstance(entry, past.TupleExpr): for param in entry.elts: _ensure_no_sliced_field(param) @@ -57,20 +57,18 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), ): raise ValueError( - f"Only calls `FieldOperator`s and `ScanOperator`s " - f"allowed in `Program`, but got `{new_func.type}`." + f"Only calls to 'FieldOperators' and 'ScanOperators' " + f"allowed in 'Program', got '{new_func.type}'." ) if "out" not in new_kwargs: - raise ValueError("Missing required keyword argument(s) `out`.") + raise ValueError("Missing required keyword argument 'out'.") if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) domain_kwarg = new_kwargs["domain"] if not isinstance(domain_kwarg, past.Dict): - raise ValueError( - f"Only Dictionaries allowed in domain, but got `{type(domain_kwarg)}`." - ) + raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(domain_kwarg)}'.") if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0: raise ValueError("Empty domain not allowed.") @@ -78,18 +76,18 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): for dim in domain_kwarg.keys_: if not isinstance(dim.type, ts.DimensionType): raise ValueError( - f"Only Dimension allowed in domain dictionary keys, but got `{dim}` which is of type `{dim.type}`." + f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." ) for domain_values in domain_kwarg.values_: if len(domain_values.elts) != 2: raise ValueError( - f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`." + f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." ) if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( domain_values.elts[1] ): raise ValueError( - f"Only integer values allowed in domain range, but got {domain_values.elts[0].type} and {domain_values.elts[1].type}." + f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." ) @@ -149,7 +147,7 @@ def _deduce_binop_type( for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) left_type = cast(ts.ScalarType, left.type) @@ -163,7 +161,7 @@ def _deduce_binop_type( ): raise errors.DSLError( arg.location, - f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: @@ -171,8 +169,8 @@ def _deduce_binop_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left_type}` and `{right_type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left_type}' and '{right_type}' to common type" + f" in call to '{node.op}'.", ) from ex def visit_BinOp(self, node: past.BinOp, **kwargs) -> past.BinOp: @@ -214,24 +212,24 @@ def visit_Call(self, node: past.Call, **kwargs): ) if operator_return_type != new_kwargs["out"].type: raise ValueError( - f"Expected keyword argument `out` to be of " - f"type {operator_return_type}, but got " - f"{new_kwargs['out'].type}." + "Expected keyword argument 'out' to be of " + f"type '{operator_return_type}', got " + f"'{new_kwargs['out'].type}'." ) elif new_func.id in ["minimum", "maximum"]: if new_args[0].type != new_args[1].type: raise ValueError( - f"First and second argument in {new_func.id} must be the same type." - f"Got `{new_args[0].type}` and `{new_args[1].type}`." + f"First and second argument in '{new_func.id}' must be of the same type." + f"Got '{new_args[0].type}' and '{new_args[1].type}'." ) return_type = new_args[0].type else: raise AssertionError( - "Only calls `FieldOperator`s, `ScanOperator`s or minimum and maximum builtins allowed" + "Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed." ) except ValueError as ex: - raise errors.DSLError(node.location, f"Invalid call to `{node.func.id}`.") from ex + raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex return past.Call( func=new_func, @@ -244,6 +242,6 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise errors.DSLError(node.location, f"Undeclared or untyped symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared or untyped symbol '{node.id}'.") return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 2c5dfc6e2f..709912077b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -37,7 +37,7 @@ def _flatten_tuple_expr( for e in node.elts: result.extend(_flatten_tuple_expr(e)) return result - raise ValueError("Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed.") + raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -174,7 +174,7 @@ def _visit_slice_bound( else: lowered_bound = self.visit(slice_bound, **kwargs) else: - raise AssertionError("Expected `None` or `past.Constant`.") + raise AssertionError("Expected 'None' or 'past.Constant'.") return lowered_bound def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: @@ -189,8 +189,8 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: ) else: raise ValueError( - "Unexpected `out` argument. Must be a `past.Name`, `past.Subscript`" - " or a `past.TupleExpr` thereof." + "Unexpected 'out' argument. Must be a 'past.Name', 'past.Subscript'" + " or a 'past.TupleExpr' thereof." ) def _construct_itir_domain_arg( @@ -209,9 +209,9 @@ def _construct_itir_domain_arg( for out_field_type in out_field_types ): raise AssertionError( - f"Expected constituents of `{out_field.id}` argument to be" - f" fields defined on the same dimensions. This error should be " - f" caught in type deduction already." + f"Expected constituents of '{out_field.id}' argument to be" + " fields defined on the same dimensions. This error should be " + " caught in type deduction already." ) for dim_i, dim in enumerate(out_dims): @@ -232,7 +232,7 @@ def _construct_itir_domain_arg( ) if dim.kind == DimensionKind.LOCAL: - raise ValueError(f"Dimension {dim.value} must not be local.") + raise ValueError(f"Dimension '{dim.value}' must not be local.") domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), @@ -259,8 +259,8 @@ def _construct_itir_initialized_domain_arg( keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim if keys_dims_types != dim: raise ValueError( - f"Dimensions in out field and field domain are not equivalent" - f"Expected {dim}, but got {keys_dims_types} " + "Dimensions in out field and field domain are not equivalent:" + f"expected '{dim}', got '{keys_dims_types}'." ) return [self.visit(bound) for bound in node_domain.values_[dim_i].elts] @@ -277,13 +277,13 @@ def _compute_field_slice(node: past.Subscript): out_field_slice_ = [node.slice_] else: raise AssertionError( - "Unexpected `out` argument. Must be tuple of slices or slice expression." + "Unexpected 'out' argument, must be tuple of slices or slice expression." ) node_dims_ls = cast(ts.FieldType, node.type).dims assert isinstance(node_dims_ls, list) if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls): raise ValueError( - f"Too many indices for field {out_field_name}: field is {len(node_dims_ls)}" + f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}" f"-dimensional, but {len(out_field_slice_)} were indexed." ) return out_field_slice_ @@ -321,7 +321,11 @@ def _visit_stencil_call_out_arg( isinstance(field, past.Subscript) for field in flattened ), "Incompatible field in tuple: either all fields or no field must be sliced." assert all( - concepts.eq_nonlocated(first_field.slice_, field.slice_) for field in flattened # type: ignore[union-attr] # mypy cannot deduce type + concepts.eq_nonlocated( + first_field.slice_, + field.slice_, # type: ignore[union-attr] # mypy cannot deduce type + ) + for field in flattened ), "Incompatible field in tuple: all fields must be sliced in the same way." field_slice = self._compute_field_slice(first_field) first_field = first_field.value @@ -332,7 +336,7 @@ def _visit_stencil_call_out_arg( ) else: raise AssertionError( - "Unexpected `out` argument. Must be a `past.Subscript`, `past.Name` or `past.TupleExpr` node." + "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node." ) def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal: @@ -340,7 +344,7 @@ def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal: match node.type.kind: case ts.ScalarKind.STRING: raise NotImplementedError( - f"Scalars of kind {node.type.kind} not supported currently." + f"Scalars of kind '{node.type.kind}' not supported currently." ) typename = node.type.kind.name.lower() return itir.Literal(value=str(node.value), type=typename) @@ -373,5 +377,5 @@ def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: ) else: raise AssertionError( - "Only `minimum` and `maximum` builtins supported supported currently." + "Only 'minimum' and 'maximum' builtins supported supported currently." ) diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index 17b2050b1b..baf3037d5e 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -37,7 +37,7 @@ def make_source_definition_from_function(func: Callable) -> SourceDefinition: filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) if not filename: raise ValueError( - "Can not create field operator from a function that is not in a source file!" + "Can not create field operator from a function that is not in a source file." ) source_lines, line_offset = inspect.getsourcelines(func) source_code = textwrap.dedent(inspect.getsource(func)) @@ -47,7 +47,7 @@ def make_source_definition_from_function(func: Callable) -> SourceDefinition: return SourceDefinition(source_code, filename, line_offset - 1, column_offset) except OSError as err: - raise ValueError(f"Can not get source code of passed function ({func})") from err + raise ValueError(f"Can not get source code of passed function '{func}'.") from err def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) -> SymbolNames: @@ -55,13 +55,13 @@ def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) mod_st = symtable.symtable(source, filename, "exec") except SyntaxError as err: raise ValueError( - f"Unexpected error when parsing provided source code (\n{source}\n)" + f"Unexpected error when parsing provided source code: \n{source}\n" ) from err assert mod_st.get_type() == "module" if len(children := mod_st.get_children()) != 1: raise ValueError( - f"Sources with multiple function definitions are not yet supported (\n{source}\n)" + f"Sources with multiple function definitions are not yet supported: \n{source}\n" ) assert children[0].get_type() == "function" diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 7f56f5d92b..affae8fbca 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -51,7 +51,7 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec: if type_info.extract_dtype(param_el) == type_info.extract_dtype(arg_el): return param_el else: - raise ValueError(f"{arg_el} is not compatible with {param_el}.") + raise ValueError(f"'{arg_el}' is not compatible with '{param_el}'.") return arg_el return type_info.apply_to_primitive_constituents(arg, _as_field, with_path_arg=True) diff --git a/src/gt4py/next/iterator/dispatcher.py b/src/gt4py/next/iterator/dispatcher.py index b2ca39df04..626c51ed1c 100644 --- a/src/gt4py/next/iterator/dispatcher.py +++ b/src/gt4py/next/iterator/dispatcher.py @@ -57,7 +57,7 @@ def register_key(self, key): def push_key(self, key): if key not in self._funs: - raise RuntimeError(f"Key {key} not registered") + raise RuntimeError(f"Key '{key}' not registered.") self.key_stack.append(key) def pop_key(self): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b00e53bfd9..a4f32929db 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -238,7 +238,7 @@ def _validate_kstart(self, args): set(arg.kstart for arg in args if isinstance(arg, Column)) - {self.kstart} ): raise ValueError( - "Incompatible Column.kstart: it should be '{self.kstart}' but found other values: {wrong_kstarts}" + "Incompatible 'Column.kstart': it should be '{self.kstart}' but found other values: {wrong_kstarts}." ) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Column: @@ -486,7 +486,7 @@ def promote_scalars(val: CompositeOfScalarOrField): return constant_field(val) else: raise ValueError( - f"Expected a `Field` or a number (`float`, `np.int64`, ...), but got {val_type}." + f"Expected a 'Field' or a number ('float', 'np.int64', ...), got '{val_type}'." ) @@ -566,7 +566,7 @@ def execute_shift( return new_pos - raise AssertionError("Unknown object in `offset_provider`") + raise AssertionError("Unknown object in 'offset_provider'.") def _is_list_of_complete_offsets( @@ -878,7 +878,7 @@ def make_in_iterator( return SparseListIterator(it, sparse_dimensions[0]) else: raise NotImplementedError( - f"More than one local dimension is currently not supported, got {sparse_dimensions}" + f"More than one local dimension is currently not supported, got {sparse_dimensions}." ) else: return it @@ -925,7 +925,7 @@ def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if common.is_mutable_field(self._ndarrayfield): self._ndarrayfield[self._translate_named_indices(named_indices)] = value else: - raise RuntimeError("Assigment into a non-mutable Field.") + raise RuntimeError("Assigment into a non-mutable Field is not allowed.") @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1023,7 +1023,7 @@ def np_as_located_field( def _maker(a) -> common.Field: if a.ndim != len(axes): - raise TypeError("ndarray.ndim incompatible with number of given dimensions") + raise TypeError("'ndarray.ndim' is incompatible with number of given dimensions.") ranges = [] for d, s in zip(axes, a.shape): offset = origin.get(d, 0) @@ -1071,7 +1071,7 @@ def dtype(self) -> core_defs.Int32DType: @property def ndarray(self) -> core_defs.NDArrayObject: - raise AttributeError("Cannot get `ndarray` of an infinite Field.") + raise AttributeError("Cannot get 'ndarray' of an infinite 'Field'.") def asnumpy(self) -> np.ndarray: raise NotImplementedError() @@ -1190,7 +1190,7 @@ def codomain(self) -> type[core_defs.ScalarT]: @property def ndarray(self) -> core_defs.NDArrayObject: - raise AttributeError("Cannot get `ndarray` of an infinite Field.") + raise AttributeError("Cannot get 'ndarray' of an infinite 'Field'.") def asnumpy(self) -> np.ndarray: raise NotImplementedError() @@ -1440,7 +1440,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: if isinstance(field, tuple): if len(field) != len(value): raise RuntimeError( - f"Tuple of incompatible size, expected tuple of len={len(field)}, got len={len(value)}" + f"Tuple of incompatible size, expected tuple of 'len={len(field)}', got 'len={len(value)}'." ) for f, v in zip(field, value): _tuple_assign(f, v, named_indices) @@ -1459,7 +1459,7 @@ def field_getitem(self, named_indices: NamedFieldIndices) -> Any: def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if not isinstance(value, tuple): - raise RuntimeError(f"Value needs to be tuple, got `{value}`.") + raise RuntimeError(f"Value needs to be tuple, got '{value}'.") _tuple_assign(self.data, value, named_indices) @@ -1503,13 +1503,13 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: if isinstance(domain, runtime.CartesianDomain): if any(isinstance(o, common.Connectivity) for o in offset_provider.values()): raise RuntimeError( - "Got a `CartesianDomain`, but found a `Connectivity` in `offset_provider`, expected `UnstructuredDomain`." + "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'." ) def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): if "offset_provider" not in kwargs: - raise RuntimeError("offset_provider not provided") + raise RuntimeError("'offset_provider' not provided.") offset_provider = kwargs["offset_provider"] @@ -1523,7 +1523,7 @@ def closure( _validate_domain(domain_, kwargs["offset_provider"]) domain: dict[Tag, range] = _dimension_to_tag(domain_) if not (common.is_field(out) or is_tuple_of_field(out)): - raise TypeError("Out needs to be a located field.") + raise TypeError("'Out' needs to be a located field.") column_range = None column: Optional[ColumnDescriptor] = None diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 535648cc47..e6ee20e227 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -49,13 +49,13 @@ class Sym(Node): # helper @datamodels.validator("kind") def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): if value and value not in ["Iterator", "Value"]: - raise ValueError(f"Invalid kind `{value}`, must be one of `Iterator`, `Value`.") + raise ValueError(f"Invalid kind '{value}', must be one of 'Iterator', 'Value'.") @datamodels.validator("dtype") def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): if value and value[0] not in TYPEBUILTINS: raise ValueError( - f"Invalid dtype `{value}`, must be one of `{'`, `'.join(TYPEBUILTINS)}`." + f"Invalid dtype '{value}', must be one of '{', '.join(TYPEBUILTINS)}'." ) @@ -71,7 +71,7 @@ class Literal(Expr): @datamodels.validator("type") def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): if value not in TYPEBUILTINS: - raise ValueError(f"{value} is not a valid builtin type.") + raise ValueError(f"'{value}' is not a valid builtin type.") class NoneLiteral(Expr): @@ -115,7 +115,7 @@ class StencilClosure(Node): @datamodels.validator("output") def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): - raise ValueError("Only FunCall to `make_tuple` allowed.") + raise ValueError("Only FunCall to 'make_tuple' allowed.") UNARY_MATH_NUMBER_BUILTINS = {"abs"} diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index f7086ada0c..94a2646422 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -295,7 +295,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Literal(value='True', type='bool') """ if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673 - raise ValueError(f"Value must be a scalar, but got {type(val).__name__}") + raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.") # At the time this has been written the iterator module has its own type system that is # uncoupled from the one used in the frontend. However since we decided to eventually replace diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ffc00e474b..e12ae84dbc 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -96,7 +96,7 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs): backend(self.itir(*args, **kwargs), *args, **kwargs) else: if fendef_embedded is None: - raise RuntimeError("Embedded execution is not registered") + raise RuntimeError("Embedded execution is not registered.") fendef_embedded(self.function, *args, **kwargs) def format_itir(self, *args, formatter: ProgramFormatter, **kwargs) -> str: diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index d1f6bba8d6..30fec1f9fd 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -164,7 +164,7 @@ def make_node(o): return NoneLiteral() if hasattr(o, "fun"): return SymRef(id=o.fun.__name__) - raise NotImplementedError(f"Cannot handle {o}") + raise NotImplementedError(f"Cannot handle '{o}'.") def trace_function_call(fun, *, args=None): @@ -269,7 +269,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: # the last parameter info might also be a keyword or variadic keyword argument, but # they are not supported. raise NotImplementedError( - "Only `POSITIONAL_OR_KEYWORD` or `VAR_POSITIONAL` parameters are supported." + "Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported." ) param_info = param_infos[-1] @@ -279,7 +279,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: param_name = param_info.name else: raise NotImplementedError( - "Only `POSITIONAL_OR_KEYWORD` or `VAR_POSITIONAL` parameters are supported." + "Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported." ) kind, dtype = None, None diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index cc70e11413..034a39d68f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -123,7 +123,7 @@ def generic_visit(self, *args, **kwargs): depth = kwargs.pop("depth") return super().generic_visit(*args, depth=depth + 1, **kwargs) - def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. + def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node): return super().visit(node, **kwargs) @@ -289,7 +289,7 @@ def extract_subexpression( # `_subexpr_2`: `x + y + (x + y)` raise NotImplementedError( "Results of the current implementation not meaningful for " - "`deepest_expr_first == True` and `once_only == True`." + "'deepest_expr_first == True' and 'once_only == True'." ) ignored_children = False diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e2feb79c44..2e05391634 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -68,7 +68,7 @@ def _inline_into_scan(ir, *, max_iter=10): break ir = inlined else: - raise RuntimeError(f"Inlining into scan did not converge with {max_iter} iterations.") + raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") return ir @@ -117,7 +117,7 @@ def apply_common_transforms( break ir = inlined else: - raise RuntimeError("Inlining lift and lambdas did not converge.") + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 60a5db7e96..861052bb25 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -81,7 +81,7 @@ def _get_connectivity( ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a `reduce` object, i.e. `reduce(...)(...)`.") + raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): @@ -90,11 +90,11 @@ def _get_connectivity( connectivities.append(conn) if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of reduce.") + raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to reduce have incompatible partial shifts.") + raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") return connectivities[0] diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 14f3e95e10..2375118cd1 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -74,7 +74,7 @@ def from_elems(cls: typing.Type[T], *elems: Type) -> typing.Union[T, EmptyTuple] def __iter__(self) -> abc.Iterator[Type]: yield self.front if not isinstance(self.others, (Tuple, EmptyTuple)): - raise ValueError(f"Can not iterate over partially defined tuple {self}") + raise ValueError(f"Can not iterate over partially defined tuple '{self}'.") yield from self.others def __len__(self) -> int: @@ -286,7 +286,7 @@ def handle_constraint( if self.name != other.name: raise TypeError( - f"Can not satisfy constraint on primitive types: {self.name} ≡ {other.name}" + f"Can not satisfy constraint on primitive types: '{self.name}' ≡ '{other.name}'." ) return True @@ -300,7 +300,7 @@ def handle_constraint( self, other: Type, add_constraint: abc.Callable[[Type, Type], None] ) -> bool: if isinstance(other, UnionPrimitive): - raise AssertionError("`UnionPrimitive` may only appear on one side of a constraint.") + raise AssertionError("'UnionPrimitive' may only appear on one side of a constraint.") if not isinstance(other, Primitive): return False @@ -551,7 +551,8 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): current_loc_out = current_loc_in for arg in shift_args: if not isinstance(arg, ir.OffsetLiteral): - continue # probably some dynamically computed offset, thus we assume it’s a number not an axis and just ignore it (see comment below) + # probably some dynamically computed offset, thus we assume it’s a number not an axis and just ignore it (see comment below) + continue offset = arg.value if isinstance(offset, int): continue # ignore ‘application’ of (partial) shifts @@ -639,7 +640,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: elif node.id in ir.GRAMMAR_BUILTINS: raise TypeError( f"Builtin '{node.id}' is only allowed as applied/called function by the type " - f"inference." + "inference." ) elif node.id in ir.TYPEBUILTINS: # TODO(tehrengruber): Implement propagating types of values referring to types, e.g. @@ -649,10 +650,10 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: # `typing.Type`. raise NotImplementedError( f"Type builtin '{node.id}' is only supported as literal argument by the " - f"type inference." + "type inference." ) else: - raise NotImplementedError(f"Missing type definition for builtin '{node.id}'") + raise NotImplementedError(f"Missing type definition for builtin '{node.id}'.") elif node.id in symtable: sym_decl = symtable[node.id] assert isinstance(sym_decl, TYPED_IR_NODES) @@ -696,13 +697,13 @@ def _visit_make_tuple(self, node: ir.FunCall, **kwargs) -> Type: def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: # Calls to `tuple_get` are handled as being part of the grammar, not as function calls. if len(node.args) != 2: - raise TypeError("`tuple_get` requires exactly two arguments.") + raise TypeError("'tuple_get' requires exactly two arguments.") if ( not isinstance(node.args[0], ir.Literal) or node.args[0].type != ir.INTEGER_INDEX_BUILTIN ): raise TypeError( - f"The first argument to `tuple_get` must be a literal of type `{ir.INTEGER_INDEX_BUILTIN}`." + f"The first argument to 'tuple_get' must be a literal of type '{ir.INTEGER_INDEX_BUILTIN}'." ) self.visit(node.args[0], **kwargs) # visit index so that its type is collected idx = int(node.args[0].value) @@ -725,9 +726,9 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: if len(node.args) != 2: - raise TypeError("`neighbors` requires exactly two arguments.") + raise TypeError("'neighbors' requires exactly two arguments.") if not (isinstance(node.args[0], ir.OffsetLiteral) and isinstance(node.args[0].value, str)): - raise TypeError("The first argument to `neighbors` must be an `OffsetLiteral` tag.") + raise TypeError("The first argument to 'neighbors' must be an 'OffsetLiteral' tag.") # Visit arguments such that their type is also inferred self.visit(node.args, **kwargs) @@ -766,11 +767,11 @@ def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: def _visit_cast_(self, node: ir.FunCall, **kwargs) -> Type: if len(node.args) != 2: - raise TypeError("`cast_` requires exactly two arguments.") + raise TypeError("'cast_' requires exactly two arguments.") val_arg_type = self.visit(node.args[0], **kwargs) type_arg = node.args[1] if not isinstance(type_arg, ir.SymRef) or type_arg.id not in ir.TYPEBUILTINS: - raise TypeError("The second argument to `cast_` must be a type literal.") + raise TypeError("The second argument to 'cast_' must be a type literal.") size = TypeVar.fresh() @@ -964,7 +965,7 @@ def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: and child_node.id in ir.GRAMMAR_BUILTINS | ir.TYPEBUILTINS ): raise AssertionError( - f"Expected a type to be inferred for node `{child_node}`, but none was found." + f"Expected a type to be inferred for node '{child_node}', but none was found." ) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 5d54512bd0..bfb3b0d474 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -206,7 +206,7 @@ def create_bindings( """ if program_source.language not in [languages.Cpp, languages.Cuda]: raise ValueError( - f"Can only create bindings for C++ program sources, received {program_source.language}." + f"Can only create bindings for C++ program sources, received '{program_source.language}'." ) wrapper_name = program_source.entry_point.name + "_wrapper" diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 5ea4ba0519..2c0511ebf4 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -101,7 +101,7 @@ def visit_FindDependency(self, dep: FindDependency): return f"find_package(GridTools REQUIRED PATHS {gridtools_cpp.get_cmake_dir()} NO_DEFAULT_PATH)" case _: - raise ValueError("Library {name} is not supported".format(name=dep.name)) + raise ValueError(f"Library '{dep.name}' is not supported") def visit_LinkDependency(self, dep: LinkDependency): # TODO(ricoh): do not add more libraries here @@ -115,7 +115,7 @@ def visit_LinkDependency(self, dep: LinkDependency): case "gridtools_gpu": lib_name = "GridTools::fn_gpu" case _: - raise ValueError("Library {name} is not supported".format(name=dep.name)) + raise ValueError(f"Library '{dep.name}' is not supported") cfg = "" if dep.name == "nanobind": diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index dacb444207..9fd20b16e2 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -80,7 +80,7 @@ def __call__( if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir): raise CompilationError( - "On-the-fly compilation unsuccessful for {inp.source_module.entry_point.name}!" + f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) return getattr( diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 0370b5eeb3..a21bc83c0b 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -59,7 +59,7 @@ class ProgramSource(Generic[SrcL, SettingT]): def __post_init__(self): if not isinstance(self.language_settings, self.language.settings_class): raise TypeError( - f"Wrong language settings type for {self.language}, must be subclass of {self.language.settings_class}" + f"Wrong language settings type for '{self.language}', must be subclass of '{self.language.settings_class}'." ) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 6b6b91a310..ed8b768972 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -80,7 +80,7 @@ def replace(self, **kwargs: Any) -> Self: TypeError: If `self` is not a dataclass. """ if not dataclasses.is_dataclass(self): - raise TypeError(f"{self.__class__} is not a dataclass") + raise TypeError(f"'{self.__class__}' is not a dataclass.") assert not isinstance(self, type) return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? @@ -242,7 +242,9 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] + hash_function: Callable[[StartT], HashT] = dataclasses.field( + default=hash + ) # type: ignore[assignment] _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index f412386bb3..74fbbfc93f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -99,7 +99,7 @@ def _get_connectivity( ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a `reduce` object, i.e. `reduce(...)(...)`.") + raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): @@ -108,11 +108,11 @@ def _get_connectivity( connectivities.append(conn) if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of reduce.") + raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to reduce have incompatible partial shifts.") + raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") return connectivities[0] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 7bf310f4e1..4abdaa6eea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -135,7 +135,7 @@ def _process_connectivity_args( if isinstance(connectivity, Connectivity): if connectivity.index_type not in [np.int32, np.int64]: raise ValueError( - "Neighbor table indices must be of type `np.int32` or `np.int64`." + "Neighbor table indices must be of type 'np.int32' or 'np.int64'." ) # parameter @@ -165,8 +165,8 @@ def _process_connectivity_args( pass else: raise AssertionError( - f"Expected offset provider `{name}` to be a `Connectivity` or `Dimension`, " - f"but got {type(connectivity).__name__}." + f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " + f"got '{type(connectivity).__name__}'." ) return parameters, arg_exprs diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index f78a052679..842080f8ae 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -59,7 +59,7 @@ def pytype_to_cpptype(t: str): "axis_literal": None, # TODO: domain? }[t] except KeyError: - raise TypeError(f"Unsupported type '{t}'") from None + raise TypeError(f"Unsupported type '{t}'.") from None _vertical_dimension = "gtfn::unstructured::dim::vertical" @@ -83,7 +83,7 @@ def _get_gridtype(closures: list[itir.StencilClosure]) -> common.GridType: grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( - f"Found StencilClosures with more than one GridType: {grid_types}. This is currently not supported." + f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported." ) return grid_types.pop() @@ -109,7 +109,7 @@ def _collect_dimensions_from_domain( offset_definitions[dim_name] = TagDefinition(name=Sym(id=dim_name)) elif domain.fun == itir.SymRef(id="unstructured_domain"): if len(domain.args) > 2: - raise ValueError("unstructured_domain must not have more than 2 arguments.") + raise ValueError("Unstructured_domain must not have more than 2 arguments.") if len(domain.args) > 0: horizontal_range = domain.args[0] assert isinstance(horizontal_range, itir.FunCall) @@ -126,7 +126,7 @@ def _collect_dimensions_from_domain( ) else: raise AssertionError( - "Expected either a call to `cartesian_domain` or to `unstructured_domain`." + "Expected either a call to 'cartesian_domain' or to 'unstructured_domain'." ) return offset_definitions @@ -181,7 +181,7 @@ def _collect_offset_definitions( ) else: raise AssertionError( - "Elements of offset provider need to be either `Dimension` or `Connectivity`." + "Elements of offset provider need to be either 'Dimension' or 'Connectivity'." ) return offset_definitions @@ -233,7 +233,7 @@ def apply( fencil_definition = node else: raise TypeError( - f"Expected a `FencilDefinition` or `FencilWithTemporaries`, but got `{type(node).__name__}`." + f"Expected a 'FencilDefinition' or 'FencilWithTemporaries', got '{type(node).__name__}'." ) grid_type = _get_gridtype(fencil_definition.closures) @@ -303,7 +303,7 @@ def _make_domain(self, node: itir.FunCall): isinstance(named_range, itir.FunCall) and named_range.fun == itir.SymRef(id="named_range") ): - raise ValueError("Arguments to `domain` need to be calls to `named_range`.") + raise ValueError("Arguments to 'domain' need to be calls to 'named_range'.") tags.append(self.visit(named_range.args[0])) sizes.append( BinaryExpr( @@ -410,9 +410,9 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs: Any) -> Node: # special handling of applied builtins is handled in `_visit_` return getattr(self, visit_method)(node, **kwargs) elif node.fun.id == "shift": - raise ValueError("unapplied shift call not supported: {node}") + raise ValueError("Unapplied shift call not supported: '{node}'.") elif node.fun.id == "scan": - raise ValueError("scans are only supported at the top level of a stencil closure") + raise ValueError("Scans are only supported at the top level of a stencil closure.") if isinstance(node.fun, itir.FunCall): if node.fun.fun == itir.SymRef(id="shift"): assert len(node.args) == 1 @@ -440,7 +440,7 @@ def _visit_output_argument(self, node: itir.Expr): return self.visit(node) elif isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="make_tuple"): return SidComposite(values=[self._visit_output_argument(v) for v in node.args]) - raise ValueError("Expected `SymRef` or `make_tuple` in output argument.") + raise ValueError("Expected 'SymRef' or 'make_tuple' in output argument.") @staticmethod def _bool_from_literal(node: itir.Node): diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index d9f8b36301..95d3d2ca35 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -56,6 +56,62 @@ def kind(self) -> type[ProgramFormatter]: return ProgramFormatter +def _make_arg_filter( + accept_args: None | int | Literal["all"] = "all", +) -> Callable[[tuple[Any, ...]], tuple[Any, ...]]: + match accept_args: + case None: + + def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]: + return () + + case "all": + + def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]: + return args + + case int(): + if accept_args < 0: + raise ValueError( + f"Number of accepted arguments cannot be a negative number, got {accept_args}." + ) + + def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]: + return args[:accept_args] + + case _: + raise ValueError(f"Invalid 'accept_args' value: {accept_args}.") + return arg_filter + + +def _make_kwarg_filter( + accept_kwargs: None | Sequence[str] | Literal["all"] = "all", +) -> Callable[[dict[str, Any]], dict[str, Any]]: + match accept_kwargs: + case None: + + def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]: + return {} + + case "all": + + def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]: + return kwargs + + case Sequence(): + if not all(isinstance(a, str) for a in accept_kwargs): + raise ValueError( + f"Provided invalid list of keyword argument names: '{accept_kwargs}'." + ) + + def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in kwargs.items() if key in accept_kwargs} + + case _: + raise ValueError(f"Invalid 'accept_kwargs' value: {accept_kwargs}") + return kwarg_filter + + def make_program_processor( func: ProgramProcessorCallable[OutputT], kind: type[ProcessorKindT], @@ -80,33 +136,9 @@ def make_program_processor( Raises: ValueError: If the value of `accept_args` or `accept_kwargs` is invalid. """ - args_filter: Callable[[Sequence], Sequence] - if accept_args is None: - args_filter = lambda args: () # noqa: E731 # use def instead of named lambdas - elif accept_args == "all": - args_filter = lambda args: args # noqa: E731 - elif isinstance(accept_args, int): - if accept_args < 0: - raise ValueError( - f"Number of accepted arguments cannot be a negative number ({accept_args})" - ) - args_filter = lambda args: args[:accept_args] # type: ignore[misc] # noqa: E731 - else: - raise ValueError(f"Invalid ({accept_args}) accept_args value") - - filtered_kwargs: Callable[[dict[str, Any]], dict[str, Any]] - if accept_kwargs is None: - filtered_kwargs = lambda kwargs: {} # noqa: E731 # use def instead of named lambdas - elif accept_kwargs == "all": # don't swap with 'isinstance(..., Sequence)' - filtered_kwargs = lambda kwargs: kwargs # noqa: E731 - elif isinstance(accept_kwargs, Sequence): - if not all(isinstance(a, str) for a in accept_kwargs): - raise ValueError(f"Provided invalid list of keyword argument names ({accept_args})") - filtered_kwargs = lambda kwargs: { # noqa: E731 - key: value for key, value in kwargs.items() if key in accept_kwargs # type: ignore[operator] # key in accept_kwargs - } - else: - raise ValueError(f"Invalid ({accept_kwargs}) 'accept_kwargs' value") + args_filter = _make_arg_filter(accept_args) + + filtered_kwargs = _make_kwarg_filter(accept_kwargs) @functools.wraps(func) def _wrapper(program: itir.FencilDefinition, *args, **kwargs) -> OutputT: @@ -195,7 +227,7 @@ def ensure_processor_kind( obj: ProgramProcessor[OutputT, ProcessorKindT], kind: type[ProcessorKindT] ) -> None: if not is_processor_kind(obj, kind): - raise TypeError(f"{obj} is not a {kind.__name__}!") + raise TypeError(f"'{obj}' is not a '{kind.__name__}'.") class ProgramBackend( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index acfa06b456..65f9d9d71a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -148,7 +148,7 @@ def get_stride_args( stride, remainder = divmod(stride_size, value.itemsize) if remainder != 0: raise ValueError( - f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)" + f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)." ) stride_args[str(sym)] = stride @@ -334,7 +334,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: else: def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: - raise RuntimeError("Missing `cupy` dependency for GPU execution.") + raise RuntimeError("Missing 'cupy' dependency for GPU execution.") run_dace_gpu = otf_exec.OTFBackend( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d10a14a1ee..d08476847f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -401,7 +401,7 @@ def builtin_tuple_get( index = node_args[0] if isinstance(index, itir.Literal): return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants") + raise ValueError("Tuple can only be subscripted with compile-time constants.") _GENERAL_BUILTIN_MAPPING: dict[ @@ -640,7 +640,7 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: elif builtin_name in _GENERAL_BUILTIN_MAPPING: return self._visit_general_builtin(node) else: - raise NotImplementedError(f"{builtin_name} not implemented") + raise NotImplementedError(f"'{builtin_name}' not implemented.") return self._visit_call(node) def _visit_call(self, node: itir.FunCall): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index cb14b89e8a..55717326a3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -32,7 +32,7 @@ def as_dace_type(type_: ts.ScalarType): return dace.float32 elif type_.kind == ts.ScalarKind.FLOAT64: return dace.float64 - raise ValueError(f"scalar type {type_} not supported") + raise ValueError(f"Scalar type '{type_}' not supported.") def filter_neighbor_tables(offset_provider: dict[str, Any]): diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 5d4b450d39..baa45ddc0e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -83,7 +83,7 @@ def extract_connectivity_args( if isinstance(conn, common.Connectivity): if not isinstance(conn, common.NeighborTable): raise NotImplementedError( - "Only `NeighborTable` connectivities implemented at this point." + "Only 'NeighborTable' connectivities implemented at this point." ) # copying to device here is a fallback for easy testing and might be removed later conn_arg = _ensure_is_on_device(conn.table, device) @@ -92,8 +92,8 @@ def extract_connectivity_args( pass else: raise AssertionError( - f"Expected offset provider `{name}` to be a `Connectivity` or `Dimension`, " - f"but got {type(conn).__name__}." + f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " + f"but got '{type(conn).__name__}'." ) return args diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 564df7fd1a..20fa8bd791 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -75,13 +75,15 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: match symbol_type: case ts.DeferredType(constraint): if constraint is None: - raise ValueError(f"No type information available for {symbol_type}!") + raise ValueError(f"No type information available for '{symbol_type}'.") elif isinstance(constraint, tuple): - raise ValueError(f"Not sufficient type information available for {symbol_type}!") + raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") return constraint case ts.TypeSpec() as concrete_type: return concrete_type.__class__ - raise ValueError(f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!") + raise ValueError( + f"Invalid type for TypeInfo: requires '{ts.TypeSpec}', got '{type(symbol_type)}'." + ) def primitive_constituents( @@ -163,7 +165,7 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: return dtype case ts.ScalarType() as dtype: return dtype - raise ValueError(f"Can not unambiguosly extract data type from {symbol_type}!") + raise ValueError(f"Can not unambiguosly extract data type from '{symbol_type}'.") def is_floating_point(symbol_type: ts.TypeSpec) -> bool: @@ -320,7 +322,7 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: return [] case ts.FieldType(dims): return dims - raise ValueError(f"Can not extract dimensions from {symbol_type}!") + raise ValueError(f"Can not extract dimensions from '{symbol_type}'.") def is_local_field(type_: ts.FieldType) -> bool: @@ -435,7 +437,7 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types))) return ts.FieldType(dims=dims, dtype=dtype) - raise TypeError("Expected a FieldType or ScalarType.") + raise TypeError("Expected a 'FieldType' or 'ScalarType'.") @functools.singledispatch @@ -446,7 +448,7 @@ def return_type( with_kwargs: dict[str, ts.TypeSpec], ): raise NotImplementedError( - f"Return type deduction of type " f"{type(callable_type).__name__} not implemented." + f"Return type deduction of type " f"'{type(callable_type).__name__}' not implemented." ) @@ -473,7 +475,7 @@ def return_type_field( raise ValueError("Could not deduce return type of invalid remap operation.") from ex if not isinstance(with_args[0], ts.OffsetType): - raise ValueError(f"First argument must be of type {ts.OffsetType}, got {with_args[0]}.") + raise ValueError(f"First argument must be of type '{ts.OffsetType}', got '{with_args[0]}'.") source_dim = with_args[0].source target_dims = with_args[0].target @@ -500,7 +502,7 @@ def canonicalize_arguments( ignore_errors=False, use_signature_ordering=False, ) -> tuple[list, dict]: - raise NotImplementedError(f"Not implemented for type {type(func_type).__name__}.") + raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.") @canonicalize_arguments.register @@ -526,7 +528,7 @@ def canonicalize_function_arguments( cargs[args_idx] = ckwargs.pop(name) elif not ignore_errors: raise AssertionError( - f"Error canonicalizing function arguments. Got multiple values for argument `{name}`." + f"Error canonicalizing function arguments. Got multiple values for argument '{name}'." ) a, b = set(func_type.kw_only_args.keys()), set(ckwargs.keys()) @@ -534,7 +536,7 @@ def canonicalize_function_arguments( if invalid_kw_args and (not ignore_errors or use_signature_ordering): # this error can not be ignored as otherwise the invariant that no arguments are dropped # is invalidated. - raise AssertionError(f"Invalid keyword arguments {', '.join(invalid_kw_args)}.") + raise AssertionError(f"Invalid keyword arguments '{', '.join(invalid_kw_args)}'.") if use_signature_ordering: ckwargs = {k: ckwargs[k] for k in func_type.kw_only_args.keys() if k in ckwargs} @@ -566,7 +568,7 @@ def structural_function_signature_incompatibilities( if args_idx < len(args): # remove the argument here such that later errors stay comprehensible kwargs.pop(name) - yield f"Got multiple values for argument `{name}`." + yield f"Got multiple values for argument '{name}'." num_pos_params = len(func_type.pos_only_args) + len(func_type.pos_or_kw_args) num_pos_args = len(args) - args.count(UNDEFINED_ARG) @@ -582,17 +584,17 @@ def structural_function_signature_incompatibilities( range(len(func_type.pos_only_args), num_pos_params), func_type.pos_or_kw_args.keys() ): if args[i] is UNDEFINED_ARG: - missing_positional_args.append(f"`{arg_type}`") + missing_positional_args.append(f"'{arg_type}'") if missing_positional_args: yield f"Missing {len(missing_positional_args)} required positional argument{'s' if len(missing_positional_args) != 1 else ''}: {', '.join(missing_positional_args)}" # check for missing or extra keyword arguments kw_a_m_b = set(func_type.kw_only_args.keys()) - set(kwargs.keys()) if len(kw_a_m_b) > 0: - yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} `{'`, `'.join(kw_a_m_b)}`." + yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} '{', '.join(kw_a_m_b)}'." kw_b_m_a = set(kwargs.keys()) - set(func_type.kw_only_args.keys()) if len(kw_b_m_a) > 0: - yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} `{'`, `'.join(kw_b_m_a)}`." + yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} '{', '.join(kw_b_m_a)}'." @functools.singledispatch @@ -604,7 +606,7 @@ def function_signature_incompatibilities( Note that all types must be concrete/complete. """ - raise NotImplementedError(f"Not implemented for type {type(func_type).__name__}.") + raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.") @function_signature_incompatibilities.register @@ -639,14 +641,14 @@ def function_signature_incompatibilities_func( # noqa: C901 if i < len(func_type.pos_only_args): arg_repr = f"{_number_to_ordinal_number(i+1)} argument" else: - arg_repr = f"argument `{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}`" - yield f"Expected {arg_repr} to be of type `{a_arg}`, but got `{b_arg}`." + arg_repr = f"argument '{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}'" + yield f"Expected {arg_repr} to be of type '{a_arg}', got '{b_arg}'." for kwarg in set(func_type.kw_only_args.keys()) & set(kwargs.keys()): if (a_kwarg := func_type.kw_only_args[kwarg]) != ( b_kwarg := kwargs[kwarg] ) and not is_concretizable(a_kwarg, to_type=b_kwarg): - yield f"Expected keyword argument `{kwarg}` to be of type `{func_type.kw_only_args[kwarg]}`, but got `{kwargs[kwarg]}`." + yield f"Expected keyword argument '{kwarg}' to be of type '{func_type.kw_only_args[kwarg]}', got '{kwargs[kwarg]}'." @function_signature_incompatibilities.register @@ -660,11 +662,11 @@ def function_signature_incompatibilities_field( return if not isinstance(args[0], ts.OffsetType): - yield f"Expected first argument to be of type {ts.OffsetType}, but got {args[0]}." + yield f"Expected first argument to be of type '{ts.OffsetType}', got '{args[0]}'." return if kwargs: - yield f"Got unexpected keyword argument(s) `{'`, `'.join(kwargs.keys())}`." + yield f"Got unexpected keyword argument(s) '{', '.join(kwargs.keys())}'." return source_dim = args[0].source @@ -705,7 +707,7 @@ def accepts_args( """ if not isinstance(callable_type, ts.CallableType): if raise_exception: - raise ValueError(f"Expected a callable type, but got `{callable_type}`.") + raise ValueError(f"Expected a callable type, got '{callable_type}'.") return False errors = function_signature_incompatibilities(callable_type, with_args, with_kwargs) @@ -713,7 +715,7 @@ def accepts_args( error_list = list(errors) if len(error_list) > 0: raise ValueError( - f"Invalid call to function of type `{callable_type}`:\n" + f"Invalid call to function of type '{callable_type}':\n" + ("\n".join([f" - {error}" for error in error_list])) ) return True diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 007a83844c..88a8347fe4 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -37,7 +37,7 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: try: dt = np.dtype(dtype) except TypeError as err: - raise ValueError(f"Invalid scalar type definition ({dtype})") from err + raise ValueError(f"Invalid scalar type definition ('{dtype}').") from err if dt.shape == () and dt.fields is None: match dt: @@ -54,9 +54,9 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: case np.str_: return ts.ScalarKind.STRING case _: - raise ValueError(f"Impossible to map '{dtype}' value to a ScalarKind") + raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: - raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported") + raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported.") def from_type_hint( @@ -76,7 +76,7 @@ def from_type_hint( type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns) except Exception as error: raise ValueError( - f"Type annotation ({type_hint}) has undefined forward references!" + f"Type annotation '{type_hint}' has undefined forward references." ) from error # Annotated @@ -98,50 +98,50 @@ def from_type_hint( case builtins.tuple: if not args: - raise ValueError(f"Tuple annotation ({type_hint}) requires at least one argument!") + raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") if Ellipsis in args: - raise ValueError(f"Unbound tuples ({type_hint}) are not allowed!") + raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") return ts.TupleType(types=[recursive_make_symbol(arg) for arg in args]) case common.Field: if (n_args := len(args)) != 2: - raise ValueError(f"Field type requires two arguments, got {n_args}! ({type_hint})") + raise ValueError(f"Field type requires two arguments, got {n_args}: '{type_hint}'.") dims: Union[Ellipsis, list[common.Dimension]] = [] dim_arg, dtype_arg = args if isinstance(dim_arg, list): for d in dim_arg: if not isinstance(d, common.Dimension): - raise ValueError(f"Invalid field dimension definition '{d}'") + raise ValueError(f"Invalid field dimension definition '{d}'.") dims.append(d) elif dim_arg is Ellipsis: dims = dim_arg else: - raise ValueError(f"Invalid field dimensions '{dim_arg}'") + raise ValueError(f"Invalid field dimensions '{dim_arg}'.") try: dtype = recursive_make_symbol(dtype_arg) except ValueError as error: raise ValueError( - f"Field dtype argument must be a scalar type (got '{dtype_arg}')!" + f"Field dtype argument must be a scalar type (got '{dtype_arg}')." ) from error if not isinstance(dtype, ts.ScalarType) or dtype.kind == ts.ScalarKind.STRING: - raise ValueError("Field dtype argument must be a scalar type (got '{dtype}')!") + raise ValueError("Field dtype argument must be a scalar type (got '{dtype}').") return ts.FieldType(dims=dims, dtype=dtype) case collections.abc.Callable: if not args: - raise ValueError("Not annotated functions are not supported!") + raise ValueError("Unannotated functions are not supported.") try: arg_types, return_type = args args = [recursive_make_symbol(arg) for arg in arg_types] except Exception as error: - raise ValueError(f"Invalid callable annotations in {type_hint}") from error + raise ValueError(f"Invalid callable annotations in '{type_hint}'.") from error kwargs_info = [arg for arg in extra_args if isinstance(arg, xtyping.CallableKwargsInfo)] if len(kwargs_info) != 1: - raise ValueError(f"Invalid callable annotations in {type_hint}") + raise ValueError(f"Invalid callable annotations in '{type_hint}'.") kwargs = { arg: recursive_make_symbol(arg_type) for arg, arg_type in kwargs_info[0].data.items() @@ -155,7 +155,7 @@ def from_type_hint( returns=recursive_make_symbol(return_type), ) - raise ValueError(f"'{type_hint}' type is not supported") + raise ValueError(f"'{type_hint}' type is not supported.") def from_value(value: Any) -> ts.TypeSpec: @@ -178,7 +178,7 @@ def from_value(value: Any) -> ts.TypeSpec: break if not symbol_type: raise ValueError( - f"Value `{value}` is out of range to be representable as `INT32` or `INT64`." + f"Value '{value}' is out of range to be representable as 'INT32' or 'INT64'." ) return candidate_type elif isinstance(value, common.Dimension): @@ -200,4 +200,4 @@ def from_value(value: Any) -> ts.TypeSpec: if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type else: - raise ValueError(f"Impossible to map '{value}' value to a Symbol") + raise ValueError(f"Impossible to map '{value}' value to a 'Symbol'.") diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index b1e26b40cb..6217d3c782 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -127,7 +127,7 @@ class ConstInitializer(DataInitializer): def __init__(self, value: ScalarValue): if not core_defs.is_scalar_type(value): raise ValueError( - "`ConstInitializer` can not be used with non-scalars. Use `Case.as_field` instead." + "'ConstInitializer' can not be used with non-scalars. Use 'Case.as_field' instead." ) self.value = value @@ -162,7 +162,7 @@ class IndexInitializer(DataInitializer): @property def scalar_value(self) -> ScalarValue: - raise AttributeError("`scalar_value` not supported in `IndexInitializer`.") + raise AttributeError("'scalar_value' not supported in 'IndexInitializer'.") def field( self, @@ -172,7 +172,7 @@ def field( ) -> FieldValue: if len(sizes) > 1: raise ValueError( - f"`IndexInitializer` only supports fields with a single `Dimension`, got {sizes}." + f"'IndexInitializer' only supports fields with a single 'Dimension', got {sizes}." ) n_data = list(sizes.values())[0] return constructors.as_field( @@ -244,7 +244,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.partial(*args, **kwargs) def __getattr__(self, name: str) -> Any: - raise AttributeError(f"No setter for argument {name}.") + raise AttributeError(f"No setter for argument '{name}'.") @typing.overload @@ -323,7 +323,7 @@ class NewBuilder(Builder): if 0 < len(args) <= 1 and args[0] is not None: return make_builder_inner(args[0]) if len(args) > 1: - raise ValueError(f"make_builder takes only one positional argument, {len(args)} received!") + raise ValueError(f"make_builder takes only one positional argument, {len(args)} received.") return make_builder_inner @@ -533,7 +533,7 @@ def _allocate_from_type( ) case _: raise TypeError( - f"Can not allocate for type {arg_type} with initializer {strategy or 'default'}" + f"Can not allocate for type '{arg_type}' with initializer '{strategy or 'default'}'." ) @@ -542,7 +542,7 @@ def get_param_types( ) -> dict[str, ts.TypeSpec]: if fieldview_prog.definition is None: raise ValueError( - f"test cases do not support {type(fieldview_prog)} with empty .definition attribute (as you would get from .as_program())!" + f"test cases do not support '{type(fieldview_prog)}' with empty .definition attribute (as you would get from .as_program())." ) annotations = xtyping.get_type_hints(fieldview_prog.definition) return { @@ -559,7 +559,7 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> case ts.TupleType(types): return sum([get_param_size(t, sizes=sizes) for t in types]) case _: - raise TypeError(f"Can not get size for parameter of type {param_type}") + raise TypeError(f"Can not get size for parameter of type '{param_type}'.") def extend_sizes( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index f8a3f6a975..e25576ebde 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -22,7 +22,6 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.runners import gtfn, roundtrip try: @@ -39,7 +38,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: """Temporary default backend to not accidentally test the wrong backend.""" - raise ValueError("No backend selected! Backend selection is mandatory in tests.") + raise ValueError("No backend selected. Backend selection is mandatory in tests.") OPTIONAL_PROCESSORS = [] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 6293ff76bd..b41696a36b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -226,7 +226,7 @@ def testee( def test_scan_wrong_return_type(cartesian_case): with pytest.raises( errors.DSLError, - match=(r"Argument `init` to scan operator `testee_scan` must have same type as its return"), + match=(r"Argument 'init' to scan operator 'testee_scan' must have same type as its return"), ): @scan_operator(axis=KDim, forward=True, init=0) @@ -245,7 +245,7 @@ def test_scan_wrong_state_type(cartesian_case): with pytest.raises( errors.DSLError, match=( - r"Argument `init` to scan operator `testee_scan` must have same type as `state` argument" + r"Argument 'init' to scan operator 'testee_scan' must have same type as 'state' argument" ), ): @@ -276,7 +276,7 @@ def program_bound_args(arg1: bool, arg2: bool, out: cases.IField): def test_bind_invalid_arg(cartesian_case, bound_args_testee): with pytest.raises( - TypeError, match="Keyword argument `inexistent_arg` is not a valid program parameter." + TypeError, match="Keyword argument 'inexistent_arg' is not a valid program parameter." ): bound_args_testee.with_bound_args(inexistent_arg=1) @@ -306,7 +306,7 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te assert ( re.search( - "Parameter `arg2` already set as a bound argument.", exc_info.value.__cause__.args[0] + "Parameter 'arg2' already set as a bound argument.", exc_info.value.__cause__.args[0] ) is not None ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 51f853d41d..a08931628b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -1188,7 +1188,7 @@ def unpack( def test_tuple_unpacking_too_many_values(cartesian_case): with pytest.raises( errors.DSLError, - match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), + match=(r"Too many values to unpack \(expected 3\)."), ): @gtx.field_operator(backend=cartesian_case.backend) @@ -1197,8 +1197,10 @@ def _star_unpack() -> tuple[int32, float64, int32]: return a, b, c -def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises(errors.DSLError, match=(r"Assignment value must be of type tuple!")): +def test_tuple_unpacking_too_few_values(cartesian_case): + with pytest.raises( + errors.DSLError, match=(r"Assignment value must be of type tuple, got 'int32'.") + ): @gtx.field_operator(backend=cartesian_case.backend) def _invalid_unpack() -> tuple[int32, float64, int32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 8cfcff160c..167ccbb0a5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -57,7 +57,7 @@ def make_builtin_field_operator(builtin_name: str): "return": cases.IFloatField, } else: - raise AssertionError(f"Unknown builtin `{builtin_name}`") + raise AssertionError(f"Unknown builtin '{builtin_name}'.") closure_vars = {"IDim": IDim, builtin_name: getattr(fbuiltins, builtin_name)} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index c2ab43773f..f5bf453a09 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -147,9 +147,9 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: def test_unary_not(cartesian_case): pytest.xfail( - "We accidentally supported `not` on fields. This is wrong, we should raise an error." + "We accidentally supported 'not' on fields. This is wrong, we should raise an error." ) - with pytest.raises: # TODO `not` on a field should be illegal + with pytest.raises: # TODO 'not' on a field should be illegal @gtx.field_operator def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 4c0613a33c..c86881ab7c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -228,8 +228,8 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program(inp, out, offset_provider={}) msgs = [ - r"- Expected argument `in_field` to be of type `Field\[\[IDim], float64\]`," - r" but got `Field\[\[JDim\], float64\]`.", + r"- Expected argument 'in_field' to be of type 'Field\[\[IDim], float64\]'," + r" got 'Field\[\[JDim\], float64\]'.", ] for msg in msgs: assert re.search(msg, exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 84b480a23d..af06da3e29 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -334,7 +334,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(errors.DSLError, match="Condition for `if` must be scalar."): + with pytest.raises(errors.DSLError, match="Condition for 'if' must be scalar"): @field_operator def if_non_scalar_condition( @@ -347,7 +347,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): - with pytest.raises(errors.DSLError, match="Condition for `if` must be of boolean type."): + with pytest.raises(errors.DSLError, match="Condition for 'if' must be of boolean type"): @field_operator def if_non_boolean_condition( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index d1a5f24f79..2174871f89 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -88,7 +88,7 @@ def type_info_cases() -> list[tuple[Optional[ts.TypeSpec], dict]]: def callable_type_info_cases(): # reuse all the other test cases not_callable = [ - (symbol_type, [], {}, [r"Expected a callable type, but got "], None) + (symbol_type, [], {}, [r"Expected a callable type, got "], None) for symbol_type, attributes in type_info_cases() if not isinstance(symbol_type, ts.CallableType) ] @@ -165,7 +165,7 @@ def callable_type_info_cases(): nullary_func_type, [], {"foo": bool_type}, - [r"Got unexpected keyword argument `foo`."], + [r"Got unexpected keyword argument 'foo'."], None, ), ( @@ -180,7 +180,7 @@ def callable_type_info_cases(): unary_func_type, [float_type], {}, - [r"Expected 1st argument to be of type `bool`, but got `float64`."], + [r"Expected 1st argument to be of type 'bool', got 'float64'."], None, ), ( @@ -188,7 +188,7 @@ def callable_type_info_cases(): [], {}, [ - r"Missing 1 required positional argument: `foo`", + r"Missing 1 required positional argument: 'foo'", r"Function takes 1 positional argument, but 0 were given.", ], None, @@ -199,31 +199,31 @@ def callable_type_info_cases(): kw_or_pos_arg_func_type, [], {"foo": float_type}, - [r"Expected argument `foo` to be of type `bool`, but got `float64`."], + [r"Expected argument 'foo' to be of type 'bool', got 'float64'."], None, ), ( kw_or_pos_arg_func_type, [], {"bar": bool_type}, - [r"Got unexpected keyword argument `bar`."], + [r"Got unexpected keyword argument 'bar'."], None, ), # function with keyword-only argument - (kw_only_arg_func_type, [], {}, [r"Missing required keyword argument `foo`."], None), + (kw_only_arg_func_type, [], {}, [r"Missing required keyword argument 'foo'."], None), (kw_only_arg_func_type, [], {"foo": bool_type}, [], ts.VoidType()), ( kw_only_arg_func_type, [], {"foo": float_type}, - [r"Expected keyword argument `foo` to be of type `bool`, but got `float64`."], + [r"Expected keyword argument 'foo' to be of type 'bool', got 'float64'."], None, ), ( kw_only_arg_func_type, [], {"bar": bool_type}, - [r"Got unexpected keyword argument `bar`."], + [r"Got unexpected keyword argument 'bar'."], None, ), # function with positional, keyword-or-positional, and keyword-only argument @@ -232,9 +232,9 @@ def callable_type_info_cases(): [], {}, [ - r"Missing 1 required positional argument: `foo`", + r"Missing 1 required positional argument: 'foo'", r"Function takes 2 positional arguments, but 0 were given.", - r"Missing required keyword argument `bar`", + r"Missing required keyword argument 'bar'", ], None, ), @@ -244,7 +244,7 @@ def callable_type_info_cases(): {}, [ r"Function takes 2 positional arguments, but 1 were given.", - r"Missing required keyword argument `bar`", + r"Missing required keyword argument 'bar'", ], None, ), @@ -252,14 +252,14 @@ def callable_type_info_cases(): pos_arg_and_kw_or_pos_arg_and_kw_only_arg_func_type, [bool_type], {"foo": int_type}, - [r"Missing required keyword argument `bar`"], + [r"Missing required keyword argument 'bar'"], None, ), ( pos_arg_and_kw_or_pos_arg_and_kw_only_arg_func_type, [bool_type], {"foo": int_type}, - [r"Missing required keyword argument `bar`"], + [r"Missing required keyword argument 'bar'"], None, ), ( @@ -274,9 +274,9 @@ def callable_type_info_cases(): [int_type], {"bar": bool_type, "foo": bool_type}, [ - r"Expected 1st argument to be of type `bool`, but got `int64`", - r"Expected argument `foo` to be of type `int64`, but got `bool`", - r"Expected keyword argument `bar` to be of type `float64`, but got `bool`", + r"Expected 1st argument to be of type 'bool', got 'int64'", + r"Expected argument 'foo' to be of type 'int64', got 'bool'", + r"Expected keyword argument 'bar' to be of type 'float64', got 'bool'", ], None, ), @@ -299,7 +299,7 @@ def callable_type_info_cases(): [ts.TupleType(types=[float_type, field_type])], {}, [ - r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" + r"Expected 1st argument to be of type 'tuple\[bool, Field\[\[I\], float64\]\]', got 'tuple\[float64, Field\[\[I\], float64\]\]'" ], ts.VoidType(), ), @@ -308,7 +308,7 @@ def callable_type_info_cases(): [int_type], {}, [ - r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" + r"Expected 1st argument to be of type 'tuple\[bool, Field\[\[I\], float64\]\]', got 'int64'" ], ts.VoidType(), ), @@ -330,8 +330,8 @@ def callable_type_info_cases(): ], {}, [ - r"Expected argument `a` to be of type `Field\[\[K\], int64\]`, but got `Field\[\[K\], float64\]`", - r"Expected argument `b` to be of type `Field\[\[K\], int64\]`, but got `Field\[\[K\], float64\]`", + r"Expected argument 'a' to be of type 'Field\[\[K\], int64\]', got 'Field\[\[K\], float64\]'", + r"Expected argument 'b' to be of type 'Field\[\[K\], int64\]', got 'Field\[\[K\], float64\]'", ], ts.FieldType(dims=[KDim], dtype=float_type), ), @@ -393,8 +393,8 @@ def callable_type_info_cases(): ], {}, [ - r"Expected argument `a` to be of type `tuple\[Field\[\[I, J, K\], int64\], " - r"Field\[\[\.\.\.\], int64\]\]`, but got `tuple\[Field\[\[I, J, K\], int64\]\]`." + r"Expected argument 'a' to be of type 'tuple\[Field\[\[I, J, K\], int64\], " + r"Field\[\[\.\.\.\], int64\]\]', got 'tuple\[Field\[\[I, J, K\], int64\]\]'." ], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), @@ -491,7 +491,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): with pytest.raises( errors.DSLError, - match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), + match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'."), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -507,7 +507,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): with pytest.raises( errors.DSLError, match=( - r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." + r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." ), ): _ = FieldOperatorParser.apply_to_function(nonmatching) @@ -519,7 +519,7 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`!"), + match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -530,7 +530,7 @@ def sign_bool(a: Field[[TDim], bool]): with pytest.raises( errors.DSLError, - match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", + match=r"Incompatible type for unary operator '\-': 'Field\[\[TDim\], bool\]'.", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -541,7 +541,7 @@ def not_int(a: Field[[TDim], int64]): with pytest.raises( errors.DSLError, - match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", + match=r"Incompatible type for unary operator 'not': 'Field\[\[TDim\], int64\]'.", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -613,7 +613,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: with pytest.raises( errors.DSLError, - match=(r"Could not promote `float32` and `float64` to common type in call to +."), + match=(r"Could not promote 'float32' and 'float64' to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -643,7 +643,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): with pytest.raises( errors.DSLError, - match=r"Expected broadcast dimension is missing", + match=r"expected broadcast dimension\(s\) \'.*\' missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -658,7 +658,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): with pytest.raises( errors.DSLError, - match=r"Expected all broadcast dimensions to be of type Dimension.", + match=r"expected all broadcast dimensions to be of type 'Dimension'.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -778,7 +778,7 @@ def simple_astype(a: Field[[TDim], float64]): with pytest.raises( errors.DSLError, - match=r"Invalid call to `astype`. Second argument must be a scalar type, but got.", + match=r"Invalid call to 'astype': second argument must be a scalar type, got.", ): _ = FieldOperatorParser.apply_to_function(simple_astype) @@ -806,7 +806,7 @@ def modulo_floats(inp: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=r"Type float64 can not be used in operator `%`", + match=r"Type 'float64' can not be used in operator '%'", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -844,6 +844,6 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): with pytest.raises( errors.DSLError, - match=f"Excepted integer for offset field dtype", + match=f"expected integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c0d565bbf4..d3f3f35699 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -126,7 +126,7 @@ def fenimpl(size, arg0, arg1, arg2, out): closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1, arg2]) else: - raise AssertionError("Add overload") + raise AssertionError("Add overload.") return run_processor(fenimpl, processor, out.shape[0], *inps, out) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 6f91557e46..4177a5aeee 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -109,7 +109,7 @@ def run_processor( elif ppi.is_processor_kind(processor, ppi.ProgramFormatter): print(program.format_itir(*args, formatter=processor, **kwargs)) else: - raise TypeError(f"program processor kind not recognized: {processor}!") + raise TypeError(f"program processor kind not recognized: '{processor}'.") @dataclasses.dataclass diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 2b78eb9114..1a38e5245e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -631,5 +631,5 @@ def test_setitem_wrong_domain(): np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) ) - with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): + with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): field[(1, slice(None))] = value_incompatible diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index e5bbed19fd..96ecc19c0b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -88,7 +88,7 @@ def mistyped(inp: gtx.Field): with pytest.raises( ValueError, - match="Field type requires two arguments, got 0!", + match="Field type requires two arguments, got 0.", ): _ = FieldOperatorParser.apply_to_function(mistyped) @@ -245,7 +245,7 @@ def conditional_wrong_mask_type( ) -> gtx.Field[[TDim], float64]: return where(a, a, a) - msg = r"Expected a field with dtype `bool`." + msg = r"expected a field with dtype 'bool'" with pytest.raises(errors.DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type) @@ -269,7 +269,7 @@ def test_ternary_with_field_condition(): def ternary_with_field_condition(cond: gtx.Field[[], bool]): return 1 if cond else 2 - with pytest.raises(errors.DSLError, match=r"should be .* `bool`"): + with pytest.raises(errors.DSLError, match=r"should be .* 'bool'"): _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition) @@ -288,7 +288,7 @@ def test_adr13_wrong_return_type_annotation(): def wrong_return_type_annotation() -> gtx.Field[[], float]: return 1.0 - with pytest.raises(errors.DSLError, match=r"Expected `float.*`"): + with pytest.raises(errors.DSLError, match=r"expected 'float.*'"): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -395,8 +395,6 @@ def zero_dims_ternary( ): return a if cond == 1 else b - msg = r"Incompatible datatypes in operator `==`" - with pytest.raises(errors.DSLError) as exc_info: + msg = r"Incompatible datatypes in operator '=='" + with pytest.raises(errors.DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) - - assert re.search(msg, exc_info.value.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 1d1a1efad4..cca05f9917 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -113,7 +113,7 @@ def undefined_field_program(in_field: gtx.Field[[IDim], "float64"]): with pytest.raises( errors.DSLError, - match=(r"Undeclared or untyped symbol `out_field`."), + match=(r"Undeclared or untyped symbol 'out_field'."), ): ProgramParser.apply_to_function(undefined_field_program) @@ -165,10 +165,10 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) - assert exc_info.match("Invalid call to `domain_format_1`") + assert exc_info.match("Invalid call to 'domain_format_1'") assert ( - re.search("Only Dictionaries allowed in domain", exc_info.value.__cause__.args[0]) + re.search("Only Dictionaries allowed in 'domain'", exc_info.value.__cause__.args[0]) is not None ) @@ -184,7 +184,7 @@ def domain_format_2_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) - assert exc_info.match("Invalid call to `domain_format_2`") + assert exc_info.match("Invalid call to 'domain_format_2'") assert ( re.search("Only 2 values allowed in domain range", exc_info.value.__cause__.args[0]) @@ -203,10 +203,10 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_3_program) - assert exc_info.match("Invalid call to `domain_format_3`") + assert exc_info.match("Invalid call to 'domain_format_3'") assert ( - re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument\ 'out'", exc_info.value.__cause__.args[0]) is not None ) @@ -224,7 +224,7 @@ def domain_format_4_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) - assert exc_info.match("Invalid call to `domain_format_4`") + assert exc_info.match("Invalid call to 'domain_format_4'") assert ( re.search("Either only domain or slicing allowed", exc_info.value.__cause__.args[0]) @@ -243,7 +243,7 @@ def domain_format_5_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) - assert exc_info.match("Invalid call to `domain_format_5`") + assert exc_info.match("Invalid call to 'domain_format_5'") assert ( re.search("Only integer values allowed in domain range", exc_info.value.__cause__.args[0]) @@ -262,6 +262,6 @@ def domain_format_6_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) - assert exc_info.match("Invalid call to `domain_format_6`") + assert exc_info.match("Invalid call to 'domain_format_6'") assert re.search("Empty domain not allowed.", exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index c4fe30c596..a1a7b79cec 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -177,7 +177,7 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): grid_type=gtx.GridType.CARTESIAN, ) - assert exc_info.match("Invalid call to `identity`") + assert exc_info.match("Invalid call to 'identity'") # TODO(tehrengruber): re-enable again when call signature check doesn't return # immediately after missing `out` argument # assert ( @@ -187,6 +187,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): # is not None # ) assert ( - re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) is not None ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 232995be58..73ad24f42b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -56,7 +56,7 @@ def test_embedded_error_on_wrong_domain(): 1, ), ) - with pytest.raises(RuntimeError, match="expected `UnstructuredDomain`"): + with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): foo[dom]( gtx.as_field([I], np.zeros((1,))), out=out, diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py index 05e982cf0c..1ba35da7c6 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py +++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py @@ -74,12 +74,12 @@ def test_undecorated_formatter_function_is_not_recognized(): def undecorated_formatter(fencil: itir.FencilDefinition, *args, **kwargs) -> str: return "" - with pytest.raises(TypeError, match="is not a ProgramFormatter"): + with pytest.raises(TypeError, match="is not a 'ProgramFormatter'"): ensure_processor_kind(undecorated_formatter, ProgramFormatter) def test_wrong_processor_type_is_caught_at_runtime(dummy_formatter): - with pytest.raises(TypeError, match="is not a ProgramExecutor"): + with pytest.raises(TypeError, match="is not a 'ProgramExecutor'"): ensure_processor_kind(dummy_formatter, ProgramExecutor) diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py index 456654c1d0..599bea75e7 100644 --- a/tests/next_tests/unit_tests/test_allocators.py +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -108,7 +108,7 @@ def test_get_allocator(): with pytest.raises( TypeError, - match=f"Object {invalid_obj} is neither a field allocator nor a field allocator factory", + match=f"Object '{invalid_obj}' is neither a field allocator nor a field allocator factory", ): next_allocators.get_allocator(invalid_obj, strict=True) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index da63536953..bafabfb56e 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -96,7 +96,7 @@ def test_unit_range_slice_error(rng): def test_unit_range_set_intersection(rng): with pytest.raises( - NotImplementedError, match="Can only find the intersection between UnitRange instances." + NotImplementedError, match="Can only find the intersection between 'UnitRange' instances." ): rng & {1, 5} diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index e8b070f0c0..8d95c9951f 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -139,7 +139,7 @@ def test_as_field_origin(): def test_field_wrong_dims(): with pytest.raises( ValueError, - match=(r"Cannot construct `Field` from array of shape"), + match=(r"Cannot construct 'Field' from array of shape"), ): gtx.as_field([I, J], np.random.rand(sizes[I]).astype(gtx.float32)) @@ -147,7 +147,7 @@ def test_field_wrong_dims(): def test_field_wrong_domain(): with pytest.raises( ValueError, - match=(r"Cannot construct `Field` from array of shape"), + match=(r"Cannot construct 'Field' from array of shape"), ): domain = common.Domain( dims=(I, J), diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py index d281f5cd90..0a0b747a28 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py @@ -158,7 +158,7 @@ def test_invalid_symbol_types(): type_translation.from_type_hint(common.Field[[IDim], None]) # Functions - with pytest.raises(ValueError, match="Not annotated functions are not supported"): + with pytest.raises(ValueError, match="Unannotated functions are not supported"): type_translation.from_type_hint(typing.Callable) with pytest.raises(ValueError, match="Invalid callable annotations"): From 0d66829d8c68b89a620c87fa3fbc8f5b64287d27 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Dec 2023 11:45:16 +0100 Subject: [PATCH 30/85] docs[next]: Partially fix Quickstart Guide (#1390) Changes to the quickstart guide to use `field.asnumpy()` (introduced in #1366) instead of `np.asarray(field)`. The quickstart guide is still broken though since the embedded backend (used by default) does not support skip neighbors connectivities. --- docs/user/next/QuickstartGuide.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 1ae1db4d92..dc70f804fd 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -102,7 +102,7 @@ You can call field operators from [programs](#Programs), other field operators, result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) add(a, b, out=result, offset_provider={}) -print("{} + {} = {} ± {}".format(a_value, b_value, np.average(np.asarray(result)), np.std(np.asarray(result)))) +print("{} + {} = {} ± {}".format(a_value, b_value, np.average(result.asnumpy()), np.std(result.asnumpy()))) ``` #### Programs @@ -128,7 +128,7 @@ You can execute the program by simply calling it: result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) run_add(a, b, result, offset_provider={}) -print("{} + {} = {} ± {}".format(b_value, (a_value + b_value), np.average(np.asarray(result)), np.std(np.asarray(result)))) +print("{} + {} = {} ± {}".format(b_value, (a_value + b_value), np.average(result.asnumpy()), np.std(result.asnumpy()))) ``` #### Composing field operators and programs @@ -256,7 +256,7 @@ def run_nearest_cell_to_edge(cell_values: gtx.Field[[CellDim], float64], out : g run_nearest_cell_to_edge(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) -print("0th adjacent cell's value: {}".format(np.asarray(edge_values))) +print("0th adjacent cell's value: {}".format(edge_values.asnumpy())) ``` Running the above snippet results in the following edge field: @@ -283,7 +283,7 @@ def run_sum_adjacent_cells(cells : gtx.Field[[CellDim], float64], out : gtx.Fiel run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) -print("sum of adjacent cells: {}".format(np.asarray(edge_values))) +print("sum of adjacent cells: {}".format(edge_values.asnumpy())) ``` For the border edges, the results are unchanged compared to the previous example, but the inner edges now contain the sum of the two adjacent cells: @@ -317,7 +317,7 @@ def conditional(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, K return where(mask, a, b) conditional(mask, a, b, out=result_where, offset_provider={}) -print("where return: {}".format(np.asarray(result_where))) +print("where return: {}".format(result_where.asnumpy())) ``` **Tuple implementation:** @@ -340,7 +340,7 @@ result_1: gtx.Field[[CellDim, KDim], float64], result_2: gtx.Field[[CellDim, KDi _conditional_tuple(mask, a, b, out=(result_1, result_2)) conditional_tuple(mask, a, b, result_1, result_2, offset_provider={}) -print("where tuple return: {}".format((np.asarray(result_1), np.asarray(result_2)))) +print("where tuple return: {}".format((result_1.asnumpy(), result_2.asnumpy()))) ``` The `where` builtin also allows for nesting of tuples. In this scenario, it will first perform an unrolling: @@ -375,7 +375,7 @@ def conditional_tuple_nested( _conditional_tuple_nested(mask, a, b, c, d, out=((result_1, result_2), (result_2, result_1))) conditional_tuple_nested(mask, a, b, c, d, result_1, result_2, offset_provider={}) -print("where nested tuple return: {}".format(((np.asarray(result_1), np.asarray(result_2)), (np.asarray(result_2), np.asarray(result_1))))) +print("where nested tuple return: {}".format(((result_1.asnumpy(), result_2.asnumpy()), (result_2.asnumpy(), result_1.asnumpy())))) ``` #### Implementing the pseudo-laplacian @@ -447,7 +447,7 @@ run_pseudo_laplacian(cell_values, result_pseudo_lap, offset_provider={"E2C": E2C_offset_provider, "C2E": C2E_offset_provider}) -print("pseudo-laplacian: {}".format(np.asarray(result_pseudo_lap))) +print("pseudo-laplacian: {}".format(result_pseudo_lap.asnumpy())) ``` As a closure, here is an example of chaining field operators, which is very simple to do when working with fields. The field operator below executes the pseudo-laplacian, and then calls the pseudo-laplacian on the result of the first, in effect, calculating the laplacian of a laplacian. From cdcd6537bbc05b050a25ae6abea5b69490ed87db Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 18 Dec 2023 08:52:29 +0100 Subject: [PATCH 31/85] feat[next]: Add missing UnitRange comparison functions (#1363) - Introduce a better Infinity - Make UnitRange Generic to express finite, infinite, left-finite, right-finite properly. - Remove `Set` from UnitRange --- src/gt4py/next/common.py | 228 ++++++++++++------ src/gt4py/next/embedded/common.py | 1 + src/gt4py/next/embedded/nd_array_field.py | 25 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/embedded.py | 2 +- .../runners/dace_iterator/__init__.py | 44 ++-- .../embedded_tests/test_nd_array_field.py | 9 +- tests/next_tests/unit_tests/test_common.py | 138 ++++++++--- 8 files changed, 305 insertions(+), 144 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 3e1fe52f31..29d606ccc0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,9 +20,8 @@ import enum import functools import numbers -import sys import types -from collections.abc import Mapping, Sequence, Set +from collections.abc import Mapping, Sequence import numpy as np import numpy.typing as npt @@ -33,10 +32,12 @@ Any, Callable, ClassVar, + Generic, Never, Optional, ParamSpec, Protocol, + Self, TypeAlias, TypeGuard, TypeVar, @@ -52,16 +53,6 @@ DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) -class Infinity(int): - @classmethod - def positive(cls) -> Infinity: - return cls(sys.maxsize) - - @classmethod - def negative(cls) -> Infinity: - return cls(-sys.maxsize) - - Tag: TypeAlias = str @@ -84,31 +75,86 @@ def __str__(self): return f"{self.value}[{self.kind}]" +class Infinity(enum.Enum): + """Describes an unbounded `UnitRange`.""" + + NEGATIVE = enum.auto() + POSITIVE = enum.auto() + + def __add__(self, _: int) -> Self: + return self + + __radd__ = __add__ + + def __sub__(self, _: int) -> Self: + return self + + __rsub__ = __sub__ + + def __le__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE or other is self.POSITIVE + + def __lt__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE and other is not self + + def __ge__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE or other is self.NEGATIVE + + def __gt__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE and other is not self + + +def _as_int(v: core_defs.IntegralScalar | Infinity) -> int | Infinity: + return v if isinstance(v, Infinity) else int(v) + + +_Left = TypeVar("_Left", int, Infinity) +_Right = TypeVar("_Right", int, Infinity) + + @dataclasses.dataclass(frozen=True, init=False) -class UnitRange(Sequence[int], Set[int]): +class UnitRange(Sequence[int], Generic[_Left, _Right]): """Range from `start` to `stop` with step size one.""" - start: int - stop: int + start: _Left + stop: _Right - def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScalar) -> None: + def __init__( + self, start: core_defs.IntegralScalar | Infinity, stop: core_defs.IntegralScalar | Infinity + ) -> None: if start < stop: - object.__setattr__(self, "start", int(start)) - object.__setattr__(self, "stop", int(stop)) + object.__setattr__(self, "start", _as_int(start)) + object.__setattr__(self, "stop", _as_int(stop)) else: # make UnitRange(0,0) the single empty UnitRange object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) - # TODO: the whole infinity idea and implementation is broken and should be replaced @classmethod - def infinity(cls) -> UnitRange: - return cls(Infinity.negative(), Infinity.positive()) + def infinite( + cls, + ) -> UnitRange: + return cls(Infinity.NEGATIVE, Infinity.POSITIVE) def __len__(self) -> int: - if Infinity.positive() in (abs(self.start), abs(self.stop)): - return Infinity.positive() - return max(0, self.stop - self.start) + if UnitRange.is_finite(self): + return max(0, self.stop - self.start) + raise ValueError("Cannot compute length of open 'UnitRange'.") + + @classmethod + def is_finite(cls, obj: UnitRange) -> TypeGuard[FiniteUnitRange]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE and obj.stop is not Infinity.POSITIVE + + @classmethod + def is_right_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[_Left, int]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.stop is not Infinity.POSITIVE + + @classmethod + def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @@ -122,6 +168,7 @@ def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unuse ... def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused + assert UnitRange.is_finite(self) if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: @@ -138,61 +185,60 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re else: raise IndexError("'UnitRange' index out of range") - def __and__(self, other: Set[int]) -> UnitRange: - if isinstance(other, UnitRange): - start = max(self.start, other.start) - stop = min(self.stop, other.stop) - return UnitRange(start, stop) - else: - raise NotImplementedError( - "Can only find the intersection between 'UnitRange' instances." - ) + def __and__(self, other: UnitRange) -> UnitRange: + return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) + + def __contains__(self, value: Any) -> bool: + return ( + isinstance(value, core_defs.INTEGRAL_TYPES) + and value >= self.start + and value < self.stop + ) + + def __le__(self, other: UnitRange) -> bool: + return self.start >= other.start and self.stop <= other.stop + + def __lt__(self, other: UnitRange) -> bool: + return (self.start > other.start and self.stop <= other.stop) or ( + self.start >= other.start and self.stop < other.stop + ) + + def __ge__(self, other: UnitRange) -> bool: + return self.start <= other.start and self.stop >= other.stop - def __le__(self, other: Set[int]): + def __gt__(self, other: UnitRange) -> bool: + return (self.start < other.start and self.stop >= other.stop) or ( + self.start <= other.start and self.stop > other.stop + ) + + def __eq__(self, other: Any) -> bool: if isinstance(other, UnitRange): - return self.start >= other.start and self.stop <= other.stop - elif len(self) == Infinity.positive(): - return False - else: - return Set.__le__(self, other) - - def __add__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.positive(): - return UnitRange.infinity() - elif other == Infinity.negative(): - return UnitRange(0, 0) - return UnitRange( - *( - s if s in [Infinity.negative(), Infinity.positive()] else s + other - for s in (self.start, self.stop) - ) - ) - else: - raise NotImplementedError("Can only compute union with 'int' instances.") - - def __sub__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.negative(): - return self + Infinity.positive() - elif other == Infinity.positive(): - return self + Infinity.negative() - else: - return self + (-other) + return self.start == other.start and self.stop == other.stop else: - raise NotImplementedError("Can only compute substraction with 'int' instances.") + return False + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) - __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented + def __add__(self, other: int) -> UnitRange: + return UnitRange(self.start + other, self.stop + other) + + def __sub__(self, other: int) -> UnitRange: + return UnitRange(self.start - other, self.stop - other) def __str__(self) -> str: return f"({self.start}:{self.stop})" +FiniteUnitRange: TypeAlias = UnitRange[int, int] + + RangeLike: TypeAlias = ( UnitRange | range | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] | core_defs.IntegralScalar + | None ) @@ -207,18 +253,23 @@ def unit_range(r: RangeLike) -> UnitRange: # once the related mypy bug (#16358) gets fixed if ( isinstance(r, tuple) - and isinstance(r[0], core_defs.INTEGRAL_TYPES) - and isinstance(r[1], core_defs.INTEGRAL_TYPES) + and (isinstance(r[0], core_defs.INTEGRAL_TYPES) or r[0] in (None, Infinity.NEGATIVE)) + and (isinstance(r[1], core_defs.INTEGRAL_TYPES) or r[1] in (None, Infinity.POSITIVE)) ): - return UnitRange(r[0], r[1]) + start = r[0] if r[0] is not None else Infinity.NEGATIVE + stop = r[1] if r[1] is not None else Infinity.POSITIVE + return UnitRange(start, stop) if isinstance(r, core_defs.INTEGRAL_TYPES): return UnitRange(0, cast(core_defs.IntegralScalar, r)) + if r is None: + return UnitRange.infinite() raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple +FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement @@ -245,6 +296,10 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: ) +def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: + return UnitRange.is_finite(v[1]) + + def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) @@ -283,18 +338,27 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: return (v[0], unit_range(v[1])) +_Rng = TypeVar( + "_Rng", + UnitRange[int, int], + UnitRange[Infinity, int], + UnitRange[int, Infinity], + UnitRange[Infinity, Infinity], +) + + @dataclasses.dataclass(frozen=True, init=False) -class Domain(Sequence[NamedRange]): +class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" dims: tuple[Dimension, ...] - ranges: tuple[UnitRange, ...] + ranges: tuple[_Rng, ...] def __init__( self, - *args: NamedRange, + *args: tuple[Dimension, _Rng], dims: Optional[Sequence[Dimension]] = None, - ranges: Optional[Sequence[UnitRange]] = None, + ranges: Optional[Sequence[_Rng]] = None, ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: @@ -343,16 +407,23 @@ def ndim(self) -> int: def shape(self) -> tuple[int, ...]: return tuple(len(r) for r in self.ranges) + @classmethod + def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return all(UnitRange.is_finite(rng) for rng in obj.ranges) + @overload - def __getitem__(self, index: int) -> NamedRange: + def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... @overload - def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused + def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused ... @overload - def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused + def __getitem__( # noqa: F811 # redefine unused + self, index: Dimension + ) -> tuple[Dimension, _Rng]: ... def __getitem__( # noqa: F811 # redefine unused @@ -434,6 +505,9 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) +FiniteDomain: TypeAlias = Domain[FiniteUnitRange] + + DomainLike: TypeAlias = ( Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] ) # `Domain` is `Sequence[NamedRange]` and therefore a subset @@ -484,7 +558,7 @@ def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinite() for d in broadcast_dims ) @@ -847,7 +921,7 @@ def asnumpy(self) -> Never: @functools.cached_property def domain(self) -> Domain: - return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),)) + return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) @property def __gt_dims__(self) -> tuple[Dimension, ...]: diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 87e0800a10..94efe4d61d 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -58,6 +58,7 @@ def _relative_sub_domain( else: # not in new domain assert common.is_int_index(idx) + assert common.UnitRange.is_finite(rng) new_index = (rng.start if idx >= 0 else rng.stop) + idx if new_index < rng.start or new_index >= rng.stop: raise embedded_exceptions.IndexOutOfBounds( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fbfe64ac42..8bd2673db9 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -113,6 +113,7 @@ def __gt_dims__(self) -> tuple[common.Dimension, ...]: @property def __gt_origin__(self) -> tuple[int, ...]: + assert common.Domain.is_finite(self._domain) return tuple(-r.start for _, r in self._domain) @property @@ -386,6 +387,7 @@ def inverse_image( assert isinstance(image_range, common.UnitRange) + assert common.UnitRange.is_finite(image_range) restricted_mask = (self._ndarray >= image_range.start) & ( self._ndarray < image_range.stop ) @@ -566,9 +568,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - named_ranges.append( - (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) - ) + named_ranges.append((dim, common.UnitRange.infinite())) return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -638,14 +638,19 @@ def _compute_slice( ValueError: If `new_rng` is not an integer or a UnitRange. """ if isinstance(rng, common.UnitRange): - if domain.ranges[pos] == common.UnitRange.infinity(): - return slice(None) - else: - return slice( - rng.start - domain.ranges[pos].start, - rng.stop - domain.ranges[pos].start, - ) + start = ( + rng.start - domain.ranges[pos].start + if common.UnitRange.is_left_finite(domain.ranges[pos]) + else None + ) + stop = ( + rng.stop - domain.ranges[pos].start + if common.UnitRange.is_right_finite(domain.ranges[pos]) + else None + ) + return slice(start, stop) elif common.is_int_index(rng): + assert common.Domain.is_finite(domain) return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: '{type(rng)}'.") diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 93f17b1eb8..278dde9180 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -192,7 +192,7 @@ def broadcast( np.asarray(field)[ tuple([np.newaxis] * len(dims)) ], # TODO(havogt) use FunctionField once available - domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinity()] * len(dims))), + domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))), ) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a4f32929db..ef70a2e645 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1059,7 +1059,7 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: - return common.Domain((self._dimension, common.UnitRange.infinity())) + return common.Domain((self._dimension, common.UnitRange.infinite())) @property def codomain(self) -> type[core_defs.int32]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 65f9d9d71a..037c4f3e4d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,10 +24,9 @@ import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.otf_compile_executor as otf_exec import gt4py.next.program_processors.processor_interface as ppi -from gt4py.next.common import Dimension, Domain, UnitRange, is_field -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.otf.compilation import cache +from gt4py.next import common +from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms +from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG @@ -40,7 +39,8 @@ cp = None -def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: +def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRange]: + assert common.Domain.is_finite(domain) sorted_dims = get_sorted_dims(domain.dims) return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] @@ -54,7 +54,7 @@ def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: def convert_arg(arg: Any): - if is_field(arg): + if common.is_field(arg): sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) dim_indices = [dim_index for dim_index, _ in sorted_dims] @@ -67,9 +67,11 @@ def convert_arg(arg: Any): def preprocess_program( - program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode + program: itir.FencilDefinition, + offset_provider: Mapping[str, Any], + lift_mode: itir_transforms.LiftMode, ): - node = apply_common_transforms( + node = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, lift_mode=lift_mode, @@ -81,7 +83,7 @@ def preprocess_program( if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): fencil_definition = node else: - fencil_definition = apply_common_transforms( + fencil_definition = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -109,7 +111,7 @@ def _ensure_is_on_device( def get_connectivity_args( - neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]], + neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]], device: dace.dtypes.DeviceType, ) -> dict[str, Any]: return { @@ -134,7 +136,7 @@ def get_offset_args( return { str(sym): -drange.start for param, arg in zip(params, args) - if is_field(arg) + if common.is_field(arg) for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) } @@ -162,13 +164,19 @@ def get_stride_args( def get_cache_id( program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - column_axis: Optional[Dimension], + column_axis: Optional[common.Dimension], offset_provider: Mapping[str, Any], ) -> str: max_neighbors = [ (k, v.max_neighbors) for k, v in offset_provider.items() - if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)) + if isinstance( + v, + ( + itir_embedded.NeighborTableOffsetProvider, + itir_embedded.StridedNeighborOffsetProvider, + ), + ) ] cache_id_args = [ str(arg) @@ -191,8 +199,8 @@ def build_sdfg_from_itir( offset_provider: dict[str, Any], auto_optimize: bool = False, on_gpu: bool = False, - column_axis: Optional[Dimension] = None, - lift_mode: LiftMode = LiftMode.FORCE_INLINE, + column_axis: Optional[common.Dimension] = None, + lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, ) -> dace.SDFG: """Translate a Fencil into an SDFG. @@ -210,7 +218,7 @@ def build_sdfg_from_itir( """ # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force # `lift_more` to `FORCE_INLINE` mode. - lift_mode = LiftMode.FORCE_INLINE + lift_mode = itir_transforms.LiftMode.FORCE_INLINE arg_types = [type_translation.from_value(arg) for arg in args] device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU @@ -237,7 +245,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) auto_optimize = kwargs.get("auto_optimize", False) - lift_mode = kwargs.get("lift_mode", LiftMode.FORCE_INLINE) + lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] @@ -263,7 +271,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): lift_mode=lift_mode, ) - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1a38e5245e..6863b09c12 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import dataclasses import itertools import math import operator @@ -20,7 +19,7 @@ import numpy as np import pytest -from gt4py.next import common, embedded +from gt4py.next import common from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice @@ -353,7 +352,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinity())), + Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinite())), ) ), ( @@ -362,7 +361,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange.infinity(), UnitRange(0, 10))), + Domain(dims=(IDim, JDim), ranges=(UnitRange.infinite(), UnitRange(0, 10))), ) ), ( @@ -373,7 +372,7 @@ def test_cartesian_remap_implementation(): ), Domain( dims=(IDim, JDim, KDim), - ranges=(UnitRange.infinity(), UnitRange(0, 10), UnitRange.infinity()), + ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), ), ) ), diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index bafabfb56e..7650e90c3c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -14,6 +14,7 @@ import operator from typing import Optional, Pattern +import numpy as np import pytest from gt4py.next.common import ( @@ -41,6 +42,56 @@ def a_domain(): return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) +@pytest.fixture(params=[Infinity.POSITIVE, Infinity.NEGATIVE]) +def unbounded(request): + yield request.param + + +def test_unbounded_add_sub(unbounded): + assert unbounded + 1 == unbounded + assert unbounded - 1 == unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.le, operator.lt]) +def test_unbounded_comparison_less(value, op): + assert not op(Infinity.POSITIVE, value) + assert op(value, Infinity.POSITIVE) + + assert op(Infinity.NEGATIVE, value) + assert not op(value, Infinity.NEGATIVE) + + assert op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.ge, operator.gt]) +def test_unbounded_comparison_greater(value, op): + assert op(Infinity.POSITIVE, value) + assert not op(value, Infinity.POSITIVE) + + assert not op(Infinity.NEGATIVE, value) + assert op(value, Infinity.NEGATIVE) + + assert not op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +def test_unbounded_eq(unbounded): + assert unbounded == unbounded + assert unbounded <= unbounded + assert unbounded >= unbounded + assert not unbounded < unbounded + assert not unbounded > unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +def test_unbounded_max_min(value): + assert max(Infinity.POSITIVE, value) == Infinity.POSITIVE + assert min(Infinity.POSITIVE, value) == value + assert max(Infinity.NEGATIVE, value) == value + assert min(Infinity.NEGATIVE, value) == Infinity.NEGATIVE + + def test_empty_range(): expected = UnitRange(0, 0) assert UnitRange(1, 1) == expected @@ -58,9 +109,20 @@ def test_unit_range_length(rng): assert len(rng) == 10 -@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) -def test_unit_range_like(rng_like): - assert unit_range(rng_like) == UnitRange(2, 4) +@pytest.mark.parametrize( + "rng_like, expected", + [ + ((2, 4), UnitRange(2, 4)), + (range(2, 4), UnitRange(2, 4)), + (UnitRange(2, 4), UnitRange(2, 4)), + ((None, None), UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ((2, None), UnitRange(2, Infinity.POSITIVE)), + ((None, 4), UnitRange(Infinity.NEGATIVE, 4)), + (None, UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ], +) +def test_unit_range_like(rng_like, expected): + assert unit_range(rng_like) == expected def test_unit_range_repr(rng): @@ -94,13 +156,6 @@ def test_unit_range_slice_error(rng): rng[1:2:5] -def test_unit_range_set_intersection(rng): - with pytest.raises( - NotImplementedError, match="Can only find the intersection between 'UnitRange' instances." - ): - rng & {1, 5} - - @pytest.mark.parametrize( "rng1, rng2, expected", [ @@ -121,46 +176,65 @@ def test_unit_range_intersection(rng1, rng2, expected): @pytest.mark.parametrize( "rng1, rng2, expected", [ - (UnitRange(20, Infinity.positive()), UnitRange(10, 15), UnitRange(0, 0)), - (UnitRange(Infinity.negative(), 0), UnitRange(5, 10), UnitRange(0, 0)), - (UnitRange(Infinity.negative(), 0), UnitRange(-10, 0), UnitRange(-10, 0)), - (UnitRange(0, Infinity.positive()), UnitRange(Infinity.negative(), 5), UnitRange(0, 5)), + (UnitRange(20, Infinity.POSITIVE), UnitRange(10, 15), UnitRange(0, 0)), + (UnitRange(Infinity.NEGATIVE, 0), UnitRange(5, 10), UnitRange(0, 0)), + (UnitRange(Infinity.NEGATIVE, 0), UnitRange(-10, 0), UnitRange(-10, 0)), + (UnitRange(0, Infinity.POSITIVE), UnitRange(Infinity.NEGATIVE, 5), UnitRange(0, 5)), ( - UnitRange(Infinity.negative(), 0), - UnitRange(Infinity.negative(), 5), - UnitRange(Infinity.negative(), 0), + UnitRange(Infinity.NEGATIVE, 0), + UnitRange(Infinity.NEGATIVE, 5), + UnitRange(Infinity.NEGATIVE, 0), ), ( - UnitRange(Infinity.negative(), Infinity.positive()), - UnitRange(Infinity.negative(), Infinity.positive()), - UnitRange(Infinity.negative(), Infinity.positive()), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), ), ], ) -def test_unit_range_infinite_intersection(rng1, rng2, expected): +def test_unit_range_unbounded_intersection(rng1, rng2, expected): result = rng1 & rng2 assert result == expected -def test_positive_infinity_range(): - pos_inf_range = UnitRange(Infinity.positive(), Infinity.positive()) - assert len(pos_inf_range) == 0 +@pytest.mark.parametrize( + "rng", + [ + UnitRange(Infinity.NEGATIVE, 0), + UnitRange(0, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + ], +) +def test_positive_infinite_range_len(rng): + with pytest.raises(ValueError, match=r".*open.*"): + len(rng) -def test_mixed_infinity_range(): - mixed_inf_range = UnitRange(Infinity.negative(), Infinity.positive()) - assert len(mixed_inf_range) == Infinity.positive() +def test_range_contains(): + assert 1 in UnitRange(0, 2) + assert 1 not in UnitRange(0, 1) + assert 1 in UnitRange(0, Infinity.POSITIVE) + assert 1 in UnitRange(Infinity.NEGATIVE, 2) + assert 1 in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE) + assert "s" not in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE) @pytest.mark.parametrize( "op, rng1, rng2, expected", [ (operator.le, UnitRange(-1, 2), UnitRange(-2, 3), True), - (operator.le, UnitRange(-1, 2), {-1, 0, 1}, True), - (operator.le, UnitRange(-1, 2), {-1, 0}, False), - (operator.le, UnitRange(-1, 2), {-2, -1, 0, 1, 2}, True), - (operator.le, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True), - (operator.le, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False), + (operator.le, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.ge, UnitRange(-2, 3), UnitRange(-1, 2), True), + (operator.ge, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.lt, UnitRange(-1, 2), UnitRange(-2, 2), True), + (operator.lt, UnitRange(-2, 1), UnitRange(-2, 2), True), + (operator.lt, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.gt, UnitRange(-2, 2), UnitRange(-1, 2), True), + (operator.gt, UnitRange(-2, 2), UnitRange(-2, 1), True), + (operator.gt, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.eq, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), False), ], ) def test_range_comparison(op, rng1, rng2, expected): From 6c7c5d51b440c40175a25fb75fcbde7c919afbd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Mon, 18 Dec 2023 10:58:51 +0100 Subject: [PATCH 32/85] feat[dace]: Buildflags to the `ITIR -> SDFG` translation (#1389) Made it possible to also pass build flags to the `ITIR -> SDFG` translator. --- .../runners/dace_iterator/__init__.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 037c4f3e4d..59569de30b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -47,10 +47,6 @@ def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRa """ Default build configuration in DaCe backend """ _build_type = "Release" -# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins -_cpu_args = ( - "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label" -) def convert_arg(arg: Any): @@ -242,6 +238,7 @@ def build_sdfg_from_itir( def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): # build parameters build_cache = kwargs.get("build_cache", None) + compiler_args = kwargs.get("compiler_args", None) # `None` will take default. build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) auto_optimize = kwargs.get("auto_optimize", False) @@ -274,7 +271,10 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) - dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) + if compiler_args is not None: + dace.config.Config.set( + "compiler", "cuda" if on_gpu else "cpu", "args", value=compiler_args + ) sdfg_program = sdfg.compile(validate=False) # store SDFG program in build cache @@ -312,12 +312,21 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + compiler_args = dace.config.Config.get("compiler", "cpu", "args") + + # disable finite-math-only in order to support isfinite/isinf/isnan builtins + if "-ffast-math" in compiler_args: + compiler_args += " -fno-finite-math-only" + if "-ffinite-math-only" in compiler_args: + compiler_args.replace("-ffinite-math-only", "") + run_dace_iterator( program, *args, **kwargs, build_cache=_build_cache_cpu, build_type=_build_type, + compiler_args=compiler_args, on_gpu=False, ) From 315d9203bb667baa3daaea4b797a0846a2b70887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 19 Dec 2023 07:35:51 +0100 Subject: [PATCH 33/85] feat[dace]: Computing SDFG call arguments (#1398) Added a function to get the arguments to call an SDFG. This commit adds a function that allows to generate the arguments needed to call an SDFG, before this was part of `run_dace_iterator()`. This made it very complex to run an SDFG outside this function. One should consider this as an amend to [PR #1379](https://github.com/GridTools/gt4py/pull/1379). --- .../runners/dace_iterator/__init__.py | 79 ++++++++++++------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 59569de30b..97dd90eb54 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -90,8 +90,9 @@ def preprocess_program( return fencil_definition -def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: - return {name.id: convert_arg(arg) for name, arg in zip(params, args)} +def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: + sdfg_params: Sequence[str] = sdfg.arg_names + return {sdfg_param: convert_arg(arg) for sdfg_param, arg in zip(sdfg_params, args)} def _ensure_is_on_device( @@ -127,13 +128,16 @@ def get_shape_args( def get_offset_args( - arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] + sdfg: dace.SDFG, + args: Sequence[Any], ) -> Mapping[str, int]: + sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays + sdfg_params: Sequence[str] = sdfg.arg_names return { str(sym): -drange.start - for param, arg in zip(params, args) + for sdfg_param, arg in zip(sdfg_params, args) if common.is_field(arg) - for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain)) } @@ -189,6 +193,45 @@ def get_cache_id( return m.hexdigest() +def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: + """Extracts the arguments needed to call the SDFG. + + This function can handle the same arguments that are passed to `run_dace_iterator()`. + + Args: + sdfg: The SDFG for which we want to get the arguments. + """ # noqa: D401 + offset_provider = kwargs["offset_provider"] + on_gpu = kwargs.get("on_gpu", False) + + neighbor_tables = filter_neighbor_tables(offset_provider) + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + dace_args = get_args(sdfg, args) + dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_conn_args = get_connectivity_args(neighbor_tables, device) + dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) + dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) + dace_strides = get_stride_args(sdfg.arrays, dace_field_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg, args) + all_args = { + **dace_args, + **dace_conn_args, + **dace_shapes, + **dace_conn_shapes, + **dace_strides, + **dace_conn_strides, + **dace_offsets, + } + expected_args = { + key: value + for key, value in all_args.items() + if key in sdfg.signature_arglist(with_types=False) + } + return expected_args + + def build_sdfg_from_itir( program: itir.FencilDefinition, *args, @@ -248,8 +291,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: @@ -281,29 +322,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): if build_cache is not None: build_cache[cache_id] = sdfg_program - dace_args = get_args(program.params, args) - dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = get_connectivity_args(neighbor_tables, device) - dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) - dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, program.params, args) - - all_args = { - **dace_args, - **dace_conn_args, - **dace_shapes, - **dace_conn_shapes, - **dace_strides, - **dace_conn_strides, - **dace_offsets, - } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = get_sdfg_args(sdfg, *args, **kwargs) with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) From 15a7bd627d9fc818befd5f6ff6e795868563ff37 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Dec 2023 08:43:41 +0100 Subject: [PATCH 34/85] fix[next][dace]: Fix memlet for array slicing (#1399) Implementation of array slicing in DaCe backend changed to a mapped tasklet. Tested on GPU. CUDA code generation did not support the previous implementation, based on memlet in nested-SDFG. --- .../runners/dace_iterator/itir_to_tasklet.py | 66 ++++++------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d08476847f..4c202b1fe8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -18,7 +18,6 @@ import dace import numpy as np -from dace import subsets from dace.transformation.dataflow import MapFusion import gt4py.eve.codegen @@ -754,52 +753,29 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:] ] - # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset - deref_sdfg = dace.SDFG("deref") - deref_sdfg.add_array( - "_inp", field_array.shape, iterator.dtype, strides=field_array.strides - ) - for connector in deref_connectors[1:]: - deref_sdfg.add_scalar(connector, _INDEX_DTYPE) - deref_sdfg.add_array("_out", result_shape, iterator.dtype) - deref_init_state = deref_sdfg.add_state("init", True) - deref_access_state = deref_sdfg.add_state("access") - deref_sdfg.add_edge( - deref_init_state, - deref_access_state, - dace.InterstateEdge( - assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]} - ), - ) - # we access the size in source field shape as symbols set on the nested sdfg - source_subset = tuple( - f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}" + # we create a mapped tasklet for array slicing + map_ranges = { + f"_i_{dim}": f"0:{size}" for dim, size in zip(sorted_dims, field_array.shape) + if dim not in iterator.indices + } + src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims]) + dst_subset = ",".join( + [f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices] ) - deref_access_state.add_nedge( - deref_access_state.add_access("_inp"), - deref_access_state.add_access("_out"), - dace.Memlet( - data="_out", - subset=subsets.Range.from_array(result_array), - other_subset=",".join(source_subset), - ), - ) - - deref_node = self.context.state.add_nested_sdfg( - deref_sdfg, - self.context.body, - inputs=set(deref_connectors), - outputs={"_out"}, - ) - for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets): - self.context.state.add_edge(node, None, deref_node, connector, memlet) - self.context.state.add_edge( - deref_node, - "_out", - result_node, - None, - dace.Memlet.from_array(result_name, result_array), + self.context.state.add_mapped_tasklet( + "deref", + map_ranges, + inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, + outputs={ + "_out": dace.Memlet.from_array(result_name, result_array), + }, + code=f"_out[{dst_subset}] = _inp[{src_subset}]", + external_edges=True, + input_nodes={node.data: node for node in deref_nodes}, + output_nodes={ + result_name: result_node, + }, ) return [ValueExpr(result_node, iterator.dtype)] From af33e21fab16fb3de13ec5721b050dada63e220c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 19 Dec 2023 11:31:23 +0100 Subject: [PATCH 35/85] fix[dace]: Fixed SDFG args (#1400) Modified how the SDFG arguments are computed. It was noticed that some transformations, especially the `SDFG.apply_gpu_transformation()`, to the SDFG, added new arguments to the SDFG. But, since a lot of functions build on the `SDFG.arg_names` member and this member was populated before the transformation, an error occurred. Thus it was changed such that `SDFG.arg_names` was only populated with the arguments also known to the Fencil. --- .../runners/dace_iterator/__init__.py | 17 ++++++++--------- .../runners/dace_iterator/itir_to_sdfg.py | 11 +++-------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 97dd90eb54..7fd4794e57 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -207,6 +207,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: neighbor_tables = filter_neighbor_tables(offset_provider) device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + sdfg_sig = sdfg.signature_arglist(with_types=False) dace_args = get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_connectivity_args(neighbor_tables, device) @@ -224,11 +225,8 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: **dace_conn_strides, **dace_offsets, } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = {key: all_args[key] for key in sdfg_sig} + return expected_args @@ -258,21 +256,22 @@ def build_sdfg_from_itir( # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force # `lift_more` to `FORCE_INLINE` mode. lift_mode = itir_transforms.LiftMode.FORCE_INLINE - arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) + # TODO: According to Lex one should build the SDFG first in a general mannor. + # Generalisation to a particular device should happen only at the end. sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() # run DaCe auto-optimization heuristics if auto_optimize: - # TODO Investigate how symbol definitions improve autoopt transformations, - # in which case the cache table should take the symbols map into account. + # TODO: Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. symbols: dict[str, int] = {} + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index b3e6662623..e3b5ddf2ac 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -209,14 +209,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) # Create the call signature for the SDFG. - # All arguments required by the SDFG, regardless if explicit and implicit, are added - # as positional arguments. In the front are all arguments to the Fencil, in that - # order, they are followed by the arguments created by the translation process, - arg_list = [str(a) for a in node.params] - sig_list = program_sdfg.signature_arglist(with_types=False) - implicit_args = set(sig_list) - set(arg_list) - call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] - program_sdfg.arg_names = call_params + # Only the arguments requiered by the Fencil, i.e. `node.params` are added as poitional arguments. + # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. + program_sdfg.arg_names = [str(a) for a in node.params] program_sdfg.validate() return program_sdfg From b21dd566bcbd805279d94f36a20c5ea34a300d97 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 19 Dec 2023 12:06:59 +0100 Subject: [PATCH 36/85] feat[next]: Test for local dimension in output (#1392) Currently only supported in field view embedded. --- pyproject.toml | 1 + tests/next_tests/exclusion_matrices.py | 3 +++ .../ffront_tests/test_external_local_field.py | 19 +++++++++++++++++++ 3 files changed, 23 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 2cf4fb12e2..5d7a2f2cb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -342,6 +342,7 @@ markers = [ 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', 'uses_sparse_fields: tests that require backend support for sparse fields', + 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple arguments', 'uses_tuple_returns: tests that require backend support for tuple results', diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 3c42a180dd..f6d2b10a14 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -95,6 +95,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SPARSE_FIELDS = "uses_sparse_fields" +USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output" USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" @@ -119,6 +120,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), @@ -159,4 +161,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], + ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], } diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 42938e2f4b..698dce2b5c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -82,3 +82,22 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 out=cases.allocate(unstructured_case, testee, cases.RETURN)(), ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), ) + + +@pytest.mark.uses_sparse_fields_as_output +def test_write_local_field(unstructured_case): + @gtx.field_operator + def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: + return inp(V2E) + + out = unstructured_case.as_field( + [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + cases.verify( + unstructured_case, + testee, + inp, + out=out, + ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table], + ) From 100bc7fee17e9235da070e1bbf0fedd615de541f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 3 Jan 2024 12:09:23 +0100 Subject: [PATCH 37/85] Add missing grid_type argument to scan operator decorator (#1404) --- src/gt4py/next/ffront/decorator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 4abd8f156a..53159008f0 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -775,6 +775,7 @@ def scan_operator( forward: bool, init: core_defs.Scalar, backend: Optional[str], + grid_type: GridType, ) -> FieldOperator[foast.ScanOperator]: ... @@ -786,6 +787,7 @@ def scan_operator( forward: bool, init: core_defs.Scalar, backend: Optional[str], + grid_type: GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... @@ -797,6 +799,7 @@ def scan_operator( forward: bool = True, init: core_defs.Scalar = 0.0, backend=None, + grid_type: GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] @@ -834,6 +837,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, backend, + grid_type, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) From 7a9489f73ddddd6aff219fc3890bed23e791a9a8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 4 Jan 2024 00:47:33 +0100 Subject: [PATCH 38/85] Fix size check in CollapseTuple pass (#1405) --- src/gt4py/next/iterator/transforms/collapse_tuple.py | 3 +++ src/gt4py/next/iterator/type_inference.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7d710fc919..30457f2246 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -41,6 +41,9 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t ): return UnknownLength + if not type_.dtype.has_known_length: + return UnknownLength + return len(type_.dtype) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 2375118cd1..68627cfd89 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -77,6 +77,12 @@ def __iter__(self) -> abc.Iterator[Type]: raise ValueError(f"Can not iterate over partially defined tuple '{self}'.") yield from self.others + @property + def has_known_length(self): + return isinstance(self.others, EmptyTuple) or ( + isinstance(self.others, Tuple) and self.others.has_known_length + ) + def __len__(self) -> int: return sum(1 for _ in self) From 27bf18f570e85233f91d74884df417162df0eef0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 4 Jan 2024 16:14:34 +0100 Subject: [PATCH 39/85] bug[next]: respect DEFAULT_BACKEND and no_backend mechanism (#1380) fixes #1376. Thanks @DropD for the testcase. --- src/gt4py/next/embedded/operators.py | 16 +- src/gt4py/next/errors/__init__.py | 2 + src/gt4py/next/errors/exceptions.py | 12 ++ src/gt4py/next/ffront/decorator.py | 47 +++--- .../ffront_tests/ffront_test_utils.py | 5 +- .../test_math_builtin_execution.py | 12 +- .../ffront_tests/test_embedded_regression.py | 137 ++++++++++++++++++ 7 files changed, 200 insertions(+), 31 deletions(-) create mode 100644 tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index f50ace7687..0992401ebb 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -17,7 +17,7 @@ from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.next import common, constructors, utils +from gt4py.next import common, constructors, errors, utils from gt4py.next.embedded import common as embedded_common, context as embedded_context @@ -77,17 +77,20 @@ def scan_loop(hpos): def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): if "out" in kwargs: # called from program or direct field_operator as program - offset_provider = kwargs.pop("offset_provider", None) - new_context_kwargs = {} if embedded_context.within_context(): # called from program - assert offset_provider is None + assert "offset_provider" not in kwargs else: # field_operator as program + if "offset_provider" not in kwargs: + raise errors.MissingArgumentError(None, "offset_provider", True) + offset_provider = kwargs.pop("offset_provider", None) + new_context_kwargs["offset_provider"] = offset_provider out = kwargs.pop("out") + domain = kwargs.pop("domain", None) flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) @@ -105,7 +108,10 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): domain=out_domain, ) else: - # called from other field_operator + # called from other field_operator or missing `out` argument + if "offset_provider" in kwargs: + # assuming we wanted to call the field_operator as program, otherwise `offset_provider` would not be there + raise errors.MissingArgumentError(None, "out", True) return op(*args, **kwargs) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 61441e83b9..dd48d6f0f9 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -21,6 +21,7 @@ from .exceptions import ( DSLError, InvalidParameterAnnotationError, + MissingArgumentError, MissingAttributeError, MissingParameterAnnotationError, UndefinedSymbolError, @@ -33,6 +34,7 @@ "InvalidParameterAnnotationError", "MissingAttributeError", "MissingParameterAnnotationError", + "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", "set_verbose_exceptions", diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 081453c023..858f969447 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -81,6 +81,18 @@ def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: self.attr_name = attr_name +class MissingArgumentError(DSLError): + arg_name: str + is_kwarg: bool + + def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: bool) -> None: + super().__init__( + location, f"Expected {'keyword-' if is_kwarg else ''}argument '{arg_name}'." + ) + self.attr_name = arg_name + self.is_kwarg = is_kwarg + + class TypeError_(DSLError): def __init__(self, location: Optional[SourceLocation], message: str) -> None: super().__init__(location, message) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 53159008f0..76a0ddcde0 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -29,10 +29,11 @@ from devtools import debug +from gt4py import eve from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, embedded as next_embedded +from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -61,11 +62,10 @@ sym, ) from gt4py.next.program_processors import processor_interface as ppi -from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -DEFAULT_BACKEND: Callable = roundtrip.executor +DEFAULT_BACKEND: Callable = None def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: @@ -176,15 +176,15 @@ class Program: past_node: past.Program closure_vars: dict[str, Any] - definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND - grid_type: Optional[GridType] = None + definition: Optional[types.FunctionType] + backend: Optional[ppi.ProgramExecutor] + grid_type: Optional[GridType] @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, + backend: Optional[ppi.ProgramExecutor], grid_type: Optional[GridType] = None, ) -> Program: source_def = SourceDefinition.from_function(definition) @@ -495,7 +495,7 @@ def program(*, backend: Optional[ppi.ProgramExecutor]) -> Callable[[types.Functi def program( definition=None, *, - backend=None, + backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) grid_type=None, ) -> Program | Callable[[types.FunctionType], Program]: """ @@ -517,7 +517,9 @@ def program( """ def program_inner(definition: types.FunctionType) -> Program: - return Program.from_function(definition, backend, grid_type) + return Program.from_function( + definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type + ) return program_inner if definition is None else program_inner(definition) @@ -549,9 +551,9 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): foast_node: OperatorNodeT closure_vars: dict[str, Any] - definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND - grid_type: Optional[GridType] = None + definition: Optional[types.FunctionType] + backend: Optional[ppi.ProgramExecutor] + grid_type: Optional[GridType] operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field(default_factory=dict) @@ -559,7 +561,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, + backend: Optional[ppi.ProgramExecutor], grid_type: Optional[GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, @@ -686,6 +688,7 @@ def as_program( self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, + definition=None, backend=self.backend, grid_type=self.grid_type, ) @@ -698,7 +701,12 @@ def __call__( ) -> None: if not next_embedded.context.within_context() and self.backend is not None: # non embedded execution - offset_provider = kwargs.pop("offset_provider", None) + if "offset_provider" not in kwargs: + raise errors.MissingArgumentError(None, "offset_provider", True) + offset_provider = kwargs.pop("offset_provider") + + if "out" not in kwargs: + raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given @@ -744,7 +752,7 @@ def field_operator( ... -def field_operator(definition=None, *, backend=None, grid_type=None): +def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): """ Generate an implementation of the field operator from a Python function object. @@ -762,7 +770,9 @@ def field_operator(definition=None, *, backend=None, grid_type=None): """ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]: - return FieldOperator.from_function(definition, backend, grid_type) + return FieldOperator.from_function( + definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type + ) return field_operator_inner if definition is None else field_operator_inner(definition) @@ -798,7 +808,7 @@ def scan_operator( axis: Dimension, forward: bool = True, init: core_defs.Scalar = 0.0, - backend=None, + backend=eve.NOTHING, grid_type: GridType = None, ) -> ( FieldOperator[foast.ScanOperator] @@ -836,8 +846,7 @@ def scan_operator( def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, - backend, - grid_type, + DEFAULT_BACKEND if backend is eve.NOTHING else backend, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index e25576ebde..1f5a1f0c48 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -22,6 +22,8 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors import processor_interface as ppi +from gt4py.next.program_processors.runners import gtfn, roundtrip try: @@ -36,9 +38,10 @@ import next_tests.exclusion_matrices as definitions +@ppi.program_executor def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: """Temporary default backend to not accidentally test the wrong backend.""" - raise ValueError("No backend selected. Backend selection is mandatory in tests.") + raise ValueError("No backend selected! Backend selection is mandatory in tests.") OPTIONAL_PROCESSORS = [] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 167ccbb0a5..4444742c66 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import math -from typing import Callable +from typing import Callable, Optional import numpy as np import pytest @@ -22,6 +22,7 @@ from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.decorator import FieldOperator from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction +from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_translation from next_tests.integration_tests import cases @@ -39,7 +40,7 @@ # becomes easier. -def make_builtin_field_operator(builtin_name: str): +def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.ProgramExecutor]): # TODO(tehrengruber): creating a field operator programmatically should be # easier than what we need to do here. # construct annotation dictionary containing the input argument and return @@ -109,8 +110,9 @@ def make_builtin_field_operator(builtin_name: str): return FieldOperator( foast_node=typed_foast_node, closure_vars=closure_vars, - backend=None, definition=None, + backend=backend, + grid_type=None, ) @@ -129,9 +131,7 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp expected = ref_impl(*inputs) out = cartesian_case.as_field([IDim], np.zeros_like(expected)) - builtin_field_op = make_builtin_field_operator(builtin_name).with_backend( - cartesian_case.backend - ) + builtin_field_op = make_builtin_field_operator(builtin_name, cartesian_case.backend) builtin_field_op(*inps, out=out, offset_provider={}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py new file mode 100644 index 0000000000..ba4b1b0cdb --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py @@ -0,0 +1,137 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +import pytest + +from gt4py import next as gtx +from gt4py.next import errors + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import IField, cartesian_case # noqa: F401 # fixtures +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 # fixtures + KDim, + fieldview_backend, +) + + +def test_default_backend_is_respected_field_operator(cartesian_case): # noqa: F811 # fixtures + """Test that manually calling the field operator without setting the backend raises an error.""" + + # Important not to set the backend here! + @gtx.field_operator + def copy(a: IField) -> IField: + return a + + a = cases.allocate(cartesian_case, copy, "a")() + + with pytest.raises(ValueError, match="No backend selected!"): + # Calling this should fail if the default backend is respected + # due to `fieldview_backend` fixture (dependency of `cartesian_case`) + # setting the default backend to something invalid. + _ = copy(a, out=a, offset_provider={}) + + +def test_default_backend_is_respected_scan_operator(cartesian_case): # noqa: F811 # fixtures + """Test that manually calling the scan operator without setting the backend raises an error.""" + + # Important not to set the backend here! + @gtx.scan_operator(axis=KDim, init=0.0, forward=True) + def sum(state: float, a: float) -> float: + return state + a + + a = gtx.ones({KDim: 10}, allocator=cartesian_case.backend) + + with pytest.raises(ValueError, match="No backend selected!"): + # see comment in field_operator test + _ = sum(a, out=a, offset_provider={}) + + +def test_default_backend_is_respected_program(cartesian_case): # noqa: F811 # fixtures + """Test that manually calling the program without setting the backend raises an error.""" + + @gtx.field_operator + def copy(a: IField) -> IField: + return a + + # Important not to set the backend here! + @gtx.program + def copy_program(a: IField, b: IField) -> IField: + copy(a, out=b) + + a = cases.allocate(cartesian_case, copy_program, "a")() + b = cases.allocate(cartesian_case, copy_program, "b")() + + with pytest.raises(ValueError, match="No backend selected!"): + # see comment in field_operator test + _ = copy_program(a, b, offset_provider={}) + + +def test_missing_arg_field_operator(cartesian_case): # noqa: F811 # fixtures + """Test that calling a field_operator without required args raises an error.""" + + @gtx.field_operator(backend=cartesian_case.backend) + def copy(a: IField) -> IField: + return a + + a = cases.allocate(cartesian_case, copy, "a")() + + with pytest.raises(errors.MissingArgumentError, match="'out'"): + _ = copy(a, offset_provider={}) + + with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"): + _ = copy(a, out=a) + + +def test_missing_arg_scan_operator(cartesian_case): # noqa: F811 # fixtures + """Test that calling a scan_operator without required args raises an error.""" + + @gtx.scan_operator(backend=cartesian_case.backend, axis=KDim, init=0.0, forward=True) + def sum(state: float, a: float) -> float: + return state + a + + a = cases.allocate(cartesian_case, sum, "a")() + + with pytest.raises(errors.MissingArgumentError, match="'out'"): + _ = sum(a, offset_provider={}) + + with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"): + _ = sum(a, out=a) + + +def test_missing_arg_program(cartesian_case): # noqa: F811 # fixtures + """Test that calling a program without required args raises an error.""" + + @gtx.field_operator + def copy(a: IField) -> IField: + return a + + a = cases.allocate(cartesian_case, copy, "a")() + b = cases.allocate(cartesian_case, copy, cases.RETURN)() + + with pytest.raises(errors.DSLError, match="Invalid call"): + + @gtx.program(backend=cartesian_case.backend) + def copy_program(a: IField, b: IField) -> IField: + copy(a) + + _ = copy_program(a, offset_provider={}) + + with pytest.raises(TypeError, match="'offset_provider'"): + + @gtx.program(backend=cartesian_case.backend) + def copy_program(a: IField, b: IField) -> IField: + copy(a, out=b) + + _ = copy_program(a) From 6b269bdf4bdc9e6e6c896b03c43f73c6342eb443 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 4 Jan 2024 17:04:55 +0100 Subject: [PATCH 40/85] bug[next]: recover grid_type in scan_operator (#1408) lost in merge conflict in #1380 --- src/gt4py/next/ffront/decorator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 76a0ddcde0..147059b1bd 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -847,6 +847,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, + grid_type, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) From 6e6271c2c5d3ad0e31038d5bd597bd2327534a5e Mon Sep 17 00:00:00 2001 From: SF-N <65219381+SF-N@users.noreply.github.com> Date: Wed, 17 Jan 2024 14:47:23 +0100 Subject: [PATCH 41/85] feature[next]: Add power unrolling functionality and respective unit tests. (#1409) * Add power unrolling functionality and respective unit tests. * Define base and exponent variables for better readability in PowerUnrolling * Remove distinction between SymRef and FunCall in power unrolling * Optimize power unrolling to avoid multiple computations of FunCalls * Further improve power unrolling * Update wrt review and adapt expected results respectively * Add correct annotation --------- Co-authored-by: Sara Faghih-Naini --- .../iterator/transforms/power_unrolling.py | 84 +++++++++ .../transforms_tests/test_power_unrolling.py | 161 ++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 src/gt4py/next/iterator/transforms/power_unrolling.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py new file mode 100644 index 0000000000..ac71f2747d --- /dev/null +++ b/src/gt4py/next/iterator/transforms/power_unrolling.py @@ -0,0 +1,84 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses +import math + +from gt4py.eve import NodeTranslator +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas + + +def _is_power_call( + node: ir.FunCall, +) -> bool: + """Match expressions of the form `power(base, integral_literal)`.""" + return ( + isinstance(node.fun, ir.SymRef) + and node.fun.id == "power" + and isinstance(node.args[1], ir.Literal) + and float(node.args[1].value) == int(node.args[1].value) + and node.args[1].value >= im.literal_from_value(0).value + ) + + +def _compute_integer_power_of_two(exp: int) -> int: + return math.floor(math.log2(exp)) + + +@dataclasses.dataclass +class PowerUnrolling(NodeTranslator): + max_unroll: int + + @classmethod + def apply(cls, node: ir.Node, max_unroll: int = 5) -> ir.Node: + return cls(max_unroll=max_unroll).visit(node) + + def visit_FunCall(self, node: ir.FunCall): + new_node = self.generic_visit(node) + + if _is_power_call(new_node): + assert len(new_node.args) == 2 + # Check if unroll should be performed or if exponent is too large + base, exponent = new_node.args[0], int(new_node.args[1].value) + if 1 <= exponent <= self.max_unroll: + # Calculate and store powers of two of the base as long as they are smaller than the exponent. + # Do the same (using the stored values) with the remainder and multiply computed values. + pow_cur = _compute_integer_power_of_two(exponent) + pow_max = pow_cur + remainder = exponent + + # Build target expression + ret = im.ref(f"power_{2 ** pow_max}") + remainder -= 2**pow_cur + while remainder > 0: + pow_cur = _compute_integer_power_of_two(remainder) + remainder -= 2**pow_cur + + ret = im.multiplies_(ret, f"power_{2 ** pow_cur}") + + # Nest target expression to avoid multiple redundant evaluations + for i in range(pow_max, 0, -1): + ret = im.let( + f"power_{2 ** i}", + im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}"), + )(ret) + ret = im.let("power_1", base)(ret) + + # Simplify expression in case of SymRef by resolving let statements + if isinstance(base, ir.SymRef): + return InlineLambdas.apply(ret, opcount_preserving=True) + else: + return ret + return new_node diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py new file mode 100644 index 0000000000..ae23becb4c --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py @@ -0,0 +1,161 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest + +from gt4py.eve import SymbolRef +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.power_unrolling import PowerUnrolling + + +def test_power_unrolling_zero(): + pytest.xfail( + "Not implementeds we don't have an easy way to determine the type of the one literal (type inference is to expensive)." + ) + testee = im.call("power")("x", 0) + expected = im.literal_from_value(1) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_one(): + testee = im.call("power")("x", 1) + expected = ir.SymRef(id=SymbolRef("x")) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two(): + testee = im.call("power")("x", 2) + expected = im.multiplies_("x", "x") + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two_x_plus_two(): + testee = im.call("power")(im.plus("x", 2), 2) + expected = im.let("power_1", im.plus("x", 2))( + im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_two_x_plus_one_times_three(): + testee = im.call("power")(im.multiplies_(im.plus("x", 1), 3), 2) + expected = im.let("power_1", im.multiplies_(im.plus("x", 1), 3))( + im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_three(): + testee = im.call("power")("x", 3) + expected = im.multiplies_(im.multiplies_("x", "x"), "x") + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_four(): + testee = im.call("power")("x", 4) + expected = im.let("power_2", im.multiplies_("x", "x"))(im.multiplies_("power_2", "power_2")) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_five(): + testee = im.call("power")("x", 5) + tmp2 = im.multiplies_("x", "x") + expected = im.multiplies_(im.multiplies_(tmp2, tmp2), "x") + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.multiplies_(im.multiplies_("power_2", "power_2"), "x") + ) + + actual = PowerUnrolling.apply(testee) + assert actual == expected + + +def test_power_unrolling_seven(): + testee = im.call("power")("x", 7) + expected = im.call("power")("x", 7) + + actual = PowerUnrolling.apply(testee, max_unroll=5) + assert actual == expected + + +def test_power_unrolling_seven_unrolled(): + testee = im.call("power")("x", 7) + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.multiplies_(im.multiplies_(im.multiplies_("power_2", "power_2"), "power_2"), "x") + ) + + actual = PowerUnrolling.apply(testee, max_unroll=7) + assert actual == expected + + +def test_power_unrolling_seven_x_plus_one_unrolled(): + testee = im.call("power")(im.plus("x", 1), 7) + expected = im.let("power_1", im.plus("x", 1))( + im.let("power_2", im.multiplies_("power_1", "power_1"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.multiplies_(im.multiplies_("power_4", "power_2"), "power_1") + ) + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=7) + assert actual == expected + + +def test_power_unrolling_eight(): + testee = im.call("power")("x", 8) + expected = im.call("power")("x", 8) + + actual = PowerUnrolling.apply(testee, max_unroll=5) + assert actual == expected + + +def test_power_unrolling_eight_unrolled(): + testee = im.call("power")("x", 8) + expected = im.let("power_2", im.multiplies_("x", "x"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.multiplies_("power_4", "power_4") + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=8) + assert actual == expected + + +def test_power_unrolling_eight_x_plus_one_unrolled(): + testee = im.call("power")(im.plus("x", 1), 8) + expected = im.let("power_1", im.plus("x", 1))( + im.let("power_2", im.multiplies_("power_1", "power_1"))( + im.let("power_4", im.multiplies_("power_2", "power_2"))( + im.let("power_8", im.multiplies_("power_4", "power_4"))("power_8") + ) + ) + ) + + actual = PowerUnrolling.apply(testee, max_unroll=8) + assert actual == expected From 6283ac930ee301a02300b2ad9bb440b3aab04b2d Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 17 Jan 2024 15:50:57 +0100 Subject: [PATCH 42/85] fix[cartesian]: DaCe array access in tasklet (#1410) Found some incompatible tasklet representation while upgrading to dace v0.15.1. Array access inside tasklet with partial index subset worked in v0.14.1, although not valid. The fix consists of modifying the memlets to pass the full array shape to such tasklet, and use all explicit indices inside the tasklet to access the array. This is the right representation in DaCe SDFG, as discussed with the DaCe developers. --- .../gtc/dace/expansion/daceir_builder.py | 35 +++++++++++++++++++ src/gt4py/cartesian/gtc/daceir.py | 4 +-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index db276a48b9..48b129fa87 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -30,6 +30,7 @@ compute_dcir_access_infos, flatten_list, get_tasklet_symbol, + make_dace_subset, union_inout_memlets, union_node_grid_subsets, untile_memlets, @@ -458,6 +459,40 @@ def visit_HorizontalExecution( write_memlets=write_memlets, ) + for memlet in [*read_memlets, *write_memlets]: + """ + This loop handles the special case of a tasklet performing array access. + The memlet should pass the full array shape (no tiling) and + the tasklet expression for array access should use all explicit indexes. + """ + array_ndims = len(global_ctx.arrays[memlet.field].shape) + field_decl = global_ctx.library_node.field_decls[memlet.field] + # calculate array subset on original memlet + memlet_subset = make_dace_subset( + global_ctx.library_node.access_infos[memlet.field], + memlet.access_info, + field_decl.data_dims, + ) + # select index values for single-point grid access + memlet_data_index = [ + dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32) + for dim_range, dim_size in zip(memlet_subset, memlet_subset.size()) + if dim_size == 1 + ] + if len(memlet_data_index) < array_ndims: + reshape_memlet = False + for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): + if access_node.data_index and access_node.name == memlet.connector: + access_node.data_index = memlet_data_index + access_node.data_index + assert len(access_node.data_index) == array_ndims + reshape_memlet = True + if reshape_memlet: + # ensure that memlet symbols used for array indexing are defined in context + for sym in memlet.access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(sym) + # set full shape on memlet + memlet.access_info = global_ctx.library_node.access_infos[memlet.field] + for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() dcir_node = self._process_iteration_item( diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 28ebc8cd8e..0366317360 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -536,7 +536,7 @@ def union(self, other): else: assert ( isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, IndexWithExtent) + and isinstance(interval1, (IndexWithExtent, DomainInterval)) ) or ( isinstance(interval1, (TileInterval, DomainInterval)) and isinstance(interval2, IndexWithExtent) @@ -573,7 +573,7 @@ def overapproximated_shape(self): def apply_iteration(self, grid_subset: GridSubset): res_intervals = dict(self.grid_subset.intervals) for axis, field_interval in self.grid_subset.intervals.items(): - if axis in grid_subset.intervals: + if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval): grid_interval = grid_subset.intervals[axis] assert isinstance(field_interval, IndexWithExtent) extent = field_interval.extent From 3edf21e9c7fce64976068d19d2a15c5b856d94d6 Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Thu, 18 Jan 2024 09:19:00 +0100 Subject: [PATCH 43/85] bug[next]: Bound args kwargs edit (#1411) * edits for BoundArgs with kwargs in correct order --- src/gt4py/next/ffront/decorator.py | 17 +++-- .../ffront_tests/test_bound_args.py | 64 +++++++++++++++++++ .../ffront_tests/test_execution.py | 20 ------ 3 files changed, 75 insertions(+), 26 deletions(-) create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 147059b1bd..05cbe1c882 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -453,27 +453,32 @@ def _process_args(self, args: tuple, kwargs: dict): ) from err full_args = [*args] + full_kwargs = {**kwargs} for index, param in enumerate(self.past_node.params): if param.id in self.bound_args.keys(): - full_args.insert(index, self.bound_args[param.id]) + if index < len(full_args): + full_args.insert(index, self.bound_args[param.id]) + else: + full_kwargs[str(param.id)] = self.bound_args[param.id] - return super()._process_args(tuple(full_args), kwargs) + return super()._process_args(tuple(full_args), full_kwargs) @functools.cached_property def itir(self): new_itir = super().itir for new_clos in new_itir.closures: - for key in self.bound_args.keys(): + new_args = [ref(inp.id) for inp in new_clos.inputs] + for key, value in self.bound_args.items(): index = next( index for index, closure_input in enumerate(new_clos.inputs) if closure_input.id == key ) + new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator( + literal_from_value(value) + ) new_clos.inputs.pop(index) - new_args = [ref(inp.id) for inp in new_clos.inputs] params = [sym(inp.id) for inp in new_clos.inputs] - for value in self.bound_args.values(): - new_args.append(promote_to_const_iterator(literal_from_value(value))) expr = itir.FunCall( fun=new_clos.stencil, args=new_args, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py new file mode 100644 index 0000000000..0de953d85f --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +import gt4py.next as gtx +from gt4py.next import int32 + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import cartesian_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + fieldview_backend, + reduction_setup, +) + + +def test_with_bound_args(cartesian_case): + @gtx.field_operator + def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField: + if not condition: + scalar = 0 + return a + scalar + + @gtx.program + def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField): + fieldop_bound_args(a, scalar, condition, out=out) + + a = cases.allocate(cartesian_case, program_bound_args, "a")() + scalar = int32(1) + ref = a + scalar + out = cases.allocate(cartesian_case, program_bound_args, "out")() + + prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) + cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) + + +def test_with_bound_args_order_args(cartesian_case): + @gtx.field_operator + def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField: + scalar = 0 if not condition else scalar + return a + scalar + + @gtx.program(backend=cartesian_case.backend) + def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField): + fieldop_args(a, condition, scalar, out=out) + + a = cases.allocate(cartesian_case, program_args, "a")() + out = cases.allocate(cartesian_case, program_args, "out")() + + prog_bounds = program_args.with_bound_args(condition=True) + prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={}) + np.allclose(out.asnumpy(), a.asnumpy() + int32(1)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a08931628b..70c79d7b6c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -898,26 +898,6 @@ def test_docstring(a: cases.IField): cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a) -def test_with_bound_args(cartesian_case): - @gtx.field_operator - def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField: - if not condition: - scalar = 0 - return a + a + scalar - - @gtx.program - def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField): - fieldop_bound_args(a, scalar, condition, out=out) - - a = cases.allocate(cartesian_case, program_bound_args, "a")() - scalar = int32(1) - ref = a + a + 1 - out = cases.allocate(cartesian_case, program_bound_args, "out")() - - prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True) - cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref) - - def test_domain(cartesian_case): @gtx.field_operator def fieldop_domain(a: cases.IField) -> cases.IField: From ba368564c27807cbac207bd7d5631501f87b062a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 18 Jan 2024 11:14:43 +0100 Subject: [PATCH 44/85] example: cartesian with next compatibility (#1202) Add an example illustrating using gt4py.cartesian and gt4py.next computations next to each other using gt4py.next storages. Refactor GTFieldInterface and cleanup GTDimsInterface for next. --- examples/lap_cartesian_vs_next.ipynb | 189 ++++++++++++++++++ src/gt4py/next/__init__.py | 9 + src/gt4py/next/common.py | 35 ++-- src/gt4py/next/embedded/nd_array_field.py | 4 - src/gt4py/next/iterator/embedded.py | 32 +-- src/gt4py/next/iterator/tracing.py | 2 +- .../next/type_system/type_translation.py | 2 +- src/gt4py/storage/cartesian/utils.py | 4 + 8 files changed, 239 insertions(+), 38 deletions(-) create mode 100644 examples/lap_cartesian_vs_next.ipynb diff --git a/examples/lap_cartesian_vs_next.ipynb b/examples/lap_cartesian_vs_next.ipynb new file mode 100644 index 0000000000..cb80122570 --- /dev/null +++ b/examples/lap_cartesian_vs_next.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "GT4Py - GridTools for Python\n", + "\n", + "Copyright (c) 2014-2023, ETH Zurich\n", + "All rights reserved.\n", + "\n", + "This file is part the GT4Py project and the GridTools framework.\n", + "GT4Py is free software: you can redistribute it and/or modify it under\n", + "the terms of the GNU General Public License as published by the\n", + "Free Software Foundation, either version 3 of the License, or any later\n", + "version. See the LICENSE.txt file at the top-level directory of this\n", + "distribution for a copy of the license or check .\n", + "\n", + "SPDX-License-Identifier: GPL-3.0-or-later" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Demonstrates gt4py.cartesian with gt4py.next compatibility" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "nx = 32\n", + "ny = 32\n", + "nz = 1\n", + "dtype = np.float64" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Storages\n", + "--\n", + "\n", + "We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use \"I\", \"J\", \"K\" as the dimension names." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import gt4py.next as gtx\n", + "\n", + "allocator = gtx.itir_embedded # should match the executor\n", + "# allocator = gtx.gtfn_cpu\n", + "# allocator = gtx.gtfn_gpu\n", + "\n", + "# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be \"I\", \"J\", \"K\"\n", + "I = gtx.Dimension(\"I\")\n", + "J = gtx.Dimension(\"J\")\n", + "K = gtx.Dimension(\"K\", kind=gtx.DimensionKind.VERTICAL)\n", + "\n", + "domain = gtx.domain({I: nx, J: ny, K: nz})\n", + "\n", + "inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n", + "out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n", + "out_next = gtx.zeros(domain, dtype, allocator=allocator)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "gt4py.cartesian\n", + "--" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import gt4py.cartesian.gtscript as gtscript\n", + "\n", + "cartesian_backend = \"numpy\"\n", + "# cartesian_backend = \"gt:cpu_ifirst\"\n", + "# cartesian_backend = \"gt:gpu\"\n", + "\n", + "@gtscript.stencil(backend=cartesian_backend)\n", + "def lap_cartesian(\n", + " inp: gtscript.Field[dtype],\n", + " out: gtscript.Field[dtype],\n", + "):\n", + " with computation(PARALLEL), interval(...):\n", + " out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n", + "\n", + "lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from gt4py.next import Field\n", + "\n", + "next_backend = gtx.itir_embedded\n", + "# next_backend = gtx.gtfn_cpu\n", + "# next_backend = gtx.gtfn_gpu\n", + "\n", + "Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n", + "Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n", + "\n", + "@gtx.field_operator\n", + "def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n", + " return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n", + "\n", + "@gtx.program(backend=next_backend)\n", + "def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n", + " lap_next(inp, out=out[1:-1, 1:-1, :])\n", + "\n", + "lap_next_program(inp, out_next, offset_provider={\"Ioff\": I, \"Joff\": J})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index cbd5735949..1398af5f03 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -39,6 +39,11 @@ index_field, np_as_located_field, ) +from .program_processors.runners.gtfn import ( + run_gtfn_cached as gtfn_cpu, + run_gtfn_gpu_cached as gtfn_gpu, +) +from .program_processors.runners.roundtrip import backend as itir_python __all__ = [ @@ -74,5 +79,9 @@ "field_operator", "program", "scan_operator", + # from program_processor + "gtfn_cpu", + "gtfn_gpu", + "itir_python", *fbuiltins.__all__, ] diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 29d606ccc0..6bf6858369 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -574,38 +574,39 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... -# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol -class NextGTDimsInterface(Protocol): +# TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol. +class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol): """ - Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions. + Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`. - The dimension names are objects of type :class:`Dimension`, in contrast to - :mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics, - see :class:`~gt4py._core.definitions.GTDimsInterface` . + Note: + - A default implementation of the `__gt_dims__` interface from `gt4py.cartesian` is provided. + - No implementation of `__gt_origin__` is provided because of infinite fields. """ @property - def __gt_dims__(self) -> tuple[Dimension, ...]: + def __gt_domain__(self) -> Domain: + # TODO probably should be changed to `DomainLike` (with a new concept `DimensionLike`) + # to allow implementations without having to import gtx.Domain. ... - -# TODO(egparedes): add support for this new protocol in the cartesian module -class GTFieldInterface(Protocol): - """Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.""" - @property - def __gt_domain__(self) -> Domain: - ... + def __gt_dims__(self) -> tuple[str, ...]: + return tuple(d.value for d in self.__gt_domain__.dims) @extended_runtime_checkable -class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): +class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property def domain(self) -> Domain: ... + @property + def __gt_domain__(self) -> Domain: + return self.domain + @property def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @@ -923,10 +924,6 @@ def asnumpy(self) -> Never: def domain(self) -> Domain: return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) - @property - def __gt_dims__(self) -> tuple[Dimension, ...]: - return self.domain.dims - @property def __gt_origin__(self) -> Never: raise TypeError("'CartesianConnectivity' does not support this operation.") diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 8bd2673db9..9fc1b42038 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -107,10 +107,6 @@ def domain(self) -> common.Domain: def shape(self) -> tuple[int, ...]: return self._ndarray.shape - @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._domain.dims - @property def __gt_origin__(self) -> tuple[int, ...]: assert common.Domain.is_finite(self._domain) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index ef70a2e645..390bec4312 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -172,7 +172,7 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def __gt_dims__(self) -> tuple[common.Dimension, ...]: + def __gt_domain__(self) -> common.Domain: ... # TODO(havogt): define generic Protocol to provide a concrete return type @@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any: @property def __gt_origin__(self) -> tuple[int, ...]: - return tuple([0] * len(self.__gt_dims__)) + return tuple([0] * len(self.__gt_domain__.dims)) @runtime_checkable @@ -675,12 +675,18 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: def _get_axes( field_or_tuple: LocatedField | tuple, ) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField + return _get_domain(field_or_tuple).dims + + +def _get_domain( + field_or_tuple: LocatedField | tuple, +) -> common.Domain: # arbitrary nesting of tuples of LocatedField if isinstance(field_or_tuple, tuple): - first = _get_axes(field_or_tuple[0]) - assert all(first == _get_axes(f) for f in field_or_tuple) + first = _get_domain(field_or_tuple[0]) + assert all(first == _get_domain(f) for f in field_or_tuple) return first else: - return field_or_tuple.__gt_dims__ + return field_or_tuple.__gt_domain__ def _single_vertical_idx( @@ -894,14 +900,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField): _ndarrayfield: common.Field @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._ndarrayfield.__gt_dims__ + def __gt_domain__(self) -> common.Domain: + return self._ndarrayfield.__gt_domain__ def _translate_named_indices( self, _named_indices: NamedFieldIndices ) -> common.AbsoluteIndexSequence: named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = { - d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__ + d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims } domain_slice: list[common.NamedRange | common.NamedIndex] = [] for d, v in named_indices.items(): @@ -1046,8 +1052,8 @@ class IndexField(common.Field): _dimension: common.Dimension @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return (self._dimension,) + def __gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1165,8 +1171,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]): _value: core_defs.ScalarT @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return tuple() + def __gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1452,7 +1458,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: class TupleOfFields(TupleField): def __init__(self, data): self.data = data - self.__gt_dims__ = _get_axes(data) + self.__gt_domain__ = _get_domain(data) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: return _build_tuple_result(self.data, named_indices) diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 30fec1f9fd..05ebd02352 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg): # other `np.int32`). We just ignore the error here and postpone fixing this to when # the new storages land (The implementation here works for LocatedFieldImpl). - return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__) + return common.is_field(arg) and any(dim is None for dim in arg.domain.dims) def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 88a8347fe4..12649bf620 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec: elif isinstance(value, common.Dimension): symbol_type = ts.DimensionType(dim=value) elif common.is_field(value): - dims = list(value.__gt_dims__) + dims = list(value.domain.dims) dtype = from_type_hint(value.dtype.scalar_type) symbol_type = ts.FieldType(dims=dims, dtype=dtype) elif isinstance(value, tuple): diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 0f7cf5d0ab..4e7ebb0c21 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -192,6 +192,10 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray: def asarray( array: FieldLike, *, device: Literal["cpu", "gpu", None] = None ) -> np.ndarray | cp.ndarray: + if hasattr(array, "ndarray"): + # extract the buffer from a gt4py.next.Field + # TODO(havogt): probably `Field` should provide the array interface methods when applicable + array = array.ndarray if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")): return cp.asarray(array) if device == "cpu" or ( From 49db7efadbe6c5329fdadbf3e3f3a0fd1728ee00 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 18 Jan 2024 17:00:33 +0100 Subject: [PATCH 45/85] feat[next]: Pass sizes to temporaries from gt4py program (#1359) --- src/gt4py/eve/trees.py | 4 +- .../next/iterator/transforms/global_tmps.py | 41 +++++-- .../next/iterator/transforms/pass_manager.py | 6 +- src/gt4py/next/iterator/type_inference.py | 6 +- .../codegens/gtfn/gtfn_backend.py | 77 ------------ .../codegens/gtfn/gtfn_module.py | 95 +++++++++++---- .../program_processors/formatters/gtfn.py | 13 +- .../test_temporaries_with_sizes.py | 113 ++++++++++++++++++ .../cpp_backend_tests/anton_lap.py | 6 +- .../cpp_backend_tests/copy_stencil.py | 6 +- .../copy_stencil_field_view.py | 6 +- .../cpp_backend_tests/fvm_nabla.py | 11 +- .../cpp_backend_tests/tridiagonal_solve.py | 6 +- .../transforms_tests/test_global_tmps.py | 2 +- 14 files changed, 264 insertions(+), 128 deletions(-) delete mode 100644 src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py create mode 100644 tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index cd7e71588f..74c5bd41bb 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -133,7 +133,7 @@ def _pre_walk_items( yield from _pre_walk_items(child, __key__=key) -def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _pre_walk_values(node: TreeLike) -> Iterable: """Create a pre-order tree traversal iterator of values.""" yield node for child in iter_children_values(node): @@ -153,7 +153,7 @@ def _post_walk_items( yield __key__, node -def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _post_walk_values(node: TreeLike) -> Iterable: """Create a post-order tree traversal iterator of values.""" if (iter_children_values := getattr(node, "iter_children_values", None)) is not None: for child in iter_children_values(): diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d9d3d18213..0033f36cab 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,6 +22,7 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator +from gt4py.next import common from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -437,9 +438,12 @@ def _group_offsets( return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly -def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): +def update_domains( + node: FencilWithTemporaries, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], +): horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] domains = dict[str, ir.FunCall]() for closure in reversed(node.fencil.closures): @@ -479,16 +483,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An # cartesian shift dim = offset_provider[offset_name].value consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider): + elif isinstance(offset_provider[offset_name], common.Connectivity): # unstructured shift nbt_provider = offset_provider[offset_name] old_axis = nbt_provider.origin_axis.value new_axis = nbt_provider.neighbor_axis.value - consumed_domain.ranges.pop(old_axis) - assert new_axis not in consumed_domain.ranges - consumed_domain.ranges[new_axis] = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), + + assert new_axis not in consumed_domain.ranges or old_axis == new_axis + + if symbolic_sizes is None: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal( + str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN + ), + ) + else: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.ref(symbolic_sizes[new_axis]), + ) + consumed_domain.ranges = dict( + (axis, range_) if axis != old_axis else (new_axis, new_range) + for axis, range_ in consumed_domain.ranges.items() ) else: raise NotImplementedError @@ -570,7 +587,11 @@ class CreateGlobalTmps(NodeTranslator): """ def visit_FencilDefinition( - self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] + self, + node: ir.FencilDefinition, + *, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries res = split_closures(node, offset_provider=offset_provider) @@ -581,6 +602,6 @@ def visit_FencilDefinition( # Perform an eta-reduction which should put all calls at the highest level of a closure res = EtaReduction().visit(res) # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider) + res = update_domains(res, offset_provider, symbolic_sizes) # Use type inference to determine the data type of the temporaries return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2e05391634..08897861c2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import enum +from typing import Optional from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import simple_inline_heuristic @@ -81,6 +82,7 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): if lift_mode is None: lift_mode = LiftMode.FORCE_INLINE @@ -147,7 +149,9 @@ def apply_common_transforms( if lift_mode != LiftMode.FORCE_INLINE: assert offset_provider is not None - ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider) + ir = CreateGlobalTmps().visit( + ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes + ) ir = InlineLifts().visit(ir) # If after creating temporaries, the scan is not at the top, we inline. # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 68627cfd89..d65f67b266 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): axis = offset_provider[offset] if isinstance(axis, gtx.Dimension): continue # Cartesian shifts don’t change the location type - elif isinstance( - axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider) - ): + elif isinstance(axis, Connectivity): assert ( axis.origin_axis.kind == axis.neighbor_axis.kind @@ -964,7 +962,7 @@ def visit_FencilDefinition( def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES): try: - child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined] + child_node.annex.type = types[id(child_node)] except KeyError: if not ( isinstance(child_node, ir.SymRef) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py deleted file mode 100644 index 4183f52550..0000000000 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py +++ /dev/null @@ -1,77 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any - -import gt4py.next.iterator.ir as itir -from gt4py.eve import codegen -from gt4py.eve.exceptions import EveValueError -from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms -from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen -from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering - - -def _lower( - program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any -): - offset_provider = kwargs.get("offset_provider") - assert isinstance(offset_provider, dict) - if enable_itir_transforms: - program = apply_common_transforms( - program, - lift_mode=kwargs.get("lift_mode"), - offset_provider=offset_provider, - unroll_reduce=do_unroll, - unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements - ) - gtfn_ir = GTFN_lowering.apply( - program, - offset_provider=offset_provider, - column_axis=kwargs.get("column_axis"), - ) - return gtfn_ir - - -def generate( - program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any -) -> str: - if kwargs.get("imperative", False): - try: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=False, - **kwargs, - ) - except EveValueError: - # if we don't unroll, there may be lifts left in the itir which can't be lowered to - # gtfn. In this case, just retry with unrolled reductions. - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs) - generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs) - else: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs) - return codegen.format_source("cpp", generated_code, style="LLVM") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 4abdaa6eea..718fef72af 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -15,21 +15,24 @@ from __future__ import annotations import dataclasses +import functools import warnings from typing import Any, Final, Optional import numpy as np from gt4py._core import definitions as core_defs -from gt4py.eve import trees, utils +from gt4py.eve import codegen, trees, utils from gt4py.next import common from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import LiftMode, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface -from gt4py.next.program_processors.codegens.gtfn import gtfn_backend +from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering +from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering from gt4py.next.type_system import type_specifications as ts, type_translation @@ -54,6 +57,7 @@ class GTFNTranslationStep( use_imperative_backend: bool = False lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + symbolic_domain_sizes: Optional[dict[str, str]] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -171,6 +175,70 @@ def _process_connectivity_args( return parameters, arg_exprs + def _preprocess_program( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> itir.FencilDefinition: + # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added + # to the interface of all (or at least all of concern) backends, but instead should be + # configured in the backend itself (like it is here), until then we respect the argument + # here and warn the user if it differs from the one configured. + lift_mode = runtime_lift_mode or self.lift_mode + if lift_mode != self.lift_mode: + warnings.warn( + f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " + f"overriden to be {str(runtime_lift_mode)} at runtime." + ) + + if not self.enable_itir_transforms: + return program + + apply_common_transforms = functools.partial( + pass_manager.apply_common_transforms, + lift_mode=lift_mode, + offset_provider=offset_provider, + # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements + unconditionally_collapse_tuples=True, + symbolic_domain_sizes=self.symbolic_domain_sizes, + ) + + new_program = apply_common_transforms( + program, unroll_reduce=not self.use_imperative_backend + ) + + if self.use_imperative_backend and any( + node.id == "neighbors" + for node in new_program.pre_walk_values().if_isinstance(itir.SymRef) + ): + # if we don't unroll, there may be lifts left in the itir which can't be lowered to + # gtfn. In this case, just retry with unrolled reductions. + new_program = apply_common_transforms(program, unroll_reduce=True) + + return new_program + + def generate_stencil_source( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + column_axis: Optional[common.Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> str: + new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + gtfn_ir = GTFN_lowering.apply( + new_program, + offset_provider=offset_provider, + column_axis=column_axis, + ) + + if self.use_imperative_backend: + gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir) + generated_code = GTFNIMCodegen.apply(gtfn_im_ir) + else: + generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") + def __call__( self, inp: stages.ProgramCall, @@ -190,18 +258,6 @@ def __call__( inp.kwargs["offset_provider"] ) - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - runtime_lift_mode = inp.kwargs.pop("lift_mode", None) - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode != self.lift_mode: - warnings.warn( - f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " - "overriden to be {str(runtime_lift_mode)} at runtime." - ) - # combine into a format that is aligned with what the backend expects parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters backend_arg = self._backend_type() @@ -213,12 +269,11 @@ def __call__( f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});" ) decl_src = cpp_interface.render_function_declaration(function, body=decl_body) - stencil_src = gtfn_backend.generate( + stencil_src = self.generate_stencil_source( program, - enable_itir_transforms=self.enable_itir_transforms, - lift_mode=lift_mode, - imperative=self.use_imperative_backend, - **inp.kwargs, + inp.kwargs["offset_provider"], + inp.kwargs.get("column_axis", None), + inp.kwargs.get("lift_mode", None), ) source_code = interface.format_source( self._language_settings(), diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index f9fa154641..27dec77ed1 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,10 +15,19 @@ from typing import Any from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep from gt4py.next.program_processors.processor_interface import program_formatter +from gt4py.next.program_processors.runners.gtfn import gtfn_executor @program_formatter def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return generate(program, **kwargs) + # TODO(tehrengruber): This is a little ugly. Revisit. + gtfn_translation = gtfn_executor.otf_workflow.translation + assert isinstance(gtfn_translation, GTFNTranslationStep) + return gtfn_translation.generate_stencil_source( + program, + offset_provider=kwargs.get("offset_provider", None), + column_axis=kwargs.get("column_axis", None), + runtime_lift_mode=kwargs.get("lift_mode", None), + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py new file mode 100644 index 0000000000..da0945fe96 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -0,0 +1,113 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from numpy import int32, int64 + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms +from gt4py.next.program_processors import otf_compile_executor +from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries +from tests.next_tests.integration_tests.cases import Case +from tests.next_tests.toy_connectivity import Cell, Edge + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import E2V, KDim, Vertex, cartesian_case, unstructured_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + reduction_setup, +) + + +@pytest.fixture +def run_gtfn_with_temporaries_and_symbolic_sizes(): + return otf_compile_executor.OTFBackend( + executor=otf_compile_executor.OTFCompileExecutor( + name="run_gtfn_with_temporaries_and_sizes", + otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( + translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + }, + ), + ), + ), + allocator=run_gtfn_with_temporaries.allocator, + ) + + +@pytest.fixture +def testee(): + @gtx.field_operator + def testee_op(a: cases.VField) -> cases.EField: + amul = a * 2 + return amul(E2V[0]) + amul(E2V[1]) + + @gtx.program + def prog( + a: cases.VField, + out: cases.EField, + num_vertices: int32, + num_edges: int64, + num_cells: int32, + ): + testee_op(a, out=out) + + return prog + + +def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup): + unstructured_case = Case( + run_gtfn_with_temporaries_and_symbolic_sizes, + offset_provider=reduction_setup.offset_provider, + default_sizes={ + Vertex: reduction_setup.num_vertices, + Edge: reduction_setup.num_edges, + Cell: reduction_setup.num_cells, + KDim: reduction_setup.k_levels, + }, + grid_type=common.GridType.UNSTRUCTURED, + ) + + a = cases.allocate(unstructured_case, testee, "a")() + out = cases.allocate(unstructured_case, testee, "out")() + + first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1]) + ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] + + cases.verify( + unstructured_case, + testee, + a, + out, + reduction_setup.num_vertices, + reduction_setup.num_edges, + reduction_setup.num_cells, + inout=out, + ref=ref, + ) + + +def test_temporary_symbols(testee, reduction_setup): + itir_with_tmp = apply_common_transforms( + testee.itir, + lift_mode=LiftMode.FORCE_TEMPORARIES, + offset_provider=reduction_setup.offset_provider, + ) + + params = ["num_vertices", "num_edges", "num_cells"] + for param in params: + assert any([param == str(p) for p in itir_with_tmp.fencil.params]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py index e851e7b130..5af4605988 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn @fundef @@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp): output_file = sys.argv[1] prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False) - generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py index 33c7d5baa7..3e8b88ac66 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out): output_file = sys.argv[1] prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False) - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py index f7472d4ac3..fdc57449ee 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next import Field, field_operator, program -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -47,7 +47,9 @@ def copy_program( output_file = sys.argv[1] prog = copy_program.itir - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py index 1dfd74baca..abc3755dca 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative E2V = offset("E2V") @@ -92,13 +92,20 @@ def mapped_index(_, __) -> int: output_file = sys.argv[1] imperative = sys.argv[2].lower() == "true" + if imperative: + backend = run_gtfn_imperative + else: + backend = run_gtfn + # prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False) offset_provider = { "V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True), "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False), } - generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative) + generated_code = backend.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider=offset_provider, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py index 578a19faab..9755774fd0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False) offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")} - generated_code = generate( + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, offset_provider=offset_provider, - lift_mode=LiftMode.SIMPLE_HEURISTIC, + runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC, column_axis=KDim, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 86c3c98c62..5c2802f90c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -323,7 +323,7 @@ def test_update_cartesian_domains(): for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - actual = update_domains(testee, {"I": gtx.Dimension("IDim")}) + actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None) assert actual == expected From b900b474566f21339d5c99aa2365f9bed86bf1ec Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 19 Jan 2024 11:46:04 +0100 Subject: [PATCH 46/85] build[cartesian][next]: Bump dace version from 0.14.4 to 0.15.1 (#1391) Bumping dace version to 0.15.1 affects both cartesian and next gt4py: * cartesian: removed try/except for dace backward compatibility * next: re-enabled some tests that were broken on dace 0.14.4 * all: fixed and/or suppressed flake8 and mypy errors --- .pre-commit-config.yaml | 38 ++-- constraints.txt | 191 ++++++++++-------- min-extra-requirements-test.txt | 4 +- pyproject.toml | 6 +- requirements-dev.txt | 191 ++++++++++-------- src/gt4py/__init__.py | 2 +- src/gt4py/cartesian/backend/dace_backend.py | 8 +- src/gt4py/cartesian/gtc/dace/nodes.py | 2 +- src/gt4py/eve/datamodels/core.py | 2 +- src/gt4py/eve/utils.py | 4 +- src/gt4py/next/common.py | 11 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/otf/workflow.py | 2 +- .../runners/dace_iterator/__init__.py | 13 +- .../runners/dace_iterator/itir_to_sdfg.py | 14 +- .../unit_tests/test_gtc/test_common.py | 2 +- .../ffront_tests/test_external_local_field.py | 10 - .../ffront_tests/test_gt4py_builtins.py | 40 ---- .../test_temporaries_with_sizes.py | 12 +- 19 files changed, 262 insertions(+), 292 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1092fafd0..d9cfa0ff48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.9.1' # version from constraints.txt + rev: '23.11.0' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -73,7 +73,7 @@ repos: ## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '5.12.0' # version from constraints.txt + rev: '5.13.0' # version from constraints.txt ##[[[end]]] hooks: - id: isort @@ -97,14 +97,14 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.9.16 - - flake8-builtins==2.1.0 + - flake8-bugbear==23.12.2 + - flake8-builtins==2.2.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 - flake8-eradicate==1.5.0 - flake8-mutable==1.2.0 - flake8-pyproject==1.2.3 - - pygments==2.16.1 + - pygments==2.17.2 ##[[[end]]] # - flake8-rst-docstrings # Disabled for now due to random false positives exclude: | @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.5.1 ========= + #========= FROM constraints.txt: v1.7.1 ========= ##[[[end]]] - rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -162,26 +162,26 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.1.0 - - black==23.9.1 - - boltons==23.0.0 + - black==23.11.0 + - boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 - - cmake==3.27.5 + - cmake==3.27.9 - cytoolz==0.12.2 - - deepdiff==6.5.0 + - deepdiff==6.7.1 - devtools==0.12.2 - - frozendict==2.3.8 + - frozendict==2.3.10 - gridtools-cpp==2.3.1 - - importlib-resources==6.0.1 + - importlib-resources==6.1.1 - jinja2==3.1.2 - - lark==1.1.7 - - mako==1.2.4 - - nanobind==1.5.2 - - ninja==1.11.1 + - lark==1.1.8 + - mako==1.3.0 + - nanobind==1.8.0 + - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==23.1 + - packaging==23.2 - pybind11==2.11.1 - - setuptools==68.2.2 + - setuptools==69.0.2 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index b334851af1..81abd64c6e 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 17709206a0..fd7724bac9 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -25,7 +25,7 @@ cmake==3.22 cogapp==3.3 coverage[toml]==5.0 cytoolz==0.12.0 -dace==0.14.2 +dace==0.15.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -70,7 +70,7 @@ scipy==1.7.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 -sympy==1.7 +sympy==1.9 tabulate==0.8.10 tomli==2.0.1 tox==3.2.0 diff --git a/pyproject.toml b/pyproject.toml index 5d7a2f2cb6..675bdae9d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,15 +69,15 @@ requires-python = '>=3.8' cuda = ['cupy>=12.0'] cuda11x = ['cupy-cuda11x>=12.0'] cuda12x = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.14.2,<0.15', 'sympy>=1.7'] +dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9'] formatting = ['clang-format>=9.0'] # Always add all extra packages to 'full' for a simple full gt4py installation full = [ 'clang-format>=9.0', - 'dace>=0.14.2,<0.15', + 'dace>=0.15.1,<0.16', 'hypothesis>=6.0.0', 'pytest>=7.0', - 'sympy>=1.7', + 'sympy>=1.9', 'scipy>=1.7.2', 'jax[cpu]>=0.4.13' ] diff --git a/requirements-dev.txt b/requirements-dev.txt index d6dcc12d21..0fa523866f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette[validation]==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 7d255de142..c28c5cf2d6 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -33,6 +33,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 __all__ += ["next"] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b1e559a41e..5dae025acb 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -562,12 +562,6 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S omp_threads = "" omp_header = "" - # Backward compatible state struct name change in DaCe >=0.15.x - try: - dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") - except (KeyError, TypeError): - dace_state_suffix = "_t" # old structure name - interface = cls.template.definition.render( name=sdfg.name, backend_specifics=omp_threads, @@ -575,7 +569,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S functor_args=self.generate_functor_args(sdfg), tmp_allocs=self.generate_tmp_allocs(sdfg), allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", - state_suffix=dace_state_suffix, + state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"), ) generated_code = textwrap.dedent( f"""#include diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index ddcb719b5f..bd8c08034c 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -121,7 +121,7 @@ def __init__( *args, **kwargs, ): - super().__init__(name=name, *args, **kwargs) + super().__init__(*args, name=name, **kwargs) from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index fcd53d1312..5660fdbf76 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -814,7 +814,7 @@ def concretize( """ # noqa: RST301 # doctest conventions confuse RST validator concrete_cls: Type[DataModelT] = _make_concrete_with_cache( - datamodel_cls, *type_args, class_name=class_name, module=module + datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type] ) assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls) diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 7104f7658f..624407f319 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1225,7 +1225,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: [('a', 'b', 'c'), (1, 2, 3)] """ - return XIterable(zip(*self.iterator)) # type: ignore # mypy gets confused with *args + return XIterable(zip(*self.iterator)) @typing.overload def islice(self, __stop: int) -> XIterable[T]: @@ -1536,7 +1536,7 @@ def reduceby( ) -> Dict[K, S]: ... - def reduceby( # type: ignore[misc] # signatures 2 and 4 are not satified due to inconsistencies with type variables + def reduceby( self, bin_op_func: Callable[[S, T], S], key: Union[str, List[K], Callable[[T], K]], diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 6bf6858369..949f4b461a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -189,11 +189,12 @@ def __and__(self, other: UnitRange) -> UnitRange: return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) def __contains__(self, value: Any) -> bool: - return ( - isinstance(value, core_defs.INTEGRAL_TYPES) - and value >= self.start - and value < self.stop - ) + # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) + # and remove int cast, once the related mypy bug (#16358) gets fixed + if isinstance(value, core_defs.INTEGRAL_TYPES): + return self.start <= cast(int, value) < self.stop + else: + return False def __le__(self, other: UnitRange) -> bool: return self.start >= other.start and self.stop <= other.stop diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 278dde9180..cd75538da7 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -15,7 +15,7 @@ import dataclasses import functools import inspect -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index ed8b768972..3a82f9c738 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self: if not dataclasses.is_dataclass(self): raise TypeError(f"'{self.__class__}' is not a dataclass.") assert not isinstance(self, type) - return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? + return dataclasses.replace(self, **kwargs) class ChainableWorkflowMixin(Workflow[StartT, EndT]): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 7fd4794e57..fdd8a61054 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -260,10 +260,12 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) - # TODO: According to Lex one should build the SDFG first in a general mannor. - # Generalisation to a particular device should happen only at the end. - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) sdfg = sdfg_genenerator.visit(program) + if sdfg is None: + raise RuntimeError(f"Visit failed for program {program.id}.") + + # run DaCe transformations to simplify the SDFG sdfg.simplify() # run DaCe auto-optimization heuristics @@ -274,6 +276,9 @@ def build_sdfg_from_itir( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + if on_gpu: + sdfg.apply_gpu_transformations() + return sdfg @@ -283,7 +288,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): compiler_args = kwargs.get("compiler_args", None) # `None` will take default. build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) - auto_optimize = kwargs.get("auto_optimize", False) + auto_optimize = kwargs.get("auto_optimize", True) lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index e3b5ddf2ac..fb2f82fed0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -99,20 +99,17 @@ class ItirToSDFG(eve.NodeVisitor): offset_provider: dict[str, Any] node_types: dict[int, next_typing.Type] unique_id: int - use_gpu_storage: bool def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTableOffsetProvider], column_axis: Optional[Dimension] = None, - use_gpu_storage: bool = False, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} - self.use_gpu_storage = use_gpu_storage def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): @@ -123,14 +120,7 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset else None ) dtype = as_dace_type(type_.dtype) - storage = ( - dace.dtypes.StorageType.GPU_Global - if self.use_gpu_storage - else dace.dtypes.StorageType.Default - ) - sdfg.add_array( - name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage - ) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) @@ -246,7 +236,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, transient=True, ) closure_init_state.add_nedge( @@ -261,7 +250,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, ) else: assert isinstance(self.storage_types[name], ts.ScalarType) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index e580333bc8..8cfff12df4 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -312,7 +312,7 @@ def test_symbolref_validation_for_valid_tree(): SymbolTableRootNode( nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], ) - SymbolTableRootNode( + SymbolTableRootNode( # noqa: B018 nodes=[ SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo"), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 698dce2b5c..d100cd380c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -30,16 +30,6 @@ def test_external_local_field(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee( inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e8d0c8b163..e2434d860a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,16 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - if unstructured_case.backend in [ gtfn.run_gtfn, gtfn.run_gtfn_gpu, @@ -79,16 +69,6 @@ def testee(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_minover_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) @@ -102,16 +82,6 @@ def minover(edge_f: cases.EField) -> cases.VField: @pytest.mark.uses_unstructured_shift def test_reduction_execution(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def reduction(edge_f: cases.EField) -> cases.VField: return neighbor_sum(edge_f(V2E), axis=V2EDim) @@ -150,16 +120,6 @@ def fencil(edge_f: cases.EField, out: cases.VField): @pytest.mark.uses_unstructured_shift def test_reduction_with_common_expression(unstructured_case): - # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15 - try: - from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu - - if unstructured_case.backend == run_dace_gpu: - # see https://github.com/spcl/dace/pull/1442 - pytest.xfail("requires fix in dace module for cuda codegen") - except ImportError: - pass - @gtx.field_operator def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index da0945fe96..788081b81e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -20,14 +20,20 @@ from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.program_processors import otf_compile_executor from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries -from tests.next_tests.integration_tests.cases import Case -from tests.next_tests.toy_connectivity import Cell, Edge from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import E2V, KDim, Vertex, cartesian_case, unstructured_case +from next_tests.integration_tests.cases import ( + E2V, + Case, + KDim, + Vertex, + cartesian_case, + unstructured_case, +) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( reduction_setup, ) +from next_tests.toy_connectivity import Cell, Edge @pytest.fixture From 90e5d5a281fb02c1c4558e9097fc5fb980584321 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 19 Jan 2024 13:44:39 +0100 Subject: [PATCH 47/85] feature[next]: add support for Python3.11 by fixing typing-related bugs (#1418) Fixes hidden bugs in `eve.datamodels` and `eve.extended_typing` to support Python 3.11. Actual bug fixes: - Previous fix to support `typing.Any` implementation as a class (https://github.com/python/cpython/commit/5a4973e29f2f5c4ee8c086f40325786c62381540) didn't work in 3.11. - Partially concretization of generic datamodels replacing typevars was broken. - Partially concretization of generic datamodels leaving some parameters as typevars was broken. Other changes: - Add python 3.11 as supported version. - Remove dead code in comments. - Fix some imports style to comply with our coding guidelines. --- .github/workflows/daily-ci.yml | 2 +- .github/workflows/test-cartesian-fallback.yml | 2 +- .github/workflows/test-cartesian.yml | 2 +- .github/workflows/test-eve-fallback.yml | 2 +- .github/workflows/test-eve.yml | 3 +- .github/workflows/test-next-fallback.yml | 2 +- .github/workflows/test-next.yml | 2 +- .github/workflows/test-storage-fallback.yml | 2 +- .github/workflows/test-storage.yml | 3 +- src/gt4py/eve/datamodels/core.py | 40 ++---- src/gt4py/eve/extended_typing.py | 15 ++- tests/eve_tests/unit_tests/test_datamodels.py | 117 ++++++++++-------- .../unit_tests/test_type_validation.py | 5 +- .../ffront_tests/test_icon_like_scan.py | 26 ++-- .../test_single_static_assign.py | 5 +- tox.ini | 35 +++--- 16 files changed, 140 insertions(+), 123 deletions(-) diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 77ba39a361..8631390dbb 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -14,7 +14,7 @@ jobs: daily-ci: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] tox-module-factor: ["cartesian", "eve", "next", "storage"] os: ["ubuntu-latest"] requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"] diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index b2eaead47a..7e9a948e9c 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index 2c2b97aaa6..ebdc4ce749 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 93dc308a53..fd7ab5452c 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -17,7 +17,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 1322c573db..222b825f38 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -20,7 +20,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] fail-fast: false @@ -68,4 +68,3 @@ jobs: # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml index 8490a3e393..bdcc061db0 100644 --- a/.github/workflows/test-next-fallback.yml +++ b/.github/workflows/test-next-fallback.yml @@ -15,7 +15,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 52f8c25386..4282a22da6 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -18,7 +18,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] fail-fast: false diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index 0cbc735564..99e4923de8 100644 --- a/.github/workflows/test-storage-fallback.yml +++ b/.github/workflows/test-storage-fallback.yml @@ -18,7 +18,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index 1133353f30..34841ed71c 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -21,7 +21,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] fail-fast: false @@ -70,4 +70,3 @@ jobs: # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 5660fdbf76..bc744b3ccc 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -883,17 +883,6 @@ def _substitute_typevars( return type_params_map[type_hint], True elif getattr(type_hint, "__parameters__", []): return type_hint[tuple(type_params_map[tp] for tp in type_hint.__parameters__)], True - # TODO(egparedes): WIP fix for partial specialization - # # Type hint is a generic model: replace all the concretized type vars - # noqa: e800 replaced = False - # noqa: e800 new_args = [] - # noqa: e800 for tp in type_hint.__parameters__: - # noqa: e800 if tp in type_params_map: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 replaced = True - # noqa: e800 else: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 return type_hint[tuple(new_args)], replaced else: return type_hint, False @@ -981,21 +970,14 @@ def __class_getitem__( """ type_args: Tuple[Type] = args if isinstance(args, tuple) else (args,) concrete_cls: Type[DataModelT] = concretize(cls, *type_args) - res = xtyping.StdGenericAliasType(concrete_cls, type_args) - if sys.version_info < (3, 9): - # in Python 3.8, xtyping.StdGenericAliasType (aka typing._GenericAlias) - # does not copy all required `__dict__` entries, so do it manually - for k, v in concrete_cls.__dict__.items(): - if k not in res.__dict__: - res.__dict__[k] = v - return res + return concrete_cls return classmethod(__class_getitem__) def _make_type_converter(type_annotation: TypeAnnotation, name: str) -> TypeConverter[_T]: - # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code as a tree traversal. - # + # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code + # as a tree traversal. if xtyping.is_actual_type(type_annotation) and not isinstance(None, type_annotation): assert not xtyping.get_args(type_annotation) assert isinstance(type_annotation, type) @@ -1316,11 +1298,7 @@ def _make_concrete_with_cache( # Replace field definitions with the new actual types for generic fields type_params_map = dict(zip(datamodel_cls.__parameters__, type_args)) model_fields = getattr(datamodel_cls, MODEL_FIELD_DEFINITIONS_ATTR) - new_annotations = { - # TODO(egparedes): ? - # noqa: e800 "__args__": "ClassVar[Tuple[Union[Type, TypeVar], ...]]", - # noqa: e800 "__parameters__": "ClassVar[Tuple[TypeVar, ...]]", - } + new_annotations = {} new_field_c_attrs = {} for field_name, field_type in xtyping.get_type_hints(datamodel_cls).items(): @@ -1353,8 +1331,16 @@ def _make_concrete_with_cache( "__module__": module if module else datamodel_cls.__module__, **new_field_c_attrs, } - concrete_cls = type(class_name, (datamodel_cls,), namespace) + + # Update the tuple of generic parameters in the new class, in case + # this is a partial concretization + assert hasattr(concrete_cls, "__parameters__") + concrete_cls.__parameters__ = tuple( + type_params_map[tp_var] + for tp_var in datamodel_cls.__parameters__ + if isinstance(type_params_map[tp_var], typing.TypeVar) + ) assert concrete_cls.__module__ == module or not module if MODEL_FIELD_DEFINITIONS_ATTR not in concrete_cls.__dict__: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 17462a37ff..3ee447ca6c 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -493,7 +493,7 @@ def _patched_proto_hook(other): # type: ignore[no-untyped-def] if isinstance(_typing.Any, type): # Python >= 3.11 _ArtefactTypes = (*_ArtefactTypes, _typing.Any) -# `Any` is a class since typing_extensions >= 4.4 +# `Any` is a class since typing_extensions >= 4.4 and Python 3.11 if (typing_exts_any := getattr(_typing_extensions, "Any", None)) is not _typing.Any and isinstance( typing_exts_any, type ): @@ -504,11 +504,13 @@ def is_actual_type(obj: Any) -> TypeGuard[Type]: """Check if an object has an actual type and instead of a typing artefact like ``GenericAlias`` or ``Any``. This is needed because since Python 3.9: - ``isinstance(types.GenericAlias(), type) is True`` + ``isinstance(types.GenericAlias(), type) is True`` and since Python 3.11: - ``isinstance(typing.Any, type) is True`` + ``isinstance(typing.Any, type) is True`` """ - return isinstance(obj, type) and type(obj) not in _ArtefactTypes + return ( + isinstance(obj, type) and (obj not in _ArtefactTypes) and (type(obj) not in _ArtefactTypes) + ) if hasattr(_typing_extensions, "Any") and _typing.Any is not _typing_extensions.Any: # type: ignore[attr-defined] # _typing_extensions.Any only from >= 4.4 @@ -641,9 +643,12 @@ def get_partial_type_hints( resolved_hints = get_type_hints( # type: ignore[call-arg] # Python 3.8 does not define `include-extras` obj, globalns=globalns, localns=localns, include_extras=include_extras ) - hints.update(resolved_hints) + hints[name] = resolved_hints[name] except NameError as error: if isinstance(hint, str): + # This conversion could be probably skipped in Python versions containing + # the fix applied in bpo-41370. Check: + # https://github.com/python/cpython/commit/b465b606049f6f7dd0711cb031fdaa251818741a#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887R344 hints[name] = ForwardRef(hint) elif isinstance(hint, (ForwardRef, _typing.ForwardRef)): hints[name] = hint diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 8fa9e02cb6..0abb893dd4 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -15,6 +15,7 @@ from __future__ import annotations import enum +import numbers import types import typing from typing import Set # noqa: F401 # imported but unused (used in exec() context) @@ -1150,66 +1151,80 @@ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): with pytest.raises(TypeError, match="'PartialGenericModel__int.value'"): PartialGenericModel__int(value=["1"]) - def test_partial_specialization(self): - class PartialGenericModel(datamodels.GenericDataModel, Generic[T, U]): + def test_partial_concretization(self): + class BaseGenericModel(datamodels.GenericDataModel, Generic[T, U]): value: List[Tuple[T, U]] - PartialGenericModel(value=[]) - PartialGenericModel(value=[("value", 3)]) - PartialGenericModel(value=[(1, "value")]) - PartialGenericModel(value=[(-1.0, "value")]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=1) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=(1, 2)) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[()]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[(1,)]) + assert len(BaseGenericModel.__parameters__) == 2 + + BaseGenericModel(value=[]) + BaseGenericModel(value=[("value", 3)]) + BaseGenericModel(value=[(1, "value")]) + BaseGenericModel(value=[(-1.0, "value")]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=1) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[()]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[(1,)]) + + PartiallyConcretizedGenericModel = BaseGenericModel[int, U] + + assert len(PartiallyConcretizedGenericModel.__parameters__) == 1 + + PartiallyConcretizedGenericModel(value=[]) + PartiallyConcretizedGenericModel(value=[(1, 2)]) + PartiallyConcretizedGenericModel(value=[(1, "value")]) + PartiallyConcretizedGenericModel(value=[(1, (11, 12))]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=[1.0]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=["1"]) - print(f"{PartialGenericModel.__parameters__=}") - print(f"{hasattr(PartialGenericModel ,'__args__')=}") + FullyConcretizedGenericModel = PartiallyConcretizedGenericModel[str] - PartiallySpecializedGenericModel = PartialGenericModel[int, U] - print(f"{PartiallySpecializedGenericModel.__datamodel_fields__=}") - print(f"{PartiallySpecializedGenericModel.__parameters__=}") - print(f"{PartiallySpecializedGenericModel.__args__=}") + assert len(FullyConcretizedGenericModel.__parameters__) == 0 - PartiallySpecializedGenericModel(value=[]) - PartiallySpecializedGenericModel(value=[(1, 2)]) - PartiallySpecializedGenericModel(value=[(1, "value")]) - PartiallySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[]) + FullyConcretizedGenericModel(value=[(1, "value")]) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=(1, 2)) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=1) + FullyConcretizedGenericModel(value=[1.0]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=(1, 2)) + FullyConcretizedGenericModel(value=["1"]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=[1.0]) + FullyConcretizedGenericModel(value=1) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=["1"]) - - # TODO(egparedes): after fixing partial nested datamodel specialization - # noqa: e800 FullySpecializedGenericModel = PartiallySpecializedGenericModel[str] - # noqa: e800 print(f"{FullySpecializedGenericModel.__datamodel_fields__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__parameters__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__args__=}") - - # noqa: e800 FullySpecializedGenericModel(value=[]) - # noqa: e800 FullySpecializedGenericModel(value=[(1, "value")]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=(1, 2)) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[1.0]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=["1"]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, 2)]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[(1, 2)]) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=[(1, (11, 12))]) + + def test_partial_concretization_with_typevar(self): + class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): + a: T + values: List[T] + + B = TypeVar("B", bound=numbers.Number) + PartiallyConcretizedGenericModel = PartialGenericModel[B] + + PartiallyConcretizedGenericModel(a=1, values=[2, 3]) + PartiallyConcretizedGenericModel(a=-1.32, values=[2.2, 3j]) + + with pytest.raises(TypeError, match=".a'"): + PartiallyConcretizedGenericModel(a="1", values=[2, 3]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=[1, "2"]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=(1, 2)) # Reuse sample_type_data from test_field_type_hint @pytest.mark.parametrize(["type_hint", "valid_values", "wrong_values"], SAMPLE_TYPE_DATA) diff --git a/tests/eve_tests/unit_tests/test_type_validation.py b/tests/eve_tests/unit_tests/test_type_validation.py index 70ef033ff0..d9977f0d3a 100644 --- a/tests/eve_tests/unit_tests/test_type_validation.py +++ b/tests/eve_tests/unit_tests/test_type_validation.py @@ -28,6 +28,7 @@ ) from gt4py.eve.extended_typing import ( Any, + Callable, Dict, Final, ForwardRef, @@ -41,8 +42,8 @@ ) -VALIDATORS: Final = [type_val.simple_type_validator] -FACTORIES: Final = [type_val.simple_type_validator_factory] +VALIDATORS: Final[list[Callable]] = [type_val.simple_type_validator] +FACTORIES: Final[list[Callable]] = [type_val.simple_type_validator_factory] class SampleEnum(enum.Enum): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 130f6bd29c..5bd255f80f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass +import dataclasses import numpy as np import pytest @@ -201,22 +201,26 @@ def test_setup(fieldview_backend): grid_type=common.GridType.UNSTRUCTURED, ) - @dataclass(frozen=True) + @dataclasses.dataclass(frozen=True) class setup: - case: cases.Case = test_case - cell_size = case.default_sizes[Cell] - k_size = case.default_sizes[KDim] - z_alpha = case.as_field( + case: cases.Case = dataclasses.field(default_factory=lambda: test_case) + cell_size = test_case.default_sizes[Cell] + k_size = test_case.default_sizes[KDim] + z_alpha = test_case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1)) ) - z_beta = case.as_field( + z_beta = test_case.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + ) + z_q = test_case.as_field( + [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) + ) + w = test_case.as_field( [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)) ) - z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) - w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))) z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray) - dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) - z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) + dummy = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool)) + z_q_out = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size))) return setup() diff --git a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py index 052f272d22..ea1cdb82a6 100644 --- a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py +++ b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py @@ -108,7 +108,10 @@ def test_unpacking_swap(): lines = ast.unparse(ssa_ast).split("\n") assert lines[0] == f"a{SEP}0 = 5" assert lines[1] == f"b{SEP}0 = 1" - assert lines[2] == f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)" + assert lines[2] in [ + f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)", + f"b{SEP}1, a{SEP}1 = (a{SEP}0, b{SEP}0)", + ] # unparse produces different parentheses in different Python versions def test_annotated_assign(): diff --git a/tox.ini b/tox.ini index 44dc912c8a..817f721f71 100644 --- a/tox.ini +++ b/tox.ini @@ -11,21 +11,24 @@ envlist = # docs labels = test-cartesian-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu + cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ + cartesian-py311-dace-cpu - test-eve-cpu = eve-py38, eve-py39, eve-py310 + test-eve-cpu = eve-py38, eve-py39, eve-py310, eve-py311 - test-next-cpu = next-py310-nomesh, next-py310-atlas + test-next-cpu = next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas test-storage-cpu = storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu + storage-py311-internal-cpu, storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, \ + storage-py311-dace-cpu test-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ - eve-py38, eve-py39, eve-py310, \ - next-py310-nomesh, next-py310-atlas, \ - storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu + cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \ + cartesian-py311-dace-cpu, \ + eve-py38, eve-py39, eve-py310, eve-py311, \ + next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas, \ + storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \ + storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu [testenv] deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt} @@ -44,7 +47,7 @@ pass_env = NUM_PROCESSES set_env = PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning} -[testenv:cartesian-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.cartesian' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE allowlist_externals = @@ -65,13 +68,13 @@ commands = ; coverage json --rcfile=setup.cfg ; coverage html --rcfile=setup.cfg --show-contexts -[testenv:eve-py{38,39,310}] +[testenv:eve-py{38,39,310,311}] description = Run 'gt4py.eve' tests commands = python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests python -m pytest --doctest-modules src{/}gt4py{/}eve -[testenv:next-py{310}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:next-py{310,311}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.next' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH deps = @@ -87,14 +90,14 @@ commands = # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests # TODO(ricoh): activate when such tests exist pytest --doctest-modules src{/}gt4py{/}next -[testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] +[testenv:storage-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.storage' tests commands = cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_gpu" {posargs} tests{/}storage_tests {cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_gpu" {posargs} tests{/}storage_tests #pytest doctest-modules {posargs} src{/}gt4py{/}storage -[testenv:linters-py{38,39,310}] +[testenv:linters-py{38,39,310,311}] description = Run linters commands = flake8 .{/}src @@ -134,11 +137,13 @@ description = py38: Update requirements for testing a specific python version py39: Update requirements for testing a specific python version py310: Update requirements for testing a specific python version + py311: Update requirements for testing a specific python version base_python = common: py38 py38: py38 py39: py39 py310: py310 + py311: py311 deps = cogapp>=3.3 pip-tools>=6.10 @@ -178,7 +183,7 @@ commands = # Run cog to update .pre-commit-config.yaml with new versions common: cog -r -P .pre-commit-config.yaml -[testenv:dev-py{38,39,310}{-atlas,}] +[testenv:dev-py{38,39,310,311}{-atlas,}] description = Initialize development environment for gt4py deps = -r {tox_root}{/}requirements-dev.txt From e20294efe31f911567713b62ac70e161cc49255f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Jan 2024 10:21:46 +0100 Subject: [PATCH 48/85] bug[next]: Improve error message on invalid call to field operator and program (#1323) After #1275 most of the error message given to the user when calling a field operator or program with invalid arguments was only available in verbose mode. This PR shows this information again. ```python @field_operator def foo(x: IField): return x @field_operator def testee(a: IField, b: IField, c: IField) -> IField: return foo(1) ``` ``` gt4py.next.errors.exceptions.DSLError: Invalid argument types in call to `foo`. E Invalid call to function of type `FieldOperatorType(definition=FunctionType(pos_only_args=[], pos_or_kw_args={'x': FieldType(dims=[Dimension(value='IDim', kind=)], dtype=ScalarType(kind=, shape=None))}, kw_only_args={}, returns=FieldType(dims=[Dimension(value='IDim', kind=)], dtype=ScalarType(kind=, shape=None))))`: E - Expected argument `x` to be of type `Field[[IDim], int32]`, but got `int32`. E File ".../gt4py_functional/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py", line 113 E return foo(1) ``` --- src/gt4py/next/ffront/decorator.py | 4 +++- src/gt4py/next/ffront/foast_passes/type_deduction.py | 2 +- src/gt4py/next/ffront/past_passes/type_deduction.py | 2 +- .../feature_tests/ffront_tests/test_program.py | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 05cbe1c882..9f8537f59b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -344,7 +344,9 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err + raise errors.DSLError( + None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}" + ) from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 639e5ff009..5e289af664 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -694,7 +694,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: ) except ValueError as err: raise errors.DSLError( - node.location, f"Invalid argument types in call to '{new_func}'." + node.location, f"Invalid argument types in call to '{new_func}'.\n{err}" ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index fc353d64e4..af8f5e8368 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -229,7 +229,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex + raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.\n{ex}") from ex return past.Call( func=new_func, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index c86881ab7c..938c69fb52 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,6 +20,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import errors from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend @@ -222,7 +223,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() - with pytest.raises(TypeError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: # program is defined on Field[[IDim], ...], but we call with # Field[[JDim], ...] copy_program(inp, out, offset_provider={}) From e6f41605ee7ae5a791d8a5d3557b4f45e0f511a8 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Jan 2024 18:35:57 +0100 Subject: [PATCH 49/85] Update AUTHORS.md --- AUTHORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AUTHORS.md b/AUTHORS.md index 89aafb9971..0fd0098fc4 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -9,6 +9,7 @@ - Deconinck, Florian. SSAI/NASA-GSFC - Ehrengruber, Till. ETH Zurich - CSCS - Elbert, Oliver D. NOAA-GFDL +- Faghih-Naini, Sara - ECMWF - Farabullini, Nicoletta. ETH Zurich - C2SM - George, Rhea. Allen Institute for AI - González Paredes, Enrique. ETH Zurich - CSCS From 8bd5a41e9d27409472442c6ad8a8a7908953b265 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 22 Jan 2024 18:37:04 +0100 Subject: [PATCH 50/85] Update AUTHORS.md --- AUTHORS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AUTHORS.md b/AUTHORS.md index 0fd0098fc4..6c76e5759e 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -9,7 +9,7 @@ - Deconinck, Florian. SSAI/NASA-GSFC - Ehrengruber, Till. ETH Zurich - CSCS - Elbert, Oliver D. NOAA-GFDL -- Faghih-Naini, Sara - ECMWF +- Faghih-Naini, Sara. ECMWF - Farabullini, Nicoletta. ETH Zurich - C2SM - George, Rhea. Allen Institute for AI - González Paredes, Enrique. ETH Zurich - CSCS From d5cfa7d7b1c74056059d0e42822b4cf01a2a2a22 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Tue, 23 Jan 2024 18:13:11 +0100 Subject: [PATCH 51/85] feat[next][dace]: Add more debug info to DaCe (#1384) * Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/eve/__init__.py | 8 +- src/gt4py/eve/traits.py | 8 + src/gt4py/next/ffront/foast_to_itir.py | 8 +- src/gt4py/next/ffront/past_to_itir.py | 20 ++- src/gt4py/next/iterator/ir.py | 3 + .../iterator/transforms/collapse_list_get.py | 2 +- .../iterator/transforms/collapse_tuple.py | 9 +- .../iterator/transforms/constant_folding.py | 4 +- src/gt4py/next/iterator/transforms/cse.py | 14 +- .../next/iterator/transforms/eta_reduction.py | 4 +- .../next/iterator/transforms/fuse_maps.py | 4 +- .../next/iterator/transforms/global_tmps.py | 10 +- .../iterator/transforms/inline_fundefs.py | 6 +- .../iterator/transforms/inline_into_scan.py | 7 +- .../iterator/transforms/inline_lambdas.py | 6 +- .../next/iterator/transforms/inline_lifts.py | 8 +- .../next/iterator/transforms/merge_let.py | 2 +- .../iterator/transforms/normalize_shifts.py | 4 +- .../iterator/transforms/propagate_deref.py | 4 +- .../transforms/prune_closure_inputs.py | 4 +- .../next/iterator/transforms/remap_symbols.py | 6 +- .../iterator/transforms/scan_eta_reduction.py | 7 +- .../iterator/transforms/symbol_ref_utils.py | 2 +- .../next/iterator/transforms/trace_shifts.py | 4 +- .../next/iterator/transforms/unroll_reduce.py | 17 +- .../runners/dace_iterator/__init__.py | 14 ++ .../runners/dace_iterator/itir_to_sdfg.py | 47 ++++-- .../runners/dace_iterator/itir_to_tasklet.py | 152 +++++++++++++----- .../runners/dace_iterator/utility.py | 24 ++- 29 files changed, 288 insertions(+), 120 deletions(-) diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 617a889e28..e726db1f1a 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -58,7 +58,12 @@ field, frozenmodel, ) -from .traits import SymbolTableTrait, ValidatedSymbolTableTrait, VisitorWithSymbolTableTrait +from .traits import ( + PreserveLocationVisitor, + SymbolTableTrait, + ValidatedSymbolTableTrait, + VisitorWithSymbolTableTrait, +) from .trees import ( bfs_walk_items, bfs_walk_values, @@ -113,6 +118,7 @@ "SymbolTableTrait", "ValidatedSymbolTableTrait", "VisitorWithSymbolTableTrait", + "PreserveLocationVisitor", # trees "bfs_walk_items", "bfs_walk_values", diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index df556c9d7f..aacae804d8 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: kwargs["symtable"] = kwargs["symtable"].parents return result + + +class PreserveLocationVisitor(visitors.NodeVisitor): + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + result = super().visit(node, **kwargs) + if hasattr(node, "location") and hasattr(result, "location"): + result.location = node.location + return result diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index c4d518d279..0c9ab4ab27 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -15,7 +15,7 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next.ffront import ( dialect_ast_enums, @@ -39,7 +39,7 @@ def promote_to_list( @dataclasses.dataclass -class FieldOperatorLowering(NodeTranslator): +class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). @@ -61,7 +61,7 @@ class FieldOperatorLowering(NodeTranslator): >>> lowered.id SymbolName('fieldop') - >>> lowered.params + >>> lowered.params # doctest: +ELLIPSIS [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] """ @@ -142,7 +142,7 @@ def visit_IfStmt( self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) + # from a scalar value (and not a field) assert ( isinstance(node.condition.type, ts.ScalarType) and node.condition.type.kind == ts.ScalarKind.BOOL diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 709912077b..ed239e0436 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -40,7 +40,9 @@ def _flatten_tuple_expr( raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") -class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class ProgramLowering( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Lower Program AST (PAST) to Iterator IR (ITIR). @@ -151,6 +153,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure: stencil=itir.SymRef(id=node.func.id), inputs=[*lowered_args, *lowered_kwargs.values()], output=output, + location=node.location, ) def _visit_slice_bound( @@ -175,17 +178,22 @@ def _visit_slice_bound( lowered_bound = self.visit(slice_bound, **kwargs) else: raise AssertionError("Expected 'None' or 'past.Constant'.") + if slice_bound: + lowered_bound.location = slice_bound.location return lowered_bound def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: if isinstance(node, past.Name): - return itir.SymRef(id=node.id) + return itir.SymRef(id=node.id, location=node.location) elif isinstance(node, past.Subscript): - return self._construct_itir_out_arg(node.value) + itir_node = self._construct_itir_out_arg(node.value) + itir_node.location = node.location + return itir_node elif isinstance(node, past.TupleExpr): return itir.FunCall( fun=itir.SymRef(id="make_tuple"), args=[self._construct_itir_out_arg(el) for el in node.elts], + location=node.location, ) else: raise ValueError( @@ -247,7 +255,11 @@ def _construct_itir_domain_arg( else: raise AssertionError() - return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args) + return itir.FunCall( + fun=itir.SymRef(id=domain_builtin), + args=domain_args, + location=(node_domain or out_field).location, + ) def _construct_itir_initialized_domain_arg( self, diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e6ee20e227..37abbec9e7 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -17,12 +17,15 @@ import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @noninstantiable class Node(eve.Node): + location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) + def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 08cbd7313e..6acb8a79c4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -16,7 +16,7 @@ from gt4py.next.iterator import ir -class CollapseListGet(eve.NodeTranslator): +class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): """Simplifies expressions containing `list_get`. Examples diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 30457f2246..42bbf28909 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -48,7 +48,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t @dataclass(frozen=True) -class CollapseTuple(eve.NodeTranslator): +class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -88,13 +88,6 @@ def apply( node_types, ).visit(node) - return cls( - ignore_tuple_size, - collapse_make_tuple_tuple_get, - collapse_tuple_get_make_tuple, - use_global_type_inference, - ).visit(node) - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if ( self.collapse_make_tuple_tuple_get diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index fa326760b0..696a87a197 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,12 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im -class ConstantFolding(NodeTranslator): +class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 034a39d68f..f9cf272c45 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -17,14 +17,20 @@ import operator import typing -from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait +from gt4py.eve import ( + NodeTranslator, + NodeVisitor, + PreserveLocationVisitor, + SymbolTableTrait, + VisitorWithSymbolTableTrait, +) from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @dataclasses.dataclass -class _NodeReplacer(NodeTranslator): +class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) expr_map: dict[int, ir.SymRef] @@ -72,7 +78,7 @@ def _is_collectable_expr(node: ir.Node) -> bool: @dataclasses.dataclass -class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): +class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor): @dataclasses.dataclass class SubexpressionData: #: A list of node ids with equal hash and a set of collected child subexpression ids @@ -341,7 +347,7 @@ def extract_subexpression( @dataclasses.dataclass(frozen=True) -class CommonSubexpressionElimination(NodeTranslator): +class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): """ Perform common subexpression elimination. diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index 55b2141499..93702a6c96 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class EtaReduction(NodeTranslator): +class EtaReduction(PreserveLocationVisitor, NodeTranslator): """Eta reduction: simplifies `λ(args...) → f(args...)` to `f`.""" def visit_Lambda(self, node: ir.Lambda) -> ir.Node: diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index e9fbb0f81d..694dcd6a61 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -38,7 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: @dataclasses.dataclass(frozen=True) -class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Fuses nested `map_`s. @@ -66,6 +66,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: return ir.Lambda( params=params, expr=ir.FunCall(fun=fun, args=[ir.SymRef(id=p.id) for p in params]), + location=fun.location, ) def visit_FunCall(self, node: ir.FunCall, **kwargs): @@ -99,6 +100,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): ir.FunCall( fun=inner_op, args=[ir.SymRef(id=param.id) for param in inner_op.params], + location=node.location, ) ) ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 0033f36cab..c423a3c277 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -19,7 +19,7 @@ import gt4py.eve as eve import gt4py.next as gtx -from gt4py.eve import Coerced, NodeTranslator +from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator from gt4py.next import common @@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp stencil=stencil, output=im.ref(tmp_sym.id), inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + location=current_closure.location, ) ) @@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp output=current_closure.output, inputs=current_closure.inputs + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], + location=current_closure.location, ) ) else: @@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp + [ir.Sym(id=tmp.id) for tmp in tmps] + [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant closures=list(reversed(closures)), + location=node.location, ), params=node.params, tmps=[Temporary(id=tmp.id) for tmp in tmps], @@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari function_definitions=node.fencil.function_definitions, params=[p for p in node.fencil.params if p.id not in unused_tmps], closures=closures, + location=node.fencil.location, ), params=node.params, tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], @@ -456,6 +460,7 @@ def update_domains( stencil=closure.stencil, output=closure.output, inputs=closure.inputs, + location=closure.location, ) else: domain = closure.domain @@ -521,6 +526,7 @@ def update_domains( function_definitions=node.fencil.function_definitions, params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again closures=list(reversed(closures)), + location=node.fencil.location, ), params=node.params, tmps=node.tmps, @@ -580,7 +586,7 @@ def convert_type(dtype): # TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be # tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore # and hence also not extract as a temporary. -class CreateGlobalTmps(NodeTranslator): +class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): """Main entry point for introducing global temporaries. Transforms an existing iterator IR fencil into a fencil with global temporaries. diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 6bf2b60592..a53232745f 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -14,11 +14,11 @@ from typing import Any, Dict, Set -from gt4py.eve import NOTHING, NodeTranslator +from gt4py.eve import NOTHING, NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class InlineFundefs(NodeTranslator): +class InlineFundefs(PreserveLocationVisitor, NodeTranslator): def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition): return ir.Lambda( @@ -31,7 +31,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition): return self.generic_visit(node, symtable=node.annex.symtable) -class PruneUnreferencedFundefs(NodeTranslator): +class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator): def visit_FunctionDefinition( self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool ): diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index fe1eae6e07..a1c9a2eb5b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineIntoScan( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Inline non-SymRef arguments into the scan. @@ -100,6 +102,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_scan = ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]] ) - result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) - return result + return ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) return self.generic_visit(node, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index a56ad5cb10..0b89fe6d98 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -15,7 +15,7 @@ import dataclasses from typing import Optional -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols @@ -104,6 +104,7 @@ def new_name(name): new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) if all(eligible_params): + new_expr.location = node.location return new_expr else: return ir.FunCall( @@ -116,11 +117,12 @@ def new_name(name): expr=new_expr, ), args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], + location=node.location, ) @dataclasses.dataclass -class InlineLambdas(NodeTranslator): +class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """Inline lambda calls by substituting every argument by its value.""" PRESERVED_ANNEX_ATTRS = ("type",) diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index d7d8e5e612..d6146d9fc8 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -103,14 +103,18 @@ def _transform_and_extract_lift_args( extracted_args[new_symbol] = arg new_args.append(ir.SymRef(id=new_symbol.id)) - return (im.lift(inner_stencil)(*new_args), extracted_args) + itir_node = im.lift(inner_stencil)(*new_args) + itir_node.location = node.location + return (itir_node, extracted_args) # TODO(tehrengruber): This pass has many different options that should be written as dedicated # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineLifts( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 7426617ac8..bcfc6b2a17 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -17,7 +17,7 @@ from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs -class MergeLet(eve.NodeTranslator): +class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Merge let-like statements. diff --git a/src/gt4py/next/iterator/transforms/normalize_shifts.py b/src/gt4py/next/iterator/transforms/normalize_shifts.py index efc9064612..c70dc1ccd1 100644 --- a/src/gt4py/next/iterator/transforms/normalize_shifts.py +++ b/src/gt4py/next/iterator/transforms/normalize_shifts.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class NormalizeShifts(NodeTranslator): +class NormalizeShifts(PreserveLocationVisitor, NodeTranslator): def visit_FunCall(self, node: ir.FunCall): node = self.generic_visit(node) if ( diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 54bdafcda8..783e54ede0 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir @@ -22,7 +22,7 @@ # `(λ(...) → plus(multiplies(...), ...))(...)`. -class PropagateDeref(NodeTranslator): +class PropagateDeref(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node): """ diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py index 7fd3c50c6e..1e637a0bfb 100644 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class PruneClosureInputs(NodeTranslator): +class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): """Removes all unused input arguments from a stencil closure.""" def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index cdf3d76173..431dd6cd7a 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -14,11 +14,11 @@ from typing import Any, Dict, Optional, Set -from gt4py.eve import NodeTranslator, SymbolTableTrait +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir -class RemapSymbolRefs(NodeTranslator): +class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): @@ -39,7 +39,7 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] return super().generic_visit(node, **kwargs) -class RenameSymbols(NodeTranslator): +class RenameSymbols(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_Sym( diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 3266c25c4b..d93b4242ab 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir @@ -24,7 +24,7 @@ def _is_scan(node: ir.Node): ) -class ScanEtaReduction(NodeTranslator): +class ScanEtaReduction(PreserveLocationVisitor, NodeTranslator): """Applies eta-reduction-like transformation involving scans. Simplifies `λ(x, y) → scan(λ(state, param_y, param_x) → ..., ...)(y, x)` to `scan(λ(state, param_x, param_y) → ..., ...)`. @@ -55,9 +55,8 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: original_scanpass.params[i + 1] for i in new_scanpass_params_idx ] new_scanpass = ir.Lambda(params=new_scanpass_params, expr=original_scanpass.expr) - result = ir.FunCall( + return ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *node.expr.fun.args[1:]] ) - return result return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1c587fb9d6..05d137e8c4 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -21,7 +21,7 @@ @dataclasses.dataclass -class CountSymbolRefs(eve.NodeVisitor): +class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): ref_counts: dict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) @classmethod diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 5c607e7df1..082987ac96 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -16,7 +16,7 @@ from collections.abc import Callable from typing import Any, Final, Iterable, Literal -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir @@ -235,7 +235,7 @@ def _tuple_get(index, tuple_val): @dataclasses.dataclass(frozen=True) -class TraceShifts(NodeTranslator): +class TraceShifts(PreserveLocationVisitor, NodeTranslator): shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 861052bb25..3c878b2b00 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -16,7 +16,7 @@ from collections.abc import Iterable, Iterator from typing import TypeGuard -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir @@ -100,31 +100,36 @@ def _get_connectivity( def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), + args=[iterator], + location=iterator.location, ) def _make_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator]) + return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location) def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="can_deref"), args=[iterator]) + return itir.FunCall( + fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location + ) def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], + location=cond.location, ) def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr]) + return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr], location=expr.location) @dataclasses.dataclass(frozen=True) -class UnrollReduce(NodeTranslator): +class UnrollReduce(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fdd8a61054..54ca08fe6e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import hashlib import warnings +from inspect import currentframe, getframeinfo from typing import Any, Mapping, Optional, Sequence import dace @@ -265,6 +266,19 @@ def build_sdfg_from_itir( if sdfg is None: raise RuntimeError(f"Visit failed for program {program.id}.") + for nested_sdfg in sdfg.all_sdfgs_recursive(): + if not nested_sdfg.debuginfo: + _, frameinfo = warnings.warn( + f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg." + ), getframeinfo( + currentframe() # type: ignore + ) + nested_sdfg.debuginfo = dace.dtypes.DebugInfo( + start_line=frameinfo.lineno, + end_line=frameinfo.lineno, + filename=frameinfo.filename, + ) + # run DaCe transformations to simplify the SDFG sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index fb2f82fed0..dc194c0436 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,6 +38,7 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, get_sorted_dims, @@ -143,6 +144,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) + program_sdfg.debuginfo = dace_debuginfo(node) last_state = program_sdfg.add_state("program_entry", True) self.node_types = itir_typing.infer_all(node) @@ -187,15 +189,16 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): inputs=set(input_names), outputs=set(output_names), symbol_mapping=symbol_mapping, + debuginfo=closure_sdfg.debuginfo, ) # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name) + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name) + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) # Create the call signature for the SDFG. @@ -213,6 +216,7 @@ def visit_StencilClosure( # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") + closure_sdfg.debuginfo = dace_debuginfo(node) closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) @@ -239,8 +243,8 @@ def visit_StencilClosure( transient=True, ) closure_init_state.add_nedge( - closure_init_state.add_access(name), - closure_init_state.add_access(transient_name), + closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), + closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), create_memlet_full(name, closure_sdfg.arrays[name]), ) input_transients_mapping[name] = transient_name @@ -276,9 +280,15 @@ def visit_StencilClosure( out_name = unique_var_name() closure_sdfg.add_scalar(out_name, dtype, transient=True) out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", {}, {"__result"}, f"__result = {name}" + f"get_{name}", + {}, + {"__result"}, + f"__result = {name}", + debuginfo=closure_sdfg.debuginfo, + ) + access = closure_init_state.add_access( + out_name, debuginfo=closure_sdfg.debuginfo ) - access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) @@ -356,19 +366,20 @@ def visit_StencilClosure( outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, + debuginfo=nsdfg.debuginfo, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): memlet = edge.data if memlet.data not in output_connectors_mapping: continue - transient_access = closure_state.add_access(memlet.data) + transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) closure_state.add_edge( nsdfg_node, edge.src_conn, transient_access, None, - dace.Memlet.simple(memlet.data, output_subset), + dace.Memlet.simple(memlet.data, output_subset, debuginfo=nsdfg.debuginfo), ) inner_memlet = dace.Memlet.simple( memlet.data, output_subset, other_subset_str=memlet.subset @@ -417,6 +428,7 @@ def _visit_scan_stencil_closure( # the scan operator is implemented as an SDFG to be nested in the closure SDFG scan_sdfg = dace.SDFG(name="scan") + scan_sdfg.debuginfo = dace_debuginfo(node) # create a state machine for lambda call over the scan dimension start_state = scan_sdfg.add_state("start", True) @@ -429,12 +441,16 @@ def _visit_scan_stencil_closure( # tasklet for initialization of carry carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" + "get_carry_init_value", + {}, + {"__result"}, + f"__result = {init_carry_value}", + debuginfo=scan_sdfg.debuginfo, ) start_state.add_edge( carry_init_tasklet, "__result", - start_state.add_access(scan_carry_name), + start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), None, dace.Memlet.simple(scan_carry_name, "0"), ) @@ -512,11 +528,12 @@ def _visit_scan_stencil_closure( inputs=set(lambda_input_names) | set(connectivity_names), outputs=set(lambda_output_names), symbol_mapping=symbol_mapping, + debuginfo=lambda_context.body.debuginfo, ) # connect scan SDFG to lambda inputs for name, memlet in array_mapping.items(): - access_node = lambda_state.add_access(name) + access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo) lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = [output_name] @@ -526,7 +543,7 @@ def _visit_scan_stencil_closure( lambda_state.add_edge( scan_inner_node, connector, - lambda_state.add_access(name), + lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo), None, dace.Memlet.simple(name, f"i_{scan_dim}"), ) @@ -534,8 +551,10 @@ def _visit_scan_stencil_closure( # add state to scan SDFG to update the carry value at each loop iteration lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") lambda_update_state.add_memlet_path( - lambda_update_state.add_access(output_name), - lambda_update_state.add_access(scan_carry_name), + lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), + lambda_update_state.add_access( + scan_carry_name, debuginfo=lambda_context.body.debuginfo + ), memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 4c202b1fe8..0ace6948b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -35,6 +35,7 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, map_nested_sdfg_symbols, @@ -183,6 +184,7 @@ def __init__( def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) offset_literal, data = node_args assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value @@ -214,13 +216,14 @@ def builtin_neighbors( sdfg.add_array( result_name, dtype=iterator.dtype, shape=(offset_provider.max_neighbors,), transient=True ) - result_access = state.add_access(result_name) + result_access = state.add_access(result_name, debuginfo=di) # generate unique map index name to avoid conflict with other maps inside same state neighbor_index = unique_name("neighbor_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", ndrange={neighbor_index: f"0:{offset_provider.max_neighbors}"}, + debuginfo=di, ) table_name = connectivity_identifier(offset_dim) table_subset = (f"0:{sdfg.arrays[table_name].shape[0]}", neighbor_index) @@ -230,17 +233,19 @@ def builtin_neighbors( code="__result = __table[__idx]", inputs={"__table", "__idx"}, outputs={"__result"}, + debuginfo=di, ) data_access_tasklet = state.add_tasklet( "data_access", code=f"__result = __field[{field_index}] if {neighbor_check} else {transformer.context.reduce_identity.value}", inputs={"__field", field_index}, outputs={"__result"}, + debuginfo=di, ) idx_name = unique_var_name() sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True) state.add_memlet_path( - state.add_access(table_name), + state.add_access(table_name, debuginfo=di), me, shift_tasklet, memlet=create_memlet_at(table_name, table_subset), @@ -250,7 +255,7 @@ def builtin_neighbors( iterator.indices[shifted_dim], me, shift_tasklet, - memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), + memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0", debuginfo=di), dst_conn="__idx", ) state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet()) @@ -270,7 +275,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet.simple(result_name, neighbor_index), + memlet=dace.Memlet.simple(result_name, neighbor_index, debuginfo=di), src_conn="__result", ) @@ -280,6 +285,7 @@ def builtin_neighbors( def builtin_can_deref( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) # first visit shift, to get set of indices for deref can_deref_callable = node_args[0] assert isinstance(can_deref_callable, itir.FunCall) @@ -296,13 +302,15 @@ def builtin_can_deref( # Returning a SymbolExpr would be preferable, but it requires update to type-checking. result_name = unique_var_name() transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name) + result_node = transformer.context.state.add_access(result_name, debuginfo=di) transformer.context.state.add_edge( - transformer.context.state.add_tasklet("can_always_deref", {}, {"_out"}, "_out = True"), + transformer.context.state.add_tasklet( + "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di + ), "_out", result_node, None, - dace.Memlet.simple(result_name, "0"), + dace.Memlet.simple(result_name, "0", debuginfo=di), ) return [ValueExpr(result_node, dace.dtypes.bool)] @@ -313,13 +321,18 @@ def builtin_can_deref( # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref" + list(zip(args, internals)), + expr_code, + dace.dtypes.bool, + "can_deref", + dace_debuginfo=di, ) def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = transformer.visit(node_args) assert len(args) == 3 if_node = args[0][0] if isinstance(args[0], list) else args[0] @@ -346,7 +359,7 @@ def builtin_if( for arg in (if_node, a, b) ] expr = "({1} if {0} else {2})".format(*internals) - if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if") + if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if", dace_debuginfo=di) if_expr_values.append(if_expr[0]) return if_expr_values @@ -355,6 +368,7 @@ def builtin_if( def builtin_list_get( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = list(itertools.chain(*transformer.visit(node_args))) assert len(args) == 2 # index node @@ -369,12 +383,15 @@ def builtin_list_get( arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args ] expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get") + return transformer.add_expr_tasklet( + expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di + ) def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = transformer.visit(node_args[0]) internals = [f"{arg.value.data}_v" for arg in args] target_type = node_args[1] @@ -383,7 +400,13 @@ def builtin_cast( node_type = transformer.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(list(zip(args, internals)), expr, type_, "cast") + return transformer.add_expr_tasklet( + list(zip(args, internals)), + expr, + type_, + "cast", + dace_debuginfo=di, + ) def builtin_make_tuple( @@ -443,7 +466,9 @@ def _add_symbol(self, param, arg): # create storage in lambda sdfg self._sdfg.add_scalar(param, dtype=arg.dtype) # update table of lambda symbol - self._symbol_map[param] = ValueExpr(self._state.add_access(param), arg.dtype) + self._symbol_map[param] = ValueExpr( + self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype + ) elif isinstance(arg, IteratorExpr): # create storage in lambda sdfg ndims = len(arg.dimensions) @@ -453,9 +478,10 @@ def _add_symbol(self, param, arg): for _, index_name in index_names.items(): self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) # update table of lambda symbol - field = self._state.add_access(param) + field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) indices = { - dim: self._state.add_access(index_arg) for dim, index_arg in index_names.items() + dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) + for dim, index_arg in index_names.items() } self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) else: @@ -503,7 +529,7 @@ def visit_SymRef(self, node: itir.SymRef): if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: node_type = self._node_types[id(node)] assert isinstance(node_type, Val) - access_node = self._state.add_access(param) + access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) self._symbol_map[param] = ValueExpr( access_node, dtype=itir_type_as_dace_type(node_type.dtype) ) @@ -542,6 +568,7 @@ def visit_Lambda( # Create the SDFG for the lambda's body lambda_sdfg = dace.SDFG(func_name) + lambda_sdfg.debuginfo = dace_debuginfo(node) lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True) lambda_symbols_pass = GatherLambdaSymbolsPass( @@ -586,11 +613,14 @@ def visit_Lambda( results: list[ValueExpr] = [] # We are flattening the returned list of value expressions because the multiple outputs of a lambda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. + node.expr.location = node.location for expr in flatten_list(lambda_taskgen.visit(node.expr)): if isinstance(expr, ValueExpr): result_name = unique_var_name() lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access(result_name) + result_access = lambda_state.add_access( + result_name, debuginfo=lambda_sdfg.debuginfo + ) lambda_state.add_nedge( expr.value, result_access, @@ -599,7 +629,9 @@ def visit_Lambda( result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + result = lambda_taskgen.add_expr_tasklet( + [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo + )[0] lambda_sdfg.arrays[result.value.data].transient = False results.append(result) @@ -624,6 +656,7 @@ def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: return [SymbolExpr(node.value, itir_type_as_dace_type(node_type.dtype))] def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: + node.fun.location = node.location if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": return self._visit_deref(node) if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): @@ -646,7 +679,7 @@ def _visit_call(self, node: itir.FunCall): args = self.visit(node.args) args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] args = list(itertools.chain(*args)) - + node.fun.location = node.location func_context, func_inputs, results = self.visit(node.fun, args=args) nsdfg_inputs = {} @@ -679,6 +712,7 @@ def _visit_call(self, node: itir.FunCall): inputs=set(nsdfg_inputs.keys()), outputs=set(r.value.data for r in results), symbol_mapping=symbol_mapping, + debuginfo=dace_debuginfo(node, func_context.body.debuginfo), ) for name, value in func_inputs: @@ -698,14 +732,14 @@ def _visit_call(self, node: itir.FunCall): for conn, _ in neighbor_tables: var = connectivity_identifier(conn) memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var) + access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) self.context.state.add_edge(access, None, nsdfg_node, var, memlet) result_exprs = [] for result in results: name = unique_var_name() self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name) + result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) result_exprs.append(ValueExpr(result_access, result.dtype)) memlet = create_memlet_full(name, self.context.body.arrays[name]) self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) @@ -713,6 +747,7 @@ def _visit_call(self, node: itir.FunCall): return result_exprs def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) iterator = self.visit(node.args[0]) if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr @@ -727,7 +762,13 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + return self.add_expr_tasklet( + list(zip(args, internals)), + expr, + iterator.dtype, + "deref", + dace_debuginfo=di, + ) else: # Not all dimensions are included in the deref index list: @@ -741,7 +782,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: result_name = unique_var_name() self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name) + result_node = self.context.state.add_access(result_name, debuginfo=di) deref_connectors = ["_inp"] + [ f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices @@ -776,6 +817,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: output_nodes={ result_name: result_node, }, + debuginfo=di, ) return [ValueExpr(result_node, iterator.dtype)] @@ -789,10 +831,13 @@ def _split_shift_args( def _make_shift_for_rest(self, rest, iterator): return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), + args=[iterator], + location=iterator.location, ) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -815,7 +860,9 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): offset_provider = self.offset_provider[offset_dim] - connectivity = self.context.state.add_access(connectivity_identifier(offset_dim)) + connectivity = self.context.state.add_access( + connectivity_identifier(offset_dim), debuginfo=di + ) shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value @@ -850,7 +897,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, offset_node.dtype, "shift" + list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -860,13 +907,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) - offset_node = self.context.state.add_access(offset_var) + offset_node = self.context.state.add_access(offset_var, debuginfo=di) tasklet_node = self.context.state.add_tasklet( - "get_offset", {}, {"__out"}, f"__out = {offset}" + "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di ) self.context.state.add_edge( tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") @@ -874,6 +922,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] def _visit_reduce(self, node: itir.FunCall): + di = dace_debuginfo(node, self.context.body.debuginfo) node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) reduce_dtype = itir_type_as_dace_type(node_type.dtype) @@ -930,7 +979,9 @@ def _visit_reduce(self, node: itir.FunCall): reduce_input_name, nreduce_shape, reduce_dtype, transient=True ) - lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_node = itir.Lambda( + expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location + ) lambda_context, inner_inputs, inner_outputs = self.visit( lambda_node, args=args, use_neighbor_tables=False ) @@ -946,7 +997,7 @@ def _visit_reduce(self, node: itir.FunCall): self.context.body, lambda_context.body, input_mapping ) - reduce_input_node = self.context.state.add_access(reduce_input_name) + reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( self.context.state, @@ -957,6 +1008,7 @@ def _visit_reduce(self, node: itir.FunCall): symbol_mapping=symbol_mapping, input_nodes={arg.value.data: arg.value for arg in args}, output_nodes={reduce_input_name: reduce_input_node}, + debuginfo=di, ) reduce_input_desc = reduce_input_node.desc(self.context.body) @@ -964,7 +1016,7 @@ def _visit_reduce(self, node: itir.FunCall): result_name = unique_var_name() # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) @@ -997,7 +1049,13 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return self.add_expr_tasklet(expr_args, expr, type_, "numeric") + return self.add_expr_tasklet( + expr_args, + expr, + type_, + "numeric", + dace_debuginfo=dace_debuginfo(node, self.context.body.debuginfo), + ) def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) @@ -1005,17 +1063,24 @@ def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: return expr_func(self, node, node.args) def add_expr_tasklet( - self, args: list[tuple[ValueExpr, str]], expr: str, result_type: Any, name: str + self, + args: list[tuple[ValueExpr, str]], + expr: str, + result_type: Any, + name: str, + dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, ) -> list[ValueExpr]: + di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo result_name = unique_var_name() self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) expr_tasklet = self.context.state.add_tasklet( name=name, inputs={internal for _, internal in args}, outputs={"__result"}, code=f"__result = {expr}", + debuginfo=di, ) for arg, internal in args: @@ -1033,7 +1098,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = dace.Memlet.simple(result_access.data, "0") + memlet = dace.Memlet.simple(result_access.data, "0", debuginfo=di) self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] @@ -1052,6 +1117,7 @@ def closure_to_tasklet_sdfg( node_types: dict[int, next_typing.Type], ) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") + body.debuginfo = dace_debuginfo(node) state = body.add_state("tasklet_toplevel_entry", True) symbol_map: dict[str, TaskletExpr] = {} @@ -1059,8 +1125,10 @@ def closure_to_tasklet_sdfg( for dim, idx in domain.items(): name = f"{idx}_value" body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") - access = state.add_access(name) + tasklet = state.add_tasklet( + f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo + ) + access = state.add_access(name, debuginfo=body.debuginfo) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) for name, ty in inputs: @@ -1070,14 +1138,14 @@ def closure_to_tasklet_sdfg( dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name) + field = state.add_access(name, debuginfo=body.debuginfo) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) else: assert isinstance(ty, ts.ScalarType) dtype = as_dace_type(ty) body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name), dtype) + symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) for arr, name in connectivities: shape, strides = new_array_symbols(name, ndim=2) body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) @@ -1089,10 +1157,12 @@ def closure_to_tasklet_sdfg( if is_scan(node.stencil): stencil = cast(FunCall, node.stencil) assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) - fun_node = itir.FunCall(fun=lambda_node, args=args) + lambda_node = itir.Lambda( + expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location + ) + fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) else: - fun_node = itir.FunCall(fun=node.stencil, args=args) + fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) results = translator.visit(fun_node) for r in results: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 55717326a3..971c1bbdf2 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,15 +12,31 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import itertools -from typing import Any, Sequence +from typing import Any, Optional, Sequence import dace from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.iterator.ir import Node from gt4py.next.type_system import type_specifications as ts +def dace_debuginfo( + node: Node, debuginfo: Optional[dace.dtypes.DebugInfo] = None +) -> Optional[dace.dtypes.DebugInfo]: + location = node.location + if location: + return dace.dtypes.DebugInfo( + start_line=location.line, + start_column=location.column if location.column else 0, + end_line=location.end_line if location.end_line else -1, + end_column=location.end_column if location.end_column else 0, + filename=location.filename, + ) + return debuginfo + + def as_dace_type(type_: ts.ScalarType): if type_.kind == ts.ScalarKind.BOOL: return dace.bool_ @@ -119,11 +135,13 @@ def add_mapped_nested_sdfg( if input_nodes is None: input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in inputs.items() } if output_nodes is None: output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in outputs.items() } if not inputs: state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) From 11f9c1cfe833755cca9bf312b1f2b6862cf41827 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 24 Jan 2024 14:18:52 +0100 Subject: [PATCH 52/85] Bump version to 1.0.2 (#1421) --- .bumpversion.cfg | 2 +- CHANGELOG.md | 17 +++++++++++++++++ src/gt4py/__about__.py | 2 +- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index d7a3acaac1..9e65fd9ae0 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.1 +current_version = 1.0.2 parse = (?P\d+)\.(?P\d+)(\.(?P\d+))? serialize = {major}.{minor}.{patch} diff --git a/CHANGELOG.md b/CHANGELOG.md index 519f7ff1db..87f3ee9d2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,23 @@ Notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +## [1.0.2] - 2024-01-24 + +### Cartesian + +- Compatibility of `gt4py.next` Fields with `gt4py.cartesian` computations. +- Fixes for DaCe 0.15.1 compatibility. +- Added `log10` as native function. +- Make `scipy` optional: get `scipy` by installing `gt4py[full]` for best performance with `numpy` backend. + +### Storage + +- Refactored low-level storage allocation. + +### Next + +See commit history. + ## [1.0.1] - 2023-02-20 First version including the experimental `gt4py.next` aka _Declarative GT4Py_. The `gt4py.next` package is excluded from semantic versioning. diff --git a/src/gt4py/__about__.py b/src/gt4py/__about__.py index 57b914f25b..10f4607724 100644 --- a/src/gt4py/__about__.py +++ b/src/gt4py/__about__.py @@ -33,5 +33,5 @@ __license__: Final = "GPL-3.0-or-later" -__version__: Final = "1.0.1" +__version__: Final = "1.0.2" __version_info__: Final = pkg_version.parse(__version__) From ac0478ada60a03474fc607e345ba4f49e7413ade Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 25 Jan 2024 13:01:19 +0100 Subject: [PATCH 53/85] fix[next][dace]: Use constant shape for neighbor tables in local dimension (#1422) Main purpose of this PR is to avoid the definition of shape symbols for array dimensions known at compile time. The local size of neighbor connectivity tables falls into this category. For each element in the origin dimension, the number of elements in the target dimension is defined by the attribute max_neighbors in the offset provider. --- .../runners/dace_iterator/__init__.py | 96 +++++++++++-------- .../runners/dace_iterator/itir_to_sdfg.py | 85 ++++++++++++---- .../runners/dace_iterator/itir_to_tasklet.py | 52 ++++++---- .../runners/dace_iterator/utility.py | 10 +- 4 files changed, 160 insertions(+), 83 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 54ca08fe6e..a039d311ca 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -14,6 +14,7 @@ import hashlib import warnings from inspect import currentframe, getframeinfo +from pathlib import Path from typing import Any, Mapping, Optional, Sequence import dace @@ -26,7 +27,7 @@ import gt4py.next.program_processors.otf_compile_executor as otf_exec import gt4py.next.program_processors.processor_interface as ppi from gt4py.next import common -from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms +from gt4py.next.iterator import transforms as itir_transforms from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation @@ -109,23 +110,29 @@ def _ensure_is_on_device( def get_connectivity_args( - neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]], + neighbor_tables: Mapping[str, common.NeighborTable], device: dace.dtypes.DeviceType, ) -> dict[str, Any]: return { - connectivity_identifier(offset): _ensure_is_on_device(table.table, device) - for offset, table in neighbor_tables + connectivity_identifier(offset): _ensure_is_on_device(offset_provider.table, device) + for offset, offset_provider in neighbor_tables.items() } def get_shape_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: - return { - str(sym): size - for name, value in args.items() - for sym, size in zip(arrays[name].shape, value.shape) - } + shape_args: dict[str, int] = {} + for name, value in args.items(): + for sym, size in zip(arrays[name].shape, value.shape): + if isinstance(sym, dace.symbol): + assert sym.name not in shape_args + shape_args[sym.name] = size + elif sym != size: + raise RuntimeError( + f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}." + ) + return shape_args def get_offset_args( @@ -158,34 +165,41 @@ def get_stride_args( return stride_args -_build_cache_cpu: dict[str, CompiledSDFG] = {} -_build_cache_gpu: dict[str, CompiledSDFG] = {} +_build_cache: dict[str, CompiledSDFG] = {} def get_cache_id( + build_type: str, + build_for_gpu: bool, program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], column_axis: Optional[common.Dimension], offset_provider: Mapping[str, Any], ) -> str: - max_neighbors = [ - (k, v.max_neighbors) - for k, v in offset_provider.items() - if isinstance( - v, - ( - itir_embedded.NeighborTableOffsetProvider, - itir_embedded.StridedNeighborOffsetProvider, - ), - ) + def offset_invariants(offset): + if isinstance(offset, common.Connectivity): + return ( + offset.origin_axis, + offset.neighbor_axis, + offset.has_skip_values, + offset.max_neighbors, + ) + if isinstance(offset, common.Dimension): + return (offset,) + return tuple() + + offset_cache_keys = [ + (name, *offset_invariants(offset)) for name, offset in offset_provider.items() ] cache_id_args = [ str(arg) for arg in ( + build_type, + build_for_gpu, program, *arg_types, column_axis, - *max_neighbors, + *offset_cache_keys, ) ] m = hashlib.sha256() @@ -262,7 +276,7 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg = sdfg_genenerator.visit(program) + sdfg: dace.SDFG = sdfg_genenerator.visit(program) if sdfg is None: raise RuntimeError(f"Visit failed for program {program.id}.") @@ -284,8 +298,8 @@ def build_sdfg_from_itir( # run DaCe auto-optimization heuristics if auto_optimize: - # TODO: Investigate how symbol definitions improve autoopt transformations, - # in which case the cache table should take the symbols map into account. + # TODO: Investigate performance improvement from SDFG specialization with constant symbols, + # for array shape and strides, although this would imply JIT compilation. symbols: dict[str, int] = {} device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) @@ -307,25 +321,31 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] + # debug option to store SDFGs on filesystem and skip lowering ITIR to SDFG at each run + skip_itir_lowering_to_sdfg = kwargs.get("skip_itir_lowering_to_sdfg", False) arg_types = [type_translation.from_value(arg) for arg in args] - cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) + cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] sdfg = sdfg_program.sdfg - else: - sdfg = build_sdfg_from_itir( - program, - *args, - offset_provider=offset_provider, - auto_optimize=auto_optimize, - on_gpu=on_gpu, - column_axis=column_axis, - lift_mode=lift_mode, - ) + sdfg_filename = f"_dacegraphs/gt4py/{cache_id}/{program.id}.sdfg" + if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()): + sdfg = build_sdfg_from_itir( + program, + *args, + offset_provider=offset_provider, + auto_optimize=auto_optimize, + on_gpu=on_gpu, + column_axis=column_axis, + lift_mode=lift_mode, + ) + sdfg.save(sdfg_filename) + else: + sdfg = dace.SDFG.from_file(sdfg_filename) sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): @@ -361,7 +381,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: program, *args, **kwargs, - build_cache=_build_cache_cpu, + build_cache=_build_cache, build_type=_build_type, compiler_args=compiler_args, on_gpu=False, @@ -380,7 +400,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: program, *args, **kwargs, - build_cache=_build_cache_gpu, + build_cache=_build_cache, build_type=_build_type, on_gpu=True, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index dc194c0436..ce1ac6073a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -11,14 +11,14 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Optional, cast +from typing import Any, Mapping, Optional, Sequence, cast import dace import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing +from gt4py.next.common import NeighborTable from gt4py.next.iterator import ir as itir, type_inference as itir_typing -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef from gt4py.next.type_system import type_specifications as ts, type_translation @@ -43,13 +43,12 @@ flatten_list, get_sorted_dims, map_nested_sdfg_symbols, - new_array_symbols, unique_name, unique_var_name, ) -def get_scan_args(stencil: Expr) -> tuple[bool, Literal]: +def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: """ Parse stencil expression to extract the scan arguments. @@ -68,7 +67,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]: return is_forward.value == "True", init_carry -def get_scan_dim( +def _get_scan_dim( column_axis: Dimension, storage_types: dict[str, ts.TypeSpec], output: SymRef, @@ -93,6 +92,35 @@ def get_scan_dim( ) +def _make_array_shape_and_strides( + name: str, + dims: Sequence[Dimension], + neighbor_tables: Mapping[str, NeighborTable], + sort_dims: bool, +) -> tuple[list[dace.symbol], list[dace.symbol]]: + """ + Parse field dimensions and allocate symbols for array shape and strides. + + For local dimensions, the size is known at compile-time and therefore + the corresponding array shape dimension is set to an integer literal value. + + Returns + ------- + tuple(shape, strides) + The output tuple fields are arrays of dace symbolic expressions. + """ + dtype = dace.int64 + sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims + shape = [ + neighbor_tables[dim.value].max_neighbors + if dim.kind == DimensionKind.LOCAL + else dace.symbol(unique_name(f"{name}_shape{i}"), dtype) + for i, dim in enumerate(sorted_dims) + ] + strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)] + return shape, strides + + class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] @@ -104,7 +132,7 @@ class ItirToSDFG(eve.NodeVisitor): def __init__( self, param_types: list[ts.TypeSpec], - offset_provider: dict[str, NeighborTableOffsetProvider], + offset_provider: dict[str, NeighborTable], column_axis: Optional[Dimension] = None, ): self.param_types = param_types @@ -112,9 +140,19 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): + def add_storage( + self, + sdfg: dace.SDFG, + name: str, + type_: ts.TypeSpec, + neighbor_tables: Mapping[str, NeighborTable], + has_offset: bool = True, + sort_dimensions: bool = True, + ): if isinstance(type_, ts.FieldType): - shape, strides = new_array_symbols(name, len(type_.dims)) + shape, strides = _make_array_shape_and_strides( + name, type_.dims, neighbor_tables, sort_dimensions + ) offset = ( [dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))] if has_offset @@ -153,14 +191,23 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # Add program parameters as SDFG storages. for param, type_ in zip(node.params, self.param_types): - self.add_storage(program_sdfg, str(param.id), type_) + self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables) # Add connectivities as SDFG storages. - for offset, table in neighbor_tables: - scalar_kind = type_translation.get_scalar_kind(table.table.dtype) - local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) - type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False) + for offset, offset_provider in neighbor_tables.items(): + scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype) + local_dim = Dimension(offset, kind=DimensionKind.LOCAL) + type_ = ts.FieldType( + [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + ) + self.add_storage( + program_sdfg, + connectivity_identifier(offset), + type_, + neighbor_tables, + has_offset=False, + sort_dimensions=False, + ) # Create a nested SDFG for all stencil closures. for closure in node.closures: @@ -222,7 +269,7 @@ def visit_StencilClosure( input_names = [str(inp.id) for inp in node.inputs] neighbor_tables = filter_neighbor_tables(self.offset_provider) - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) output_names = [k for k, _ in output_nodes.items()] @@ -400,11 +447,11 @@ def _visit_scan_stencil_closure( output_name: str, ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: # extract scan arguments - is_forward, init_carry_value = get_scan_args(node.stencil) + is_forward, init_carry_value = _get_scan_args(node.stencil) # select the scan dimension based on program argument for column axis assert self.column_axis assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = get_scan_dim( + scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( self.column_axis, self.storage_types, node.output, @@ -570,7 +617,7 @@ def _visit_parallel_stencil_closure( ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: neighbor_tables = filter_neighbor_tables(self.offset_provider) input_names = [str(inp.id) for inp in node.inputs] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] # find the scan dimension, same as output dimension, and exclude it from the map domain map_ranges = {} @@ -583,7 +630,7 @@ def _visit_parallel_stencil_closure( index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} input_arrays = [(name, self.storage_types[name]) for name in input_names] - connectivity_arrays = [(array_table[name], name) for name in conn_names] + connectivity_arrays = [(array_table[name], name) for name in connectivity_names] context, results = closure_to_tasklet_sdfg( node, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 0ace6948b0..322a147382 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -237,7 +237,12 @@ def builtin_neighbors( ) data_access_tasklet = state.add_tasklet( "data_access", - code=f"__result = __field[{field_index}] if {neighbor_check} else {transformer.context.reduce_identity.value}", + code=f"__result = __field[{field_index}]" + + ( + f" if {neighbor_check} else {transformer.context.reduce_identity.value}" + if offset_provider.has_skip_values + else "" + ), inputs={"__field", field_index}, outputs={"__result"}, debuginfo=di, @@ -372,20 +377,25 @@ def builtin_list_get( args = list(itertools.chain(*transformer.visit(node_args))) assert len(args) == 2 # index node - assert isinstance(args[0], (SymbolExpr, ValueExpr)) - # 1D-array node - assert isinstance(args[1], ValueExpr) - # source node should be a 1D array - assert len(transformer.context.body.arrays[args[1].value.data].shape) == 1 - - expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args - ] - expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet( - expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di - ) + if isinstance(args[0], SymbolExpr): + index_value = args[0].value + result_name = unique_var_name() + transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True) + result_node = transformer.context.state.add_access(result_name) + transformer.context.state.add_nedge( + args[1].value, + result_node, + dace.Memlet.simple(args[1].value.data, index_value), + ) + return [ValueExpr(result_node, args[1].dtype)] + + else: + expr_args = [(arg, f"{arg.value.data}_v") for arg in args] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[1]}[{internals[0]}]" + return transformer.add_expr_tasklet( + expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di + ) def builtin_cast( @@ -562,9 +572,9 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = ( - filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else [] + filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else {} ) - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] # Create the SDFG for the lambda's body lambda_sdfg = dace.SDFG(func_name) @@ -700,8 +710,8 @@ def _visit_call(self, node: itir.FunCall): nsdfg_inputs[var] = create_memlet_full(store, self.context.body.arrays[store]) neighbor_tables = filter_neighbor_tables(self.offset_provider) - for conn, _ in neighbor_tables: - var = connectivity_identifier(conn) + for offset in neighbor_tables.keys(): + var = connectivity_identifier(offset) nsdfg_inputs[var] = create_memlet_full(var, self.context.body.arrays[var]) symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) @@ -729,8 +739,8 @@ def _visit_call(self, node: itir.FunCall): store = value.indices[dim] idx_memlet = nsdfg_inputs[var] self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet) - for conn, _ in neighbor_tables: - var = connectivity_identifier(conn) + for offset in neighbor_tables.keys(): + var = connectivity_identifier(offset) memlet = nsdfg_inputs[var] access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) self.context.state.add_edge(access, None, nsdfg_node, var, memlet) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 971c1bbdf2..a66fc36b1b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -17,7 +17,7 @@ import dace from gt4py.next import Dimension -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.common import NeighborTable from gt4py.next.iterator.ir import Node from gt4py.next.type_system import type_specifications as ts @@ -52,11 +52,11 @@ def as_dace_type(type_: ts.ScalarType): def filter_neighbor_tables(offset_provider: dict[str, Any]): - return [ - (offset, table) + return { + offset: table for offset, table in offset_provider.items() - if isinstance(table, NeighborTableOffsetProvider) - ] + if isinstance(table, NeighborTable) + } def connectivity_identifier(name: str): From 70f0f88df76d10f29f28a63bcc8802460da2269c Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 25 Jan 2024 13:02:13 +0100 Subject: [PATCH 54/85] feat[next][dace]: use new LoopRegion construct for scan operator (#1424) The lowering of scan operator to SDFG uses a state machine to represent a loop. This PR replaces the state machine with a LoopRegion construct introduced in dace v0.15. The LoopRegion construct is not yet supported by dace transformation, but it will in the future and it could open new optimization opportunities (e.g. K-caching). --- .../runners/dace_iterator/__init__.py | 4 ++ .../runners/dace_iterator/itir_to_sdfg.py | 71 ++++++++++--------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index a039d311ca..6a8b9bc9c6 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -20,6 +20,7 @@ import dace import numpy as np from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.sdfg import utils as sdutils from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.allocators as next_allocators @@ -293,6 +294,9 @@ def build_sdfg_from_itir( filename=frameinfo.filename, ) + # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct + sdutils.inline_loop_blocks(sdfg) + # run DaCe transformations to simplify the SDFG sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index ce1ac6073a..8a7826dae4 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -14,6 +14,7 @@ from typing import Any, Mapping, Optional, Sequence, cast import dace +from dace.sdfg.state import LoopRegion import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing @@ -477,15 +478,38 @@ def _visit_scan_stencil_closure( scan_sdfg = dace.SDFG(name="scan") scan_sdfg.debuginfo = dace_debuginfo(node) - # create a state machine for lambda call over the scan dimension - start_state = scan_sdfg.add_state("start", True) - lambda_state = scan_sdfg.add_state("lambda_compute") - end_state = scan_sdfg.add_state("end") - # the carry value of the scan operator exists only in the scope of the scan sdfg scan_carry_name = unique_var_name() scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) + # create a loop region for lambda call over the scan dimension + scan_loop_var = f"i_{scan_dim}" + if is_forward: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_ub_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lb_str}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= {scan_lb_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + scan_sdfg.add_node(scan_loop) + compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) + update_state = scan_loop.add_state("lambda_update") + scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) + + start_state = scan_sdfg.add_state("start", is_start_block=True) + scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) + # tasklet for initialization of carry carry_init_tasklet = start_state.add_tasklet( "get_carry_init_value", @@ -502,19 +526,6 @@ def _visit_scan_stencil_closure( dace.Memlet.simple(scan_carry_name, "0"), ) - # TODO(edopao): replace state machine with dace loop construct - scan_sdfg.add_loop( - start_state, - lambda_state, - end_state, - loop_var=f"i_{scan_dim}", - initialize_expr=f"{scan_lb_str}" if is_forward else f"{scan_ub_str} - 1", - condition_expr=f"i_{scan_dim} < {scan_ub_str}" - if is_forward - else f"i_{scan_dim} >= {scan_lb_str}", - increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1", - ) - # add storage to scan SDFG for inputs for name in [*input_names, *connectivity_names]: assert name not in scan_sdfg.arrays @@ -569,7 +580,7 @@ def _visit_scan_stencil_closure( array_mapping = {**input_mapping, **connectivity_mapping} symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - scan_inner_node = lambda_state.add_nested_sdfg( + scan_inner_node = compute_state.add_nested_sdfg( lambda_context.body, parent=scan_sdfg, inputs=set(lambda_input_names) | set(connectivity_names), @@ -580,29 +591,25 @@ def _visit_scan_stencil_closure( # connect scan SDFG to lambda inputs for name, memlet in array_mapping.items(): - access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) + access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) + compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = [output_name] assert len(lambda_output_names) == 1 # connect lambda output to scan SDFG for name, connector in zip(output_names, lambda_output_names): - lambda_state.add_edge( + compute_state.add_edge( scan_inner_node, connector, - lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo), + compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), None, - dace.Memlet.simple(name, f"i_{scan_dim}"), + dace.Memlet.simple(name, scan_loop_var), ) - # add state to scan SDFG to update the carry value at each loop iteration - lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") - lambda_update_state.add_memlet_path( - lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - lambda_update_state.add_access( - scan_carry_name, debuginfo=lambda_context.body.debuginfo - ), - memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), + update_state.add_nedge( + update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), + update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), + dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"), ) return scan_sdfg, map_ranges, scan_dim_index From 9cd9879181669c6a8b8f58db329c58e4bef1813d Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 25 Jan 2024 14:19:11 +0100 Subject: [PATCH 55/85] Remove usage of deprecated API dace.Memlet.simple (#1425) Replace deprecated constructor API dace.Memlet.simple() with dace.Memlet() --- .../runners/dace_iterator/itir_to_sdfg.py | 14 ++++++------- .../runners/dace_iterator/itir_to_tasklet.py | 20 +++++++++---------- .../runners/dace_iterator/utility.py | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 8a7826dae4..63e9fb03dc 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -338,7 +338,7 @@ def visit_StencilClosure( out_name, debuginfo=closure_sdfg.debuginfo ) value = ValueExpr(access, dtype) - memlet = dace.Memlet.simple(out_name, "0") + memlet = dace.Memlet(data=out_name, subset="0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: @@ -427,10 +427,10 @@ def visit_StencilClosure( edge.src_conn, transient_access, None, - dace.Memlet.simple(memlet.data, output_subset, debuginfo=nsdfg.debuginfo), + dace.Memlet(data=memlet.data, subset=output_subset, debuginfo=nsdfg.debuginfo), ) - inner_memlet = dace.Memlet.simple( - memlet.data, output_subset, other_subset_str=memlet.subset + inner_memlet = dace.Memlet( + data=memlet.data, subset=output_subset, other_subset=memlet.subset ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) @@ -523,7 +523,7 @@ def _visit_scan_stencil_closure( "__result", start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), None, - dace.Memlet.simple(scan_carry_name, "0"), + dace.Memlet(data=scan_carry_name, subset="0"), ) # add storage to scan SDFG for inputs @@ -603,13 +603,13 @@ def _visit_scan_stencil_closure( connector, compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), None, - dace.Memlet.simple(name, scan_loop_var), + dace.Memlet(data=name, subset=scan_loop_var), ) update_state.add_nedge( update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), - dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"), + dace.Memlet(data=output_name, subset=scan_loop_var, other_subset="0"), ) return scan_sdfg, map_ranges, scan_dim_index diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 322a147382..ab03d29389 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -260,7 +260,7 @@ def builtin_neighbors( iterator.indices[shifted_dim], me, shift_tasklet, - memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0", debuginfo=di), + memlet=dace.Memlet(data=iterator.indices[shifted_dim].data, subset="0", debuginfo=di), dst_conn="__idx", ) state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet()) @@ -280,7 +280,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet.simple(result_name, neighbor_index, debuginfo=di), + memlet=dace.Memlet(data=result_name, subset=neighbor_index, debuginfo=di), src_conn="__result", ) @@ -315,7 +315,7 @@ def builtin_can_deref( "_out", result_node, None, - dace.Memlet.simple(result_name, "0", debuginfo=di), + dace.Memlet(data=result_name, subset="0", debuginfo=di), ) return [ValueExpr(result_node, dace.dtypes.bool)] @@ -385,7 +385,7 @@ def builtin_list_get( transformer.context.state.add_nedge( args[1].value, result_node, - dace.Memlet.simple(args[1].value.data, index_value), + dace.Memlet(data=args[1].value.data, subset=index_value), ) return [ValueExpr(result_node, args[1].dtype)] @@ -634,7 +634,7 @@ def visit_Lambda( lambda_state.add_nedge( expr.value, result_access, - dace.Memlet.simple(result_access.data, "0"), + dace.Memlet(data=result_access.data, subset="0"), ) result = ValueExpr(value=result_access, dtype=expr.dtype) else: @@ -801,7 +801,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices ] deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [ - dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:] + dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:] ] # we create a mapped tasklet for array slicing @@ -927,7 +927,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di ) self.context.state.add_edge( - tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") + tasklet_node, "__out", offset_node, None, dace.Memlet(data=offset_var, subset="0") ) return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] @@ -1036,7 +1036,7 @@ def _visit_reduce(self, node: itir.FunCall): dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc), ) self.context.state.add_nedge( - reduce_node, result_access, dace.Memlet.simple(result_name, "0") + reduce_node, result_access, dace.Memlet(data=result_name, subset="0") ) # we apply map fusion only to the nested-SDFG which is generated for the reduction operator @@ -1108,7 +1108,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = dace.Memlet.simple(result_access.data, "0", debuginfo=di) + memlet = dace.Memlet(data=result_access.data, subset="0", debuginfo=di) self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] @@ -1140,7 +1140,7 @@ def closure_to_tasklet_sdfg( ) access = state.add_access(name, debuginfo=body.debuginfo) idx_accesses[dim] = access - state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) + state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0")) for name, ty in inputs: if isinstance(ty, ts.FieldType): ndim = len(ty.dims) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index a66fc36b1b..49dd2472c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -69,7 +69,7 @@ def create_memlet_full(source_identifier: str, source_array: dace.data.Array): def create_memlet_at(source_identifier: str, index: tuple[str, ...]): subset = ", ".join(index) - return dace.Memlet.simple(source_identifier, subset) + return dace.Memlet(data=source_identifier, subset=subset) def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: From 8c3b3d739dcfd9d989dee66e94c0248e45d46d5c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 26 Jan 2024 09:32:08 +0100 Subject: [PATCH 56/85] ci: test jupyter notebooks (#1426) --- .github/workflows/test-notebooks.yml | 42 ++++ .pre-commit-config.yaml | 30 +-- constraints.txt | 203 ++++++++++-------- examples/lap_cartesian_vs_next.ipynb | 4 +- min-extra-requirements-test.txt | 2 + min-requirements-test.txt | 2 + requirements-dev.in | 2 + requirements-dev.txt | 203 ++++++++++-------- .../cartesian/frontend/gtscript_frontend.py | 4 +- .../ffront/ast_passes/remove_docstrings.py | 17 +- tox.ini | 6 +- 11 files changed, 302 insertions(+), 213 deletions(-) create mode 100644 .github/workflows/test-notebooks.yml diff --git a/.github/workflows/test-notebooks.yml b/.github/workflows/test-notebooks.yml new file mode 100644 index 0000000000..9de6a409e8 --- /dev/null +++ b/.github/workflows/test-notebooks.yml @@ -0,0 +1,42 @@ +name: "Test Jupyter Notebooks" + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test-notebooks: + strategy: + matrix: + python-version: ["3.10", "3.11"] + os: ["ubuntu-latest"] + fail-fast: false + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: | + **/pyproject.toml + **/constraints.txt + **/requirements-dev.txt + - name: Install python dependencies + run: | + python -m pip install -c ./constraints.txt pip setuptools wheel + python -m pip install -r ./requirements-dev.txt + - name: Run tox tests + env: + NUM_PROCESSES: auto + shell: bash + run: | + pyversion=${{ matrix.python-version }} + pyversion_no_dot=${pyversion//./} + tox run -e notebooks-py${pyversion_no_dot} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d9cfa0ff48..3259a74f38 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.11.0' # version from constraints.txt + rev: '23.12.1' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -73,7 +73,7 @@ repos: ## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '5.13.0' # version from constraints.txt + rev: '5.13.2' # version from constraints.txt ##[[[end]]] hooks: - id: isort @@ -84,7 +84,7 @@ repos: ## version = re.search('flake8==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '6.1.0' # version from constraints.txt + rev: '7.0.0' # version from constraints.txt ##[[[end]]] hooks: - id: flake8 @@ -97,7 +97,7 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.12.2 + - flake8-bugbear==24.1.17 - flake8-builtins==2.2.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.7.1 ========= + #========= FROM constraints.txt: v1.8.0 ========= ##[[[end]]] - rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.8.0 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -161,27 +161,27 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - astunparse==1.6.3 - - attrs==23.1.0 - - black==23.11.0 + - attrs==23.2.0 + - black==23.12.1 - boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 - - cmake==3.27.9 - - cytoolz==0.12.2 + - cmake==3.28.1 + - cytoolz==0.12.3 - deepdiff==6.7.1 - devtools==0.12.2 - - frozendict==2.3.10 + - frozendict==2.4.0 - gridtools-cpp==2.3.1 - importlib-resources==6.1.1 - - jinja2==3.1.2 - - lark==1.1.8 - - mako==1.3.0 + - jinja2==3.1.3 + - lark==1.1.9 + - mako==1.3.1 - nanobind==1.8.0 - ninja==1.11.1.1 - numpy==1.24.4 - packaging==23.2 - pybind11==2.11.1 - - setuptools==69.0.2 + - setuptools==69.0.3 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/constraints.txt b/constraints.txt index 81abd64c6e..343615b421 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,17 +6,17 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.1 # via devtools +asttokens==2.4.1 # via devtools, stack-data astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) -attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.13.1 # via sphinx -black==23.11.0 # via gt4py (pyproject.toml) +attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing +babel==2.14.0 # via sphinx +backcall==0.2.0 # via ipython +black==23.12.1 # via gt4py (pyproject.toml) blinker==1.7.0 # via flask boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.2 # via tox -cerberus==1.3.5 # via plette certifi==2023.11.17 # via requests cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit @@ -24,31 +24,34 @@ chardet==5.2.0 # via tox charset-normalizer==3.3.2 # via requests clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.9 # via dace, gt4py (pyproject.toml) +cmake==3.28.1 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.2 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis -cytoolz==0.12.2 # via gt4py (pyproject.toml) +comm==0.2.1 # via ipykernel +contourpy==1.1.1 # via matplotlib +coverage==7.4.0 # via -r requirements-dev.in, coverage, pytest-cov +cryptography==42.0.1 # via types-paramiko, types-pyopenssl, types-redis +cycler==0.12.1 # via matplotlib +cytoolz==0.12.3 # via gt4py (pyproject.toml) dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in +debugpy==1.8.0 # via ipykernel +decorator==5.1.1 # via ipython deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via requirementslib, virtualenv -distro==1.8.0 # via scikit-build -docopt==0.6.2 # via pipreqs +distlib==0.3.8 # via virtualenv docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==2.0.1 # via devtools +executing==2.0.1 # via devtools, stack-data factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==20.1.0 # via factory-boy -fastjsonschema==2.19.0 # via nbformat +faker==22.5.1 # via factory-boy +fastjsonschema==2.19.1 # via nbformat filelock==3.13.1 # via tox, virtualenv -flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings +flake8-bugbear==24.1.17 # via -r requirements-dev.in flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -56,80 +59,93 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==3.0.0 # via dace +flask==3.0.1 # via dace +fonttools==4.47.2 # via matplotlib fparser==0.1.3 # via dace -frozendict==2.3.10 # via gt4py (pyproject.toml) +frozendict==2.4.0 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +hypothesis==6.97.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==7.0.0 # via build, flask, fparser, sphinx -importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.1 # via build, flask, fparser, jupyter-client, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.13.0 # via -r requirements-dev.in +ipykernel==6.29.0 # via nbmake +ipython==8.12.3 # via ipykernel +isort==5.13.2 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask -jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.20.0 # via nbformat -jsonschema-specifications==2023.11.2 # via jsonschema -jupyter-core==5.5.0 # via nbformat -jupytext==1.16.0 # via -r requirements-dev.in -lark==1.1.8 # via gt4py (pyproject.toml) -mako==1.3.0 # via gt4py (pyproject.toml) +jedi==0.19.1 # via ipython +jinja2==3.1.3 # via flask, gt4py (pyproject.toml), sphinx +jsonschema==4.21.1 # via nbformat +jsonschema-specifications==2023.12.1 # via jsonschema +jupyter-client==8.6.0 # via ipykernel, nbclient +jupyter-core==5.7.1 # via ipykernel, jupyter-client, nbformat +jupytext==1.16.1 # via -r requirements-dev.in +kiwisolver==1.4.5 # via matplotlib +lark==1.1.9 # via gt4py (pyproject.toml) +mako==1.3.1 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins -markupsafe==2.1.3 # via jinja2, mako, werkzeug +markupsafe==2.1.4 # via jinja2, mako, werkzeug +matplotlib==3.7.4 # via -r requirements-dev.in +matplotlib-inline==0.1.6 # via ipykernel, ipython mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.7.1 # via -r requirements-dev.in +mypy==1.8.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy nanobind==1.8.0 # via gt4py (pyproject.toml) -nbformat==5.9.2 # via jupytext +nbclient==0.6.8 # via nbmake +nbformat==5.9.2 # via jupytext, nbclient, nbmake +nbmake==1.4.6 # via -r requirements-dev.in +nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit -numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client +numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +packaging==23.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pyproject-api, pytest, setuptools-scm, sphinx, tox +parso==0.8.3 # via jedi pathspec==0.12.1 # via black -pep517==0.13.1 # via requirementslib -pip-api==0.0.30 # via isort +pexpect==4.9.0 # via ipython +pickleshare==0.7.5 # via ipython +pillow==10.2.0 # via matplotlib pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.1 # via -r requirements-dev.in -pipreqs==0.4.13 # via isort +pipdeptree==2.13.2 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv -plette==0.4.4 # via requirementslib -pluggy==1.3.0 # via pytest, tox +platformdirs==4.1.0 # via black, jupyter-core, tox, virtualenv +pluggy==1.4.0 # via pytest, tox ply==3.11 # via dace pre-commit==3.5.0 # via -r requirements-dev.in -psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist +prompt-toolkit==3.0.43 # via ipython +psutil==5.9.8 # via -r requirements-dev.in, ipykernel, pytest-xdist +ptyprocess==0.7.0 # via pexpect +pure-eval==0.2.2 # via stack-data pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi -pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings -pyflakes==3.1.0 # via flake8 -pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyflakes==3.2.0 # via flake8 +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, ipython, nbmake, sphinx +pyparsing==3.1.1 # via matplotlib pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.6.0 # via -r requirements-dev.in -pytest-xdist==3.5.0 # via -r requirements-dev.in -python-dateutil==2.8.2 # via faker +pytest-xdist==3.5.0 # via -r requirements-dev.in, pytest-xdist +python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.32.0 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, requirementslib, sphinx, yarg -requirementslib==3.0.0 # via isort +pyzmq==25.1.2 # via ipykernel, jupyter-client +referencing==0.32.1 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.13.2 # via jsonschema, referencing -ruff==0.1.7 # via -r requirements-dev.in -scikit-build==0.17.6 # via dace +rpds-py==0.17.1 # via jsonschema, referencing +ruff==0.1.14 # via -r requirements-dev.in setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx @@ -143,15 +159,16 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx +stack-data==0.6.3 # via ipython sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox -tomlkit==0.12.3 # via plette, requirementslib -toolz==0.12.0 # via cytoolz -tox==4.11.4 # via -r requirements-dev.in -traitlets==5.14.0 # via jupyter-core, nbformat -types-aiofiles==23.2.0.0 # via types-all +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, setuptools-scm, tox +toolz==0.12.1 # via cytoolz +tornado==6.4 # via ipykernel, jupyter-client +tox==4.12.1 # via -r requirements-dev.in +traitlets==5.14.1 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat +types-aiofiles==23.2.0.20240106 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all @@ -161,22 +178,22 @@ types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.16.0.0 # via types-jack-client +types-cffi==1.16.0.20240106 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.6 # via types-all -types-colorama==0.4.15.12 # via types-all +types-click-spinner==0.1.13.20240106 # via types-all +types-colorama==0.4.15.20240106 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==2.0.0.0 # via types-all +types-croniter==2.0.0.20240106 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all -types-dateparser==1.1.4.10 # via types-all +types-dateparser==1.1.4.20240106 # via types-all types-datetimerange==2.0.0.6 # via types-all -types-decorator==5.1.8.4 # via types-all -types-deprecated==1.2.9.3 # via types-all +types-decorator==5.1.8.20240106 # via types-all +types-deprecated==1.2.9.20240106 # via types-all types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.3 # via types-all +types-docutils==0.20.0.20240125 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -189,67 +206,67 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.10 # via types-all +types-jack-client==0.5.10.20240106 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.5.0.3 # via types-all +types-markdown==3.5.0.20240106 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.3 # via types-all +types-mock==5.1.0.20240106 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.2 # via types-all, types-pysftp +types-paramiko==3.4.0.20240120 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.1.0.2 # via types-all +types-pillow==10.2.0.20240125 # via types-all types-pkg-resources==0.1.3 # via types-all -types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.4 # via types-all -types-pyaudio==0.2.16.7 # via types-all -types-pycurl==7.45.2.5 # via types-all +types-polib==1.2.0.20240115 # via types-all +types-protobuf==4.24.0.20240106 # via types-all +types-pyaudio==0.2.16.20240106 # via types-all +types-pycurl==7.45.2.20240106 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.3.0.0 # via types-redis +types-pyopenssl==23.3.0.20240106 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all -types-pysftp==0.2.17.6 # via types-all -types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange +types-pysftp==0.2.17.20240106 # via types-all +types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.12 # via types-all -types-redis==4.6.0.11 # via types-all -types-requests==2.31.0.10 # via types-all +types-redis==4.6.0.20240106 # via types-all +types-requests==2.31.0.20240125 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==69.0.0.0 # via types-cffi +types-setuptools==69.0.0.20240125 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all -types-six==1.16.21.9 # via types-all -types-tabulate==0.9.0.3 # via types-all +types-six==1.16.21.20240106 # via types-all +types-tabulate==0.9.0.20240106 # via types-all types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all types-tzlocal==5.1.0.1 # via types-all -types-ujson==5.8.0.1 # via types-all -types-waitress==2.1.4.9 # via types-all +types-ujson==5.9.0.0 # via types-all +types-waitress==2.1.4.20240106 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm urllib3==2.1.0 # via requests, types-requests virtualenv==20.25.0 # via pre-commit, tox +wcwidth==0.2.13 # via prompt-toolkit websockets==12.0 # via dace werkzeug==3.0.1 # via flask -wheel==0.42.0 # via astunparse, pip-tools, scikit-build +wheel==0.42.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) -yarg==0.1.9 # via pipreqs zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.3.1 # via pip-api, pip-tools, requirementslib -setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm +pip==23.3.2 # via pip-tools +setuptools==69.0.3 # via gt4py (pyproject.toml), nodeenv, pip-tools, setuptools-scm diff --git a/examples/lap_cartesian_vs_next.ipynb b/examples/lap_cartesian_vs_next.ipynb index cb80122570..8187a571dd 100644 --- a/examples/lap_cartesian_vs_next.ipynb +++ b/examples/lap_cartesian_vs_next.ipynb @@ -81,7 +81,7 @@ "source": [ "import gt4py.next as gtx\n", "\n", - "allocator = gtx.itir_embedded # should match the executor\n", + "allocator = gtx.itir_python # should match the executor\n", "# allocator = gtx.gtfn_cpu\n", "# allocator = gtx.gtfn_gpu\n", "\n", @@ -137,7 +137,7 @@ "source": [ "from gt4py.next import Field\n", "\n", - "next_backend = gtx.itir_embedded\n", + "next_backend = gtx.itir_python\n", "# next_backend = gtx.gtfn_cpu\n", "# next_backend = gtx.gtfn_gpu\n", "\n", diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index fd7724bac9..3c6cd3d9ff 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -49,8 +49,10 @@ jinja2==3.0.0 jupytext==1.14 lark==1.1.2 mako==1.1 +matplotlib==3.3 mypy==1.0 nanobind==1.4.0 +nbmake==1.4.6 ninja==1.10 numpy==1.21.2 packaging==20.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index a6e5d19d1d..d2ebaba331 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -46,8 +46,10 @@ jinja2==3.0.0 jupytext==1.14 lark==1.1.2 mako==1.1 +matplotlib==3.3 mypy==1.0 nanobind==1.4.0 +nbmake==1.4.6 ninja==1.10 numpy==1.21.2 packaging==20.0 diff --git a/requirements-dev.in b/requirements-dev.in index 46bf3d9a8c..59ddb733d0 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -26,6 +26,8 @@ flake8-rst-docstrings>=0.0.14 isort>=5.10 jupytext>=1.14 mypy>=1.0 +matplotlib>=3.3 +nbmake>=1.4.6 pipdeptree>=2.3 pip-tools>=6.10 pre-commit>=2.17 diff --git a/requirements-dev.txt b/requirements-dev.txt index 0fa523866f..abfa99a2ae 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,17 +6,17 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.1 # via devtools +asttokens==2.4.1 # via devtools, stack-data astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) -attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.13.1 # via sphinx -black==23.11.0 # via gt4py (pyproject.toml) +attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing +babel==2.14.0 # via sphinx +backcall==0.2.0 # via ipython +black==23.12.1 # via gt4py (pyproject.toml) blinker==1.7.0 # via flask boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.2 # via tox -cerberus==1.3.5 # via plette certifi==2023.11.17 # via requests cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit @@ -24,31 +24,34 @@ chardet==5.2.0 # via tox charset-normalizer==3.3.2 # via requests clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.9 # via dace, gt4py (pyproject.toml) +cmake==3.28.1 # via gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis -cytoolz==0.12.2 # via gt4py (pyproject.toml) +comm==0.2.1 # via ipykernel +contourpy==1.1.1 # via matplotlib +coverage[toml]==7.4.0 # via -r requirements-dev.in, coverage, pytest-cov +cryptography==42.0.1 # via types-paramiko, types-pyopenssl, types-redis +cycler==0.12.1 # via matplotlib +cytoolz==0.12.3 # via gt4py (pyproject.toml) dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in +debugpy==1.8.0 # via ipykernel +decorator==5.1.1 # via ipython deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via requirementslib, virtualenv -distro==1.8.0 # via scikit-build -docopt==0.6.2 # via pipreqs +distlib==0.3.8 # via virtualenv docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==2.0.1 # via devtools +executing==2.0.1 # via devtools, stack-data factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==20.1.0 # via factory-boy -fastjsonschema==2.19.0 # via nbformat +faker==22.5.1 # via factory-boy +fastjsonschema==2.19.1 # via nbformat filelock==3.13.1 # via tox, virtualenv -flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings +flake8-bugbear==24.1.17 # via -r requirements-dev.in flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in @@ -56,80 +59,93 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==3.0.0 # via dace +flask==3.0.1 # via dace +fonttools==4.47.2 # via matplotlib fparser==0.1.3 # via dace -frozendict==2.3.10 # via gt4py (pyproject.toml) +frozendict==2.4.0 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +hypothesis==6.97.0 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==7.0.0 # via build, flask, fparser, sphinx -importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.1 # via build, flask, fparser, jupyter-client, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.13.0 # via -r requirements-dev.in +ipykernel==6.29.0 # via nbmake +ipython==8.12.3 # via ipykernel +isort==5.13.2 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask -jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.20.0 # via nbformat -jsonschema-specifications==2023.11.2 # via jsonschema -jupyter-core==5.5.0 # via nbformat -jupytext==1.16.0 # via -r requirements-dev.in -lark==1.1.8 # via gt4py (pyproject.toml) -mako==1.3.0 # via gt4py (pyproject.toml) +jedi==0.19.1 # via ipython +jinja2==3.1.3 # via flask, gt4py (pyproject.toml), sphinx +jsonschema==4.21.1 # via nbformat +jsonschema-specifications==2023.12.1 # via jsonschema +jupyter-client==8.6.0 # via ipykernel, nbclient +jupyter-core==5.7.1 # via ipykernel, jupyter-client, nbformat +jupytext==1.16.1 # via -r requirements-dev.in +kiwisolver==1.4.5 # via matplotlib +lark==1.1.9 # via gt4py (pyproject.toml) +mako==1.3.1 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins -markupsafe==2.1.3 # via jinja2, mako, werkzeug +markupsafe==2.1.4 # via jinja2, mako, werkzeug +matplotlib==3.7.4 # via -r requirements-dev.in +matplotlib-inline==0.1.6 # via ipykernel, ipython mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.7.1 # via -r requirements-dev.in +mypy==1.8.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy nanobind==1.8.0 # via gt4py (pyproject.toml) -nbformat==5.9.2 # via jupytext +nbclient==0.6.8 # via nbmake +nbformat==5.9.2 # via jupytext, nbclient, nbmake +nbmake==1.4.6 # via -r requirements-dev.in +nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit -numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client +numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +packaging==23.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pyproject-api, pytest, setuptools-scm, sphinx, tox +parso==0.8.3 # via jedi pathspec==0.12.1 # via black -pep517==0.13.1 # via requirementslib -pip-api==0.0.30 # via isort +pexpect==4.9.0 # via ipython +pickleshare==0.7.5 # via ipython +pillow==10.2.0 # via matplotlib pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.1 # via -r requirements-dev.in -pipreqs==0.4.13 # via isort +pipdeptree==2.13.2 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv -plette[validation]==0.4.4 # via requirementslib -pluggy==1.3.0 # via pytest, tox +platformdirs==4.1.0 # via black, jupyter-core, tox, virtualenv +pluggy==1.4.0 # via pytest, tox ply==3.11 # via dace pre-commit==3.5.0 # via -r requirements-dev.in -psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist +prompt-toolkit==3.0.43 # via ipython +psutil==5.9.8 # via -r requirements-dev.in, ipykernel, pytest-xdist +ptyprocess==0.7.0 # via pexpect +pure-eval==0.2.2 # via stack-data pybind11==2.11.1 # via gt4py (pyproject.toml) pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi -pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings -pyflakes==3.1.0 # via flake8 -pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pyflakes==3.2.0 # via flake8 +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, ipython, nbmake, sphinx +pyparsing==3.1.1 # via matplotlib pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.6.0 # via -r requirements-dev.in -pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in -python-dateutil==2.8.2 # via faker +pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in, pytest-xdist +python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.32.0 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, requirementslib, sphinx, yarg -requirementslib==3.0.0 # via isort +pyzmq==25.1.2 # via ipykernel, jupyter-client +referencing==0.32.1 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.13.2 # via jsonschema, referencing -ruff==0.1.7 # via -r requirements-dev.in -scikit-build==0.17.6 # via dace +rpds-py==0.17.1 # via jsonschema, referencing +ruff==0.1.14 # via -r requirements-dev.in setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx @@ -143,15 +159,16 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx +stack-data==0.6.3 # via ipython sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox -tomlkit==0.12.3 # via plette, requirementslib -toolz==0.12.0 # via cytoolz -tox==4.11.4 # via -r requirements-dev.in -traitlets==5.14.0 # via jupyter-core, nbformat -types-aiofiles==23.2.0.0 # via types-all +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, setuptools-scm, tox +toolz==0.12.1 # via cytoolz +tornado==6.4 # via ipykernel, jupyter-client +tox==4.12.1 # via -r requirements-dev.in +traitlets==5.14.1 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat +types-aiofiles==23.2.0.20240106 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all @@ -161,22 +178,22 @@ types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.16.0.0 # via types-jack-client +types-cffi==1.16.0.20240106 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.6 # via types-all -types-colorama==0.4.15.12 # via types-all +types-click-spinner==0.1.13.20240106 # via types-all +types-colorama==0.4.15.20240106 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==2.0.0.0 # via types-all +types-croniter==2.0.0.20240106 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all -types-dateparser==1.1.4.10 # via types-all +types-dateparser==1.1.4.20240106 # via types-all types-datetimerange==2.0.0.6 # via types-all -types-decorator==5.1.8.4 # via types-all -types-deprecated==1.2.9.3 # via types-all +types-decorator==5.1.8.20240106 # via types-all +types-deprecated==1.2.9.20240106 # via types-all types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.3 # via types-all +types-docutils==0.20.0.20240125 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -189,67 +206,67 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.10 # via types-all +types-jack-client==0.5.10.20240106 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.5.0.3 # via types-all +types-markdown==3.5.0.20240106 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.3 # via types-all +types-mock==5.1.0.20240106 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.2 # via types-all, types-pysftp +types-paramiko==3.4.0.20240120 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.1.0.2 # via types-all +types-pillow==10.2.0.20240125 # via types-all types-pkg-resources==0.1.3 # via types-all -types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.4 # via types-all -types-pyaudio==0.2.16.7 # via types-all -types-pycurl==7.45.2.5 # via types-all +types-polib==1.2.0.20240115 # via types-all +types-protobuf==4.24.0.20240106 # via types-all +types-pyaudio==0.2.16.20240106 # via types-all +types-pycurl==7.45.2.20240106 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.3.0.0 # via types-redis +types-pyopenssl==23.3.0.20240106 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all -types-pysftp==0.2.17.6 # via types-all -types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange +types-pysftp==0.2.17.20240106 # via types-all +types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.12 # via types-all -types-redis==4.6.0.11 # via types-all -types-requests==2.31.0.10 # via types-all +types-redis==4.6.0.20240106 # via types-all +types-requests==2.31.0.20240125 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==69.0.0.0 # via types-cffi +types-setuptools==69.0.0.20240125 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all -types-six==1.16.21.9 # via types-all -types-tabulate==0.9.0.3 # via types-all +types-six==1.16.21.20240106 # via types-all +types-tabulate==0.9.0.20240106 # via types-all types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all types-tzlocal==5.1.0.1 # via types-all -types-ujson==5.8.0.1 # via types-all -types-waitress==2.1.4.9 # via types-all +types-ujson==5.9.0.0 # via types-all +types-waitress==2.1.4.20240106 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm urllib3==2.1.0 # via requests, types-requests virtualenv==20.25.0 # via pre-commit, tox +wcwidth==0.2.13 # via prompt-toolkit websockets==12.0 # via dace werkzeug==3.0.1 # via flask -wheel==0.42.0 # via astunparse, pip-tools, scikit-build +wheel==0.42.0 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) -yarg==0.1.9 # via pipreqs zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.3.1 # via pip-api, pip-tools, requirementslib -setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm +pip==23.3.2 # via pip-tools +setuptools==69.0.3 # via gt4py (pyproject.toml), nodeenv, pip-tools, setuptools-scm diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index 1c96773ff2..f665410b30 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1889,7 +1889,9 @@ def resolve_external_symbols( "{}.{}".format(value._gtscript_["qualified_name"], item_name): item_value for item_name, item_value in value._gtscript_["nonlocals"].items() } - resolved_values_list.extend(nested_inlined_values.items()) + resolved_values_list.extend( # noqa[B038] #editing a loop's mutable iterable (probably intended here) + nested_inlined_values.items() + ) for imported_name, imported_name_accesses in value._gtscript_[ "imported" diff --git a/src/gt4py/next/ffront/ast_passes/remove_docstrings.py b/src/gt4py/next/ffront/ast_passes/remove_docstrings.py index 2ae12c01f2..653456f6c5 100644 --- a/src/gt4py/next/ffront/ast_passes/remove_docstrings.py +++ b/src/gt4py/next/ffront/ast_passes/remove_docstrings.py @@ -53,12 +53,13 @@ def apply(cls, node: ast.AST) -> ast.AST: return cls().visit(node) def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: - for obj in node.body: - if ( - isinstance(obj, ast.Expr) - and isinstance(obj.value, ast.Constant) - and isinstance(obj.value.value, str) - ): - node.body.remove(obj) - + node.body = [obj for obj in node.body if not _is_const_str_expr(obj)] return self.generic_visit(node) + + +def _is_const_str_expr(obj: ast.stmt) -> bool: + return ( + isinstance(obj, ast.Expr) + and isinstance(obj.value, ast.Constant) + and isinstance(obj.value.value, str) + ) diff --git a/tox.ini b/tox.ini index 817f721f71..6a21a298ba 100644 --- a/tox.ini +++ b/tox.ini @@ -97,6 +97,10 @@ commands = {cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_gpu" {posargs} tests{/}storage_tests #pytest doctest-modules {posargs} src{/}gt4py{/}storage +[testenv:notebooks-py{310,311}] +description = Run notebooks +commands = python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} + [testenv:linters-py{38,39,310,311}] description = Run linters commands = @@ -131,7 +135,7 @@ commands = # git add _static # commands_post = -[testenv:requirements-{common,py38,py39,py310}] +[testenv:requirements-{common,py38,py39,py310,py311}] description = common: Update pinned development requirements py38: Update requirements for testing a specific python version From f0986bb58e2ceffca39ee7a6cb9bf7a6749414b8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 29 Jan 2024 14:10:25 +0100 Subject: [PATCH 57/85] feat[next]: Enable tests for embedded with cupy (#1372) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces mechanism in tests for having different allocators for the same (`None`) backend. Fixes: - The resulting buffer for scan is deduced from the buffer type of the arguments, if there are no arguments we fallback to numpy (maybe break). We need to find a mechanism for this corner case. Currently these tests are excluded with `pytest.mark.uses_scan_without_field_args` for cupy embedded execution. Refactoring: - make common.field and common.connectivity private - rename next_tests.exclusion_matrices to definitions TODOs for later: - `broadcast` of scalar ignores the broadcast --------- Co-authored-by: Enrique González Paredes --- .../ADRs/0015-Test_Exclusion_Matrices.md | 8 +-- pyproject.toml | 6 +- src/gt4py/next/allocators.py | 1 + src/gt4py/next/common.py | 9 ++- src/gt4py/next/constructors.py | 6 +- src/gt4py/next/embedded/nd_array_field.py | 22 +++---- src/gt4py/next/embedded/operators.py | 33 +++++++++-- src/gt4py/next/ffront/fbuiltins.py | 8 +-- src/gt4py/next/iterator/embedded.py | 2 +- tests/next_tests/__init__.py | 4 +- .../{exclusion_matrices.py => definitions.py} | 59 ++++++++++++++++--- tests/next_tests/integration_tests/cases.py | 43 ++++++++------ .../ffront_tests/ffront_test_utils.py | 37 +++++++----- .../ffront_tests/test_arg_call_interface.py | 6 +- .../ffront_tests/test_bound_args.py | 4 +- .../ffront_tests/test_execution.py | 35 ++++------- .../ffront_tests/test_external_local_field.py | 2 +- .../ffront_tests/test_gt4py_builtins.py | 11 +--- .../test_math_builtin_execution.py | 6 +- .../ffront_tests/test_math_unary_builtins.py | 2 +- .../ffront_tests/test_program.py | 18 ++++-- .../ffront_tests/test_scalar_if.py | 2 +- .../test_temporaries_with_sizes.py | 3 +- .../feature_tests/test_util_cases.py | 10 ++-- .../ffront_tests/test_embedded_regression.py | 14 ++--- .../ffront_tests/test_icon_like_scan.py | 37 ++++++------ .../ffront_tests/test_laplacian.py | 2 +- tests/next_tests/unit_tests/conftest.py | 24 ++++---- .../embedded_tests/test_nd_array_field.py | 54 ++++++++--------- 29 files changed, 274 insertions(+), 194 deletions(-) rename tests/next_tests/{exclusion_matrices.py => definitions.py} (77%) diff --git a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md index b338169d61..c868757905 100644 --- a/docs/development/ADRs/0015-Test_Exclusion_Matrices.md +++ b/docs/development/ADRs/0015-Test_Exclusion_Matrices.md @@ -1,5 +1,5 @@ --- -tags: [] +tags: [testing] --- # Test-Exclusion Matrices @@ -7,7 +7,7 @@ tags: [] - **Status**: valid - **Authors**: Edoardo Paone (@edopao), Enrique G. Paredes (@egparedes) - **Created**: 2023-09-21 -- **Updated**: 2023-09-21 +- **Updated**: 2024-01-25 In the context of Field View testing, lacking support for specific ITIR features while a certain backend is being developed, we decided to use `pytest` fixtures to exclude unsupported tests. @@ -22,7 +22,7 @@ the supported backends, while keeping the test code clean. ## Decision It was decided to apply fixtures and markers from `pytest` module. The fixture is the same used to execute the test -on different backends (`fieldview_backend` and `program_processor`), but it is extended with a check on the available feature markers. +on different backends (`exec_alloc_descriptor` and `program_processor`), but it is extended with a check on the available feature markers. If a test is annotated with a feature marker, the fixture will check if this feature is supported on the selected backend. If no marker is specified, the test is supposed to run on all backends. @@ -33,7 +33,7 @@ In the example below, `test_offset_field` requires the backend to support dynami def test_offset_field(cartesian_case): ``` -In order to selectively enable the backends, the dictionary `next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX` +In order to selectively enable the backends, the dictionary `next_tests.definitions.BACKEND_SKIP_TEST_MATRIX` lists for each backend the features that are not supported. The fixture will check if the annotated feature is present in the exclusion-matrix for the selected backend. If so, the exclusion matrix will also specify the action `pytest` should take (e.g. `SKIP` or `XFAIL`). diff --git a/pyproject.toml b/pyproject.toml index 675bdae9d0..51cfc267d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -340,7 +340,11 @@ markers = [ 'uses_origin: tests that require backend support for domain origin', 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', + 'uses_scan: tests that uses scan', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', + 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', + 'uses_scan_nested: tests that use nested scans', + 'uses_scan_requiring_projector: tests need a projector implementation in gtfn', 'uses_sparse_fields: tests that require backend support for sparse fields', 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', @@ -349,7 +353,7 @@ markers = [ 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields', 'uses_cartesian_shift: tests that use a Cartesian connectivity', 'uses_unstructured_shift: tests that use a unstructured connectivity', - 'uses_scan: tests that uses scan', + 'uses_max_over: tests that use the max_over builtin', 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 97e83276fe..44203bf6d8 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -231,6 +231,7 @@ def __init__(self) -> None: device_allocators[core_defs.DeviceType.CPU] = StandardCPUFieldBufferAllocator() + assert is_field_allocator(device_allocators[core_defs.DeviceType.CPU]) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 949f4b461a..33a0591813 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -843,8 +843,10 @@ def is_connectivity_field( return isinstance(v, ConnectivityField) # type: ignore[misc] # we use extended_runtime_checkable +# Utility function to construct a `Field` from different buffer representations. +# Consider removing this function and using `Field` constructor directly. See also `_connectivity`. @functools.singledispatch -def field( +def _field( definition: Any, /, *, @@ -854,8 +856,9 @@ def field( raise NotImplementedError +# See comment for `_field`. @functools.singledispatch -def connectivity( +def _connectivity( definition: Any, /, codomain: Dimension, @@ -980,7 +983,7 @@ def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar: __getitem__ = restrict -connectivity.register(numbers.Integral, CartesianConnectivity.from_offset) +_connectivity.register(numbers.Integral, CartesianConnectivity.from_offset) @enum.unique diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 9bb4cf17e5..8b41bf7cba 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -87,7 +87,7 @@ def empty( buffer = next_allocators.allocate( domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device ) - res = common.field(buffer.ndarray, domain=domain) + res = common._field(buffer.ndarray, domain=domain) assert common.is_mutable_field(res) assert isinstance(res, nd_array_field.NdArrayField) return res @@ -356,9 +356,9 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) - # TODO(havogt): consider addin MutableNDArrayObject + # TODO(havogt): consider adding MutableNDArrayObject buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] - connectivity_field = common.connectivity( + connectivity_field = common._connectivity( buffer.ndarray, codomain=codomain, domain=actual_domain ) assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9fc1b42038..52a61b40bb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -95,9 +95,7 @@ class NdArrayField( _domain: common.Domain _ndarray: core_defs.NDArrayObject - array_ns: ClassVar[ - ModuleType - ] # TODO(havogt) after storage PR is merged, update to the NDArrayNamespace protocol + array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol @property def domain(self) -> common.Domain: @@ -197,7 +195,11 @@ def remap( # finally, take the new array new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx) - return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype) + return self.__class__.from_array( + new_buffer, + domain=new_domain, + dtype=self.dtype, + ) __call__ = remap # type: ignore[assignment] @@ -510,7 +512,7 @@ class NumPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = np -common.field.register(np.ndarray, NumPyArrayField.from_array) +common._field.register(np.ndarray, NumPyArrayField.from_array) @dataclasses.dataclass(frozen=True, eq=False) @@ -518,7 +520,7 @@ class NumPyArrayConnectivityField(NdArrayConnectivityField): array_ns: ClassVar[ModuleType] = np -common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array) +common._connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array) # CuPy if cp: @@ -528,13 +530,13 @@ class NumPyArrayConnectivityField(NdArrayConnectivityField): class CuPyArrayField(NdArrayField): array_ns: ClassVar[ModuleType] = cp - common.field.register(cp.ndarray, CuPyArrayField.from_array) + common._field.register(cp.ndarray, CuPyArrayField.from_array) @dataclasses.dataclass(frozen=True, eq=False) class CuPyArrayConnectivityField(NdArrayConnectivityField): array_ns: ClassVar[ModuleType] = cp - common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array) + common._connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array) # JAX if jnp: @@ -552,7 +554,7 @@ def __setitem__( # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") - common.field.register(jnp.ndarray, JaxArrayField.from_array) + common._field.register(jnp.ndarray, JaxArrayField.from_array) def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: @@ -565,7 +567,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] else: domain_slice.append(np.newaxis) named_ranges.append((dim, common.UnitRange.infinite())) - return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) + return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) def _builtins_broadcast( diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index 0992401ebb..cb03373b41 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -13,11 +13,14 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses +from types import ModuleType from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar +import numpy as np + from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.next import common, constructors, errors, utils +from gt4py.next import common, errors, utils from gt4py.next.embedded import common as embedded_common, context as embedded_context @@ -43,7 +46,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel scan_range = embedded_context.closure_column_range.get() assert self.axis == scan_range[0] scan_axis = scan_range[0] - domain_intersection = _intersect_scan_args(*args, *kwargs.values()) + all_args = [*args, *kwargs.values()] + domain_intersection = _intersect_scan_args(*all_args) non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) out_domain = common.Domain( @@ -53,7 +57,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel # even if the scan dimension is not in the input, we can scan over it out_domain = common.Domain(*out_domain, (scan_range)) - res = _construct_scan_array(out_domain)(self.init) + xp = _get_array_ns(*all_args) + res = _construct_scan_array(out_domain, xp)(self.init) def scan_loop(hpos): acc = self.init @@ -128,7 +133,11 @@ def _tuple_assign_field( ): @utils.tree_map def impl(target: common.MutableField, source: common.Field): - target[domain] = source[domain] + if common.is_field(source): + target[domain] = source[domain] + else: + assert core_defs.is_scalar_type(source) + target[domain] = source impl(target, source) @@ -141,10 +150,21 @@ def _intersect_scan_args( ) -def _construct_scan_array(domain: common.Domain): +def _get_array_ns( + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] +) -> ModuleType: + for arg in utils.flatten_nested_tuple(args): + if hasattr(arg, "array_ns"): + return arg.array_ns + return np + + +def _construct_scan_array( + domain: common.Domain, xp: ModuleType +): # TODO(havogt) introduce a NDArrayNamespace protocol @utils.tree_map def impl(init: core_defs.Scalar) -> common.Field: - return constructors.empty(domain, dtype=type(init)) + return common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain) return impl @@ -168,6 +188,7 @@ def _tuple_at( @utils.tree_map def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: res = field[pos] if common.is_field(field) else field + res = res.item() if hasattr(res, "item") else res # extract scalar value from array assert core_defs.is_scalar_type(res) return res diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index cd75538da7..493493f697 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -188,12 +188,8 @@ def broadcast( assert core_defs.is_scalar_type( field ) # default implementation for scalars, Fields are handled via dispatch - return common.field( - np.asarray(field)[ - tuple([np.newaxis] * len(dims)) - ], # TODO(havogt) use FunctionField once available - domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))), - ) + # TODO(havogt) implement with FunctionField, the workaround is to ignore broadcasting on scalars as they broadcast automatically, but we lose the check for compatible dimensions + return field # type: ignore[return-value] # see comment above @WhereBuiltinFunction diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 390bec4312..6d610fd136 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1035,7 +1035,7 @@ def _maker(a) -> common.Field: offset = origin.get(d, 0) ranges.append(common.UnitRange(-offset, s - offset)) - res = common.field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges))) + res = common._field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges))) return res return _maker diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index e2905ab49a..1745dac6ef 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -12,10 +12,10 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from . import exclusion_matrices +from . import definitions -__all__ = ["exclusion_matrices", "get_processor_id"] +__all__ = ["definitions", "get_processor_id"] def get_processor_id(processor): diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/definitions.py similarity index 77% rename from tests/next_tests/exclusion_matrices.py rename to tests/next_tests/definitions.py index f6d2b10a14..dbb2366f47 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/definitions.py @@ -14,11 +14,16 @@ """Contains definition of test-exclusion matrices, see ADR 15.""" +import dataclasses import enum import importlib +from typing import Final, Optional, Protocol import pytest +from gt4py.next import allocators as next_allocators +from gt4py.next.program_processors import processor_interface as ppi + # Skip definitions XFAIL = pytest.xfail @@ -38,8 +43,6 @@ def load(self) -> object: obj = eval(f"_m.{obj}", globs) return obj - __invert__ = load - def short_id(self, num_components: int = 2) -> str: return ".".join(self.value.split(".")[-num_components:]) @@ -55,6 +58,32 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" +class ExecutionAndAllocatorDescriptor(Protocol): + # Used for test infrastructure, consider implementing this in gt4py when refactoring otf + @property + def executor(self) -> Optional[ppi.ProgramExecutor]: + ... + + @property + def allocator(self) -> next_allocators.FieldBufferAllocatorProtocol: + ... + + +@dataclasses.dataclass(frozen=True) +class EmbeddedExecutionDescriptor: + allocator: next_allocators.FieldBufferAllocatorProtocol + executor: Final = None + + +numpy_execution = EmbeddedExecutionDescriptor(next_allocators.StandardCPUFieldBufferAllocator()) +cupy_execution = EmbeddedExecutionDescriptor(next_allocators.StandardGPUFieldBufferAllocator()) + + +class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): + NUMPY_EXECUTION = "next_tests.definitions.numpy_execution" + CUPY_EXECUTION = "next_tests.definitions.cupy_execution" + + class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): DACE_CPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_cpu" DACE_GPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_gpu" @@ -93,7 +122,11 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_NEGATIVE_MODULO = "uses_negative_modulo" USES_ORIGIN = "uses_origin" USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" +USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" +USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" +USES_SCAN_NESTED = "uses_scan_nested" +USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" USES_SPARSE_FIELDS = "uses_sparse_fields" USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output" USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" @@ -103,7 +136,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" USES_CARTESIAN_SHIFT = "uses_cartesian_shift" USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift" -USES_SCAN = "uses_scan" +USES_MAX_OVER = "uses_max_over" CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') @@ -134,26 +167,38 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EMBEDDED_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), + ( + USES_SCAN_WITHOUT_FIELD_ARGS, + XFAIL, + UNSUPPORTED_MESSAGE, + ), # we can't extract the field type from scan args ] GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 + (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), ] #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) BACKEND_SKIP_TEST_MATRIX = { - None: EMBEDDED_SKIP_LIST, + EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, + EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST + [ # awaiting dace fix, see https://github.com/spcl/dace/pull/1442 (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], - ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST, - ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST, - ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST, + ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 6217d3c782..03a0a9f5a7 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,11 +28,12 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common, constructors, field_utils +from gt4py.next import allocators as next_allocators, common, constructors, field_utils from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation +from next_tests import definitions as test_definitions from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 # fixture and aliases Cell, Edge, @@ -43,7 +44,7 @@ KDim, Koff, Vertex, - fieldview_backend, + exec_alloc_descriptor, reduction_setup, ) @@ -103,7 +104,7 @@ def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue: def field( self, - backend: ppi.ProgramProcessor, + allocator: next_allocators.FieldBufferAllocatorProtocol, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, ) -> FieldValue: @@ -137,7 +138,7 @@ def scalar_value(self) -> ScalarValue: def field( self, - backend: ppi.ProgramExecutor, + allocator: next_allocators.FieldBufferAllocatorProtocol, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, ) -> FieldValue: @@ -145,7 +146,7 @@ def field( domain=common.domain(sizes), fill_value=self.value, dtype=dtype, - allocator=backend, + allocator=allocator, ) @@ -166,7 +167,7 @@ def scalar_value(self) -> ScalarValue: def field( self, - backend: ppi.ProgramExecutor, + allocator: next_allocators.FieldBufferAllocatorProtocol, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, ) -> FieldValue: @@ -176,7 +177,7 @@ def field( ) n_data = list(sizes.values())[0] return constructors.as_field( - domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=backend + domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=allocator ) def from_case( @@ -207,7 +208,7 @@ def scalar_value(self) -> ScalarValue: def field( self, - backend: ppi.ProgramProcessor, + allocator: next_allocators.FieldBufferAllocatorProtocol, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, ) -> FieldValue: @@ -218,7 +219,7 @@ def field( return constructors.as_field( common.domain(sizes), np.arange(start, start + n_data, dtype=dtype).reshape(svals), - allocator=backend, + allocator=allocator, ) def from_case( @@ -382,7 +383,7 @@ def run( """Run fieldview code in the context of a given test case.""" if kwargs.get("offset_provider", None) is None: kwargs["offset_provider"] = case.offset_provider - fieldview_prog.with_grid_type(case.grid_type).with_backend(case.backend)(*args, **kwargs) + fieldview_prog.with_grid_type(case.grid_type).with_backend(case.executor)(*args, **kwargs) def verify( @@ -480,19 +481,25 @@ def verify_with_default_data( @pytest.fixture -def cartesian_case(fieldview_backend): # noqa: F811 # fixtures +def cartesian_case( + exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, # noqa: F811 # fixtures +): yield Case( - fieldview_backend, + exec_alloc_descriptor.executor, offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim}, default_sizes={IDim: 10, JDim: 10, KDim: 10}, grid_type=common.GridType.CARTESIAN, + allocator=exec_alloc_descriptor.allocator, ) @pytest.fixture -def unstructured_case(reduction_setup, fieldview_backend): # noqa: F811 # fixtures +def unstructured_case( + reduction_setup, # noqa: F811 # fixtures + exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, # noqa: F811 # fixtures +): yield Case( - fieldview_backend, + exec_alloc_descriptor.executor, offset_provider=reduction_setup.offset_provider, default_sizes={ Vertex: reduction_setup.num_vertices, @@ -501,6 +508,7 @@ def unstructured_case(reduction_setup, fieldview_backend): # noqa: F811 # fixtu KDim: reduction_setup.k_levels, }, grid_type=common.GridType.UNSTRUCTURED, + allocator=exec_alloc_descriptor.allocator, ) @@ -516,7 +524,7 @@ def _allocate_from_type( match arg_type: case ts.FieldType(dims=dims, dtype=arg_dtype): return strategy.field( - backend=case.backend, + allocator=case.allocator, sizes={dim: sizes[dim] for dim in dims}, dtype=dtype or arg_dtype.kind.name.lower(), ) @@ -601,11 +609,12 @@ def get_default_data( class Case: """Parametrizable components for single feature integration tests.""" - backend: ppi.ProgramProcessor + executor: Optional[ppi.ProgramProcessor] offset_provider: dict[str, common.Connectivity | gtx.Dimension] default_sizes: dict[gtx.Dimension, int] grid_type: common.GridType + allocator: next_allocators.FieldBufferAllocatorFactoryProtocol @property def as_field(self): - return constructors.as_field.partial(allocator=self.backend) + return constructors.as_field.partial(allocator=self.allocator) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 1f5a1f0c48..e421763699 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -14,7 +14,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from collections import namedtuple -from typing import Any, TypeVar +from typing import Any, Optional, TypeVar import numpy as np import pytest @@ -35,7 +35,6 @@ raise e import next_tests -import next_tests.exclusion_matrices as definitions @ppi.program_executor @@ -46,25 +45,33 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non OPTIONAL_PROCESSORS = [] if dace_iterator: - OPTIONAL_PROCESSORS.append(definitions.OptionalProgramBackendId.DACE_CPU) + OPTIONAL_PROCESSORS.append(next_tests.definitions.OptionalProgramBackendId.DACE_CPU) OPTIONAL_PROCESSORS.append( - pytest.param(definitions.OptionalProgramBackendId.DACE_GPU, marks=pytest.mark.requires_gpu) + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_GPU, marks=pytest.mark.requires_gpu + ) ), @pytest.fixture( params=[ - definitions.ProgramBackendId.ROUNDTRIP, - definitions.ProgramBackendId.GTFN_CPU, - definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, - definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, - pytest.param(definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu), - None, + next_tests.definitions.ProgramBackendId.ROUNDTRIP, + next_tests.definitions.ProgramBackendId.GTFN_CPU, + next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, + next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, + pytest.param( + next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu + ), + # will use the default (embedded) execution, but input/output allocated with the provided allocator + next_tests.definitions.EmbeddedIds.NUMPY_EXECUTION, + pytest.param( + next_tests.definitions.EmbeddedIds.CUPY_EXECUTION, marks=pytest.mark.requires_gpu + ), ] + OPTIONAL_PROCESSORS, - ids=lambda p: p.short_id() if p is not None else "None", + ids=lambda p: p.short_id(), ) -def fieldview_backend(request): +def exec_alloc_descriptor(request): """ Fixture creating field-view operator backend on-demand for tests. @@ -72,9 +79,9 @@ def fieldview_backend(request): Check ADR 15 for details on the test-exclusion matrices. """ backend_id = request.param - backend = None if backend_id is None else backend_id.load() + backend = backend_id.load() - for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( backend_id, [] ): if request.node.get_closest_marker(marker): @@ -225,7 +232,7 @@ def reduction_setup(): __all__ = [ - "fieldview_backend", + "exec_alloc_descriptor", "reduction_setup", "debug_itir", "DimsType", diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index b41696a36b..354323afeb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -26,7 +26,7 @@ from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, IField, IJKFloatField, KDim, cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, ) @@ -59,7 +59,7 @@ def testee(a: IField, b: IField, c: IField) -> IField: pos_args = [args[name] for name in arg_names] kw_args = {name: args[name] for name in kwarg_names} - testee.with_backend(cartesian_case.backend)( + testee.with_backend(cartesian_case.executor)( *pos_args, **kw_args, out=out, offset_provider=cartesian_case.offset_provider ) @@ -85,7 +85,7 @@ def testee(a: IField, b: IField, out: IField): pos_args = [args[name] for name in arg_names] kw_args = {name: args[name] for name in kwarg_names} - testee.with_backend(cartesian_case.backend)( + testee.with_backend(cartesian_case.executor)( *pos_args, **kw_args, offset_provider=cartesian_case.offset_provider ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py index 0de953d85f..e4baedc6ee 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py @@ -21,7 +21,7 @@ from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, reduction_setup, ) @@ -52,7 +52,7 @@ def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IFiel scalar = 0 if not condition else scalar return a + scalar - @gtx.program(backend=cartesian_case.backend) + @gtx.program(backend=cartesian_case.executor) def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField): fieldop_args(a, condition, scalar, out=out) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 70c79d7b6c..9482860d13 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -51,7 +51,7 @@ unstructured_case, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, reduction_setup, ) @@ -589,6 +589,7 @@ def testee(a: tuple[tuple[cases.IField, cases.IField], cases.IField]) -> cases.I @pytest.mark.uses_scan +@pytest.mark.uses_scan_without_field_args @pytest.mark.parametrize("forward", [True, False]) def test_fieldop_from_scan(cartesian_case, forward): init = 1.0 @@ -611,15 +612,9 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_scan @pytest.mark.uses_lift_expressions +@pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.backend in [ - gtfn.run_gtfn, - gtfn.run_gtfn_gpu, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, - ]: - pytest.xfail("Nested `scan`s requires creating temporaries.") - if cartesian_case.backend == gtfn.run_gtfn_with_temporaries: + if cartesian_case.executor == gtfn.run_gtfn_with_temporaries: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) @@ -721,7 +716,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan def test_ternary_scan(cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: + if cartesian_case.executor in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=0.0) @@ -743,9 +738,10 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.uses_scan +@pytest.mark.uses_scan_without_field_args @pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: + if cartesian_case.executor in [gtfn.run_gtfn_with_temporaries]: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") init = (1, (2, 3)) @@ -916,15 +912,8 @@ def program_domain(a: cases.IField, out: cases.IField): cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) +@pytest.mark.uses_floordiv def test_domain_input_bounds(cartesian_case): - if cartesian_case.backend in [ - gtfn.run_gtfn, - gtfn.run_gtfn_gpu, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, - ]: - pytest.xfail("FloorDiv not fully supported in gtfn.") - lower_i = 1 upper_i = 10 @@ -970,7 +959,7 @@ def test_domain_input_bounds_1(cartesian_case): def fieldop_domain(a: cases.IJField) -> cases.IJField: return a + a - @gtx.program(backend=cartesian_case.backend) + @gtx.program(backend=cartesian_case.executor) def program_domain( a: cases.IJField, out: cases.IJField, @@ -1071,7 +1060,7 @@ def prog(inp: cases.IKField, k_index: gtx.Field[[KDim], gtx.IndexType], out: cas def test_undefined_symbols(cartesian_case): with pytest.raises(errors.DSLError, match="Undeclared symbol"): - @gtx.field_operator(backend=cartesian_case.backend) + @gtx.field_operator(backend=cartesian_case.executor) def return_undefined(): return undefined_symbol @@ -1171,7 +1160,7 @@ def test_tuple_unpacking_too_many_values(cartesian_case): match=(r"Too many values to unpack \(expected 3\)."), ): - @gtx.field_operator(backend=cartesian_case.backend) + @gtx.field_operator(backend=cartesian_case.executor) def _star_unpack() -> tuple[int32, float64, int32]: a, b, c = (1, 2.0, 3, 4, 5, 6, 7.0) return a, b, c @@ -1182,7 +1171,7 @@ def test_tuple_unpacking_too_few_values(cartesian_case): errors.DSLError, match=(r"Assignment value must be of type tuple, got 'int32'.") ): - @gtx.field_operator(backend=cartesian_case.backend) + @gtx.field_operator(backend=cartesian_case.executor) def _invalid_unpack() -> tuple[int32, float64, int32]: a, b, c = 1 return a diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index d100cd380c..bb1d878a6a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -21,7 +21,7 @@ from next_tests.integration_tests import cases from next_tests.integration_tests.cases import V2E, Edge, V2EDim, Vertex, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, reduction_setup, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e2434d860a..90d07f360d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -34,26 +34,19 @@ unstructured_case, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, reduction_setup, ) @pytest.mark.uses_unstructured_shift +@pytest.mark.uses_max_over @pytest.mark.parametrize( "strategy", [cases.UniqueInitializer(1), cases.UniqueInitializer(-100)], ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - if unstructured_case.backend in [ - gtfn.run_gtfn, - gtfn.run_gtfn_gpu, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, - ]: - pytest.xfail("`maxover` broken in gtfn, see #1289.") - @gtx.field_operator def testee(edge_f: cases.EField) -> cases.VField: out = max_over(edge_f(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 4444742c66..e076ec4227 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -28,7 +28,7 @@ from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, ) from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -118,7 +118,7 @@ def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.Program @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inputs): - if cartesian_case.backend is None: + if cartesian_case.executor is None: # TODO(havogt) find a way that works for embedded pytest.xfail("Test does not have a field view program.") if builtin_name == "gamma": @@ -131,7 +131,7 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp expected = ref_impl(*inputs) out = cartesian_case.as_field([IDim], np.zeros_like(expected)) - builtin_field_op = make_builtin_field_operator(builtin_name, cartesian_case.backend) + builtin_field_op = make_builtin_field_operator(builtin_name, cartesian_case.executor) builtin_field_op(*inps, out=out, offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index f5bf453a09..b21f29c9bc 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -41,7 +41,7 @@ from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 938c69fb52..df0009d0d4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -23,7 +23,13 @@ from gt4py.next import errors from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend +from next_tests.integration_tests.cases import ( + IDim, + Ioff, + JDim, + cartesian_case, + exec_alloc_descriptor, +) from next_tests.past_common_fixtures import ( copy_program_def, copy_restrict_program_def, @@ -34,7 +40,7 @@ def test_identity_fo_execution(cartesian_case, identity_def): - identity = gtx.field_operator(identity_def, backend=cartesian_case.backend) + identity = gtx.field_operator(identity_def, backend=cartesian_case.executor) in_field = cases.allocate(cartesian_case, identity, "in_field").strategy( cases.ConstInitializer(1) @@ -82,13 +88,13 @@ def shift_by_one_program(in_field: cases.IFloatField, out_field: cases.IFloatFie def test_copy_execution(cartesian_case, copy_program_def): - copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) + copy_program = gtx.program(copy_program_def, backend=cartesian_case.executor) cases.verify_with_default_data(cartesian_case, copy_program, ref=lambda in_field: in_field) def test_double_copy_execution(cartesian_case, double_copy_program_def): - double_copy_program = gtx.program(double_copy_program_def, backend=cartesian_case.backend) + double_copy_program = gtx.program(double_copy_program_def, backend=cartesian_case.executor) cases.verify_with_default_data( cartesian_case, double_copy_program, ref=lambda in_field, intermediate_field: in_field @@ -96,7 +102,7 @@ def test_double_copy_execution(cartesian_case, double_copy_program_def): def test_copy_restricted_execution(cartesian_case, copy_restrict_program_def): - copy_restrict_program = gtx.program(copy_restrict_program_def, backend=cartesian_case.backend) + copy_restrict_program = gtx.program(copy_restrict_program_def, backend=cartesian_case.executor) cases.verify_with_default_data( cartesian_case, @@ -218,7 +224,7 @@ def prog( def test_wrong_argument_type(cartesian_case, copy_program_def): - copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) + copy_program = gtx.program(copy_program_def, backend=cartesian_case.executor) inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index af06da3e29..834966c125 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -40,7 +40,7 @@ ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( Cell, - fieldview_backend, + exec_alloc_descriptor, size, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 788081b81e..c4cdd8a4be 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -77,7 +77,7 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup): unstructured_case = Case( - run_gtfn_with_temporaries_and_symbolic_sizes, + run_gtfn_with_temporaries_and_symbolic_sizes.executor, offset_provider=reduction_setup.offset_provider, default_sizes={ Vertex: reduction_setup.num_vertices, @@ -86,6 +86,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, redu KDim: reduction_setup.k_levels, }, grid_type=common.GridType.UNSTRUCTURED, + allocator=run_gtfn_with_temporaries_and_symbolic_sizes.allocator, ) a = cases.allocate(unstructured_case, testee, "a")() diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 579dec11f8..59c72bbf3f 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -18,11 +18,11 @@ import gt4py.next as gtx from gt4py.next import errors -import next_tests.exclusion_matrices as definitions +from next_tests import definitions from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( # noqa: F401 # fixtures cartesian_case, - fieldview_backend, + exec_alloc_descriptor, ) @@ -70,7 +70,7 @@ def test_allocate_const(cartesian_case): # noqa: F811 # fixtures assert b == 42.0 -@pytest.mark.parametrize("fieldview_backend", [~definitions.ProgramBackendId.ROUNDTRIP]) +@pytest.mark.parametrize("exec_alloc_descriptor", [definitions.ProgramBackendId.ROUNDTRIP.load()]) def test_verify_fails_with_wrong_reference(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, addition, "a")() b = cases.allocate(cartesian_case, addition, "b")() @@ -81,7 +81,7 @@ def test_verify_fails_with_wrong_reference(cartesian_case): # noqa: F811 # fixt cases.verify(cartesian_case, addition, a, b, out=out, ref=wrong_ref) -@pytest.mark.parametrize("fieldview_backend", [~definitions.ProgramBackendId.ROUNDTRIP]) +@pytest.mark.parametrize("exec_alloc_descriptor", [definitions.ProgramBackendId.ROUNDTRIP.load()]) def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures a = cases.allocate(cartesian_case, addition, "a").dtype(np.float32)() b = cases.allocate(cartesian_case, addition, "b")() @@ -91,7 +91,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures cases.verify(cartesian_case, addition, a, b, out=out, ref=a + b) -@pytest.mark.parametrize("fieldview_backend", [~definitions.ProgramBackendId.ROUNDTRIP]) +@pytest.mark.parametrize("exec_alloc_descriptor", [definitions.ProgramBackendId.ROUNDTRIP.load()]) def test_verify_with_default_data_fails_with_wrong_reference( cartesian_case, # noqa: F811 # fixtures ): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py index ba4b1b0cdb..65f017a518 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py @@ -22,7 +22,7 @@ from next_tests.integration_tests.cases import IField, cartesian_case # noqa: F401 # fixtures from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 # fixtures KDim, - fieldview_backend, + exec_alloc_descriptor, ) @@ -38,7 +38,7 @@ def copy(a: IField) -> IField: with pytest.raises(ValueError, match="No backend selected!"): # Calling this should fail if the default backend is respected - # due to `fieldview_backend` fixture (dependency of `cartesian_case`) + # due to `exec_alloc_descriptor` fixture (dependency of `cartesian_case`) # setting the default backend to something invalid. _ = copy(a, out=a, offset_provider={}) @@ -51,7 +51,7 @@ def test_default_backend_is_respected_scan_operator(cartesian_case): # noqa: F8 def sum(state: float, a: float) -> float: return state + a - a = gtx.ones({KDim: 10}, allocator=cartesian_case.backend) + a = gtx.ones({KDim: 10}, allocator=cartesian_case.allocator) with pytest.raises(ValueError, match="No backend selected!"): # see comment in field_operator test @@ -81,7 +81,7 @@ def copy_program(a: IField, b: IField) -> IField: def test_missing_arg_field_operator(cartesian_case): # noqa: F811 # fixtures """Test that calling a field_operator without required args raises an error.""" - @gtx.field_operator(backend=cartesian_case.backend) + @gtx.field_operator(backend=cartesian_case.executor) def copy(a: IField) -> IField: return a @@ -97,7 +97,7 @@ def copy(a: IField) -> IField: def test_missing_arg_scan_operator(cartesian_case): # noqa: F811 # fixtures """Test that calling a scan_operator without required args raises an error.""" - @gtx.scan_operator(backend=cartesian_case.backend, axis=KDim, init=0.0, forward=True) + @gtx.scan_operator(backend=cartesian_case.executor, axis=KDim, init=0.0, forward=True) def sum(state: float, a: float) -> float: return state + a @@ -122,7 +122,7 @@ def copy(a: IField) -> IField: with pytest.raises(errors.DSLError, match="Invalid call"): - @gtx.program(backend=cartesian_case.backend) + @gtx.program(backend=cartesian_case.executor) def copy_program(a: IField, b: IField) -> IField: copy(a) @@ -130,7 +130,7 @@ def copy_program(a: IField, b: IField) -> IField: with pytest.raises(TypeError, match="'offset_provider'"): - @gtx.program(backend=cartesian_case.backend) + @gtx.program(backend=cartesian_case.executor) def copy_program(a: IField, b: IField) -> IField: copy(a, out=b) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 5bd255f80f..f1a5b41f81 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -19,12 +19,12 @@ import gt4py.next as gtx from gt4py.next import common -from gt4py.next.program_processors.runners import gtfn, roundtrip +from next_tests import definitions as test_definitions from next_tests.integration_tests import cases from next_tests.integration_tests.cases import Cell, KDim, Koff from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, ) @@ -193,12 +193,13 @@ def reference( @pytest.fixture -def test_setup(fieldview_backend): +def test_setup(exec_alloc_descriptor): test_case = cases.Case( - fieldview_backend, + exec_alloc_descriptor.executor, offset_provider={"Koff": KDim}, default_sizes={Cell: 14, KDim: 10}, grid_type=common.GridType.UNSTRUCTURED, + allocator=exec_alloc_descriptor.allocator, ) @dataclasses.dataclass(frozen=True) @@ -226,15 +227,8 @@ class setup: @pytest.mark.uses_tuple_returns +@pytest.mark.uses_scan_requiring_projector def test_solve_nonhydro_stencil_52_like_z_q(test_setup): - if test_setup.case.backend in [ - gtfn.run_gtfn, - gtfn.run_gtfn_gpu, - gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, - ]: - pytest.xfail("Needs implementation of scan projector.") - cases.verify( test_setup.case, solve_nonhydro_stencil_52_like_z_q, @@ -253,12 +247,15 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): - if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: + if ( + test_setup.case.executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + ): pytest.xfail( "Needs implementation of scan projector. Breaks in type inference as executed" "again after CollapseTuple." ) - if test_setup.case.backend == roundtrip.backend: + if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") cases.verify( @@ -277,7 +274,10 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup): - if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: + if ( + test_setup.case.executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + ): pytest.xfail("Temporary extraction does not work correctly in combination with scans.") cases.run( @@ -296,9 +296,12 @@ def test_solve_nonhydro_stencil_52_like(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): - if test_setup.case.backend in [gtfn.run_gtfn_with_temporaries]: + if ( + test_setup.case.executor + == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load().executor + ): pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - if test_setup.case.backend == roundtrip.backend: + if test_setup.case.executor == test_definitions.ProgramBackendId.ROUNDTRIP.load().executor: pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") cases.run( diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py index 4f4d4969a9..6784857211 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_laplacian.py @@ -20,7 +20,7 @@ from next_tests.integration_tests import cases from next_tests.integration_tests.cases import IDim, Ioff, JDim, Joff, cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - fieldview_backend, + exec_alloc_descriptor, ) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 4177a5aeee..d3f9bdb761 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -31,8 +31,8 @@ else: raise e + import next_tests -import next_tests.exclusion_matrices as definitions @pytest.fixture( @@ -49,7 +49,7 @@ def lift_mode(request): OPTIONAL_PROCESSORS = [] if dace_iterator: - OPTIONAL_PROCESSORS.append((definitions.OptionalProgramBackendId.DACE_CPU, True)) + OPTIONAL_PROCESSORS.append((next_tests.definitions.OptionalProgramBackendId.DACE_CPU, True)) # TODO(havogt): update tests to use proper allocation # OPTIONAL_PROCESSORS.append( # pytest.param( @@ -61,16 +61,16 @@ def lift_mode(request): @pytest.fixture( params=[ (None, True), - (definitions.ProgramBackendId.ROUNDTRIP, True), - (definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), - (definitions.ProgramBackendId.GTFN_CPU, True), - (definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), - (definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), + (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation - (definitions.ProgramFormatterId.LISP_FORMATTER, False), - (definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), - (definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False), - (definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), + (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), + (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), + (next_tests.definitions.ProgramFormatterId.ITIR_TYPE_CHECKER, False), + (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), ] + OPTIONAL_PROCESSORS, ids=lambda p: p[0].short_id() if p[0] is not None else "None", @@ -89,7 +89,7 @@ def program_processor(request) -> tuple[ppi.ProgramProcessor, bool]: processor = processor_id.load() assert is_backend == ppi.is_program_backend(processor) - for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get( + for marker, skip_mark, msg in next_tests.definitions.BACKEND_SKIP_TEST_MATRIX.get( processor_id, [] ): if request.node.get_closest_marker(marker): diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 6863b09c12..1972b55852 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -77,7 +77,7 @@ def _make_field(lst: Iterable, nd_array_implementation, *, domain=None, dtype=No domain = tuple( (common.Dimension(f"D{i}"), common.UnitRange(0, s)) for i, s in enumerate(buffer.shape) ) - return common.field( + return common._field( buffer, domain=domain, ) @@ -119,14 +119,14 @@ def test_where_builtin_different_domain(nd_array_implementation): true_ = np.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) false_ = np.asarray([7.0, 8.0, 9.0, 10.0], dtype=np.float32) - cond_field = common.field( + cond_field = common._field( nd_array_implementation.asarray(cond), domain=common.domain({JDim: 2}) ) - true_field = common.field( + true_field = common._field( nd_array_implementation.asarray(true_), domain=common.domain({IDim: common.UnitRange(0, 2), JDim: common.UnitRange(-1, 2)}), ) - false_field = common.field( + false_field = common._field( nd_array_implementation.asarray(false_), domain=common.domain({JDim: common.UnitRange(-1, 3)}), ) @@ -225,8 +225,8 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte arr2 = np.ones((5, 5)) arr2_domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 10), UnitRange(5, 10))) - field1 = common.field(arr1, domain=arr1_domain) - field2 = common.field(arr2, domain=arr2_domain) + field1 = common._field(arr1, domain=arr1_domain) + field2 = common._field(arr2, domain=arr2_domain) op_result = binary_arithmetic_op(field1, field2) expected_result = binary_arithmetic_op(arr1[expected_indices[0], expected_indices[1]], arr2) @@ -287,11 +287,11 @@ def test_remap_implementation(): V_START, V_STOP = 2, 7 E_START, E_STOP = 0, 10 - v_field = common.field( + v_field = common._field( -0.1 * np.arange(V_START, V_STOP), domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START, V_STOP),)), ) - e2v_conn = common.connectivity( + e2v_conn = common._connectivity( np.arange(E_START, E_STOP), domain=common.Domain( dims=(E,), @@ -303,7 +303,7 @@ def test_remap_implementation(): ) result = v_field.remap(e2v_conn) - expected = common.field( + expected = common._field( -0.1 * np.arange(V_START, V_STOP), domain=common.Domain(dims=(E,), ranges=(UnitRange(V_START, V_STOP),)), ) @@ -318,14 +318,14 @@ def test_cartesian_remap_implementation(): V_START, V_STOP = 2, 7 OFFSET = 2 - v_field = common.field( + v_field = common._field( -0.1 * np.arange(V_START, V_STOP), domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START, V_STOP),)), ) - v2_conn = common.connectivity(OFFSET, V) + v2_conn = common._connectivity(OFFSET, V) result = v_field.remap(v2_conn) - expected = common.field( + expected = common._field( v_field.ndarray, domain=common.Domain(dims=(V,), ranges=(UnitRange(V_START - OFFSET, V_STOP - OFFSET),)), ) @@ -340,7 +340,7 @@ def test_cartesian_remap_implementation(): ( ( (IDim,), - common.field( + common._field( np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) ), Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)), @@ -349,7 +349,7 @@ def test_cartesian_remap_implementation(): ( ( (IDim, JDim), - common.field( + common._field( np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) ), Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinite())), @@ -358,7 +358,7 @@ def test_cartesian_remap_implementation(): ( ( (IDim, JDim), - common.field( + common._field( np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) ), Domain(dims=(IDim, JDim), ranges=(UnitRange.infinite(), UnitRange(0, 10))), @@ -367,7 +367,7 @@ def test_cartesian_remap_implementation(): ( ( (IDim, JDim, KDim), - common.field( + common._field( np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) ), Domain( @@ -455,7 +455,7 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): domain = common.Domain( dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 10), UnitRange(5, 15), UnitRange(10, 25)) ) - field = common.field(np.ones((5, 10, 15)), domain=domain) + field = common._field(np.ones((5, 10, 15)), domain=domain) indexed_field = field[domain_slice] assert common.is_field(indexed_field) @@ -465,7 +465,7 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): def test_absolute_indexing_value_return(): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) - field = common.field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) + field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) named_index = ((IDim, 12), (JDim, 6)) value = field[named_index] @@ -502,7 +502,7 @@ def test_absolute_indexing_value_return(): ) def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) - field = common.field(np.ones((10, 10)), domain=domain) + field = common._field(np.ones((10, 10)), domain=domain) indexed_field = field[index] assert common.is_field(indexed_field) @@ -558,7 +558,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): domain = common.Domain( dims=(IDim, JDim, KDim), ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)) ) - field = common.field(np.ones((10, 15, 10)), domain=domain) + field = common._field(np.ones((10, 15, 10)), domain=domain) indexed_field = field[index] assert common.is_field(indexed_field) @@ -572,7 +572,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): ) def test_relative_indexing_value_return(index, expected_value): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) - field = common.field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) + field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) indexed_field = field[index] assert indexed_field == expected_value @@ -581,7 +581,7 @@ def test_relative_indexing_value_return(index, expected_value): @pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]]) def test_relative_indexing_out_of_bounds(lazy_slice): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) - field = common.field(np.ones((10, 10)), domain=domain) + field = common._field(np.ones((10, 10)), domain=domain) with pytest.raises((embedded_exceptions.IndexOutOfBounds, IndexError)): lazy_slice(field) @@ -590,7 +590,7 @@ def test_relative_indexing_out_of_bounds(lazy_slice): @pytest.mark.parametrize("index", [IDim, "1", (IDim, JDim)]) def test_field_unsupported_index(index): domain = common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) - field = common.field(np.ones((10,)), domain=domain) + field = common._field(np.ones((10,)), domain=domain) with pytest.raises(IndexError, match="Unsupported index type"): field[index] @@ -602,12 +602,12 @@ def test_field_unsupported_index(index): ((1, slice(None)), np.ones((10,)) * 42.0), ( (1, slice(None)), - common.field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), + common._field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), ), ], ) def test_setitem(index, value): - field = common.field( + field = common._field( np.arange(100).reshape(10, 10), domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) @@ -621,12 +621,12 @@ def test_setitem(index, value): def test_setitem_wrong_domain(): - field = common.field( + field = common._field( np.arange(100).reshape(10, 10), domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), ) - value_incompatible = common.field( + value_incompatible = common._field( np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) ) From eb430023bb3dabef665ace944408a67b53a7a5e5 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 30 Jan 2024 10:55:48 +0100 Subject: [PATCH 58/85] fix[next][dace]: Bugfix in deref (dynamic memory allocation) (#1430) Baseline contained a bug in the lowering of deref in the context of neighbor reduction. The data container should be statically allocated with size equal to the max_neighbors attribute in the offset provider. --- .../runners/dace_iterator/itir_to_sdfg.py | 2 +- .../runners/dace_iterator/itir_to_tasklet.py | 43 +++++++++---------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 63e9fb03dc..525a5c694e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -461,7 +461,7 @@ def _visit_scan_stencil_closure( assert isinstance(node.output, SymRef) neighbor_tables = filter_neighbor_tables(self.offset_provider) input_names = [str(inp.id) for inp in node.inputs] - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] # find the scan dimension, same as output dimension, and exclude it from the map domain map_ranges = {} diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index ab03d29389..ba969608a7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -763,7 +763,6 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # already a list of ValueExpr return iterator - args: list[ValueExpr] sorted_dims = sorted(iterator.dimensions) if all([dim in iterator.indices for dim in iterator.dimensions]): # The deref iterator has index values on all dimensions: the result will be a scalar @@ -781,16 +780,16 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ) else: - # Not all dimensions are included in the deref index list: - # this means the ND-field will be sliced along one or more dimensions and the result will be an array - field_array = self.context.body.arrays[iterator.field.data] - result_shape = tuple( - dim_size - for dim, dim_size in zip(sorted_dims, field_array.shape) - if dim not in iterator.indices - ) + dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] + assert len(dims_not_indexed) == 1 + offset = dims_not_indexed[0] + offset_provider = self.offset_provider[offset] + neighbor_dim = offset_provider.neighbor_axis.value + result_name = unique_var_name() - self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) + self.context.body.add_array( + result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True + ) result_array = self.context.body.arrays[result_name] result_node = self.context.state.add_access(result_name, debuginfo=di) @@ -800,19 +799,17 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: deref_nodes = [iterator.field] + [ iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices ] - deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [ - dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:] - ] + deref_memlets = [ + dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body)) + ] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]] # we create a mapped tasklet for array slicing + index_name = unique_name(f"_i_{neighbor_dim}") map_ranges = { - f"_i_{dim}": f"0:{size}" - for dim, size in zip(sorted_dims, field_array.shape) - if dim not in iterator.indices + index_name: f"0:{offset_provider.max_neighbors}", } - src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims]) - dst_subset = ",".join( - [f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices] + src_subset = ",".join( + [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] ) self.context.state.add_mapped_tasklet( "deref", @@ -821,7 +818,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: outputs={ "_out": dace.Memlet.from_array(result_name, result_array), }, - code=f"_out[{dst_subset}] = _inp[{src_subset}]", + code=f"_out[{index_name}] = _inp[{src_subset}]", external_edges=True, input_nodes={node.data: node for node in deref_nodes}, output_nodes={ @@ -952,10 +949,10 @@ def _visit_reduce(self, node: itir.FunCall): # set reduction state self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - args = self.visit(node.args) + args = self.visit(node.args[0]) - assert len(args) == 1 and len(args[0]) == 1 - reduce_input_node = args[0][0].value + assert len(args) == 1 + reduce_input_node = args[0].value else: assert isinstance(node.fun, itir.FunCall) From 3fb512df03ff9245beab0b495bcfc94ec505d213 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 31 Jan 2024 12:08:47 +0100 Subject: [PATCH 59/85] build: update min requirements (#1435) - Update minimal version for pygments due to conflict (failing daily min requirements ci) - Many files touched due to formatting change in black - Fix a bug in cartesian hypothesis setup --- .pre-commit-config.yaml | 6 +- constraints.txt | 40 ++++---- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 4 +- requirements-dev.in | 2 +- requirements-dev.txt | 40 ++++---- src/gt4py/_core/definitions.py | 84 ++++++----------- src/gt4py/cartesian/backend/base.py | 3 +- src/gt4py/cartesian/backend/dace_backend.py | 16 ++-- src/gt4py/cartesian/backend/gtcpp_backend.py | 6 +- src/gt4py/cartesian/backend/pyext_builder.py | 6 +- src/gt4py/cartesian/frontend/defir_to_gtir.py | 12 +-- .../cartesian/frontend/gtscript_frontend.py | 16 ++-- src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py | 3 +- .../gtc/dace/expansion/daceir_builder.py | 32 ++++--- .../gtc/dace/expansion_specification.py | 8 +- src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py | 3 +- .../cartesian/testing/input_strategies.py | 8 +- src/gt4py/cartesian/type_hints.py | 3 +- src/gt4py/cartesian/utils/attrib.py | 9 +- src/gt4py/eve/codegen.py | 6 +- src/gt4py/eve/datamodels/core.py | 42 ++++----- src/gt4py/eve/extended_typing.py | 76 ++++++--------- src/gt4py/eve/trees.py | 6 +- src/gt4py/eve/type_validation.py | 12 +-- src/gt4py/eve/utils.py | 53 ++++------- src/gt4py/next/allocators.py | 15 ++- src/gt4py/next/common.py | 93 +++++++------------ src/gt4py/next/embedded/nd_array_field.py | 12 ++- src/gt4py/next/ffront/decorator.py | 20 ++-- src/gt4py/next/ffront/dialect_parser.py | 16 ++-- src/gt4py/next/ffront/field_operator_ast.py | 3 +- .../ffront/foast_passes/type_deduction.py | 6 +- src/gt4py/next/ffront/foast_to_itir.py | 3 +- src/gt4py/next/ffront/program_ast.py | 3 +- src/gt4py/next/iterator/embedded.py | 48 ++++------ src/gt4py/next/iterator/ir.py | 9 +- src/gt4py/next/iterator/runtime.py | 6 +- src/gt4py/next/otf/compilation/compiler.py | 6 +- src/gt4py/next/otf/languages.py | 3 +- src/gt4py/next/otf/stages.py | 6 +- src/gt4py/next/otf/step_types.py | 8 +- src/gt4py/next/otf/workflow.py | 3 +- .../codegens/gtfn/gtfn_im_ir.py | 6 +- .../codegens/gtfn/gtfn_ir_common.py | 3 +- .../program_processors/processor_interface.py | 9 +- .../runners/dace_iterator/itir_to_sdfg.py | 30 +++--- src/gt4py/next/type_inference.py | 12 +-- src/gt4py/next/type_system/type_info.py | 5 +- src/gt4py/storage/allocators.py | 12 +-- .../stencil_definitions.py | 6 +- .../unit_tests/test_extended_typing.py | 39 +++----- tests/next_tests/definitions.py | 6 +- tests/next_tests/integration_tests/cases.py | 18 ++-- .../ffront_tests/test_execution.py | 7 +- .../ffront_tests/test_foast_to_itir.py | 36 ++++--- .../unit_tests/test_type_inference.py | 3 +- 58 files changed, 398 insertions(+), 554 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3259a74f38..862aa46d66 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.12.1' # version from constraints.txt + rev: '24.1.1' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -162,7 +162,7 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.2.0 - - black==23.12.1 + - black==24.1.1 - boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 @@ -175,7 +175,7 @@ repos: - importlib-resources==6.1.1 - jinja2==3.1.3 - lark==1.1.9 - - mako==1.3.1 + - mako==1.3.2 - nanobind==1.8.0 - ninja==1.11.1.1 - numpy==1.24.4 diff --git a/constraints.txt b/constraints.txt index 343615b421..1aa47d8340 100644 --- a/constraints.txt +++ b/constraints.txt @@ -11,7 +11,7 @@ astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.14.0 # via sphinx backcall==0.2.0 # via ipython -black==23.12.1 # via gt4py (pyproject.toml) +black==24.1.1 # via gt4py (pyproject.toml) blinker==1.7.0 # via flask boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools @@ -29,8 +29,8 @@ cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.1 # via ipykernel contourpy==1.1.1 # via matplotlib -coverage==7.4.0 # via -r requirements-dev.in, coverage, pytest-cov -cryptography==42.0.1 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.4.1 # via -r requirements-dev.in, coverage, pytest-cov +cryptography==42.0.2 # via types-paramiko, types-pyopenssl, types-redis cycler==0.12.1 # via matplotlib cytoolz==0.12.3 # via gt4py (pyproject.toml) dace==0.15.1 # via gt4py (pyproject.toml) @@ -39,7 +39,7 @@ debugpy==1.8.0 # via ipykernel decorator==5.1.1 # via ipython deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.7 # via dace +dill==0.3.8 # via dace distlib==0.3.8 # via virtualenv docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate @@ -47,7 +47,7 @@ exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==2.0.1 # via devtools, stack-data factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==22.5.1 # via factory-boy +faker==22.6.0 # via factory-boy fastjsonschema==2.19.1 # via nbformat filelock==3.13.1 # via tox, virtualenv flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings @@ -64,7 +64,7 @@ fonttools==4.47.2 # via matplotlib fparser==0.1.3 # via dace frozendict==2.4.0 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.97.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx @@ -85,7 +85,7 @@ jupyter-core==5.7.1 # via ipykernel, jupyter-client, nbformat jupytext==1.16.1 # via -r requirements-dev.in kiwisolver==1.4.5 # via matplotlib lark==1.1.9 # via gt4py (pyproject.toml) -mako==1.3.1 # via gt4py (pyproject.toml) +mako==1.3.2 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.4 # via jinja2, mako, werkzeug matplotlib==3.7.4 # via -r requirements-dev.in @@ -99,7 +99,7 @@ mypy-extensions==1.0.0 # via black, mypy nanobind==1.8.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.9.2 # via jupytext, nbclient, nbmake -nbmake==1.4.6 # via -r requirements-dev.in +nbmake==1.5.0 # via -r requirements-dev.in nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace ninja==1.11.1.1 # via gt4py (pyproject.toml) @@ -115,7 +115,7 @@ pillow==10.2.0 # via matplotlib pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.2 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.1.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.2.0 # via black, jupyter-core, tox, virtualenv pluggy==1.4.0 # via pytest, tox ply==3.11 # via dace pre-commit==3.5.0 # via -r requirements-dev.in @@ -132,20 +132,20 @@ pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-doc pyparsing==3.1.1 # via matplotlib pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==8.0.0 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.6.0 # via -r requirements-dev.in pytest-xdist==3.5.0 # via -r requirements-dev.in, pytest-xdist python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib -pytz==2023.3.post1 # via babel +pytz==2023.4 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit pyzmq==25.1.2 # via ipykernel, jupyter-client -referencing==0.32.1 # via jsonschema, jsonschema-specifications +referencing==0.33.0 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings rpds-py==0.17.1 # via jsonschema, referencing -ruff==0.1.14 # via -r requirements-dev.in +ruff==0.1.15 # via -r requirements-dev.in setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx @@ -193,7 +193,7 @@ types-datetimerange==2.0.0.6 # via types-all types-decorator==5.1.8.20240106 # via types-all types-deprecated==1.2.9.20240106 # via types-all types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.20240125 # via types-all +types-docutils==0.20.0.20240126 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -209,7 +209,7 @@ types-itsdangerous==1.1.6 # via types-all types-jack-client==0.5.10.20240106 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.5.0.20240106 # via types-all +types-markdown==3.5.0.20240129 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 types-mock==5.1.0.20240106 # via types-all @@ -222,20 +222,20 @@ types-pathlib2==2.3.0 # via types-all types-pillow==10.2.0.20240125 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.20240115 # via types-all -types-protobuf==4.24.0.20240106 # via types-all +types-protobuf==4.24.0.20240129 # via types-all types-pyaudio==0.2.16.20240106 # via types-all types-pycurl==7.45.2.20240106 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.3.0.20240106 # via types-redis +types-pyopenssl==24.0.0.20240130 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.20240106 # via types-all types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all -types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.1 # via types-all, types-tzlocal +types-python-slugify==8.0.2.20240127 # via types-all +types-pytz==2023.4.0.20240130 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.12 # via types-all types-redis==4.6.0.20240106 # via types-all @@ -258,7 +258,7 @@ types-waitress==2.1.4.20240106 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm -urllib3==2.1.0 # via requests, types-requests +urllib3==2.2.0 # via requests, types-requests virtualenv==20.25.0 # via pre-commit, tox wcwidth==0.2.13 # via prompt-toolkit websockets==12.0 # via dace diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 3c6cd3d9ff..7200018616 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -61,7 +61,7 @@ pipdeptree==2.3 pre-commit==2.17 psutil==5.0 pybind11==2.5 -pygments==2.7 +pygments==2.7.3 pytest-cache==1.0 pytest-cov==2.8 pytest-factoryboy==2.0.3 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index d2ebaba331..259663ffc4 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -58,7 +58,7 @@ pipdeptree==2.3 pre-commit==2.17 psutil==5.0 pybind11==2.5 -pygments==2.7 +pygments==2.7.3 pytest-cache==1.0 pytest-cov==2.8 pytest-factoryboy==2.0.3 diff --git a/pyproject.toml b/pyproject.toml index 51cfc267d5..5a1618fc49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,7 +162,9 @@ ignore = [ 'DAR', # Disable dargling errors by default 'E203', # Whitespace before ':' (black formatter breaks this sometimes) 'E501', # Line too long (using Bugbear's B950 warning) - 'W503' # Line break occurred before a binary operator + 'W503', # Line break occurred before a binary operator + 'E701', # Multiple statements on one line, see https://github.com/psf/black/issues/3887 + 'E704' # Multiple statements on one line, see https://github.com/psf/black/issues/3887 ] max-complexity = 15 max-line-length = 100 # It should be the same as in `tool.black.line-length` above diff --git a/requirements-dev.in b/requirements-dev.in index 59ddb733d0..4bb05ecbc5 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -32,7 +32,7 @@ pipdeptree>=2.3 pip-tools>=6.10 pre-commit>=2.17 psutil>=5.0 -pygments>=2.7 +pygments>=2.7.3 pytest-cache>=1.0 pytest-cov>=2.8 pytest-factoryboy>=2.0.3 diff --git a/requirements-dev.txt b/requirements-dev.txt index abfa99a2ae..e54e56ad62 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,7 +11,7 @@ astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.14.0 # via sphinx backcall==0.2.0 # via ipython -black==23.12.1 # via gt4py (pyproject.toml) +black==24.1.1 # via gt4py (pyproject.toml) blinker==1.7.0 # via flask boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools @@ -29,8 +29,8 @@ cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.1 # via ipykernel contourpy==1.1.1 # via matplotlib -coverage[toml]==7.4.0 # via -r requirements-dev.in, coverage, pytest-cov -cryptography==42.0.1 # via types-paramiko, types-pyopenssl, types-redis +coverage[toml]==7.4.1 # via -r requirements-dev.in, coverage, pytest-cov +cryptography==42.0.2 # via types-paramiko, types-pyopenssl, types-redis cycler==0.12.1 # via matplotlib cytoolz==0.12.3 # via gt4py (pyproject.toml) dace==0.15.1 # via gt4py (pyproject.toml) @@ -39,7 +39,7 @@ debugpy==1.8.0 # via ipykernel decorator==5.1.1 # via ipython deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.7 # via dace +dill==0.3.8 # via dace distlib==0.3.8 # via virtualenv docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate @@ -47,7 +47,7 @@ exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==2.0.1 # via devtools, stack-data factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==22.5.1 # via factory-boy +faker==22.6.0 # via factory-boy fastjsonschema==2.19.1 # via nbformat filelock==3.13.1 # via tox, virtualenv flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings @@ -64,7 +64,7 @@ fonttools==4.47.2 # via matplotlib fparser==0.1.3 # via dace frozendict==2.4.0 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.97.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx @@ -85,7 +85,7 @@ jupyter-core==5.7.1 # via ipykernel, jupyter-client, nbformat jupytext==1.16.1 # via -r requirements-dev.in kiwisolver==1.4.5 # via matplotlib lark==1.1.9 # via gt4py (pyproject.toml) -mako==1.3.1 # via gt4py (pyproject.toml) +mako==1.3.2 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.4 # via jinja2, mako, werkzeug matplotlib==3.7.4 # via -r requirements-dev.in @@ -99,7 +99,7 @@ mypy-extensions==1.0.0 # via black, mypy nanobind==1.8.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.9.2 # via jupytext, nbclient, nbmake -nbmake==1.4.6 # via -r requirements-dev.in +nbmake==1.5.0 # via -r requirements-dev.in nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace ninja==1.11.1.1 # via gt4py (pyproject.toml) @@ -115,7 +115,7 @@ pillow==10.2.0 # via matplotlib pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.2 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.1.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.2.0 # via black, jupyter-core, tox, virtualenv pluggy==1.4.0 # via pytest, tox ply==3.11 # via dace pre-commit==3.5.0 # via -r requirements-dev.in @@ -132,20 +132,20 @@ pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-doc pyparsing==3.1.1 # via matplotlib pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==8.0.0 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.6.0 # via -r requirements-dev.in pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in, pytest-xdist python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib -pytz==2023.3.post1 # via babel +pytz==2023.4 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit pyzmq==25.1.2 # via ipykernel, jupyter-client -referencing==0.32.1 # via jsonschema, jsonschema-specifications +referencing==0.33.0 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings rpds-py==0.17.1 # via jsonschema, referencing -ruff==0.1.14 # via -r requirements-dev.in +ruff==0.1.15 # via -r requirements-dev.in setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx @@ -193,7 +193,7 @@ types-datetimerange==2.0.0.6 # via types-all types-decorator==5.1.8.20240106 # via types-all types-deprecated==1.2.9.20240106 # via types-all types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.20240125 # via types-all +types-docutils==0.20.0.20240126 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -209,7 +209,7 @@ types-itsdangerous==1.1.6 # via types-all types-jack-client==0.5.10.20240106 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.5.0.20240106 # via types-all +types-markdown==3.5.0.20240129 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 types-mock==5.1.0.20240106 # via types-all @@ -222,20 +222,20 @@ types-pathlib2==2.3.0 # via types-all types-pillow==10.2.0.20240125 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.20240115 # via types-all -types-protobuf==4.24.0.20240106 # via types-all +types-protobuf==4.24.0.20240129 # via types-all types-pyaudio==0.2.16.20240106 # via types-all types-pycurl==7.45.2.20240106 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.3.0.20240106 # via types-redis +types-pyopenssl==24.0.0.20240130 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.20240106 # via types-all types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all -types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.1 # via types-all, types-tzlocal +types-python-slugify==8.0.2.20240127 # via types-all +types-pytz==2023.4.0.20240130 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.12 # via types-all types-redis==4.6.0.20240106 # via types-all @@ -258,7 +258,7 @@ types-waitress==2.1.4.20240106 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm -urllib3==2.1.0 # via requests, types-requests +urllib3==2.2.0 # via requests, types-requests virtualenv==20.25.0 # via pre-commit, tox wcwidth==0.2.13 # via prompt-toolkit websockets==12.0 # via dace diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 091fa77e3f..6237704f69 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -165,28 +165,23 @@ class DTypeKind(eve.StrEnum): @overload -def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]: - ... +def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]: ... @overload -def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]: - ... +def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]: ... @overload -def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]: - ... +def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]: ... @overload -def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]: - ... +def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]: ... @overload -def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: - ... +def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: ... def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: @@ -360,8 +355,7 @@ class GTDimsInterface(Protocol): """ @property - def __gt_dims__(self) -> Tuple[str, ...]: - ... + def __gt_dims__(self) -> Tuple[str, ...]: ... class GTOriginInterface(Protocol): @@ -372,8 +366,7 @@ class GTOriginInterface(Protocol): """ @property - def __gt_origin__(self) -> Tuple[int, ...]: - ... + def __gt_origin__(self) -> Tuple[int, ...]: ... # -- Device representation -- @@ -443,61 +436,43 @@ def __iter__(self) -> Iterator[DeviceTypeT | int]: class NDArrayObject(Protocol): @property - def ndim(self) -> int: - ... + def ndim(self) -> int: ... @property - def shape(self) -> tuple[int, ...]: - ... + def shape(self) -> tuple[int, ...]: ... @property - def dtype(self) -> Any: - ... + def dtype(self) -> Any: ... - def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: - ... + def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... - def __getitem__(self, item: Any) -> NDArrayObject: - ... + def __getitem__(self, item: Any) -> NDArrayObject: ... - def __abs__(self) -> NDArrayObject: - ... + def __abs__(self) -> NDArrayObject: ... - def __neg__(self) -> NDArrayObject: - ... + def __neg__(self) -> NDArrayObject: ... - def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __radd__(self, other: Any) -> NDArrayObject: - ... + def __radd__(self, other: Any) -> NDArrayObject: ... - def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rsub__(self, other: Any) -> NDArrayObject: - ... + def __rsub__(self, other: Any) -> NDArrayObject: ... - def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rmul__(self, other: Any) -> NDArrayObject: - ... + def __rmul__(self, other: Any) -> NDArrayObject: ... - def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rfloordiv__(self, other: Any) -> NDArrayObject: - ... + def __rfloordiv__(self, other: Any) -> NDArrayObject: ... - def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rtruediv__(self, other: Any) -> NDArrayObject: - ... + def __rtruediv__(self, other: Any) -> NDArrayObject: ... - def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... @@ -517,11 +492,8 @@ def __lt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignor def __le__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable ... - def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 62e36de721..669110161e 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -305,8 +305,7 @@ def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs class MakeModuleSourceCallable(Protocol): - def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: - ... + def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: ... class PurePythonBackendCLIMixin(CLIBackendMixin): diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5dae025acb..b02c765ad7 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -684,14 +684,14 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li if name in sdfg.arrays: data = sdfg.arrays[name] assert isinstance(data, dace.data.Array) - res[ - name - ] = "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type="object" - if self.backend.storage_info["device"] == "gpu" - else "buffer", - name=name, - ndim=len(data.shape), + res[name] = ( + "py::{pybind_type} {name}, std::array {name}_origin".format( + pybind_type=( + "object" if self.backend.storage_info["device"] == "gpu" else "buffer" + ), + name=name, + ndim=len(data.shape), + ) ) elif name in sdfg.symbols and not name.startswith("__"): assert name in sdfg.symbols diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index e0f982b8be..c69b5b5088 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -88,9 +88,9 @@ def visit_FieldDecl(self, node: gtcpp.FieldDecl, **kwargs): sid_ndim = domain_ndim + data_ndim if kwargs["external_arg"]: return "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type="object" - if self.backend.storage_info["device"] == "gpu" - else "buffer", + pybind_type=( + "object" if self.backend.storage_info["device"] == "gpu" else "buffer" + ), name=node.name, sid_ndim=sid_ndim, ) diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index e12669ae0f..1ffa5a412d 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -179,8 +179,7 @@ def build_pybind_ext( build_path: str, target_path: str, **kwargs: str, -) -> Tuple[str, str]: - ... +) -> Tuple[str, str]: ... @overload @@ -198,8 +197,7 @@ def build_pybind_ext( build_ext_class: Type = None, verbose: bool = False, clean: bool = False, -) -> Tuple[str, str]: - ... +) -> Tuple[str, str]: ... def build_pybind_ext( diff --git a/src/gt4py/cartesian/frontend/defir_to_gtir.py b/src/gt4py/cartesian/frontend/defir_to_gtir.py index f2ee544900..eb53e49ac5 100644 --- a/src/gt4py/cartesian/frontend/defir_to_gtir.py +++ b/src/gt4py/cartesian/frontend/defir_to_gtir.py @@ -489,18 +489,18 @@ def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]: return gtir.FieldIfStmt( cond=cond, true_branch=gtir.BlockStmt(body=self.visit(node.main_body)), - false_branch=gtir.BlockStmt(body=self.visit(node.else_body)) - if node.else_body - else None, + false_branch=( + gtir.BlockStmt(body=self.visit(node.else_body)) if node.else_body else None + ), loc=location_to_source_location(node.loc), ) else: return gtir.ScalarIfStmt( cond=cond, true_branch=gtir.BlockStmt(body=self.visit(node.main_body)), - false_branch=gtir.BlockStmt(body=self.visit(node.else_body)) - if node.else_body - else None, + false_branch=( + gtir.BlockStmt(body=self.visit(node.else_body)) if node.else_body else None + ), loc=location_to_source_location(node.loc), ) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index f665410b30..2df8c106ce 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1154,9 +1154,11 @@ def visit_Subscript(self, node: ast.Subscript): result.offset = {axis: value for axis, value in zip(field_axes, index)} elif isinstance(node.value, ast.Subscript): result.data_index = [ - nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32) - if isinstance(value, numbers.Integral) - else value + ( + nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32) + if isinstance(value, numbers.Integral) + else value + ) for value in index ] if len(result.data_index) != len(self.fields[result.name].data_dims): @@ -1321,9 +1323,11 @@ def visit_If(self, node: ast.If) -> list: condition=self.visit(node.test), loc=nodes.Location.from_ast_node(node), main_body=nodes.BlockStmt(stmts=main_stmts, loc=nodes.Location.from_ast_node(node)), - else_body=nodes.BlockStmt(stmts=else_stmts, loc=nodes.Location.from_ast_node(node)) - if else_stmts - else None, + else_body=( + nodes.BlockStmt(stmts=else_stmts, loc=nodes.Location.from_ast_node(node)) + if else_stmts + else None + ), ) ) diff --git a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py index 567d128c29..de1ca93557 100644 --- a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py @@ -29,8 +29,7 @@ class SymbolNameCreator(Protocol): - def __call__(self, name: str) -> str: - ... + def __call__(self, name: str) -> str: ... def _make_axis_offset_expr(bound: common.AxisBound, axis_index: int) -> cuir.Expr: diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 48b129fa87..9a214441ad 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -118,18 +118,26 @@ def all_regions_same(scope_nodes): len( set( ( - None - if mask.intervals[axis.to_idx()].start is None - else mask.intervals[axis.to_idx()].start.level, - None - if mask.intervals[axis.to_idx()].start is None - else mask.intervals[axis.to_idx()].start.offset, - None - if mask.intervals[axis.to_idx()].end is None - else mask.intervals[axis.to_idx()].end.level, - None - if mask.intervals[axis.to_idx()].end is None - else mask.intervals[axis.to_idx()].end.offset, + ( + None + if mask.intervals[axis.to_idx()].start is None + else mask.intervals[axis.to_idx()].start.level + ), + ( + None + if mask.intervals[axis.to_idx()].start is None + else mask.intervals[axis.to_idx()].start.offset + ), + ( + None + if mask.intervals[axis.to_idx()].end is None + else mask.intervals[axis.to_idx()].end.level + ), + ( + None + if mask.intervals[axis.to_idx()].end is None + else mask.intervals[axis.to_idx()].end.offset + ), ) for mask in eve.walk_values(scope_nodes).if_isinstance(common.HorizontalMask) ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index 57146ef2a8..7c99146426 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -173,9 +173,11 @@ def _order_as_spec(computation_node, expansion_order): expansion_specification.append( Loop( axis=axis, - stride=-1 - if computation_node.oir_node.loop_order == common.LoopOrder.BACKWARD - else 1, + stride=( + -1 + if computation_node.oir_node.loop_order == common.LoopOrder.BACKWARD + else 1 + ), ) ) elif item == "Sections": diff --git a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py index 82991af1d4..58cddffd5f 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py @@ -94,8 +94,7 @@ def _make_axis_offset_expr( class SymbolNameCreator(Protocol): - def __call__(self, name: str) -> str: - ... + def __call__(self, name: str) -> str: ... class OIRToGTCpp(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): diff --git a/src/gt4py/cartesian/testing/input_strategies.py b/src/gt4py/cartesian/testing/input_strategies.py index 5f3ed32572..008b859929 100644 --- a/src/gt4py/cartesian/testing/input_strategies.py +++ b/src/gt4py/cartesian/testing/input_strategies.py @@ -142,7 +142,11 @@ def scalar_value_st(dtype, min_value, max_value, allow_nan=False): """Hypothesis strategy for `dtype` scalar values in range [min_value, max_value].""" allow_infinity = not (np.isfinite(min_value) and np.isfinite(max_value)) - if issubclass(dtype.type, numbers.Real): + if issubclass(dtype.type, numbers.Integral): + value_st = hyp_st.integers(min_value, max_value) + elif issubclass( + dtype.type, numbers.Real + ): # after numbers.Integral because np.int32 is a subclass of numbers.Real value_st = hyp_st.floats( min_value, max_value, @@ -150,8 +154,6 @@ def scalar_value_st(dtype, min_value, max_value, allow_nan=False): allow_nan=allow_nan, width=dtype.itemsize * 8, ) - elif issubclass(dtype.type, numbers.Integral): - value_st = hyp_st.integers(min_value, max_value) return value_st.map(dtype.type) diff --git a/src/gt4py/cartesian/type_hints.py b/src/gt4py/cartesian/type_hints.py index a1af6b93d1..3a776ba847 100644 --- a/src/gt4py/cartesian/type_hints.py +++ b/src/gt4py/cartesian/type_hints.py @@ -21,8 +21,7 @@ class StencilFunc(Protocol): __name__: str __module__: str - def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: - ... + def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ... class AnnotatedStencilFunc(StencilFunc, Protocol): diff --git a/src/gt4py/cartesian/utils/attrib.py b/src/gt4py/cartesian/utils/attrib.py index f2f77769ec..da53e5c128 100644 --- a/src/gt4py/cartesian/utils/attrib.py +++ b/src/gt4py/cartesian/utils/attrib.py @@ -240,16 +240,13 @@ def attribute(of, optional=False, **kwargs): class AttributeClassLike: - def validate(self): - ... + def validate(self): ... @property - def attributes(self): - ... + def attributes(self): ... @property - def as_dict(self): - ... + def as_dict(self): ... def attribclass(cls_or_none=None, **kwargs): diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 3a964c92a9..72f0e8858f 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -641,15 +641,13 @@ def __init_subclass__(cls, *, inherit_templates: bool = True, **kwargs: Any) -> @overload @classmethod - def apply(cls, root: LeafNode, **kwargs: Any) -> str: - ... + def apply(cls, root: LeafNode, **kwargs: Any) -> str: ... @overload @classmethod def apply( # noqa: F811 # redefinition of symbol cls, root: CollectionNode, **kwargs: Any - ) -> Collection[str]: - ... + ) -> Collection[str]: ... @classmethod def apply( # noqa: F811 # redefinition of symbol diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index bc744b3ccc..11ad824aab 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -84,8 +84,7 @@ class _AttrsClassTP(Protocol): class DataModelTP(_AttrsClassTP, xtyping.DevToolsPrettyPrintable, Protocol): - def __init__(self, *args: Any, **kwargs: Any) -> None: - ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... __datamodel_fields__: ClassVar[utils.FrozenNamespace[Attribute]] = cast( utils.FrozenNamespace[Attribute], None @@ -116,8 +115,7 @@ class GenericDataModelTP(DataModelTP, Protocol): @classmethod def __class_getitem__( cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: - ... + ) -> Union[DataModelTP, GenericDataModelTP]: ... _DM = TypeVar("_DM", bound="DataModel") @@ -280,8 +278,7 @@ def datamodel( coerce: bool = _COERCE_DEFAULT, generic: bool = _GENERIC_DEFAULT, type_validation_factory: Optional[FieldTypeValidatorFactory] = DefaultFieldTypeValidatorFactory, -) -> Callable[[Type[_T]], Type[_T]]: - ... +) -> Callable[[Type[_T]], Type[_T]]: ... @overload @@ -300,8 +297,7 @@ def datamodel( # noqa: F811 # redefinion of unused symbol coerce: bool = _COERCE_DEFAULT, generic: bool = _GENERIC_DEFAULT, type_validation_factory: Optional[FieldTypeValidatorFactory] = DefaultFieldTypeValidatorFactory, -) -> Type[_T]: - ... +) -> Type[_T]: ... # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) @@ -410,8 +406,7 @@ def __call__( type_validation_factory: Optional[ FieldTypeValidatorFactory ] = DefaultFieldTypeValidatorFactory, - ) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: - ... + ) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: ... frozenmodel: _DataModelDecoratorTP = functools.partial(datamodel, frozen=True) @@ -424,13 +419,11 @@ def __call__( if xtyping.TYPE_CHECKING: class DataModel(DataModelTP): - def __init__(self, *args: Any, **kwargs: Any) -> None: - ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... def __pretty__( self, fmt: Callable[[Any], Any], **kwargs: Any - ) -> Generator[Any, None, None]: - ... + ) -> Generator[Any, None, None]: ... else: # TODO(egparedes): use @dataclass_transform(eq_default=True, field_specifiers=("field",)) @@ -453,9 +446,9 @@ def __init_subclass__( cls, /, *, - repr: bool # noqa: A002 # shadowing 'repr' python builtin - | None - | Literal["inherited"] = "inherited", + repr: ( # noqa: A002 # shadowing 'repr' python builtin + bool | None | Literal["inherited"] + ) = "inherited", eq: bool | None | Literal["inherited"] = "inherited", order: bool | None | Literal["inherited"] = "inherited", unsafe_hash: bool | None | Literal["inherited"] = "inherited", @@ -463,8 +456,9 @@ def __init_subclass__( match_args: bool | Literal["inherited"] = "inherited", kw_only: bool | Literal["inherited"] = "inherited", coerce: bool | Literal["inherited"] = "inherited", - type_validation_factory: Optional[FieldTypeValidatorFactory] - | Literal["inherited"] = "inherited", + type_validation_factory: ( + Optional[FieldTypeValidatorFactory] | Literal["inherited"] + ) = "inherited", **kwargs: Any, ) -> None: dm_opts = kwargs.pop(_DM_OPTS, []) @@ -519,10 +513,9 @@ def field( metadata: Optional[Mapping[Any, Any]] = None, kw_only: bool = _KW_ONLY_DEFAULT, converter: Callable[[Any], Any] | Literal["coerce"] | None = None, - validator: AttrsValidator - | FieldValidator - | Sequence[AttrsValidator | FieldValidator] - | None = None, + validator: ( + AttrsValidator | FieldValidator | Sequence[AttrsValidator | FieldValidator] | None + ) = None, ) -> Any: # attr.s lies in some typings """Define a new attribute on a class with advanced options. @@ -1373,8 +1366,7 @@ class GenericDataModel(GenericDataModelTP): @classmethod def __class_getitem__( cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: - ... + ) -> Union[DataModelTP, GenericDataModelTP]: ... else: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 3ee447ca6c..82076d1a9c 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -177,19 +177,16 @@ class NonDataDescriptor(Protocol[_C, _V]): @overload def __get__( self, _instance: Literal[None], _owner_type: Optional[Type[_C]] = None - ) -> NonDataDescriptor[_C, _V]: - ... + ) -> NonDataDescriptor[_C, _V]: ... @overload def __get__( # noqa: F811 # redefinion of unused member self, _instance: _C, _owner_type: Optional[Type[_C]] = None - ) -> _V: - ... + ) -> _V: ... def __get__( # noqa: F811 # redefinion of unused member self, _instance: Optional[_C], _owner_type: Optional[Type[_C]] = None - ) -> _V | NonDataDescriptor[_C, _V]: - ... + ) -> _V | NonDataDescriptor[_C, _V]: ... class DataDescriptor(NonDataDescriptor[_C, _V], Protocol): @@ -198,11 +195,9 @@ class DataDescriptor(NonDataDescriptor[_C, _V], Protocol): See https://docs.python.org/3/howto/descriptor.html for further information. """ - def __set__(self, _instance: _C, _value: _V) -> None: - ... + def __set__(self, _instance: _C, _value: _V) -> None: ... - def __delete__(self, _instance: _C) -> None: - ... + def __delete__(self, _instance: _C) -> None: ... # -- Based on typeshed definitions -- @@ -220,26 +215,20 @@ class HashlibAlgorithm(Protocol): block_size: int name: str - def __init__(self, data: ReadableBuffer = ...) -> None: - ... + def __init__(self, data: ReadableBuffer = ...) -> None: ... - def copy(self) -> HashlibAlgorithm: - ... + def copy(self) -> HashlibAlgorithm: ... - def update(self, data: ReadableBuffer) -> None: - ... + def update(self, data: ReadableBuffer) -> None: ... - def digest(self) -> bytes: - ... + def digest(self) -> bytes: ... - def hexdigest(self) -> str: - ... + def hexdigest(self) -> str: ... # -- Third party protocols -- class SupportsArray(Protocol): - def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> npt.NDArray[Any]: - ... + def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> npt.NDArray[Any]: ... def supports_array(value: Any) -> TypeGuard[SupportsArray]: @@ -248,8 +237,7 @@ def supports_array(value: Any) -> TypeGuard[SupportsArray]: class ArrayInterface(Protocol): @property - def __array_interface__(self) -> Dict[str, Any]: - ... + def __array_interface__(self) -> Dict[str, Any]: ... class ArrayInterfaceTypedDict(TypedDict): @@ -265,8 +253,7 @@ class ArrayInterfaceTypedDict(TypedDict): class StrictArrayInterface(Protocol): @property - def __array_interface__(self) -> ArrayInterfaceTypedDict: - ... + def __array_interface__(self) -> ArrayInterfaceTypedDict: ... def supports_array_interface(value: Any) -> TypeGuard[ArrayInterface]: @@ -275,8 +262,7 @@ def supports_array_interface(value: Any) -> TypeGuard[ArrayInterface]: class CUDAArrayInterface(Protocol): @property - def __cuda_array_interface__(self) -> Dict[str, Any]: - ... + def __cuda_array_interface__(self) -> Dict[str, Any]: ... class CUDAArrayInterfaceTypedDict(TypedDict): @@ -292,8 +278,7 @@ class CUDAArrayInterfaceTypedDict(TypedDict): class StrictCUDAArrayInterface(Protocol): @property - def __cuda_array_interface__(self) -> CUDAArrayInterfaceTypedDict: - ... + def __cuda_array_interface__(self) -> CUDAArrayInterfaceTypedDict: ... def supports_cuda_array_interface(value: Any) -> TypeGuard[CUDAArrayInterface]: @@ -305,19 +290,15 @@ def supports_cuda_array_interface(value: Any) -> TypeGuard[CUDAArrayInterface]: class MultiStreamDLPackBuffer(Protocol): - def __dlpack__(self, *, stream: Optional[int] = None) -> Any: - ... + def __dlpack__(self, *, stream: Optional[int] = None) -> Any: ... - def __dlpack_device__(self) -> DLPackDevice: - ... + def __dlpack_device__(self) -> DLPackDevice: ... class SingleStreamDLPackBuffer(Protocol): - def __dlpack__(self, *, stream: None = None) -> Any: - ... + def __dlpack__(self, *, stream: None = None) -> Any: ... - def __dlpack_device__(self) -> DLPackDevice: - ... + def __dlpack_device__(self) -> DLPackDevice: ... DLPackBuffer: TypeAlias = Union[MultiStreamDLPackBuffer, SingleStreamDLPackBuffer] @@ -333,8 +314,9 @@ def supports_dlpack(value: Any) -> TypeGuard[DLPackBuffer]: class DevToolsPrettyPrintable(Protocol): """Used by python-devtools (https://python-devtools.helpmanual.io/).""" - def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]: - ... + def __pretty__( + self, fmt: Callable[[Any], Any], **kwargs: Any + ) -> Generator[Any, None, None]: ... # -- Added functionality -- @@ -357,8 +339,7 @@ def extended_runtime_checkable( *, instance_check_shortcut: bool = True, subclass_check_with_data_members: bool = False, -) -> Callable[[_ProtoT], _ProtoT]: - ... +) -> Callable[[_ProtoT], _ProtoT]: ... @overload @@ -367,8 +348,7 @@ def extended_runtime_checkable( *, instance_check_shortcut: bool = True, subclass_check_with_data_members: bool = False, -) -> _ProtoT: - ... +) -> _ProtoT: ... def extended_runtime_checkable( # noqa: C901 # too complex but unavoidable @@ -414,9 +394,11 @@ def _decorator(cls: _ProtoT) -> _ProtoT: _allow_reckless_class_checks = getattr( _typing, - "_allow_reckless_class_checks" - if hasattr(_typing, "_allow_reckless_class_checks") - else "_allow_reckless_class_cheks", # There is a typo in 3.8 and 3.9 + ( + "_allow_reckless_class_checks" + if hasattr(_typing, "_allow_reckless_class_checks") + else "_allow_reckless_class_cheks" + ), # There is a typo in 3.8 and 3.9 ) _get_protocol_attrs = ( diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index 74c5bd41bb..7bfd22cdf7 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -62,12 +62,10 @@ class TreeLike(abc.ABC): # noqa: B024 class Tree(Protocol): @abc.abstractmethod - def iter_children_values(self) -> Iterable: - ... + def iter_children_values(self) -> Iterable: ... @abc.abstractmethod - def iter_children_items(self) -> Iterable[Tuple[TreeKey, Any]]: - ... + def iter_children_items(self) -> Iterable[Tuple[TreeKey, Any]]: ... TreeLike.register(Tree) diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 65f492ebfe..124957fa20 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -110,8 +110,7 @@ def __call__( globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> FixedTypeValidator: - ... + ) -> FixedTypeValidator: ... @overload def __call__( # noqa: F811 # redefinion of unused member @@ -123,8 +122,7 @@ def __call__( # noqa: F811 # redefinion of unused member globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Optional[FixedTypeValidator]: - ... + ) -> Optional[FixedTypeValidator]: ... @abc.abstractmethod def __call__( # noqa: F811 # redefinion of unused member @@ -169,8 +167,7 @@ def __call__( globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> FixedTypeValidator: - ... + ) -> FixedTypeValidator: ... @overload def __call__( # noqa: F811 # redefinion of unused member @@ -182,8 +179,7 @@ def __call__( # noqa: F811 # redefinion of unused member globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Optional[FixedTypeValidator]: - ... + ) -> Optional[FixedTypeValidator]: ... def __call__( # noqa: F811,C901 # redefinion of unused member / complex but well organized in cases self, diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 624407f319..8e634c4b11 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -241,15 +241,13 @@ def partial(self, *args: Any, **kwargs: Any) -> fluid_partial: @overload def with_fluid_partial( func: Literal[None] = None, *args: Any, **kwargs: Any -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - ... +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... @overload def with_fluid_partial( # noqa: F811 # redefinition of unused function func: Callable[_P, _T], *args: Any, **kwargs: Any -) -> Callable[_P, _T]: - ... +) -> Callable[_P, _T]: ... def with_fluid_partial( # noqa: F811 # redefinition of unused function @@ -286,15 +284,13 @@ def _decorator(func: Callable[..., Any]) -> Callable[..., Any]: @overload def optional_lru_cache( func: Literal[None] = None, *, maxsize: Optional[int] = 128, typed: bool = False -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - ... +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... @overload def optional_lru_cache( # noqa: F811 # redefinition of unused function func: Callable[_P, _T], *, maxsize: Optional[int] = 128, typed: bool = False -) -> Callable[_P, _T]: - ... +) -> Callable[_P, _T]: ... def optional_lru_cache( # noqa: F811 # redefinition of unused function @@ -1228,12 +1224,10 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: return XIterable(zip(*self.iterator)) @typing.overload - def islice(self, __stop: int) -> XIterable[T]: - ... + def islice(self, __stop: int) -> XIterable[T]: ... @typing.overload - def islice(self, __start: int, __stop: int, __step: int = 1) -> XIterable[T]: - ... + def islice(self, __start: int, __stop: int, __step: int = 1) -> XIterable[T]: ... def islice( self, @@ -1315,18 +1309,17 @@ def unique(self, *, key: Union[NOTHING, Callable] = NOTHING) -> XIterable[T]: @typing.overload def groupby( self, key: str, *other_keys: str, as_dict: bool = False - ) -> XIterable[Tuple[Any, List[T]]]: - ... + ) -> XIterable[Tuple[Any, List[T]]]: ... @typing.overload - def groupby(self, key: List[Any], *, as_dict: bool = False) -> XIterable[Tuple[Any, List[T]]]: - ... + def groupby( + self, key: List[Any], *, as_dict: bool = False + ) -> XIterable[Tuple[Any, List[T]]]: ... @typing.overload def groupby( self, key: Callable[[T], Any], *, as_dict: bool = False - ) -> XIterable[Tuple[Any, List[T]]]: - ... + ) -> XIterable[Tuple[Any, List[T]]]: ... def groupby( self, @@ -1454,8 +1447,7 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[str, S]]: - ... + ) -> XIterable[Tuple[str, S]]: ... @typing.overload def reduceby( @@ -1466,8 +1458,7 @@ def reduceby( *attr_keys: str, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[Tuple[str, ...], S]]: - ... + ) -> XIterable[Tuple[Tuple[str, ...], S]]: ... @typing.overload def reduceby( @@ -1477,8 +1468,7 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[str, S]: - ... + ) -> Dict[str, S]: ... @typing.overload def reduceby( @@ -1489,8 +1479,7 @@ def reduceby( *attr_keys: str, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[Tuple[str, ...], S]: - ... + ) -> Dict[Tuple[str, ...], S]: ... @typing.overload def reduceby( @@ -1500,8 +1489,7 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[K, S]]: - ... + ) -> XIterable[Tuple[K, S]]: ... @typing.overload def reduceby( @@ -1511,8 +1499,7 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[K, S]: - ... + ) -> Dict[K, S]: ... @typing.overload def reduceby( @@ -1522,8 +1509,7 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[K, S]]: - ... + ) -> XIterable[Tuple[K, S]]: ... @typing.overload def reduceby( @@ -1533,8 +1519,7 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[K, S]: - ... + ) -> Dict[K, S]: ... def reduceby( self, diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 44203bf6d8..559e78eb3e 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -58,8 +58,7 @@ class FieldBufferAllocatorProtocol(Protocol[core_defs.DeviceTypeT]): @property @abc.abstractmethod - def __gt_device_type__(self) -> core_defs.DeviceTypeT: - ... + def __gt_device_type__(self) -> core_defs.DeviceTypeT: ... @abc.abstractmethod def __gt_allocate__( @@ -68,8 +67,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: - ... + ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: ... def is_field_allocator(obj: Any) -> TypeGuard[FieldBufferAllocatorProtocol]: @@ -87,8 +85,7 @@ class FieldBufferAllocatorFactoryProtocol(Protocol[core_defs.DeviceTypeT]): @property @abc.abstractmethod - def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: - ... + def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: ... def is_field_allocator_factory(obj: Any) -> TypeGuard[FieldBufferAllocatorFactoryProtocol]: @@ -178,9 +175,9 @@ def __gt_allocate__( if TYPE_CHECKING: - __TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[ - FieldBufferAllocatorProtocol - ] = BaseFieldBufferAllocator + __TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[FieldBufferAllocatorProtocol] = ( + BaseFieldBufferAllocator + ) def horizontal_first_layout_mapper( diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 33a0591813..d8ffc2057b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -160,8 +160,7 @@ def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @overload - def __getitem__(self, index: int) -> int: - ... + def __getitem__(self, index: int) -> int: ... @overload def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unused @@ -414,8 +413,7 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: return all(UnitRange.is_finite(rng) for rng in obj.ranges) @overload - def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: - ... + def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... @overload def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused @@ -424,8 +422,7 @@ def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused @overload def __getitem__( # noqa: F811 # redefine unused self, index: Dimension - ) -> tuple[Dimension, _Rng]: - ... + ) -> tuple[Dimension, _Rng]: ... def __getitem__( # noqa: F811 # redefine unused self, index: int | slice | Dimension @@ -571,8 +568,7 @@ def _broadcast_ranges( _R = TypeVar("_R", _Value, tuple[_Value, ...]) class GTBuiltInFuncDispatcher(Protocol): - def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: - ... + def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: ... # TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol. @@ -601,56 +597,45 @@ class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property - def domain(self) -> Domain: - ... + def domain(self) -> Domain: ... @property def __gt_domain__(self) -> Domain: return self.domain @property - def codomain(self) -> type[core_defs.ScalarT] | Dimension: - ... + def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @property - def dtype(self) -> core_defs.DType[core_defs.ScalarT]: - ... + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... @property - def ndarray(self) -> core_defs.NDArrayObject: - ... + def ndarray(self) -> core_defs.NDArrayObject: ... def __str__(self) -> str: return f"⟨{self.domain!s} → {self.dtype}⟩" @abc.abstractmethod - def asnumpy(self) -> np.ndarray: - ... + def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: - ... + def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: - ... + def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... # Operators @abc.abstractmethod - def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: - ... + def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: - ... + def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... @abc.abstractmethod - def __abs__(self) -> Field: - ... + def __abs__(self) -> Field: ... @abc.abstractmethod - def __neg__(self) -> Field: - ... + def __neg__(self) -> Field: ... @abc.abstractmethod def __invert__(self) -> Field: @@ -665,48 +650,37 @@ def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants re ... @abc.abstractmethod - def __add__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __radd__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __radd__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __sub__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __sub__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rsub__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rsub__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __mul__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __mul__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rmul__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rmul__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __floordiv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __floordiv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rfloordiv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rfloordiv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __truediv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __truediv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __pow__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod def __and__(self, other: Field | core_defs.ScalarT) -> Field: @@ -734,8 +708,7 @@ def is_field( @extended_runtime_checkable class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): @abc.abstractmethod - def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: - ... + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: ... def is_mutable_field( @@ -759,8 +732,7 @@ class ConnectivityKind(enum.Flag): class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: - ... + def codomain(self) -> DimT: ... @property def kind(self) -> ConnectivityKind: @@ -771,8 +743,7 @@ def kind(self) -> ConnectivityKind: ) @abc.abstractmethod - def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: - ... + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ... # Operators def __abs__(self) -> Never: @@ -1076,9 +1047,9 @@ class FieldBuiltinFuncRegistry: dispatching (via ChainMap) to its parent's registries. """ - _builtin_func_map: collections.ChainMap[ - fbuiltins.BuiltInFunction, Callable - ] = collections.ChainMap() + _builtin_func_map: collections.ChainMap[fbuiltins.BuiltInFunction, Callable] = ( + collections.ChainMap() + ) def __init_subclass__(cls, **kwargs): cls._builtin_func_map = collections.ChainMap( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 52a61b40bb..3a22df1032 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -131,8 +131,9 @@ def dtype(self) -> core_defs.DType[core_defs.ScalarT]: @classmethod def from_array( cls, - data: npt.ArrayLike - | core_defs.NDArrayObject, # TODO: NDArrayObject should be part of ArrayLike + data: ( + npt.ArrayLike | core_defs.NDArrayObject + ), # TODO: NDArrayObject should be part of ArrayLike /, *, domain: common.DomainLike, @@ -476,9 +477,10 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _make_reduction( - builtin_name: str, array_builtin_name: str -) -> Callable[..., NdArrayField[common.DimsT, core_defs.ScalarT],]: +def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[ + ..., + NdArrayField[common.DimsT, core_defs.ScalarT], +]: def _builtin_op( field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9f8537f59b..6510be560e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -490,13 +490,13 @@ def itir(self): @typing.overload -def program(definition: types.FunctionType) -> Program: - ... +def program(definition: types.FunctionType) -> Program: ... @typing.overload -def program(*, backend: Optional[ppi.ProgramExecutor]) -> Callable[[types.FunctionType], Program]: - ... +def program( + *, backend: Optional[ppi.ProgramExecutor] +) -> Callable[[types.FunctionType], Program]: ... def program( @@ -748,15 +748,13 @@ def __call__( @typing.overload def field_operator( definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor] -) -> FieldOperator[foast.FieldOperator]: - ... +) -> FieldOperator[foast.FieldOperator]: ... @typing.overload def field_operator( *, backend: Optional[ppi.ProgramExecutor] -) -> Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]: - ... +) -> Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]: ... def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): @@ -793,8 +791,7 @@ def scan_operator( init: core_defs.Scalar, backend: Optional[str], grid_type: GridType, -) -> FieldOperator[foast.ScanOperator]: - ... +) -> FieldOperator[foast.ScanOperator]: ... @typing.overload @@ -805,8 +802,7 @@ def scan_operator( init: core_defs.Scalar, backend: Optional[str], grid_type: GridType, -) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: - ... +) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... def scan_operator( diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index c04e978e51..07490db27c 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -39,12 +39,16 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: line=err.lineno + source_definition.line_offset, column=err.offset + source_definition.column_offset, filename=source_definition.filename, - end_line=err.end_lineno + source_definition.line_offset - if err.end_lineno is not None - else None, - end_column=err.end_offset + source_definition.column_offset - if err.end_offset is not None - else None, + end_line=( + err.end_lineno + source_definition.line_offset + if err.end_lineno is not None + else None + ), + end_column=( + err.end_offset + source_definition.column_offset + if err.end_offset is not None + else None + ), ) raise errors.DSLError(loc, err.msg).with_traceback(err.__traceback__) diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 6b772227b2..322a6df2e0 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -153,8 +153,7 @@ class Call(Expr): kwargs: dict[str, Expr] -class Stmt(LocatedNode): - ... +class Stmt(LocatedNode): ... class Starred(Expr): diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 5e289af664..64fea7935c 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -438,9 +438,9 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: f"got types '{true_type}' and '{false_type}.", ) # TODO: properly patch symtable (new node?) - symtable[sym].type = new_node.annex.propagated_symbols[ - sym - ].type = new_true_branch.annex.symtable[sym].type + symtable[sym].type = new_node.annex.propagated_symbols[sym].type = ( + new_true_branch.annex.symtable[sym].type + ) return new_node diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 0c9ab4ab27..c0e618a42d 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -430,5 +430,4 @@ def _process_elements( return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) -class FieldOperatorLoweringError(Exception): - ... +class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py index 14151fc243..4ff8265f70 100644 --- a/src/gt4py/next/ffront/program_ast.py +++ b/src/gt4py/next/ffront/program_ast.py @@ -93,8 +93,7 @@ class Slice(Expr): step: Literal[None] -class Stmt(LocatedNode): - ... +class Stmt(LocatedNode): ... class Program(LocatedNode, SymbolTableTrait): diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 6d610fd136..6985aea853 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -80,8 +80,7 @@ ) -class SparseTag(Tag): - ... +class SparseTag(Tag): ... class NeighborTableOffsetProvider: @@ -156,14 +155,11 @@ class ItIterator(Protocol): `ItIterator` to avoid name clashes with `Iterator` from `typing` and `collections.abc`. """ - def shift(self, *offsets: OffsetPart) -> ItIterator: - ... + def shift(self, *offsets: OffsetPart) -> ItIterator: ... - def can_deref(self) -> bool: - ... + def can_deref(self) -> bool: ... - def deref(self) -> Any: - ... + def deref(self) -> Any: ... @runtime_checkable @@ -172,13 +168,11 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def __gt_domain__(self) -> common.Domain: - ... + def __gt_domain__(self) -> common.Domain: ... # TODO(havogt): define generic Protocol to provide a concrete return type @abc.abstractmethod - def field_getitem(self, indices: NamedFieldIndices) -> Any: - ... + def field_getitem(self, indices: NamedFieldIndices) -> Any: ... @property def __gt_origin__(self) -> tuple[int, ...]: @@ -191,8 +185,7 @@ class MutableLocatedField(LocatedField, Protocol): # TODO(havogt): define generic Protocol to provide a concrete return type @abc.abstractmethod - def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: - ... + def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ... #: Column range used in column mode (`column_axis != None`) in the current closure execution context. @@ -705,8 +698,7 @@ def _make_tuple( named_indices: NamedFieldIndices, *, column_axis: Tag, -) -> tuple[tuple | Column, ...]: - ... +) -> tuple[tuple | Column, ...]: ... @overload @@ -722,8 +714,7 @@ def _make_tuple( @overload def _make_tuple( field_or_tuple: LocatedField, named_indices: NamedFieldIndices, *, column_axis: Tag -) -> Column: - ... +) -> Column: ... @overload @@ -732,8 +723,7 @@ def _make_tuple( named_indices: NamedFieldIndices, *, column_axis: Literal[None] = None, -) -> npt.DTypeLike | Undefined: - ... +) -> npt.DTypeLike | Undefined: ... def _make_tuple( @@ -974,13 +964,11 @@ def get_ordered_indices(axes: Iterable[Axis], pos: NamedFieldIndices) -> tuple[F @overload -def _shift_range(range_or_index: range, offset: int) -> slice: - ... +def _shift_range(range_or_index: range, offset: int) -> slice: ... @overload -def _shift_range(range_or_index: common.IntIndex, offset: int) -> common.IntIndex: - ... +def _shift_range(range_or_index: common.IntIndex, offset: int) -> common.IntIndex: ... def _shift_range(range_or_index: range | common.IntIndex, offset: int) -> ArrayIndex: @@ -994,13 +982,11 @@ def _shift_range(range_or_index: range | common.IntIndex, offset: int) -> ArrayI @overload -def _range2slice(r: range) -> slice: - ... +def _range2slice(r: range) -> slice: ... @overload -def _range2slice(r: common.IntIndex) -> common.IntIndex: - ... +def _range2slice(r: common.IntIndex) -> common.IntIndex: ... def _range2slice(r: range | common.IntIndex) -> slice | common.IntIndex: @@ -1288,8 +1274,7 @@ def impl(it: ItIterator) -> ItIterator: DT = TypeVar("DT") -class _List(tuple, Generic[DT]): - ... +class _List(tuple, Generic[DT]): ... @dataclasses.dataclass(frozen=True) @@ -1424,8 +1409,7 @@ def is_tuple_of_field(field) -> bool: ) -class TupleFieldMeta(type): - ... +class TupleFieldMeta(type): ... class TupleField(metaclass=TupleFieldMeta): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 37abbec9e7..10caecc591 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -45,9 +45,9 @@ class Sym(Node): # helper # TODO(tehrengruber): Revisit. Using strings is a workaround to avoid coupling with the # type inference. kind: typing.Literal["Iterator", "Value", None] = None - dtype: Optional[ - tuple[str, bool] - ] = None # format: name of primitive type, boolean indicating if it is a list + dtype: Optional[tuple[str, bool]] = ( + None # format: name of primitive type, boolean indicating if it is a list + ) @datamodels.validator("kind") def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str): @@ -63,8 +63,7 @@ def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu @noninstantiable -class Expr(Node): - ... +class Expr(Node): ... class Literal(Expr): diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index e12ae84dbc..5de4839b55 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -43,12 +43,10 @@ def offset(value): return Offset(value) -class CartesianDomain(dict): - ... +class CartesianDomain(dict): ... -class UnstructuredDomain(dict): - ... +class UnstructuredDomain(dict): ... # dependency inversion, register fendef for embedded execution or for tracing/parsing here diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 9fd20b16e2..29541a3ae5 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -41,8 +41,7 @@ def __call__( self, source: stages.CompilableSource[SrcL, LS, TgtL], cache_strategy: cache.Strategy, - ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: - ... + ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... @dataclasses.dataclass(frozen=True) @@ -88,5 +87,4 @@ def __call__( ) -class CompilationError(RuntimeError): - ... +class CompilationError(RuntimeError): ... diff --git a/src/gt4py/next/otf/languages.py b/src/gt4py/next/otf/languages.py index b0d01d91ab..2397878271 100644 --- a/src/gt4py/next/otf/languages.py +++ b/src/gt4py/next/otf/languages.py @@ -57,8 +57,7 @@ class Python(LanguageTag): ... -class NanobindSrcL(LanguageTag): - ... +class NanobindSrcL(LanguageTag): ... class Cpp(NanobindSrcL): diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index a21bc83c0b..bd7f59e7aa 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -107,15 +107,13 @@ class BuildSystemProject(Protocol[SrcL_co, SettingT_co, TgtL_co]): and is not responsible for importing the results into Python. """ - def build(self) -> None: - ... + def build(self) -> None: ... class CompiledProgram(Protocol): """Executable python representation of a program.""" - def __call__(self, *args, **kwargs) -> None: - ... + def __call__(self, *args, **kwargs) -> None: ... def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: diff --git a/src/gt4py/next/otf/step_types.py b/src/gt4py/next/otf/step_types.py index 5eeb5c495b..43def259ab 100644 --- a/src/gt4py/next/otf/step_types.py +++ b/src/gt4py/next/otf/step_types.py @@ -46,8 +46,7 @@ class BindingStep(Protocol[SrcL, LS, TgtL]): def __call__( self, program_source: stages.ProgramSource[SrcL, LS] - ) -> stages.CompilableSource[SrcL, LS, TgtL]: - ... + ) -> stages.CompilableSource[SrcL, LS, TgtL]: ... class CompilationStep( @@ -56,5 +55,6 @@ class CompilationStep( ): """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram).""" - def __call__(self, source: stages.CompilableSource[SrcL, LS, TgtL]) -> stages.CompiledProgram: - ... + def __call__( + self, source: stages.CompilableSource[SrcL, LS, TgtL] + ) -> stages.CompiledProgram: ... diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 3a82f9c738..4bdb4bbb41 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -61,8 +61,7 @@ class Workflow(Protocol[StartT_contra, EndT_co]): - take a single input argument """ - def __call__(self, inp: StartT_contra) -> EndT_co: - ... + def __call__(self, inp: StartT_contra) -> EndT_co: ... class ReplaceEnabledWorkflowMixin(Workflow[StartT_contra, EndT_co], Protocol): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py index f0843919fe..a62f50fc44 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py @@ -21,8 +21,7 @@ from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Sym, SymRef -class Stmt(Node): - ... +class Stmt(Node): ... class AssignStmt(Stmt): @@ -35,8 +34,7 @@ class InitStmt(AssignStmt): init_type: str = "auto" -class EmptyListInitializer(Expr): - ... +class EmptyListInitializer(Expr): ... class Conditional(Stmt): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py index 79d4c18828..cb9aeffb90 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py @@ -25,8 +25,7 @@ class Sym(Node): # helper id: Coerced[SymbolName] # noqa: A003 -class Expr(Node): - ... +class Expr(Node): ... class SymRef(Expr): diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py index 95d3d2ca35..0c280202b8 100644 --- a/src/gt4py/next/program_processors/processor_interface.py +++ b/src/gt4py/next/program_processors/processor_interface.py @@ -40,14 +40,12 @@ class ProgramProcessorCallable(Protocol[OutputT]): - def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT: - ... + def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT: ... class ProgramProcessor(ProgramProcessorCallable[OutputT], Protocol[OutputT, ProcessorKindT]): @property - def kind(self) -> type[ProcessorKindT]: - ... + def kind(self) -> type[ProcessorKindT]: ... class ProgramFormatter(ProgramProcessor[str, "ProgramFormatter"], Protocol): @@ -234,8 +232,7 @@ class ProgramBackend( ProgramProcessor[None, "ProgramExecutor"], next_allocators.FieldBufferAllocatorFactoryProtocol[core_defs.DeviceTypeT], Protocol[core_defs.DeviceTypeT], -): - ... +): ... def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 525a5c694e..073c856d86 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -113,9 +113,11 @@ def _make_array_shape_and_strides( dtype = dace.int64 sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims shape = [ - neighbor_tables[dim.value].max_neighbors - if dim.kind == DimensionKind.LOCAL - else dace.symbol(unique_name(f"{name}_shape{i}"), dtype) + ( + neighbor_tables[dim.value].max_neighbors + if dim.kind == DimensionKind.LOCAL + else dace.symbol(unique_name(f"{name}_shape{i}"), dtype) + ) for i, dim in enumerate(sorted_dims) ] strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)] @@ -348,11 +350,15 @@ def visit_StencilClosure( # Map SDFG tasklet arguments to parameters input_local_names = [ - input_transients_mapping[input_name] - if input_name in input_transients_mapping - else input_name - if input_name in input_field_names - else cast(ValueExpr, program_arg_syms[input_name]).value.data + ( + input_transients_mapping[input_name] + if input_name in input_transients_mapping + else ( + input_name + if input_name in input_field_names + else cast(ValueExpr, program_arg_syms[input_name]).value.data + ) + ) for input_name in input_names ] input_memlets = [ @@ -380,9 +386,11 @@ def visit_StencilClosure( create_memlet_at( output_name, tuple( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" + ( + f"i_{dim}" + if f"i_{dim}" in map_ranges + else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" + ) for dim, _ in closure_domain ), ) diff --git a/src/gt4py/next/type_inference.py b/src/gt4py/next/type_inference.py index 9b5d9070e3..10ae524451 100644 --- a/src/gt4py/next/type_inference.py +++ b/src/gt4py/next/type_inference.py @@ -94,13 +94,11 @@ def visit_TypeVar(self, node: V, *, index_map: dict[int, int]) -> V: @typing.overload -def freshen(dtypes: list[T]) -> list[T]: - ... +def freshen(dtypes: list[T]) -> list[T]: ... @typing.overload -def freshen(dtypes: T) -> T: - ... +def freshen(dtypes: T) -> T: ... def freshen(dtypes: list[T] | T) -> list[T] | T: @@ -325,15 +323,13 @@ def _handle_constraint(self, constraint: tuple[_Box, _Box]) -> bool: @typing.overload def unify( dtypes: list[Type], constraints: set[tuple[Type, Type]] -) -> tuple[list[Type], list[tuple[Type, Type]]]: - ... +) -> tuple[list[Type], list[tuple[Type, Type]]]: ... @typing.overload def unify( dtypes: Type, constraints: set[tuple[Type, Type]] -) -> tuple[Type, list[tuple[Type, Type]]]: - ... +) -> tuple[Type, list[tuple[Type, Type]]]: ... def unify( diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 20fa8bd791..7c4c8e6e23 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -118,8 +118,9 @@ def constituents_yielder(symbol_type: ts.TypeSpec): def apply_to_primitive_constituents( symbol_type: ts.TypeSpec, - fun: Callable[[ts.TypeSpec], ts.TypeSpec] - | Callable[[ts.TypeSpec, tuple[int, ...]], ts.TypeSpec], + fun: ( + Callable[[ts.TypeSpec], ts.TypeSpec] | Callable[[ts.TypeSpec, tuple[int, ...]], ts.TypeSpec] + ), with_path_arg=False, _path=(), ): diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 061f79f146..0482ec1e65 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -156,8 +156,7 @@ class BufferAllocator(Protocol[core_defs.DeviceTypeT]): """Protocol for buffer allocators.""" @property - def device_type(self) -> core_defs.DeviceTypeT: - ... + def device_type(self) -> core_defs.DeviceTypeT: ... def allocate( self, @@ -321,20 +320,17 @@ class _NumPyLibStridesModule(Protocol): @staticmethod def as_strided( ndarray: core_defs.NDArrayObject, **kwargs: Any - ) -> core_defs.NDArrayObject: - ... + ) -> core_defs.NDArrayObject: ... stride_tricks: _NumPyLibStridesModule lib: _NumPyLibModule @staticmethod - def empty(shape: Tuple[int, ...], dtype: Any) -> _NDBuffer: - ... + def empty(shape: Tuple[int, ...], dtype: Any) -> _NDBuffer: ... @staticmethod - def byte_bounds(ndarray: _NDBuffer) -> Tuple[int, int]: - ... + def byte_bounds(ndarray: _NDBuffer) -> Tuple[int, int]: ... def is_valid_nplike_allocation_ns(obj: Any) -> TypeGuard[ValidNumPyLikeAllocationNS]: diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 79056c2914..4ac239fdd2 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -142,11 +142,7 @@ def native_functions(field_a: Field3D, field_b: Field3D): field_b = ( trunc_res if isfinite(trunc_res) - else field_a - if isinf(trunc_res) - else field_b - if isnan(trunc_res) - else 0.0 + else field_a if isinf(trunc_res) else field_b if isnan(trunc_res) else 0.0 ) diff --git a/tests/eve_tests/unit_tests/test_extended_typing.py b/tests/eve_tests/unit_tests/test_extended_typing.py index 733e12577c..d90a577bf9 100644 --- a/tests/eve_tests/unit_tests/test_extended_typing.py +++ b/tests/eve_tests/unit_tests/test_extended_typing.py @@ -413,12 +413,10 @@ class B: def test_is_protocol(): class AProtocol(typing.Protocol): - def do_something(self, value: int) -> int: - ... + def do_something(self, value: int) -> int: ... class NotProtocol(AProtocol): - def do_something_else(self, value: float) -> float: - ... + def do_something_else(self, value: float) -> float: ... class AXProtocol(xtyping.Protocol): A = 1 @@ -427,8 +425,7 @@ class NotXProtocol(AXProtocol): A = 1 class AgainProtocol(AProtocol, xtyping.Protocol): - def do_something_else(self, value: float) -> float: - ... + def do_something_else(self, value: float) -> float: ... assert xtyping.is_protocol(AProtocol) assert xtyping.is_protocol(AXProtocol) @@ -440,16 +437,13 @@ def do_something_else(self, value: float) -> float: def test_get_partial_type_hints(): - def f1(a: int) -> float: - ... + def f1(a: int) -> float: ... assert xtyping.get_partial_type_hints(f1) == {"a": int, "return": float} - class MissingRef: - ... + class MissingRef: ... - def f_partial(a: int) -> MissingRef: - ... + def f_partial(a: int) -> MissingRef: ... # This is expected behavior because this test file uses # 'from __future__ import annotations' and therefore local @@ -467,8 +461,7 @@ def f_partial(a: int) -> MissingRef: "return": int, } - def f_nested_partial(a: int) -> Dict[str, MissingRef]: - ... + def f_nested_partial(a: int) -> Dict[str, MissingRef]: ... assert xtyping.get_partial_type_hints(f_nested_partial) == { "a": int, @@ -500,8 +493,7 @@ def test_eval_forward_ref(): == Dict[str, Tuple[int, float]] ) - class MissingRef: - ... + class MissingRef: ... assert ( xtyping.eval_forward_ref("Callable[[int], MissingRef]", localns={"MissingRef": MissingRef}) @@ -559,19 +551,16 @@ def test_infer_type(): assert xtyping.infer_type(str) == Type[str] - class A: - ... + class A: ... assert xtyping.infer_type(A()) == A assert xtyping.infer_type(A) == Type[A] - def f1(): - ... + def f1(): ... assert xtyping.infer_type(f1) == Callable[[], Any] - def f2(a: int, b: float) -> None: - ... + def f2(a: int, b: float) -> None: ... assert xtyping.infer_type(f2) == Callable[[int, float], type(None)] @@ -579,8 +568,7 @@ def f3( a: Dict[Tuple[str, ...], List[int]], b: List[Callable[[List[int]], Set[Set[int]]]], c: Type[List[int]], - ) -> Any: - ... + ) -> Any: ... assert ( xtyping.infer_type(f3) @@ -594,8 +582,7 @@ def f3( ] ) - def f4(a: int, b: float, *, foo: Tuple[str, ...] = ()) -> None: - ... + def f4(a: int, b: float, *, foo: Tuple[str, ...] = ()) -> None: ... assert xtyping.infer_type(f4) == Callable[[int, float], type(None)] assert ( diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index dbb2366f47..70a0e7d090 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -61,12 +61,10 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): class ExecutionAndAllocatorDescriptor(Protocol): # Used for test infrastructure, consider implementing this in gt4py when refactoring otf @property - def executor(self) -> Optional[ppi.ProgramExecutor]: - ... + def executor(self) -> Optional[ppi.ProgramExecutor]: ... @property - def allocator(self) -> next_allocators.FieldBufferAllocatorProtocol: - ... + def allocator(self) -> next_allocators.FieldBufferAllocatorProtocol: ... @dataclasses.dataclass(frozen=True) diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 03a0a9f5a7..7d55e26118 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -95,8 +95,7 @@ class DataInitializer(Protocol): @property - def scalar_value(self) -> ScalarValue: - ... + def scalar_value(self) -> ScalarValue: ... def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue: # some unlikely numpy dtypes are picky about arguments @@ -107,8 +106,7 @@ def field( allocator: next_allocators.FieldBufferAllocatorProtocol, sizes: dict[gtx.Dimension, int], dtype: np.typing.DTypeLike, - ) -> FieldValue: - ... + ) -> FieldValue: ... def from_case( self: Self, @@ -249,22 +247,19 @@ def __getattr__(self, name: str) -> Any: @typing.overload -def make_builder(*args: Callable) -> Callable[..., Builder]: - ... +def make_builder(*args: Callable) -> Callable[..., Builder]: ... @typing.overload def make_builder( *args: Literal[None], **kwargs: dict[str, Any] -) -> Callable[[Callable], Callable[..., Builder]]: - ... +) -> Callable[[Callable], Callable[..., Builder]]: ... @typing.overload def make_builder( *args: Optional[Callable], **kwargs: dict[str, Any] -) -> Callable[[Callable], Callable[..., Builder]] | Callable[..., Builder]: - ... +) -> Callable[[Callable], Callable[..., Builder]] | Callable[..., Builder]: ... # TODO(ricoh): Think about improving the type hints using `typing.ParamSpec`. @@ -305,8 +300,7 @@ def setter(self: Builder) -> Builder: argspec = inspect.getfullargspec(func) @dataclasses.dataclass(frozen=True) - class NewBuilder(Builder): - ... + class NewBuilder(Builder): ... for argname in argspec.args + argspec.kwonlyargs: setattr(NewBuilder, argname, make_setter(argname)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 9482860d13..6b7737df67 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -1099,7 +1099,12 @@ def test_tuple_unpacking(cartesian_case): @gtx.field_operator def unpack( inp: cases.IField, - ) -> tuple[cases.IField, cases.IField, cases.IField, cases.IField,]: + ) -> tuple[ + cases.IField, + cases.IField, + cases.IField, + cases.IField, + ]: a, b, c, d = (inp + 2, inp + 3, inp + 5, inp + 7) return a, b, c, d diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index 2dd4b91c48..bc92efc02c 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -526,15 +526,13 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl def test_builtin_int_constructors(): - def int_constrs() -> ( - tuple[ - int32, - int32, - int64, - int32, - int64, - ] - ): + def int_constrs() -> tuple[ + int32, + int32, + int64, + int32, + int64, + ]: return 1, int32(1), int64(1), int32("1"), int64("1") parsed = FieldOperatorParser.apply_to_function(int_constrs) @@ -552,17 +550,15 @@ def int_constrs() -> ( def test_builtin_float_constructors(): - def float_constrs() -> ( - tuple[ - float, - float, - float32, - float64, - float, - float32, - float64, - ] - ): + def float_constrs() -> tuple[ + float, + float, + float32, + float64, + float, + float32, + float64, + ]: return ( 0.1, float(0.1), diff --git a/tests/next_tests/unit_tests/test_type_inference.py b/tests/next_tests/unit_tests/test_type_inference.py index 74178e7548..3db67320f1 100644 --- a/tests/next_tests/unit_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/test_type_inference.py @@ -20,8 +20,7 @@ class Foo(ti.Type): bar: ti.Type baz: ti.Type - class Bar(ti.Type): - ... + class Bar(ti.Type): ... r = ti._Renamer() actual = [ From 6262708783152628fb05886a6396a1ee5e464d1e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 31 Jan 2024 14:17:17 +0100 Subject: [PATCH 60/85] test[next]: fix obsolete asarray (#1436) --- .../ffront_tests/test_arg_call_interface.py | 2 +- .../feature_tests/ffront_tests/test_execution.py | 14 ++++++++------ .../feature_tests/ffront_tests/test_scalar_if.py | 10 +++++----- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 354323afeb..e5a821de52 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -143,7 +143,7 @@ def testee( for name in ("out1", "out2", "out3", "out4") ) - ref = np.asarray(a) + 2 * np.asarray(b) + 3 * np.asarray(c) + ref = a + 2 * b + 3 * c cases.verify( cartesian_case, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 6b7737df67..02d54b1cb3 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -473,7 +473,7 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD offset_field, out=out, offset_provider={"Ioff": IDim, "Koff": KDim}, - ref=np.full_like(offset_field, True, dtype=bool), + ref=np.full_like(offset_field.asnumpy(), True, dtype=bool), comparison=lambda out, ref: np.all(out == ref), ) @@ -665,7 +665,7 @@ def testee(a: cases.IField, b: cases.IField, left: int32, right: int32) -> cases def testee(left: int32, right: int32) -> cases.IField: return broadcast(3, (IDim,)) if left > right else broadcast(4, (IDim,)) - e = np.asarray(a) if left < right else np.asarray(b) + e = a if left < right else b cases.verify( cartesian_case, testee, @@ -764,9 +764,9 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): cartesian_case, testee, ref=lambda: (expected + 1.0, (expected + 2.0, expected + 3.0)), - comparison=lambda ref, out: np.all(np.asarray(out[0]) == ref[0]) - and np.all(np.asarray(out[1][0]) == ref[1][0]) - and np.all(np.asarray(out[1][1]) == ref[1][1]), + comparison=lambda ref, out: np.all(out[0] == ref[0]) + and np.all(out[1][0] == ref[1][0]) + and np.all(out[1][1] == ref[1][1]), ) @@ -1195,5 +1195,7 @@ def consume_constants(input: cases.IFloatField) -> cases.IFloatField: return constants.PI * constants.E * input cases.verify_with_default_data( - cartesian_case, consume_constants, ref=lambda input: constants.PI * constants.E * input + cartesian_case, + consume_constants, + ref=lambda input: constants.PI * constants.E * input, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 834966c125..fc51af747f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -241,10 +241,10 @@ def nested_if_conditional_return( out = cases.allocate(cartesian_case, nested_if_conditional_return, cases.RETURN)() ref = { - (True, True): np.asarray(inp) + 1, - (True, False): np.asarray(inp) + 2, - (False, True): np.asarray(inp) + 3, - (False, False): np.asarray(inp) + 3, + (True, True): inp.asnumpy() + 1, + (True, False): inp.asnumpy() + 2, + (False, True): inp.asnumpy() + 3, + (False, False): inp.asnumpy() + 3, } cases.verify( @@ -289,7 +289,7 @@ def nested_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField b, condition, out=out, - ref=np.asarray(a) + 1 if condition else np.asarray(b) + 5, + ref=a.asnumpy() + 1 if condition else b.asnumpy() + 5, ) From e4dc1ee35482f0037c5db465d3b43ad4c96ceaea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 31 Jan 2024 15:27:46 +0100 Subject: [PATCH 61/85] test[next]: Add unit test for embedded `inverse_image` and fix bugs (#1432) Add unit tests for `ConnectivityField.inverse_image()`. --- src/gt4py/next/embedded/nd_array_field.py | 5 +- .../embedded_tests/test_nd_array_field.py | 110 ++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 3a22df1032..65a71718e4 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -403,10 +403,11 @@ def inverse_image( last_data_index = dim_nnz_indices[-1] assert isinstance(last_data_index, core_defs.INTEGRAL_TYPES) indices, counts = xp.unique(dim_nnz_indices, return_counts=True) + dim_range = self._domain[i] + if len(xp.unique(counts)) == 1 and ( len(indices) == last_data_index - first_data_index + 1 ): - dim_range = self._domain[i] idx_offset = dim_range[1].start start = idx_offset + first_data_index assert common.is_int_index(start) @@ -428,6 +429,8 @@ def inverse_image( f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'." ) + self._cache[cache_key] = new_dims + return new_dims def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1972b55852..70fa274457 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -632,3 +632,113 @@ def test_setitem_wrong_domain(): with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): field[(1, slice(None))] = value_incompatible + + +def test_connectivity_field_inverse_image(): + V = Dimension("V") + E = Dimension("E") + + V_START, V_STOP = 2, 7 + E_START, E_STOP = 0, 10 + + e2v_conn = common._connectivity( + np.roll(np.arange(E_START, E_STOP), 1), + domain=common.domain([common.named_range((E, (E_START, E_STOP)))]), + codomain=V, + ) + + # Test range + image_range = UnitRange(V_START, V_STOP) + result = e2v_conn.inverse_image(image_range) + + assert len(result) == 1 + assert result[0] == (E, UnitRange(V_START + 1, V_STOP + 1)) + + # Test cache + cached_result = e2v_conn.inverse_image(image_range) + assert result is cached_result # If the cache is not used, the result would be a new object + + # Test codomain + with pytest.raises(ValueError, match="does not match the codomain dimension"): + e2v_conn.inverse_image((E, UnitRange(1, 2))) + + +def test_connectivity_field_inverse_image_2d_domain(): + V = Dimension("V") + E = Dimension("E") + E2V = Dimension("E2V") + + V_START, V_STOP = 0, 3 + E_START, E_STOP = 0, 3 + E2V_START, E2V_STOP = 0, 3 + + e2v_conn = common._connectivity( + np.asarray([[0, 0, 2], [1, 1, 2], [2, 2, 2]]), + domain=common.domain( + [ + common.named_range((E, (E_START, E_STOP))), + common.named_range((E2V, (E2V_START, E2V_STOP))), + ] + ), + codomain=V, + ) + + # e2c_conn: + # ---E2V---- + # |[[0 0 2] + # E [1 1 2] + # | [2 2 2]] + + # Test contiguous and non-contiguous ranges. + # For the 'e2c_conn' defined above, the only valid range including 2 + # is [0, 3). Otherwise, the inverse image would be non-contiguous. + image_range = UnitRange(V_START, V_STOP) + result = e2v_conn.inverse_image(image_range) + + assert len(result) == 2 + assert result[0] == (E, UnitRange(E_START, E_STOP)) + assert result[1] == (E2V, UnitRange(E2V_START, E2V_STOP)) + + result = e2v_conn.inverse_image(UnitRange(0, 2)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(0, 2)) + assert result[1] == (E2V, UnitRange(0, 2)) + + result = e2v_conn.inverse_image(UnitRange(0, 1)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(0, 1)) + assert result[1] == (E2V, UnitRange(0, 2)) + + result = e2v_conn.inverse_image(UnitRange(1, 2)) + assert len(result) == 2 + assert result[0] == (E, UnitRange(1, 2)) + assert result[1] == (E2V, UnitRange(0, 2)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + result = e2v_conn.inverse_image(UnitRange(1, 3)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + result = e2v_conn.inverse_image(UnitRange(2, 3)) + + +def test_connectivity_field_inverse_image_non_contiguous(): + V = Dimension("V") + E = Dimension("E") + + V_START, V_STOP = 2, 7 + E_START, E_STOP = 0, 10 + + e2v_conn = common._connectivity( + np.asarray([0, 1, 2, 3, 4, 9, 7, 5, 8, 6]), + domain=common.domain([common.named_range((E, (E_START, E_STOP)))]), + codomain=V, + ) + + result = e2v_conn.inverse_image(UnitRange(V_START, 5)) + assert result[0] == (E, UnitRange(V_START, 5)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + e2v_conn.inverse_image(UnitRange(V_START, 6)) + + with pytest.raises(ValueError, match="generates non-contiguous dimensions"): + e2v_conn.inverse_image(UnitRange(V_START, V_STOP)) From 28ed830b0dac64eb69418900b3f403b6189dcfd6 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 31 Jan 2024 17:45:47 +0100 Subject: [PATCH 62/85] build: Update gridtools-cpp version to 2.3.2 (#1437) --- .pre-commit-config.yaml | 2 +- constraints.txt | 6 +++--- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 6 +++--- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 862aa46d66..3f26dfea55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -171,7 +171,7 @@ repos: - deepdiff==6.7.1 - devtools==0.12.2 - frozendict==2.4.0 - - gridtools-cpp==2.3.1 + - gridtools-cpp==2.3.2 - importlib-resources==6.1.1 - jinja2==3.1.3 - lark==1.1.9 diff --git a/constraints.txt b/constraints.txt index 1aa47d8340..3b32e53c0c 100644 --- a/constraints.txt +++ b/constraints.txt @@ -61,14 +61,14 @@ flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in flask==3.0.1 # via dace fonttools==4.47.2 # via matplotlib -fparser==0.1.3 # via dace +fparser==0.1.4 # via dace frozendict==2.4.0 # via gt4py (pyproject.toml) -gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.2 # via gt4py (pyproject.toml) hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==7.0.1 # via build, flask, fparser, jupyter-client, sphinx +importlib-metadata==7.0.1 # via build, flask, jupyter-client, sphinx importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 7200018616..1db48693be 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -40,7 +40,7 @@ flake8-pyproject==1.2.2 flake8-rst-docstrings==0.0.14 flake8==5.0.4 frozendict==2.3 -gridtools-cpp==2.3.1 +gridtools-cpp==2.3.2 hypothesis==6.0.0 importlib-resources==5.0;python_version<'3.9' isort==5.10 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 259663ffc4..badf08864e 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -38,7 +38,7 @@ flake8-pyproject==1.2.2 flake8-rst-docstrings==0.0.14 flake8==5.0.4 frozendict==2.3 -gridtools-cpp==2.3.1 +gridtools-cpp==2.3.2 hypothesis==6.0.0 importlib-resources==5.0;python_version<'3.9' isort==5.10 diff --git a/pyproject.toml b/pyproject.toml index 5a1618fc49..1ba20165e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ 'deepdiff>=5.6.0', 'devtools>=0.6', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.1,==2.*', + 'gridtools-cpp>=2.3.2,==2.*', "importlib-resources>=5.0;python_version<'3.9'", 'jinja2>=3.0.0', 'lark>=1.1.2', diff --git a/requirements-dev.txt b/requirements-dev.txt index e54e56ad62..94052ec478 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -61,14 +61,14 @@ flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in flask==3.0.1 # via dace fonttools==4.47.2 # via matplotlib -fparser==0.1.3 # via dace +fparser==0.1.4 # via dace frozendict==2.4.0 # via gt4py (pyproject.toml) -gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) +gridtools-cpp==2.3.2 # via gt4py (pyproject.toml) hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==7.0.1 # via build, flask, fparser, jupyter-client, sphinx +importlib-metadata==7.0.1 # via build, flask, jupyter-client, sphinx importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest From adf3a3cd96966c503e4a2b95272aa8757b9c7c92 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 31 Jan 2024 18:25:17 +0100 Subject: [PATCH 63/85] feat[next]: add tests with mesh with skip values (#1433) - Adds a mesh with skip values - Define `common.SKIP_VALUE = -1` instead of using `-1` explicitly - Skip tests with that mesh in embedded (will come in a next PR). --- pyproject.toml | 1 + src/gt4py/next/common.py | 8 +- src/gt4py/next/iterator/embedded.py | 12 +- tests/next_tests/definitions.py | 2 + tests/next_tests/integration_tests/cases.py | 32 ++- .../ffront_tests/ffront_test_utils.py | 190 +++++++++++++----- .../ffront_tests/test_bound_args.py | 2 +- .../ffront_tests/test_execution.py | 28 ++- .../ffront_tests/test_external_local_field.py | 12 +- .../ffront_tests/test_gt4py_builtins.py | 44 +++- .../test_temporaries_with_sizes.py | 28 +-- 11 files changed, 259 insertions(+), 100 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1ba20165e9..5fffb9cf0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -356,6 +356,7 @@ markers = [ 'uses_cartesian_shift: tests that use a Cartesian connectivity', 'uses_unstructured_shift: tests that use a unstructured connectivity', 'uses_max_over: tests that use the max_over builtin', + 'uses_mesh_with_skip_values: tests that use a mesh with skip values', 'checks_specific_error: tests that rely on the backend to produce a specific error message' ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index d8ffc2057b..90e76d671d 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -32,6 +32,7 @@ Any, Callable, ClassVar, + Final, Generic, Never, Optional, @@ -1073,4 +1074,9 @@ def register_builtin_func( @classmethod def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Callable[_P, _R]: return cls._builtin_func_map.get(func, NotImplemented) - return cls._builtin_func_map.get(func, NotImplemented) + + +#: Numeric value used to represent missing values in connectivities. +#: Equivalent to the `_FillValue` attribute in the UGRID Conventions +#: (see: http://ugrid-conventions.github.io/ugrid-conventions/). +SKIP_VALUE: Final[int] = -1 diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 6985aea853..7e0e060834 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -526,6 +526,16 @@ def execute_shift( for i, p in reversed(list(enumerate(new_entry))): # first shift applies to the last sparse dimensions of that axis type if p is None: + offset_implementation = offset_provider[tag] + assert isinstance(offset_implementation, common.Connectivity) + cur_index = pos[offset_implementation.origin_axis.value] + assert common.is_int_index(cur_index) + if offset_implementation.mapped_index(cur_index, index) in [ + None, + common.SKIP_VALUE, + ]: + return None + new_entry[i] = index break # the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard @@ -549,7 +559,7 @@ def execute_shift( assert common.is_int_index(cur_index) if offset_implementation.mapped_index(cur_index, index) in [ None, - -1, + common.SKIP_VALUE, ]: return None else: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 70a0e7d090..c95292d702 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -135,6 +135,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_CARTESIAN_SHIFT = "uses_cartesian_shift" USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift" USES_MAX_OVER = "uses_max_over" +USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" CHECKS_SPECIFIC_ERROR = "checks_specific_error" # Skip messages (available format keys: 'marker', 'backend') @@ -170,6 +171,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): XFAIL, UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args + (USES_MESH_WITH_SKIP_VALUES, XFAIL, UNSUPPORTED_MESSAGE), ] GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 7d55e26118..8513c98d89 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -35,7 +35,14 @@ from next_tests import definitions as test_definitions from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 # fixture and aliases + C2E, + C2V, + E2V, + V2E, + C2EDim, + C2VDim, Cell, + E2VDim, Edge, IDim, Ioff, @@ -43,9 +50,10 @@ Joff, KDim, Koff, + V2EDim, Vertex, exec_alloc_descriptor, - reduction_setup, + mesh_descriptor, ) @@ -65,16 +73,6 @@ CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type] EmptyField: TypeAlias = gtx.Field[[], np.int32] # type: ignore [valid-type] -# TODO(ricoh): unify the following with the `ffront_test_utils.reduction_setup` -# fixture if `ffront_test_utils.reduction_setup` is not completely superseded -# by `unstructured_case`. -V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) -E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) -C2EDim = gtx.Dimension("C2E", kind=common.DimensionKind.LOCAL) -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) -C2E = gtx.FieldOffset("C2E", source=Edge, target=(Cell, C2EDim)) - ScalarValue: TypeAlias = core_defs.Scalar FieldValue: TypeAlias = gtx.Field FieldViewArg: TypeAlias = FieldValue | ScalarValue | tuple["FieldViewArg", ...] @@ -489,17 +487,17 @@ def cartesian_case( @pytest.fixture def unstructured_case( - reduction_setup, # noqa: F811 # fixtures + mesh_descriptor, # noqa: F811 # fixtures exec_alloc_descriptor: test_definitions.ExecutionAndAllocatorDescriptor, # noqa: F811 # fixtures ): yield Case( exec_alloc_descriptor.executor, - offset_provider=reduction_setup.offset_provider, + offset_provider=mesh_descriptor.offset_provider, default_sizes={ - Vertex: reduction_setup.num_vertices, - Edge: reduction_setup.num_edges, - Cell: reduction_setup.num_cells, - KDim: reduction_setup.k_levels, + Vertex: mesh_descriptor.num_vertices, + Edge: mesh_descriptor.num_edges, + Cell: mesh_descriptor.num_cells, + KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index e421763699..d8c4696073 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -13,13 +13,15 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import types from collections import namedtuple -from typing import Any, Optional, TypeVar +from typing import Any, Protocol, TypeVar import numpy as np import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir from gt4py.next.program_processors import processor_interface as ppi @@ -118,18 +120,41 @@ def debug_itir(tree): Cell = gtx.Dimension("Cell") EdgeOffset = gtx.FieldOffset("EdgeOffset", source=Edge, target=(Edge,)) +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) +C2EDim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL) +C2VDim = gtx.Dimension("C2V", kind=gtx.DimensionKind.LOCAL) +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) +C2E = gtx.FieldOffset("C2E", source=Edge, target=(Cell, C2EDim)) +C2V = gtx.FieldOffset("C2V", source=Vertex, target=(Cell, C2VDim)) + size = 10 -@pytest.fixture -def reduction_setup(): +class MeshDescriptor(Protocol): + @property + def name(self) -> str: ... + + @property + def num_vertices(self) -> int: ... + + @property + def num_cells(self) -> int: ... + + @property + def num_edges(self) -> int: ... + + @property + def num_levels(self) -> int: ... + + @property + def offset_provider(self) -> dict[str, common.Connectivity]: ... + + +def simple_mesh() -> MeshDescriptor: num_vertices = 9 num_cells = 8 - k_levels = 10 - v2edim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) - e2vdim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) - c2vdim = gtx.Dimension("C2V", kind=gtx.DimensionKind.LOCAL) - c2edim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL) v2e_arr = np.array( [ @@ -183,57 +208,115 @@ def reduction_setup(): assert all(len(row) == 2 for row in e2v_arr) e2v_arr = np.asarray(e2v_arr, dtype=gtx.IndexType) - yield namedtuple( - "ReductionSetup", + return types.SimpleNamespace( + name="simple_mesh", + num_vertices=num_vertices, + num_edges=np.int32(num_edges), + num_cells=num_cells, + offset_provider={ + V2E.value: gtx.NeighborTableOffsetProvider( + v2e_arr, Vertex, Edge, 4, has_skip_values=False + ), + E2V.value: gtx.NeighborTableOffsetProvider( + e2v_arr, Edge, Vertex, 2, has_skip_values=False + ), + C2V.value: gtx.NeighborTableOffsetProvider( + c2v_arr, Cell, Vertex, 4, has_skip_values=False + ), + C2E.value: gtx.NeighborTableOffsetProvider( + c2e_arr, Cell, Edge, 4, has_skip_values=False + ), + }, + ) + + +def skip_value_mesh() -> MeshDescriptor: + """Mesh with skip values from the GT4Py quickstart guide.""" + + num_vertices = 7 + num_cells = 6 + num_edges = 12 + + v2e_arr = np.array( + [ + [1, 8, 7, 0, -1], + [2, 8, 1, -1, -1], + [3, 9, 8, 2, -1], + [4, 10, 3, -1, -1], + [5, 11, 4, -1, -1], + [0, 6, 4, -1, -1], + [6, 7, 9, 10, 11], + ], + dtype=gtx.IndexType, + ) + + e2v_arr = np.array( + [ + [0, 5], + [0, 1], + [1, 2], + [2, 3], + [3, 4], + [4, 5], + [5, 6], + [6, 0], + [0, 2], + [2, 6], + [3, 6], + [4, 6], + ], + dtype=gtx.IndexType, + ) + + c2v_arr = np.array( + [ + [0, 6, 5], + [0, 2, 6], + [0, 1, 2], + [2, 3, 6], + [3, 4, 6], + [4, 5, 6], + ], + dtype=gtx.IndexType, + ) + + c2e_arr = np.array( [ - "num_vertices", - "num_edges", - "num_cells", - "k_levels", - "V2EDim", - "E2VDim", - "C2VDim", - "C2EDim", - "V2E", - "E2V", - "C2V", - "C2E", - "inp", - "out", - "offset_provider", - "v2e_table", - "e2v_table", + [0, 6, 7], # cell 0 (neighbors: edge 0, edge 6, edge 7) + [7, 8, 9], # cell 1 + [1, 2, 8], # cell 2 + [3, 9, 10], # cell 3 + [4, 10, 11], # cell 4 + [5, 6, 11], # cell 5 ], - )( + dtype=gtx.IndexType, + ) + + return types.SimpleNamespace( + name="skip_value_mesh", num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, - k_levels=k_levels, - V2EDim=v2edim, - E2VDim=e2vdim, - C2VDim=c2vdim, - C2EDim=c2edim, - V2E=gtx.FieldOffset("V2E", source=Edge, target=(Vertex, v2edim)), - E2V=gtx.FieldOffset("E2V", source=Vertex, target=(Edge, e2vdim)), - C2V=gtx.FieldOffset("C2V", source=Vertex, target=(Cell, c2vdim)), - C2E=gtx.FieldOffset("C2E", source=Edge, target=(Cell, c2edim)), - # inp=gtx.index_field(edge, dtype=np.int64), # TODO enable once we support gtx.index_fields in bindings - inp=gtx.as_field([Edge], np.arange(num_edges, dtype=np.int32)), - out=gtx.as_field([Vertex], np.zeros([num_vertices], dtype=np.int32)), offset_provider={ - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4, has_skip_values=False), - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2, has_skip_values=False), - "C2V": gtx.NeighborTableOffsetProvider(c2v_arr, Cell, Vertex, 4, has_skip_values=False), - "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4, has_skip_values=False), + V2E.value: gtx.NeighborTableOffsetProvider( + v2e_arr, Vertex, Edge, 5, has_skip_values=True + ), + E2V.value: gtx.NeighborTableOffsetProvider( + e2v_arr, Edge, Vertex, 2, has_skip_values=False + ), + C2V.value: gtx.NeighborTableOffsetProvider( + c2v_arr, Cell, Vertex, 3, has_skip_values=False + ), + C2E.value: gtx.NeighborTableOffsetProvider( + c2e_arr, Cell, Edge, 3, has_skip_values=False + ), }, - v2e_table=v2e_arr, - e2v_table=e2v_arr, - ) # type: ignore + ) __all__ = [ "exec_alloc_descriptor", - "reduction_setup", + "mesh_descriptor", "debug_itir", "DimsType", "DType", @@ -249,3 +332,14 @@ def reduction_setup(): "EdgeOffset", "size", ] + + +@pytest.fixture( + params=[ + simple_mesh(), + pytest.param(skip_value_mesh(), marks=pytest.mark.uses_mesh_with_skip_values), + ], + ids=lambda p: p.name, +) +def mesh_descriptor(request) -> MeshDescriptor: + yield request.param diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py index e4baedc6ee..e4df5b9b78 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py @@ -22,7 +22,7 @@ from next_tests.integration_tests.cases import cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, - reduction_setup, + mesh_descriptor, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 02d54b1cb3..7fc2d82e67 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -22,6 +22,7 @@ from gt4py.next import ( astype, broadcast, + common, errors, float32, float64, @@ -52,7 +53,7 @@ ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, - reduction_setup, + mesh_descriptor, ) @@ -508,9 +509,12 @@ def testee(a: cases.EField) -> cases.EField: unstructured_case, testee, ref=lambda a: np.sum( - np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1)[ - unstructured_case.offset_provider["E2V"].table - ], + np.sum( + a[unstructured_case.offset_provider["V2E"].table], + axis=1, + initial=0, + where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, + )[unstructured_case.offset_provider["E2V"].table], axis=1, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), @@ -568,11 +572,15 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: tmp = red(E2V[0]) return tmp + v2e = unstructured_case.offset_provider["V2E"] cases.verify_with_default_data( unstructured_case, reduce_tuple_element, ref=lambda e, v: np.sum( - e[unstructured_case.offset_provider["V2E"].table] + np.tile(v, (4, 1)).T, axis=1 + e[v2e.table] + np.tile(v, (v2e.max_neighbors, 1)).T, + axis=1, + initial=0, + where=v2e.table != common.SKIP_VALUE, )[unstructured_case.offset_provider["E2V"].table[:, 0]], ) @@ -703,13 +711,17 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) return tmp + v2e_table = unstructured_case.offset_provider["V2E"].table cases.verify_with_default_data( unstructured_case, testee, ref=lambda a, b: ( - np.sum(b[unstructured_case.offset_provider["V2E"].table], axis=1) - if 2 < 3 - else np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1) + np.sum( + b[v2e_table], + axis=1, + initial=0, + where=v2e_table != common.SKIP_VALUE, + ) ), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index bb1d878a6a..569d7b5631 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -16,13 +16,13 @@ import pytest import gt4py.next as gtx -from gt4py.next import int32, neighbor_sum +from gt4py.next import common, int32, neighbor_sum from next_tests.integration_tests import cases from next_tests.integration_tests.cases import V2E, Edge, V2EDim, Vertex, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, - reduction_setup, + mesh_descriptor, ) @@ -43,13 +43,19 @@ def testee( ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() + v2e_table = unstructured_case.offset_provider["V2E"].table cases.verify( unstructured_case, testee, inp, ones, out=cases.allocate(unstructured_case, testee, cases.RETURN)(), - ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), + ref=np.sum( + v2e_table, + axis=1, + initial=0, + where=v2e_table != common.SKIP_VALUE, + ), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 90d07f360d..e27e73c80d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -17,7 +17,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import broadcast, float64, int32, max_over, min_over, neighbor_sum, where +from gt4py.next import broadcast, common, float64, int32, max_over, min_over, neighbor_sum, where from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests import cases @@ -35,7 +35,7 @@ ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, - reduction_setup, + mesh_descriptor, ) @@ -56,7 +56,12 @@ def testee(edge_f: cases.EField) -> cases.VField: out = cases.allocate(unstructured_case, testee, cases.RETURN)() v2e_table = unstructured_case.offset_provider["V2E"].table - ref = np.max(inp.ndarray[v2e_table], axis=1) + ref = np.max( + inp.asnumpy()[v2e_table], + axis=1, + initial=np.min(inp.asnumpy()), + where=v2e_table != common.SKIP_VALUE, + ) cases.verify(unstructured_case, testee, inp, ref=ref, out=out) @@ -69,7 +74,14 @@ def minover(edge_f: cases.EField) -> cases.VField: v2e_table = unstructured_case.offset_provider["V2E"].table cases.verify_with_default_data( - unstructured_case, minover, ref=lambda edge_f: np.min(edge_f[v2e_table], axis=1) + unstructured_case, + minover, + ref=lambda edge_f: np.min( + edge_f[v2e_table], + axis=1, + initial=np.max(edge_f), + where=v2e_table != common.SKIP_VALUE, + ), ) @@ -83,10 +95,16 @@ def reduction(edge_f: cases.EField) -> cases.VField: def fencil(edge_f: cases.EField, out: cases.VField): reduction(edge_f, out=out) + v2e_table = unstructured_case.offset_provider["V2E"].table cases.verify_with_default_data( unstructured_case, fencil, - ref=lambda edge_f: np.sum(edge_f[unstructured_case.offset_provider["V2E"].table], axis=1), + ref=lambda edge_f: np.sum( + edge_f[v2e_table], + axis=1, + initial=0, + where=v2e_table != common.SKIP_VALUE, + ), ) @@ -103,11 +121,17 @@ def reduce_expr(edge_f: cases.EField) -> cases.VField: def fencil(edge_f: cases.EField, out: cases.VField): reduce_expr(edge_f, out=out) + v2e_table = unstructured_case.offset_provider["V2E"].table cases.verify_with_default_data( unstructured_case, fencil, ref=lambda edge_f: 3 - * np.sum(-edge_f[unstructured_case.offset_provider["V2E"].table] ** 2 * 2, axis=1), + * np.sum( + -edge_f[v2e_table] ** 2 * 2, + axis=1, + initial=0, + where=v2e_table != common.SKIP_VALUE, + ), ) @@ -117,10 +141,16 @@ def test_reduction_with_common_expression(unstructured_case): def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) + v2e_table = unstructured_case.offset_provider["V2E"].table cases.verify_with_default_data( unstructured_case, testee, - ref=lambda flux: np.sum(flux[unstructured_case.offset_provider["V2E"].table] * 2, axis=1), + ref=lambda flux: np.sum( + flux[v2e_table] * 2, + axis=1, + initial=0, + where=v2e_table != common.SKIP_VALUE, + ), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index c4cdd8a4be..13d8f7711e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -31,7 +31,7 @@ unstructured_case, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - reduction_setup, + mesh_descriptor, ) from next_tests.toy_connectivity import Cell, Edge @@ -67,7 +67,7 @@ def prog( a: cases.VField, out: cases.EField, num_vertices: int32, - num_edges: int64, + num_edges: int32, num_cells: int32, ): testee_op(a, out=out) @@ -75,15 +75,15 @@ def prog( return prog -def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup): +def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes.executor, - offset_provider=reduction_setup.offset_provider, + offset_provider=mesh_descriptor.offset_provider, default_sizes={ - Vertex: reduction_setup.num_vertices, - Edge: reduction_setup.num_edges, - Cell: reduction_setup.num_cells, - KDim: reduction_setup.k_levels, + Vertex: mesh_descriptor.num_vertices, + Edge: mesh_descriptor.num_edges, + Cell: mesh_descriptor.num_cells, + KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=run_gtfn_with_temporaries_and_symbolic_sizes.allocator, @@ -92,7 +92,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, redu a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( @@ -100,19 +100,19 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, redu testee, a, out, - reduction_setup.num_vertices, - reduction_setup.num_edges, - reduction_setup.num_cells, + mesh_descriptor.num_vertices, + mesh_descriptor.num_edges, + mesh_descriptor.num_cells, inout=out, ref=ref, ) -def test_temporary_symbols(testee, reduction_setup): +def test_temporary_symbols(testee, mesh_descriptor): itir_with_tmp = apply_common_transforms( testee.itir, lift_mode=LiftMode.FORCE_TEMPORARIES, - offset_provider=reduction_setup.offset_provider, + offset_provider=mesh_descriptor.offset_provider, ) params = ["num_vertices", "num_edges", "num_cells"] From 75d23d0afe1d190e3d6bc1f9c886d417e1c30a4d Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 31 Jan 2024 23:00:11 +0100 Subject: [PATCH 64/85] Fix missing cstdint header in gtcpp codegen (#1439) --- src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py index 00c3fb16d8..4e56b159d9 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py @@ -288,6 +288,7 @@ def visit_Program(self, node: gtcpp.Program, **kwargs: Any) -> Union[str, Collec Program = as_mako( """ + #include #include #include #include From d6dfd6ff46cc1d50b0fb6d05fb0b6271e4a1f5cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 1 Feb 2024 09:12:44 +0100 Subject: [PATCH 65/85] feat[next][dace]: Modified the file caching. (#1434) In PR #1422 @edopao introduced a mechanism to skip the SDFG translation. This PR moves this cache from the `run_dace_iterator()` function into the `build_sdfg_from_itir()` function. --- .../runners/dace_iterator/__init__.py | 64 ++++++++++++------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 6a8b9bc9c6..ed68e66bc9 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -254,21 +254,36 @@ def build_sdfg_from_itir( on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, + load_sdfg_from_file: bool = False, + cache_id: Optional[str] = None, + save_sdfg: bool = True, ) -> dace.SDFG: """Translate a Fencil into an SDFG. Args: - program: The Fencil that should be translated. - *args: Arguments for which the fencil should be called. - offset_provider: The set of offset providers that should be used. - auto_optimize: Apply DaCe's `auto_optimize` heuristic. - on_gpu: Performs the translation for GPU, defaults to `False`. - column_axis: The column axis to be used, defaults to `None`. - lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. + program: The Fencil that should be translated. + *args: Arguments for which the fencil should be called. + offset_provider: The set of offset providers that should be used. + auto_optimize: Apply DaCe's `auto_optimize` heuristic. + on_gpu: Performs the translation for GPU, defaults to `False`. + column_axis: The column axis to be used, defaults to `None`. + lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. + load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only. + cache_id: The id of the cache entry, used to disambiguate stored sdfgs. + save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`. Notes: Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored. """ + # Test if we can go through the cache? + sdfg_filename = ( + f"_dacegraphs/gt4py/{cache_id if cache_id is not None else '.'}/{program.id}.sdfg" + ) + if load_sdfg_from_file and Path(sdfg_filename).exists(): + sdfg: dace.SDFG = dace.SDFG.from_file(sdfg_filename) + sdfg.validate() + return sdfg + # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force # `lift_more` to `FORCE_INLINE` mode. lift_mode = itir_transforms.LiftMode.FORCE_INLINE @@ -277,7 +292,7 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg: dace.SDFG = sdfg_genenerator.visit(program) + sdfg = sdfg_genenerator.visit(program) if sdfg is None: raise RuntimeError(f"Visit failed for program {program.id}.") @@ -311,6 +326,10 @@ def build_sdfg_from_itir( if on_gpu: sdfg.apply_gpu_transformations() + # Store the sdfg such that we can later reuse it. + if save_sdfg: + sdfg.save(sdfg_filename) + return sdfg @@ -326,7 +345,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] # debug option to store SDFGs on filesystem and skip lowering ITIR to SDFG at each run - skip_itir_lowering_to_sdfg = kwargs.get("skip_itir_lowering_to_sdfg", False) + load_sdfg_from_file = kwargs.get("load_sdfg_from_file", False) + save_sdfg = kwargs.get("save_sdfg", True) arg_types = [type_translation.from_value(arg) for arg in args] @@ -336,20 +356,18 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): sdfg_program = build_cache[cache_id] sdfg = sdfg_program.sdfg else: - sdfg_filename = f"_dacegraphs/gt4py/{cache_id}/{program.id}.sdfg" - if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()): - sdfg = build_sdfg_from_itir( - program, - *args, - offset_provider=offset_provider, - auto_optimize=auto_optimize, - on_gpu=on_gpu, - column_axis=column_axis, - lift_mode=lift_mode, - ) - sdfg.save(sdfg_filename) - else: - sdfg = dace.SDFG.from_file(sdfg_filename) + sdfg = build_sdfg_from_itir( + program, + *args, + offset_provider=offset_provider, + auto_optimize=auto_optimize, + on_gpu=on_gpu, + column_axis=column_axis, + lift_mode=lift_mode, + load_sdfg_from_file=load_sdfg_from_file, + cache_id=cache_id, + save_sdfg=save_sdfg, + ) sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): From 58ec4dd2fa80d8b06d146ade1ba7ef354b1f41b9 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 2 Feb 2024 12:01:15 +0100 Subject: [PATCH 66/85] bug[next]: allow fields of different sizes in tuple in itir embedded (#1442) Undo an unintended change in #1202 to re-enable an icon4py pattern. Longer term, probably, only transposable tuples of fields make sense, e.g. by intersecting. --- src/gt4py/next/iterator/embedded.py | 22 +++++++------------ .../iterator_tests/test_tuple.py | 6 +++-- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 7e0e060834..011ca4d92b 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -168,7 +168,7 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def __gt_domain__(self) -> common.Domain: ... + def dims(self) -> tuple[common.Dimension, ...]: ... # TODO(havogt): define generic Protocol to provide a concrete return type @abc.abstractmethod @@ -176,7 +176,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any: ... @property def __gt_origin__(self) -> tuple[int, ...]: - return tuple([0] * len(self.__gt_domain__.dims)) + return tuple([0] * len(self.dims)) @runtime_checkable @@ -678,18 +678,12 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: def _get_axes( field_or_tuple: LocatedField | tuple, ) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField - return _get_domain(field_or_tuple).dims - - -def _get_domain( - field_or_tuple: LocatedField | tuple, -) -> common.Domain: # arbitrary nesting of tuples of LocatedField if isinstance(field_or_tuple, tuple): - first = _get_domain(field_or_tuple[0]) - assert all(first == _get_domain(f) for f in field_or_tuple) + first = _get_axes(field_or_tuple[0]) + assert all(first == _get_axes(f) for f in field_or_tuple) return first else: - return field_or_tuple.__gt_domain__ + return field_or_tuple.dims def _single_vertical_idx( @@ -900,8 +894,8 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField): _ndarrayfield: common.Field @property - def __gt_domain__(self) -> common.Domain: - return self._ndarrayfield.__gt_domain__ + def dims(self) -> tuple[common.Dimension, ...]: + return self._ndarrayfield.__gt_domain__.dims def _translate_named_indices( self, _named_indices: NamedFieldIndices @@ -1452,7 +1446,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: class TupleOfFields(TupleField): def __init__(self, data): self.data = data - self.__gt_domain__ = _get_domain(data) + self.dims = _get_axes(data) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: return _build_tuple_result(self.data, named_indices) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index add772e7ef..925ad33e86 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -311,7 +311,9 @@ def test_tuple_field_input(program_processor): ) inp2 = gtx.as_field( [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), + rng.normal( + size=(shape[0], shape[1], shape[2] + 1) + ), # TODO(havogt) currently we allow different sizes, needed for icon4py compatibility ) out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) @@ -323,7 +325,7 @@ def test_tuple_field_input(program_processor): } run_processor(tuple_input[dom], program_processor, (inp1, inp2), out=out, offset_provider={}) if validate: - assert np.allclose(inp1.asnumpy() + inp2.asnumpy(), out.asnumpy()) + assert np.allclose(inp1.asnumpy() + inp2.asnumpy()[:, :, :-1], out.asnumpy()) @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple") From 0d158adcbeb26172033221757435c5cd9b6d9582 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 2 Feb 2024 15:11:34 +0100 Subject: [PATCH 67/85] fix[next][dace]: Fix for neighbor reduction with skip values (#1443) This PR provides a bugfix for the case of neighbor reductions with lambda function as reduction operation and connectivity table containing skip values. The lambda function should only accumulate the results for the valid neighbors. On the contrary, the baseline implementation was using the reduction identity value for the missing neighbors, resulting in invalid result. The fix consists of producing an array of boolean flags to determine if the neighbor value is valid or not. If not valid, the call to the lambda function is by-passed. --- .../runners/dace_iterator/__init__.py | 2 +- .../runners/dace_iterator/itir_to_tasklet.py | 264 +++++++++++++----- tests/next_tests/definitions.py | 1 - 3 files changed, 200 insertions(+), 67 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index ed68e66bc9..2e9a66c435 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -299,7 +299,7 @@ def build_sdfg_from_itir( for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: _, frameinfo = warnings.warn( - f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg." + f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg." ), getframeinfo( currentframe() # type: ignore ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index ba969608a7..56ffe7e104 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -18,11 +18,11 @@ import dace import numpy as np -from dace.transformation.dataflow import MapFusion import gt4py.eve.codegen from gt4py import eve from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing +from gt4py.next.common import SKIP_VALUE as neighbor_skip_value from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import FunCall, Lambda @@ -184,86 +184,93 @@ def __init__( def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: - di = dace_debuginfo(node, transformer.context.body.debuginfo) + sdfg: dace.SDFG = transformer.context.body + state: dace.SDFGState = transformer.context.state + + di = dace_debuginfo(node, sdfg.debuginfo) offset_literal, data = node_args assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value assert isinstance(offset_dim, str) - iterator = transformer.visit(data) - assert isinstance(iterator, IteratorExpr) - field_desc = iterator.field.desc(transformer.context.body) - - field_index = "__field_idx" offset_provider = transformer.offset_provider[offset_dim] - if isinstance(offset_provider, NeighborTableOffsetProvider): - neighbor_check = f"{field_index} >= 0" - elif isinstance(offset_provider, StridedNeighborOffsetProvider): - neighbor_check = f"{field_index} < {field_desc.shape[offset_provider.neighbor_axis.value]}" - else: - assert isinstance(offset_provider, Dimension) + if not isinstance(offset_provider, NeighborTableOffsetProvider): raise NotImplementedError( - "Neighbor reductions for cartesian grids not implemented in DaCe backend." + "Neighbor reduction only implemented for connectivity based on neighbor tables." ) - assert transformer.context.reduce_identity is not None - - sdfg: dace.SDFG = transformer.context.body - state: dace.SDFGState = transformer.context.state + iterator = transformer.visit(data) + assert isinstance(iterator, IteratorExpr) + field_desc = iterator.field.desc(transformer.context.body) + origin_index_node = iterator.indices[offset_provider.origin_axis.value] - shifted_dim = offset_provider.origin_axis.value + assert transformer.context.reduce_identity is not None + assert transformer.context.reduce_identity.dtype == iterator.dtype - result_name = unique_var_name() + # gather the neighbors in a result array dimensioned for `max_neighbors` + neighbor_value_var = unique_var_name() sdfg.add_array( - result_name, dtype=iterator.dtype, shape=(offset_provider.max_neighbors,), transient=True + neighbor_value_var, + dtype=iterator.dtype, + shape=(offset_provider.max_neighbors,), + transient=True, ) - result_access = state.add_access(result_name, debuginfo=di) + neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) + + # allocate scalar to store index for direct addressing of neighbor field + neighbor_index_var = unique_var_name() + sdfg.add_scalar(neighbor_index_var, _INDEX_DTYPE, transient=True) + neighbor_index_node = state.add_access(neighbor_index_var, debuginfo=di) # generate unique map index name to avoid conflict with other maps inside same state - neighbor_index = unique_name("neighbor_idx") + neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") me, mx = state.add_map( - f"{offset_dim}_neighbors_map", - ndrange={neighbor_index: f"0:{offset_provider.max_neighbors}"}, + f"{offset_dim}_neighbor_map", + ndrange={neighbor_map_index: f"0:{offset_provider.max_neighbors}"}, debuginfo=di, ) - table_name = connectivity_identifier(offset_dim) - table_subset = (f"0:{sdfg.arrays[table_name].shape[0]}", neighbor_index) + table_name = connectivity_identifier(offset_dim) shift_tasklet = state.add_tasklet( "shift", - code="__result = __table[__idx]", + code=f"__result = __table[__idx, {neighbor_map_index}]", inputs={"__table", "__idx"}, outputs={"__result"}, debuginfo=di, ) - data_access_tasklet = state.add_tasklet( - "data_access", - code=f"__result = __field[{field_index}]" - + ( - f" if {neighbor_check} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values - else "" - ), - inputs={"__field", field_index}, - outputs={"__result"}, - debuginfo=di, - ) - idx_name = unique_var_name() - sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True) state.add_memlet_path( state.add_access(table_name, debuginfo=di), me, shift_tasklet, - memlet=create_memlet_at(table_name, table_subset), + memlet=create_memlet_full(table_name, sdfg.arrays[table_name]), dst_conn="__table", ) state.add_memlet_path( - iterator.indices[shifted_dim], + origin_index_node, me, shift_tasklet, - memlet=dace.Memlet(data=iterator.indices[shifted_dim].data, subset="0", debuginfo=di), + memlet=dace.Memlet(data=origin_index_node.data, subset="0", debuginfo=di), dst_conn="__idx", ) - state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet()) + state.add_edge( + shift_tasklet, + "__result", + neighbor_index_node, + None, + dace.Memlet(data=neighbor_index_var, subset="0"), + ) + + data_access_tasklet = state.add_tasklet( + "data_access", + code="__data = __field[__idx]" + + ( + f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" + if offset_provider.has_skip_values + else "" + ), + inputs={"__field", "__idx"}, + outputs={"__data"}, + debuginfo=di, + ) # select full shape only in the neighbor-axis dimension field_subset = tuple( f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" @@ -276,15 +283,63 @@ def builtin_neighbors( memlet=create_memlet_at(iterator.field.data, field_subset), dst_conn="__field", ) + state.add_edge( + neighbor_index_node, + None, + data_access_tasklet, + "__idx", + dace.Memlet(data=neighbor_index_var, subset="0"), + ) state.add_memlet_path( data_access_tasklet, mx, - result_access, - memlet=dace.Memlet(data=result_name, subset=neighbor_index, debuginfo=di), - src_conn="__result", + neighbor_value_node, + memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), + src_conn="__data", ) - return [ValueExpr(result_access, iterator.dtype)] + if not offset_provider.has_skip_values: + return [ValueExpr(neighbor_value_node, iterator.dtype)] + else: + """ + In case of neighbor tables with skip values, in addition to the array of neighbor values this function also + returns an array of booleans to indicate if the neighbor value is present or not. This node is only used + for neighbor reductions with lambda functions, a very specific case. For single input neighbor reductions, + the regular case, this node will be removed by the simplify pass. + """ + neighbor_valid_var = unique_var_name() + sdfg.add_array( + neighbor_valid_var, + dtype=dace.dtypes.bool, + shape=(offset_provider.max_neighbors,), + transient=True, + ) + neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) + + neighbor_valid_tasklet = state.add_tasklet( + "check_valid_neighbor", + {"__idx"}, + {"__valid"}, + f"__valid = True if __idx != {neighbor_skip_value} else False", + ) + state.add_edge( + neighbor_index_node, + None, + neighbor_valid_tasklet, + "__idx", + dace.Memlet(data=neighbor_index_var, subset="0"), + ) + state.add_memlet_path( + neighbor_valid_tasklet, + mx, + neighbor_valid_node, + memlet=dace.Memlet(data=neighbor_valid_var, subset=neighbor_map_index), + src_conn="__valid", + ) + return [ + ValueExpr(neighbor_value_node, iterator.dtype), + ValueExpr(neighbor_valid_node, dace.dtypes.bool), + ] def builtin_can_deref( @@ -419,6 +474,42 @@ def builtin_cast( ) +def builtin_make_const_list( + transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] +) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) + args = [transformer.visit(arg)[0] for arg in node_args] + assert all(isinstance(x, (SymbolExpr, ValueExpr)) for x in args) + args_dtype = [x.dtype for x in args] + assert len(set(args_dtype)) == 1 + dtype = args_dtype[0] + + var_name = unique_var_name() + transformer.context.body.add_array(var_name, (len(args),), dtype, transient=True) + var_node = transformer.context.state.add_access(var_name, debuginfo=di) + + for i, arg in enumerate(args): + if isinstance(arg, SymbolExpr): + transformer.context.state.add_edge( + transformer.context.state.add_tasklet( + f"get_arg{i}", {}, {"val"}, f"val = {arg.value}" + ), + "val", + var_node, + None, + dace.Memlet(data=var_name, subset=f"{i}"), + ) + else: + assert arg.value.desc(transformer.context.body).shape == (1,) + transformer.context.state.add_nedge( + arg.value, + var_node, + dace.Memlet(data=arg.value.data, subset="0", other_subset=f"{i}"), + ) + + return [ValueExpr(var_node, dtype)] + + def builtin_make_tuple( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -443,6 +534,7 @@ def builtin_tuple_get( "cast_": builtin_cast, "if_": builtin_if, "list_get": builtin_list_get, + "make_const_list": builtin_make_const_list, "make_tuple": builtin_make_tuple, "neighbors": builtin_neighbors, "tuple_get": builtin_tuple_get, @@ -578,7 +670,7 @@ def visit_Lambda( # Create the SDFG for the lambda's body lambda_sdfg = dace.SDFG(func_name) - lambda_sdfg.debuginfo = dace_debuginfo(node) + lambda_sdfg.debuginfo = dace_debuginfo(node, self.context.body.debuginfo) lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True) lambda_symbols_pass = GatherLambdaSymbolsPass( @@ -951,7 +1043,7 @@ def _visit_reduce(self, node: itir.FunCall): args = self.visit(node.args[0]) - assert len(args) == 1 + assert 1 <= len(args) <= 2 reduce_input_node = args[0].value else: @@ -967,16 +1059,22 @@ def _visit_reduce(self, node: itir.FunCall): # set reduction state in visit context self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - args = flatten_list(self.visit(node.args)) + args = self.visit(node.args) # clear context self.context.reduce_identity = None # check that all neighbor expressions have the same shape - nreduce_shape = args[1].value.desc(self.context.body).shape - assert all( - [arg.value.desc(self.context.body).shape == nreduce_shape for arg in args[2:]] - ) + args_shape = [ + arg[0].value.desc(self.context.body).shape + for arg in args + if arg[0].value.desc(self.context.body).shape != (1,) + ] + assert len(set(args_shape)) == 1 + nreduce_shape = args_shape[0] + + input_args = [arg[0] for arg in args] + input_valid = [arg[1] for arg in args if len(arg) == 2] nreduce_index = tuple(f"_i{i}" for i in range(len(nreduce_shape))) nreduce_domain = {idx: f"0:{size}" for idx, size in zip(nreduce_index, nreduce_shape)} @@ -990,12 +1088,16 @@ def _visit_reduce(self, node: itir.FunCall): expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location ) lambda_context, inner_inputs, inner_outputs = self.visit( - lambda_node, args=args, use_neighbor_tables=False + lambda_node, args=input_args, use_neighbor_tables=False ) input_mapping = { - param: create_memlet_at(arg.value.data, nreduce_index) - for (param, _), arg in zip(inner_inputs, args) + param: ( + dace.Memlet(data=arg.value.data, subset="0") + if arg.value.desc(self.context.body).shape == (1,) + else create_memlet_at(arg.value.data, nreduce_index) + ) + for (param, _), arg in zip(inner_inputs, input_args) } output_mapping = { inner_outputs[0].value.data: create_memlet_at(reduce_input_name, nreduce_index) @@ -1004,6 +1106,42 @@ def _visit_reduce(self, node: itir.FunCall): self.context.body, lambda_context.body, input_mapping ) + if input_valid: + """ + The neighbors builtin returns an array of booleans in case the connectivity table + contains skip values. These boolean values indicate whether the neighbor value is present or not, + and are used below to construct an if/else branch to bypass the lambda call for neighbor skip values. + If the neighbor table has full connectivity (no skip values by type definition), the input_valid node + is not built, and the construction of the if/else branch below is also skipped. + """ + input_args.append(input_valid[0]) + input_valid_node = input_valid[0].value + # add input connector to nested sdfg + input_mapping["is_valid"] = create_memlet_at(input_valid_node.data, nreduce_index) + # check neighbor validity on if/else inter-state edge + start_state = lambda_context.body.add_state("start", is_start_block=True) + skip_neighbor_state = lambda_context.body.add_state("skip_neighbor") + skip_neighbor_state.add_edge( + skip_neighbor_state.add_tasklet( + "identity", {}, {"val"}, f"val = {reduce_identity}" + ), + "val", + skip_neighbor_state.add_access(inner_outputs[0].value.data), + None, + dace.Memlet(data=inner_outputs[0].value.data, subset="0"), + ) + lambda_context.body.add_scalar("is_valid", dace.dtypes.bool) + lambda_context.body.add_edge( + start_state, + skip_neighbor_state, + dace.InterstateEdge(condition="is_valid == False"), + ) + lambda_context.body.add_edge( + start_state, + lambda_context.state, + dace.InterstateEdge(condition="is_valid == True"), + ) + reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( @@ -1013,7 +1151,7 @@ def _visit_reduce(self, node: itir.FunCall): inputs=input_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, - input_nodes={arg.value.data: arg.value for arg in args}, + input_nodes={arg.value.data: arg.value for arg in input_args}, output_nodes={reduce_input_name: reduce_input_node}, debuginfo=di, ) @@ -1036,10 +1174,6 @@ def _visit_reduce(self, node: itir.FunCall): reduce_node, result_access, dace.Memlet(data=result_name, subset="0") ) - # we apply map fusion only to the nested-SDFG which is generated for the reduction operator - # the purpose is to keep the ITIR-visitor program simple and to clean up the generated SDFG - self.context.body.apply_transformations_repeated([MapFusion], validate=False) - return [ValueExpr(result_access, reduce_dtype)] def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index c95292d702..56b220e0e9 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -155,7 +155,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), From e462a2ec0e72e3d7079fb4fdd909160448044b4d Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 2 Feb 2024 16:35:41 +0100 Subject: [PATCH 68/85] feat[next][dace]: Add support for lift expressions in neighbor reductions (no unrolling) (#1431) Baseline dace backend forced unroll of neighbor reductions, in the ITIR pass, in order to eliminate all lift expressions. This PR adds support for lowering of lift expressions in neighbor reductions, thus avoiding the need to unroll reduce expressions. The result is a more compact SDFG, which leaves to the optimization backend the option of unrolling neighbor reductions. --- .../runners/dace_iterator/__init__.py | 20 +- .../runners/dace_iterator/itir_to_sdfg.py | 29 ++- .../runners/dace_iterator/itir_to_tasklet.py | 224 ++++++++++++++---- 3 files changed, 205 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 2e9a66c435..fa28793187 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -69,28 +69,16 @@ def preprocess_program( program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: itir_transforms.LiftMode, + unroll_reduce: bool = False, ): - node = itir_transforms.apply_common_transforms( + return itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, + force_inline_lambda_args=True, lift_mode=lift_mode, offset_provider=offset_provider, - unroll_reduce=False, + unroll_reduce=unroll_reduce, ) - # If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG. - # In this case, just retry with unrolled reductions. - if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): - fencil_definition = node - else: - fencil_definition = itir_transforms.apply_common_transforms( - program, - common_subexpression_elimination=False, - force_inline_lambda_args=True, - lift_mode=lift_mode, - offset_provider=offset_provider, - unroll_reduce=True, - ) - return fencil_definition def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 073c856d86..eaff9f467e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -124,6 +124,24 @@ def _make_array_shape_and_strides( return shape, strides +def _check_no_lifts(node: itir.StencilClosure): + """ + Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions. + + Returns + ------- + True if lifts do not appear in the ITIR exception lift expressions in neighbor reductions. False otherwise. + """ + neighbors_call_count = 0 + for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): + if getattr(fun, "id", "") == "neighbors": + neighbors_call_count = 3 + elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: + return False + neighbors_call_count = max(0, neighbors_call_count - 1) + return True + + class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] @@ -262,7 +280,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] ) -> tuple[dace.SDFG, list[str], list[str]]: - assert ItirToSDFG._check_no_lifts(node) + assert _check_no_lifts(node) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") @@ -681,15 +699,6 @@ def _visit_domain( return tuple(sorted(bounds, key=lambda item: item[0])) - @staticmethod - def _check_no_lifts(node: itir.StencilClosure): - if any( - getattr(fun, "id", "") == "lift" - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun") - ): - return False - return True - @staticmethod def _check_shift_offsets_are_literals(node: itir.StencilClosure): fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 56ffe7e104..773a3a61f7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -181,6 +181,126 @@ def __init__( self.reduce_identity = reduce_identity +def _visit_lift_in_neighbors_reduction( + transformer: "PythonTaskletCodegen", + node: itir.FunCall, + node_args: Sequence[IteratorExpr | list[ValueExpr]], + offset_provider: NeighborTableOffsetProvider, + map_entry: dace.nodes.MapEntry, + map_exit: dace.nodes.MapExit, + neighbor_index_node: dace.nodes.AccessNode, + neighbor_value_node: dace.nodes.AccessNode, +) -> list[ValueExpr]: + neighbor_dim = offset_provider.neighbor_axis.value + origin_dim = offset_provider.origin_axis.value + + lifted_args: list[IteratorExpr | ValueExpr] = [] + for arg in node_args: + if isinstance(arg, IteratorExpr): + if origin_dim in arg.indices: + lifted_indices = arg.indices.copy() + lifted_indices.pop(origin_dim) + lifted_indices[neighbor_dim] = neighbor_index_node + lifted_args.append( + IteratorExpr( + arg.field, + lifted_indices, + arg.dtype, + arg.dimensions, + ) + ) + else: + lifted_args.append(arg) + else: + lifted_args.append(arg[0]) + + lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) + assert len(inner_outputs) == 1 + inner_out_connector = inner_outputs[0].value.data + + input_nodes = {} + iterator_index_nodes = {} + lifted_index_connectors = set() + + for x, y in inner_inputs: + if isinstance(y, IteratorExpr): + field_connector, inner_index_table = x + input_nodes[field_connector] = y.field + for dim, connector in inner_index_table.items(): + if dim == neighbor_dim: + lifted_index_connectors.add(connector) + iterator_index_nodes[connector] = y.indices[dim] + else: + assert isinstance(y, ValueExpr) + input_nodes[x] = y.value + + neighbor_tables = filter_neighbor_tables(transformer.offset_provider) + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] + + parent_sdfg = transformer.context.body + parent_state = transformer.context.state + + input_mapping = { + connector: create_memlet_full(node.data, node.desc(parent_sdfg)) + for connector, node in input_nodes.items() + } + connectivity_mapping = { + name: create_memlet_full(name, parent_sdfg.arrays[name]) for name in connectivity_names + } + array_mapping = {**input_mapping, **connectivity_mapping} + symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) + + nested_sdfg_node = parent_state.add_nested_sdfg( + lift_context.body, + parent_sdfg, + inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, + outputs={inner_out_connector}, + symbol_mapping=symbol_mapping, + debuginfo=lift_context.body.debuginfo, + ) + + for connectivity_connector, memlet in connectivity_mapping.items(): + parent_state.add_memlet_path( + parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), + map_entry, + nested_sdfg_node, + dst_conn=connectivity_connector, + memlet=memlet, + ) + + for inner_connector, access_node in input_nodes.items(): + parent_state.add_memlet_path( + access_node, + map_entry, + nested_sdfg_node, + dst_conn=inner_connector, + memlet=input_mapping[inner_connector], + ) + + for inner_connector, access_node in iterator_index_nodes.items(): + memlet = dace.Memlet(data=access_node.data, subset="0") + if inner_connector in lifted_index_connectors: + parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) + else: + parent_state.add_memlet_path( + access_node, + map_entry, + nested_sdfg_node, + dst_conn=inner_connector, + memlet=memlet, + ) + + parent_state.add_memlet_path( + nested_sdfg_node, + map_exit, + neighbor_value_node, + src_conn=inner_out_connector, + memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), + ) + + return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] + + def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: @@ -198,7 +318,16 @@ def builtin_neighbors( "Neighbor reduction only implemented for connectivity based on neighbor tables." ) - iterator = transformer.visit(data) + lift_node = None + if isinstance(data, FunCall): + assert isinstance(data.fun, itir.FunCall) + fun_node = data.fun + if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": + lift_node = fun_node + lift_args = transformer.visit(data.args) + iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) + if lift_node is None: + iterator = transformer.visit(data) assert isinstance(iterator, IteratorExpr) field_desc = iterator.field.desc(transformer.context.body) origin_index_node = iterator.indices[offset_provider.origin_axis.value] @@ -259,44 +388,56 @@ def builtin_neighbors( dace.Memlet(data=neighbor_index_var, subset="0"), ) - data_access_tasklet = state.add_tasklet( - "data_access", - code="__data = __field[__idx]" - + ( - f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values - else "" - ), - inputs={"__field", "__idx"}, - outputs={"__data"}, - debuginfo=di, - ) - # select full shape only in the neighbor-axis dimension - field_subset = tuple( - f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" - for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=create_memlet_at(iterator.field.data, field_subset), - dst_conn="__field", - ) - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), - src_conn="__data", - ) + if lift_node is not None: + _visit_lift_in_neighbors_reduction( + transformer, + lift_node, + lift_args, + offset_provider, + me, + mx, + neighbor_index_node, + neighbor_value_node, + ) + else: + data_access_tasklet = state.add_tasklet( + "data_access", + code="__data = __field[__idx]" + + ( + f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" + if offset_provider.has_skip_values + else "" + ), + inputs={"__field", "__idx"}, + outputs={"__data"}, + debuginfo=di, + ) + # select full shape only in the neighbor-axis dimension + field_subset = tuple( + f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) + ) + state.add_memlet_path( + iterator.field, + me, + data_access_tasklet, + memlet=create_memlet_at(iterator.field.data, field_subset), + dst_conn="__field", + ) + state.add_edge( + neighbor_index_node, + None, + data_access_tasklet, + "__idx", + dace.Memlet(data=neighbor_index_var, subset="0"), + ) + state.add_memlet_path( + data_access_tasklet, + mx, + neighbor_value_node, + memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di), + src_conn="__data", + ) if not offset_provider.has_skip_values: return [ValueExpr(neighbor_value_node, iterator.dtype)] @@ -377,9 +518,8 @@ def builtin_can_deref( # create tasklet to check that field indices are non-negative (-1 is invalid) args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join([f"{v} >= 0" for v in internals]) + expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( list(zip(args, internals)), expr_code, @@ -946,7 +1086,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: iterator = self.visit(node.args[0]) if not isinstance(iterator, IteratorExpr): # shift cannot be applied because the argument is not iterable - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it + # TODO: remove this special case when ITIR pass is able to catch it assert isinstance(iterator, list) and len(iterator) == 1 assert isinstance(iterator[0], ValueExpr) return iterator From 6509dd963878f4f1ee0c50c5a00746e33d6e4c38 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 6 Feb 2024 07:34:40 +0100 Subject: [PATCH 69/85] feat[next][dace]: DaCe support for temporaries (#1351) Temporaries are implemented in DaCe backend as transient arrays. This PR adds extraction of temporaries and generation of corresponding transient arrays in the SDFG representation. --- .../runners/dace_iterator/__init__.py | 34 +++++-- .../runners/dace_iterator/itir_to_sdfg.py | 92 ++++++++++++++++++- .../runners/dace_iterator/utility.py | 8 ++ 3 files changed, 124 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fa28793187..432bf3e1bf 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -22,6 +22,7 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.sdfg import utils as sdutils from dace.transformation.auto import auto_optimize as autoopt +from dace.transformation.interstate import RefineNestedAccess import gt4py.next.allocators as next_allocators import gt4py.next.iterator.ir as itir @@ -71,7 +72,7 @@ def preprocess_program( lift_mode: itir_transforms.LiftMode, unroll_reduce: bool = False, ): - return itir_transforms.apply_common_transforms( + node = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -80,6 +81,21 @@ def preprocess_program( unroll_reduce=unroll_reduce, ) + if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries): + fencil_definition = node.fencil + tmps = node.tmps + + elif isinstance(node, itir.FencilDefinition): + fencil_definition = node + tmps = [] + + else: + raise TypeError( + f"Expected 'FencilDefinition' or 'FencilWithTemporaries', got '{type(program).__name__}'." + ) + + return fencil_definition, tmps + def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names @@ -160,6 +176,7 @@ def get_stride_args( def get_cache_id( build_type: str, build_for_gpu: bool, + lift_mode: itir_transforms.LiftMode, program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], column_axis: Optional[common.Dimension], @@ -185,6 +202,7 @@ def offset_invariants(offset): for arg in ( build_type, build_for_gpu, + lift_mode, program, *arg_types, column_axis, @@ -272,17 +290,17 @@ def build_sdfg_from_itir( sdfg.validate() return sdfg - # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force - # `lift_more` to `FORCE_INLINE` mode. - lift_mode = itir_transforms.LiftMode.FORCE_INLINE arg_types = [type_translation.from_value(arg) for arg in args] # visit ITIR and generate SDFG - program = preprocess_program(program, offset_provider, lift_mode) - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + program, tmps = preprocess_program(program, offset_provider, lift_mode) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis) sdfg = sdfg_genenerator.visit(program) if sdfg is None: raise RuntimeError(f"Visit failed for program {program.id}.") + elif tmps: + # This pass is needed to avoid transformation errors in SDFG inlining, because temporaries are using offsets + sdfg.apply_transformations_repeated(RefineNestedAccess) for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: @@ -338,7 +356,9 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): arg_types = [type_translation.from_value(arg) for arg in args] - cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider) + cache_id = get_cache_id( + build_type, on_gpu, lift_mode, program, arg_types, column_axis, offset_provider + ) if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index eaff9f467e..a578e9c19b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -19,8 +19,12 @@ import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing from gt4py.next.common import NeighborTable -from gt4py.next.iterator import ir as itir, type_inference as itir_typing -from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef +from gt4py.next.iterator import ( + ir as itir, + transforms as itir_transforms, + type_inference as itir_typing, +) +from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_tasklet import ( @@ -36,6 +40,7 @@ from .utility import ( add_mapped_nested_sdfg, as_dace_type, + as_scalar_type, connectivity_identifier, create_memlet_at, create_memlet_full, @@ -44,6 +49,7 @@ flatten_list, get_sorted_dims, map_nested_sdfg_symbols, + new_array_symbols, unique_name, unique_var_name, ) @@ -154,12 +160,14 @@ def __init__( self, param_types: list[ts.TypeSpec], offset_provider: dict[str, NeighborTable], + tmps: list[itir_transforms.global_tmps.Temporary], column_axis: Optional[Dimension] = None, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} + self.tmps = tmps def add_storage( self, @@ -189,6 +197,70 @@ def add_storage( raise NotImplementedError() self.storage_types[name] = type_ + def add_storage_for_temporaries( + self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG + ) -> dict[str, str]: + symbol_map: dict[str, TaskletExpr] = {} + # The shape of temporary arrays might be defined based on scalar values passed as program arguments. + # Here we collect these values in a symbol map. + tmp_ids = set(tmp.id for tmp in self.tmps) + for sym in node_params: + if sym.id not in tmp_ids and sym.kind != "Iterator": + name_ = str(sym.id) + type_ = self.storage_types[name_] + assert isinstance(type_, ts.ScalarType) + symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_)) + + tmp_symbols: dict[str, str] = {} + for tmp in self.tmps: + tmp_name = str(tmp.id) + + # We visit the domain of the temporary field, passing the set of available symbols. + assert isinstance(tmp.domain, itir.FunCall) + self.node_types.update(itir_typing.infer_all(tmp.domain)) + domain_ctx = Context(program_sdfg, defs_state, symbol_map) + tmp_domain = self._visit_domain(tmp.domain, domain_ctx) + + # We build the FieldType for this temporary array. + dims: list[Dimension] = [] + for dim, _ in tmp_domain: + dims.append( + Dimension( + value=dim, + kind=( + DimensionKind.VERTICAL + if self.column_axis is not None and self.column_axis.value == dim + else DimensionKind.HORIZONTAL + ), + ) + ) + assert isinstance(tmp.dtype, str) + type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype)) + self.storage_types[tmp_name] = type_ + + # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. + # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics + # to assign optimal stride values. + tmp_shape, _ = new_array_symbols(tmp_name, len(dims)) + tmp_offset = [ + dace.symbol(unique_name(f"{tmp_name}_offset{i}")) for i in range(len(dims)) + ] + _, tmp_array = program_sdfg.add_array( + tmp_name, tmp_shape, as_dace_type(type_.dtype), offset=tmp_offset, transient=True + ) + + # Loop through all dimensions to visit the symbolic expressions for array shape and offset. + # These expressions are later mapped to interstate symbols. + for (_, (begin, end)), offset_sym, shape_sym in zip( + tmp_domain, + tmp_array.offset, + tmp_array.shape, + ): + tmp_symbols[str(offset_sym)] = f"0 - {begin.value}" + tmp_symbols[str(shape_sym)] = f"{end.value} - {begin.value}" + + return tmp_symbols + def get_output_nodes( self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState ) -> dict[str, dace.nodes.AccessNode]: @@ -204,7 +276,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) program_sdfg.debuginfo = dace_debuginfo(node) - last_state = program_sdfg.add_state("program_entry", True) + entry_state = program_sdfg.add_state("program_entry", is_start_block=True) self.node_types = itir_typing.infer_all(node) # Filter neighbor tables from offset providers. @@ -214,6 +286,20 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): for param, type_ in zip(node.params, self.param_types): self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables) + if self.tmps: + tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg) + # on the first interstate edge define symbols for shape and offsets of temporary arrays + last_state = program_sdfg.add_state("init_symbols_for_temporaries") + program_sdfg.add_edge( + entry_state, + last_state, + dace.InterstateEdge( + assignments=tmp_symbols, + ), + ) + else: + last_state = entry_state + # Add connectivities as SDFG storages. for offset, offset_provider in neighbor_tables.items(): scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 49dd2472c5..0c3fd741d5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -51,6 +51,14 @@ def as_dace_type(type_: ts.ScalarType): raise ValueError(f"Scalar type '{type_}' not supported.") +def as_scalar_type(typestr: str) -> ts.ScalarType: + try: + kind = getattr(ts.ScalarKind, typestr.upper()) + except AttributeError: + raise ValueError(f"Data type {typestr} not supported.") + return ts.ScalarType(kind) + + def filter_neighbor_tables(offset_provider: dict[str, Any]): return { offset: table From d95bf89b34c425dedc46327cd51b0d9649db967a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 6 Feb 2024 17:30:43 +0100 Subject: [PATCH 70/85] bug[next]: fix field_operator caching (#1445) The cache was copied in `with_backend`, but backend is not part of the hash. Now the cache will be empty after `with_backend`. --- src/gt4py/next/ffront/decorator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 6510be560e..a556d0ea34 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -562,7 +562,9 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): backend: Optional[ppi.ProgramExecutor] grid_type: Optional[GridType] operator_attributes: Optional[dict[str, Any]] = None - _program_cache: dict = dataclasses.field(default_factory=dict) + _program_cache: dict = dataclasses.field( + init=False, default_factory=dict + ) # init=False ensure the cache is not copied in calls to replace @classmethod def from_function( From cbc34dde69f97cdc759ddf4537a4bf9e9f09171d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Tue, 6 Feb 2024 20:48:44 +0100 Subject: [PATCH 71/85] build: Update requirements versions and scripts to support 3.11 (#1444) Update requirements-related scripts and minimum package versions to support for 3.11 and other stability improvements. Main changes: - Fix `tomli` compatibility with python 3.11 - Add support for python 3 .11 in package metadata - Add new `all-` extra requirement groups - Update minimum versions of: cython, cytoolz, numpy, pybind11, scipy, setuptools, wheel - Add jax-cpu and other dependencies to the frozen requirements - Replace regular expression hacks with proper requirement constraint parsing using `packaging`. - Remove `cog` processing of `requirements-dev.in` by simplifying the specification of the constraints shared with `pyproject.toml` - Rename tox task to update requirements - Update developers' documentation on how to update requirements - Update ruff configuration --- .pre-commit-config.yaml | 20 +- constraints.txt | 47 ++- docs/development/tools/requirements.md | 8 +- min-extra-requirements-test.txt | 79 +++- min-requirements-test.txt | 73 +++- pyproject.toml | 101 +++-- requirements-dev.in | 24 +- requirements-dev.txt | 535 +++++++++++++------------ tox.ini | 50 ++- 9 files changed, 531 insertions(+), 406 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3f26dfea55..62bf3ce0ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -67,6 +67,18 @@ repos: hooks: - id: black +# - repo: https://github.com/charliermarsh/ruff-pre-commit +# ##[[[cog +# ## import re +# ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] +# ## print(f"# rev: 'v{version}' # version from constraints.txt") +# ##]]] +# rev: 'v0.2.0' # version from constraints.txt +# ##[[[end]]] +# hooks: +# - id: ruff +# # args: [ --fix, --exit-non-zero-on-fix ] + - repo: https://github.com/PyCQA/isort ##[[[cog ## import re @@ -153,9 +165,13 @@ repos: - id: mypy additional_dependencies: # versions from constraints.txt ##[[[cog - ## import re, tomli + ## import re, sys + ## if sys.version_info >= (3, 11): + ## import tomllib + ## else: + ## import tomli as tomllib ## constraints = open("constraints.txt").read() - ## project = tomli.loads(open("pyproject.toml").read()) + ## project = tomllib.loads(open("pyproject.toml").read()) ## packages = [re.match('^([\w-][\w\d-]*)', r)[1] for r in project["project"]["dependencies"] if r.strip()] ## for pkg in packages: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) diff --git a/constraints.txt b/constraints.txt index 3b32e53c0c..61bc04e671 100644 --- a/constraints.txt +++ b/constraints.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.8 # by the following command: # -# "tox run -e requirements-common" +# "tox run -e requirements-base" # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx @@ -17,7 +17,7 @@ boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) cachetools==5.3.2 # via tox -certifi==2023.11.17 # via requests +certifi==2024.2.2 # via requests cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox @@ -29,7 +29,7 @@ cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.1 # via ipykernel contourpy==1.1.1 # via matplotlib -coverage==7.4.1 # via -r requirements-dev.in, coverage, pytest-cov +coverage==7.4.1 # via -r requirements-dev.in, pytest-cov cryptography==42.0.2 # via types-paramiko, types-pyopenssl, types-redis cycler==0.12.1 # via matplotlib cytoolz==0.12.3 # via gt4py (pyproject.toml) @@ -47,7 +47,7 @@ exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==2.0.1 # via devtools, stack-data factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==22.6.0 # via factory-boy +faker==22.7.0 # via factory-boy fastjsonschema==2.19.1 # via nbformat filelock==3.13.1 # via tox, virtualenv flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings @@ -59,23 +59,25 @@ flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==3.0.1 # via dace +flask==3.0.2 # via dace fonttools==4.47.2 # via matplotlib fparser==0.1.4 # via dace frozendict==2.4.0 # via gt4py (pyproject.toml) gridtools-cpp==2.3.2 # via gt4py (pyproject.toml) -hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) +hypothesis==6.98.2 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==7.0.1 # via build, flask, jupyter-client, sphinx +importlib-metadata==7.0.1 # via build, flask, jax, jupyter-client, sphinx importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -ipykernel==6.29.0 # via nbmake +ipykernel==6.29.1 # via nbmake ipython==8.12.3 # via ipykernel isort==5.13.2 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask +jax==0.4.13 # via gt4py (pyproject.toml) +jaxlib==0.4.13 # via jax jedi==0.19.1 # via ipython jinja2==3.1.3 # via flask, gt4py (pyproject.toml), sphinx jsonschema==4.21.1 # via nbformat @@ -87,12 +89,13 @@ kiwisolver==1.4.5 # via matplotlib lark==1.1.9 # via gt4py (pyproject.toml) mako==1.3.2 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins -markupsafe==2.1.4 # via jinja2, mako, werkzeug +markupsafe==2.1.5 # via jinja2, mako, werkzeug matplotlib==3.7.4 # via -r requirements-dev.in matplotlib-inline==0.1.6 # via ipykernel, ipython mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py +ml-dtypes==0.2.0 # via jax, jaxlib mpmath==1.3.0 # via sympy mypy==1.8.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy @@ -104,7 +107,8 @@ nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit -numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, types-jack-client +numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), jax, jaxlib, matplotlib, ml-dtypes, opt-einsum, scipy, types-jack-client +opt-einsum==3.3.0 # via jax ordered-set==4.1.0 # via deepdiff packaging==23.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pyproject-api, pytest, setuptools-scm, sphinx, tox parso==0.8.3 # via jedi @@ -136,16 +140,17 @@ pytest==8.0.0 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.6.0 # via -r requirements-dev.in -pytest-xdist==3.5.0 # via -r requirements-dev.in, pytest-xdist +pytest-xdist==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib -pytz==2023.4 # via babel +pytz==2024.1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit pyzmq==25.1.2 # via ipykernel, jupyter-client referencing==0.33.0 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings rpds-py==0.17.1 # via jsonschema, referencing -ruff==0.1.15 # via -r requirements-dev.in +ruff==0.2.1 # via -r requirements-dev.in +scipy==1.10.1 # via gt4py (pyproject.toml), jax, jaxlib setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx @@ -163,7 +168,7 @@ stack-data==0.6.3 # via ipython sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, setuptools-scm, tox +tomli==2.0.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, setuptools-scm, tox toolz==0.12.1 # via cytoolz tornado==6.4 # via ipykernel, jupyter-client tox==4.12.1 # via -r requirements-dev.in @@ -175,7 +180,7 @@ types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all types-bleach==6.1.0.1 # via types-all -types-boto==2.49.18.9 # via types-all +types-boto==2.49.18.20240205 # via types-all types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all types-cffi==1.16.0.20240106 # via types-jack-client @@ -183,7 +188,7 @@ types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask types-click-spinner==0.1.13.20240106 # via types-all -types-colorama==0.4.15.20240106 # via types-all +types-colorama==0.4.15.20240205 # via types-all types-contextvars==2.4.7.3 # via types-all types-croniter==2.0.0.20240106 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt @@ -193,7 +198,7 @@ types-datetimerange==2.0.0.6 # via types-all types-decorator==5.1.8.20240106 # via types-all types-deprecated==1.2.9.20240106 # via types-all types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.20240126 # via types-all +types-docutils==0.20.0.20240201 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -217,9 +222,9 @@ types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.4.0.20240120 # via types-all, types-pysftp +types-paramiko==3.4.0.20240205 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.2.0.20240125 # via types-all +types-pillow==10.2.0.20240206 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.20240115 # via types-all types-protobuf==4.24.0.20240129 # via types-all @@ -235,7 +240,7 @@ types-pysftp==0.2.17.20240106 # via types-all types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.2.20240127 # via types-all -types-pytz==2023.4.0.20240130 # via types-all, types-tzlocal +types-pytz==2024.1.0.20240203 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.12 # via types-all types-redis==4.6.0.20240106 # via types-all @@ -268,5 +273,5 @@ xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.3.2 # via pip-tools +pip==24.0 # via pip-tools setuptools==69.0.3 # via gt4py (pyproject.toml), nodeenv, pip-tools, setuptools-scm diff --git a/docs/development/tools/requirements.md b/docs/development/tools/requirements.md index a58cf49ac3..3c76e9b80f 100644 --- a/docs/development/tools/requirements.md +++ b/docs/development/tools/requirements.md @@ -5,7 +5,7 @@ The specification of required third-party packages is scattered and partially du The following files in this repository contain information about required third-party packages: - `pyproject.toml`: GT4Py [package configuration](https://peps.python.org/pep-0621/) used by the build backend (`setuptools`). Install dependencies are specified in the _project.dependencies_ and _project.optional-dependencies_ tables. -- `requirements-dev.in`: [requirements file](https://pip.pypa.io/en/stable/reference/requirements-file-format/) used by **pip**. It contains a list of packages required for the development of GT4Py. Part of its content is generated automatically from `pyproject.toml` using **cog**. +- `requirements-dev.in`: [requirements file](https://pip.pypa.io/en/stable/reference/requirements-file-format/) used by **pip**. It contains a list of packages required for the development of GT4Py. - `requirements-dev.txt`: requirements file used by **pip**. It contains a completely frozen list of all packages required for installing and developing GT4Py. It is used by **pip** and **tox** to initialize the standard development and testing environments. It is automatically generated automatically from `requirements-dev.in` by **pip-compile**, when running the **tox** environment to update requirements. - `constraints.txt`: [constraints file](https://pip.pypa.io/en/stable/user_guide/#constraints-files) used by **pip** and **tox** to initialize a subset of the standard development environment making sure that if other packages are installed, transitive dependencies are taken from the frozen package list. It is generated automatically from `requirements-dev.in` using **pip-compile**. - `min-requirements-test.txt`: requirements file used by **pip**. It contains the minimum list of requirements to run GT4Py tests with the oldest compatible versions of all dependencies. It is generated automatically from `pyproject.toml` using **cog**. @@ -14,14 +14,14 @@ The following files in this repository contain information about required third- The expected workflow to update GT4Py requirements is as follows: -1. For changes in the GT4Py package dependencies, update the relevant table in `pyproject.toml`. When modifying the _project.optional-dependencies_ tables, make sure the `full` extra table **always** contains the combined dependencies from all the other extra tables. +1. For changes in the GT4Py package dependencies, update the relevant table in `pyproject.toml`. When adding new tables to the _project.optional-dependencies_ section, make sure to add the new table as a dependency of the `all-` extra tables when possible. 2. For changes in the development tools, update the `requirements-dev.in` file. -3. Run the **tox** _requirements-common_ environment to update all files automatically with **pip-compile** and **cog**. Note that **pip-compile** will most likely update the versions of some unrelated tools if new versions are available in PyPI. +3. Run the **tox** _requirements-base_ environment to update all files automatically with **pip-compile** and **cog**. Note that **pip-compile** will most likely update the versions of some unrelated tools if new versions are available in PyPI. ```bash - tox r -e requirements-common + tox r -e requirements-base ``` 4. Check that the **mypy** mirror used by **pre-commit** (https://github.com/pre-commit/mirrors-mypy) in `.pre-commit-config.yaml` supports the same version as in `constraints.txt`, and manually update the `rev` version number. diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 1db48693be..d97ed299e1 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -1,20 +1,55 @@ +# +# Generated automatically by cog from pyproject.toml and requirements-dev.in +# Run: +# tox r -e requirements-common +# + ##[[[cog -## import re, tomli -## project = tomli.loads(open("pyproject.toml").read()) -## requirements = set( -## rr -## for r in ( -## project["project"]["dependencies"] -## + project["project"]["optional-dependencies"]["full"] -## + open("requirements-dev.in").readlines() -## ) -## if (rr := r.strip()) and not rr.startswith('#') -## ) -## for r in sorted(requirements): -## m = re.match("^([\w-][\w\d\[\]-]*)[=>~][=]([^,]+)", r) -## print(f"{m[1]}=={m[2].strip()}") +## import copy, sys +## from packaging import requirements as reqs, specifiers as specs +## if sys.version_info >= (3, 11): +## import tomllib +## else: +## import tomli as tomllib +## +## def make_min_req(r: reqs.Requirement) -> reqs.Requirement: +## for s in r.specifier: +## if (ss := str(s)).startswith(">"): +## assert ss.startswith(">="), f"'{r!s}' requires a '>=' constraint" +## min_spec = specs.SpecifierSet(f"=={ss[2:]}") +## break +## min_r = copy.deepcopy(r) +## min_r.specifier = min_spec +## return min_r +## +## project = tomllib.loads(open("pyproject.toml").read()) +## all_cpu_extra = project["project"]["optional-dependencies"]["all-cpu"] +## assert len(all_cpu_extra) == 1 and all_cpu_extra[0].startswith("gt4py[") +## opt_req_versions = { +## reqs.Requirement(r).name: reqs.Requirement(r) +## for e in reqs.Requirement(all_cpu_extra[0]).extras +## for r in project["project"]["optional-dependencies"][e] +## } +## requirements = [ +## reqs.Requirement(rr) +## for r in (project["project"]["dependencies"] + open("requirements-dev.in").readlines()) +## if (rr := (r[: r.find("#")] if "#" in r else r)) +## ] +## processed = set() +## result = [] +## for r in requirements: +## assert r.name not in processed +## processed.add(r.name) +## if not r.specifier: +## assert r.name in opt_req_versions, f"Missing contraints for '{r.name}'" +## r = opt_req_versions[r.name] +## result.append(str(make_min_req(r))) +## for r_name, r in opt_req_versions.items(): +## if r_name not in processed: +## result.append(str(make_min_req(r))) +## print("\n".join(sorted(result))) ##]]] -astunparse==1.6.3;python_version<'3.9' +astunparse==1.6.3; python_version < "3.9" attrs==21.3 black==22.3 boltons==20.1 @@ -24,7 +59,7 @@ click==8.0.0 cmake==3.22 cogapp==3.3 coverage[toml]==5.0 -cytoolz==0.12.0 +cytoolz==0.12.1 dace==0.15.1 darglint==1.6 deepdiff==5.6.0 @@ -42,7 +77,7 @@ flake8==5.0.4 frozendict==2.3 gridtools-cpp==2.3.2 hypothesis==6.0.0 -importlib-resources==5.0;python_version<'3.9' +importlib-resources==5.0; python_version < "3.9" isort==5.10 jax[cpu]==0.4.13 jinja2==3.0.0 @@ -54,27 +89,27 @@ mypy==1.0 nanobind==1.4.0 nbmake==1.4.6 ninja==1.10 -numpy==1.21.2 +numpy==1.23.3 packaging==20.0 pip-tools==6.10 pipdeptree==2.3 pre-commit==2.17 psutil==5.0 -pybind11==2.5 +pybind11==2.10.1 pygments==2.7.3 pytest-cache==1.0 pytest-cov==2.8 pytest-factoryboy==2.0.3 pytest-xdist[psutil]==2.4 pytest==7.0 -ruff==0.0.265 -scipy==1.7.2 +ruff==0.2.0 +scipy==1.9.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 sympy==1.9 tabulate==0.8.10 -tomli==2.0.1 +tomli==2.0.1; python_version < "3.11" tox==3.2.0 types-all==1.0.0 typing-extensions==4.2 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index badf08864e..63553623a2 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -1,19 +1,52 @@ +# +# Generated automatically by cog from pyproject.toml and requirements-dev.in +# Run: +# tox r -e requirements-common +# + ##[[[cog -## import re, tomli -## project = tomli.loads(open("pyproject.toml").read()) -## requirements = set( -## rr -## for r in ( -## project["project"]["dependencies"] -## + open("requirements-dev.in").readlines() -## ) -## if (rr := r.strip()) and not rr.startswith('#') -## ) -## for r in sorted(requirements): -## m = re.match("^([\w-][\w\d\[\]-]*)[=>~][=]([^,]+)", r) -## print(f"{m[1]}=={m[2].strip()}") +## import copy, sys +## from packaging import requirements as reqs, specifiers as specs +## if sys.version_info >= (3, 11): +## import tomllib +## else: +## import tomli as tomllib +## +## def make_min_req(r: reqs.Requirement) -> reqs.Requirement: +## for s in r.specifier: +## if (ss := str(s)).startswith(">"): +## assert ss.startswith(">="), f"'{r!s}' requires a '>=' constraint" +## min_spec = specs.SpecifierSet(f"=={ss[2:]}") +## break +## min_r = copy.deepcopy(r) +## min_r.specifier = min_spec +## return min_r +## +## project = tomllib.loads(open("pyproject.toml").read()) +## all_cpu_extra = project["project"]["optional-dependencies"]["all-cpu"] +## assert len(all_cpu_extra) == 1 and all_cpu_extra[0].startswith("gt4py[") +## opt_req_versions = { +## reqs.Requirement(r).name: reqs.Requirement(r) +## for e in reqs.Requirement(all_cpu_extra[0]).extras +## for r in project["project"]["optional-dependencies"][e] +## } +## requirements = [ +## reqs.Requirement(rr) +## for r in (project["project"]["dependencies"] + open("requirements-dev.in").readlines()) +## if (rr := (r[: r.find("#")] if "#" in r else r)) +## ] +## processed = set() +## result = [] +## for r in requirements: +## assert r.name not in processed +## processed.add(r.name) +## if not r.specifier: +## assert r.name in opt_req_versions, f"Missing contraints for '{r.name}'" +## r = opt_req_versions[r.name] +## result.append(str(make_min_req(r))) +## print("\n".join(sorted(result))) ##]]] -astunparse==1.6.3;python_version<'3.9' +astunparse==1.6.3; python_version < "3.9" attrs==21.3 black==22.3 boltons==20.1 @@ -23,7 +56,7 @@ click==8.0.0 cmake==3.22 cogapp==3.3 coverage[toml]==5.0 -cytoolz==0.12.0 +cytoolz==0.12.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -40,7 +73,7 @@ flake8==5.0.4 frozendict==2.3 gridtools-cpp==2.3.2 hypothesis==6.0.0 -importlib-resources==5.0;python_version<'3.9' +importlib-resources==5.0; python_version < "3.9" isort==5.10 jinja2==3.0.0 jupytext==1.14 @@ -51,25 +84,25 @@ mypy==1.0 nanobind==1.4.0 nbmake==1.4.6 ninja==1.10 -numpy==1.21.2 +numpy==1.23.3 packaging==20.0 pip-tools==6.10 pipdeptree==2.3 pre-commit==2.17 psutil==5.0 -pybind11==2.5 +pybind11==2.10.1 pygments==2.7.3 pytest-cache==1.0 pytest-cov==2.8 pytest-factoryboy==2.0.3 pytest-xdist[psutil]==2.4 pytest==7.0 -ruff==0.0.265 +ruff==0.2.0 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 tabulate==0.8.10 -tomli==2.0.1 +tomli==2.0.1; python_version < "3.11" tox==3.2.0 types-all==1.0.0 typing-extensions==4.2 diff --git a/pyproject.toml b/pyproject.toml index 5fffb9cf0c..5297dfb45b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] build-backend = 'setuptools.build_meta' -requires = ['setuptools>=60.6', 'wheel', 'cython'] +requires = ['setuptools>=65.5.0', 'wheel>=0.33.6', 'cython>=0.29.13'] # ---- Project description ---- # -- Standard options (PEP 621) -- @@ -17,6 +17,7 @@ classifiers = [ 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: Implementation :: CPython', 'Topic :: Scientific/Engineering :: Atmospheric Science', 'Topic :: Scientific/Engineering :: Mathematics', @@ -30,7 +31,7 @@ dependencies = [ 'cached-property>=1.5.1', 'click>=8.0.0', 'cmake>=3.22', - 'cytoolz>=0.12.0', + 'cytoolz>=0.12.1', 'deepdiff>=5.6.0', 'devtools>=0.6', 'frozendict>=2.3', @@ -41,9 +42,9 @@ dependencies = [ 'mako>=1.1', 'nanobind>=1.4.0 ', 'ninja>=1.10', - 'numpy>=1.21.2', + 'numpy>=1.23.3', 'packaging>=20.0', - 'pybind11>=2.5', + 'pybind11>=2.10.1', 'setuptools>=65.5.0', 'tabulate>=0.8.10', 'typing-extensions>=4.2,<4.6.0', @@ -66,23 +67,21 @@ readme = 'README.md' requires-python = '>=3.8' [project.optional-dependencies] -cuda = ['cupy>=12.0'] -cuda11x = ['cupy-cuda11x>=12.0'] -cuda12x = ['cupy-cuda12x>=12.0'] +# Bundles +all-cpu = ['gt4py[dace,formatting,jax-cpu,performance,testing]'] +all-cuda11 = ['gt4py[cuda11,dace,formatting,jax-cuda11,performance,testing]'] +all-cuda12 = ['gt4py[cuda12,dace,formatting,jax-cuda12,performance,testing]'] +# Other extras +cuda11 = ['cupy-cuda11x>=12.0'] +cuda12 = ['cupy-cuda12x>=12.0'] dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9'] formatting = ['clang-format>=9.0'] -# Always add all extra packages to 'full' for a simple full gt4py installation -full = [ - 'clang-format>=9.0', - 'dace>=0.15.1,<0.16', - 'hypothesis>=6.0.0', - 'pytest>=7.0', - 'sympy>=1.9', - 'scipy>=1.7.2', - 'jax[cpu]>=0.4.13' -] -jax = ['jax[cpu]>=0.4.13'] -performance = ['scipy>=1.7.2'] +gpu = ['cupy>=12.0'] +jax-cpu = ['jax[cpu]>=0.4.13'] +jax-cuda11 = ['jax[cuda11_pip]>=0.4.13'] +jax-cuda12 = ['jax[cuda12_pip]>=0.4.13'] +performance = ['scipy>=1.9.2'] +rocm-43 = ['cupy-rocm-4-3'] testing = ['hypothesis>=6.0.0', 'pytest>=7.0'] [project.scripts] @@ -196,6 +195,7 @@ known_third_party = [ 'boltons', 'cached_property', 'click', + 'cupy', 'dace', 'devtools', 'factory', @@ -218,7 +218,14 @@ lexicographical = true line_length = 100 # It should be the same as in `tool.black.line-length` above lines_after_imports = 2 profile = 'black' -sections = ['FUTURE', 'STDLIB', 'THIRDPARTY', 'FIRSTPARTY', 'TESTS', 'LOCALFOLDER'] +sections = [ + 'FUTURE', + 'STDLIB', + 'THIRDPARTY', + 'FIRSTPARTY', + 'TESTS', + 'LOCALFOLDER' +] skip_gitignore = true skip_glob = ['*.venv/**', '_local/**'] @@ -364,33 +371,42 @@ testpaths = 'tests' # -- ruff -- [tool.ruff] -ignore = [ - 'E501', - 'B008', # Do not perform function calls in argument defaults - 'B028', # Consider replacing f"'{foo}'" with f"{foo!r}" # TODO: review - 'B905' # B905 `zip()` without an explicit `strict=` parameter # TODO: review -] -ignore-init-module-imports = true line-length = 100 # It should be the same as in `tool.black.line-length` above respect-gitignore = true -# Rules: +show-fixes = true +# show-source = true +target-version = 'py310' + +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] +# # Rules sets: # E: pycodestyle # F: Pyflakes # I: isort # B: flake8-bugbear # A: flake8-builtins -# T100: flake8-debugger +# T10: flake8-debugger # ERA: eradicate # NPY: NumPy-specific rules # RUF: Ruff-specific rules -select = ['E', 'F', 'I', 'B', 'A', 'T100', 'ERA', 'NPY', 'RUF'] -show-fixes = true -# show-source = true -target-version = 'py310' +ignore = [ + 'B008', # Do not perform function calls in argument defaults + # 'B028', # Consider replacing f"'{foo}'" with f"{foo!r}" # TODO: review + 'B905', # B905 `zip()` without an explicit `strict=` parameter # TODO: review + # 'D1', # Public code object needs docstring + # 'E203', # Whitespace before ':' (black formatter breaks this sometimes) + 'E501', # Line too long (using Bugbear's B950 warning) + 'E701', # Multiple statements on one line, see https://github.com/psf/black/issues/3887 + 'RUF100' +] +ignore-init-module-imports = true +select = ['E', 'F', 'I', 'B', 'A', 'T10', 'ERA', 'NPY', 'RUF'] typing-modules = ['gt4py.eve.extended_typing'] unfixable = [] -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true # force-wrap-aliases = true known-first-party = ['gt4py', '__externals__', '__gtscript__'] @@ -400,6 +416,7 @@ known-third-party = [ 'boltons', 'cached_property', 'click', + 'cupy', 'dace', 'devtools', 'factory', @@ -420,17 +437,25 @@ known-third-party = [ ] lines-after-imports = 2 order-by-type = true -section-order = ['future', 'standard-library', 'third-party', 'first-party', 'tests', 'local-folder'] +section-order = [ + 'future', + 'standard-library', + 'third-party', + 'first-party', + 'tests', + 'local-folder' +] split-on-trailing-comma = false -[tool.ruff.isort.sections] +[tool.ruff.lint.isort.sections] 'tests' = ['cartesian_tests', 'eve_tests', 'next_tests', 'storage_tests'] -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 15 -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] 'src/gt4py/eve/extended_typing.py' = ['F401', 'F405'] +'src/gt4py/next/__init__.py' = ['F401'] # -- setuptools build backend -- [tool.setuptools] diff --git a/requirements-dev.in b/requirements-dev.in index 4bb05ecbc5..3c0a33898a 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,14 +1,10 @@ -##[[[cog -## import re, tomli -## project = tomli.loads(open("pyproject.toml").read()) -## versions = "\n".join(project["project"]["optional-dependencies"]["full"]) -## for pkg in ['hypothesis', 'pytest']: -## print(re.search(f"\n({pkg} *[=>~!].*)\n", versions)[1]) -##]]] -hypothesis>=6.0.0 -pytest>=7.0 -##[[[end]]] - +# +# Constraints should specify the minimum required version (>=). +# +# Packages also required in the extra `gt4py['all-cpu']` configuration +# should be added here without constraints, so they will use the +# constraints defined in `pyproject.toml`. +# clang-format>=9.0 cogapp>=3.3 coverage[toml]>=5.0 @@ -23,6 +19,7 @@ flake8-eradicate>=1.3.0 flake8-mutable>=1.2.0 flake8-pyproject>=1.2.2 flake8-rst-docstrings>=0.0.14 +hypothesis # constraints in gt4py['testing'] isort>=5.10 jupytext>=1.14 mypy>=1.0 @@ -33,13 +30,14 @@ pip-tools>=6.10 pre-commit>=2.17 psutil>=5.0 pygments>=2.7.3 +pytest # constraints in gt4py['testing'] pytest-cache>=1.0 pytest-cov>=2.8 pytest-factoryboy>=2.0.3 pytest-xdist[psutil]>=2.4 -ruff>=0.0.265 +ruff>=0.2.0 sphinx>=4.4 sphinx_rtd_theme>=1.0 -tomli>=2.0.1 +tomli>=2.0.1;python_version<'3.11' tox>=3.2.0 types-all>=1.0.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index 94052ec478..84987138d5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,271 +2,276 @@ # This file is autogenerated by pip-compile with Python 3.8 # by the following command: # -# "tox run -e requirements-common" +# "tox run -e requirements-base" # -aenum==3.1.15 # via dace -alabaster==0.7.13 # via sphinx -asttokens==2.4.1 # via devtools, stack-data -astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) -attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.14.0 # via sphinx -backcall==0.2.0 # via ipython -black==24.1.1 # via gt4py (pyproject.toml) -blinker==1.7.0 # via flask -boltons==23.1.1 # via gt4py (pyproject.toml) -build==1.0.3 # via pip-tools -cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.2 # via tox -certifi==2023.11.17 # via requests -cffi==1.16.0 # via cryptography -cfgv==3.4.0 # via pre-commit -chardet==5.2.0 # via tox -charset-normalizer==3.3.2 # via requests -clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.28.1 # via gt4py (pyproject.toml) -cogapp==3.3.0 # via -r requirements-dev.in -colorama==0.4.6 # via tox -comm==0.2.1 # via ipykernel -contourpy==1.1.1 # via matplotlib -coverage[toml]==7.4.1 # via -r requirements-dev.in, coverage, pytest-cov -cryptography==42.0.2 # via types-paramiko, types-pyopenssl, types-redis -cycler==0.12.1 # via matplotlib -cytoolz==0.12.3 # via gt4py (pyproject.toml) -dace==0.15.1 # via gt4py (pyproject.toml) -darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.0 # via ipykernel -decorator==5.1.1 # via ipython -deepdiff==6.7.1 # via gt4py (pyproject.toml) -devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.8 # via dace -distlib==0.3.8 # via virtualenv -docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme -eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.2.0 # via hypothesis, pytest -execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==2.0.1 # via devtools, stack-data -factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==22.6.0 # via factory-boy -fastjsonschema==2.19.1 # via nbformat -filelock==3.13.1 # via tox, virtualenv -flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==24.1.17 # via -r requirements-dev.in -flake8-builtins==2.2.0 # via -r requirements-dev.in -flake8-debugger==4.1.2 # via -r requirements-dev.in -flake8-docstrings==1.7.0 # via -r requirements-dev.in -flake8-eradicate==1.5.0 # via -r requirements-dev.in -flake8-mutable==1.2.0 # via -r requirements-dev.in -flake8-pyproject==1.2.3 # via -r requirements-dev.in -flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==3.0.1 # via dace -fonttools==4.47.2 # via matplotlib -fparser==0.1.4 # via dace -frozendict==2.4.0 # via gt4py (pyproject.toml) -gridtools-cpp==2.3.2 # via gt4py (pyproject.toml) -hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.33 # via pre-commit -idna==3.6 # via requests -imagesize==1.4.1 # via sphinx -importlib-metadata==7.0.1 # via build, flask, jupyter-client, sphinx -importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib -inflection==0.5.1 # via pytest-factoryboy -iniconfig==2.0.0 # via pytest -ipykernel==6.29.0 # via nbmake -ipython==8.12.3 # via ipykernel -isort==5.13.2 # via -r requirements-dev.in -itsdangerous==2.1.2 # via flask -jedi==0.19.1 # via ipython -jinja2==3.1.3 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.21.1 # via nbformat -jsonschema-specifications==2023.12.1 # via jsonschema -jupyter-client==8.6.0 # via ipykernel, nbclient -jupyter-core==5.7.1 # via ipykernel, jupyter-client, nbformat -jupytext==1.16.1 # via -r requirements-dev.in -kiwisolver==1.4.5 # via matplotlib -lark==1.1.9 # via gt4py (pyproject.toml) -mako==1.3.2 # via gt4py (pyproject.toml) -markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins -markupsafe==2.1.4 # via jinja2, mako, werkzeug -matplotlib==3.7.4 # via -r requirements-dev.in -matplotlib-inline==0.1.6 # via ipykernel, ipython -mccabe==0.7.0 # via flake8 -mdit-py-plugins==0.4.0 # via jupytext -mdurl==0.1.2 # via markdown-it-py -mpmath==1.3.0 # via sympy -mypy==1.8.0 # via -r requirements-dev.in -mypy-extensions==1.0.0 # via black, mypy -nanobind==1.8.0 # via gt4py (pyproject.toml) -nbclient==0.6.8 # via nbmake -nbformat==5.9.2 # via jupytext, nbclient, nbmake -nbmake==1.5.0 # via -r requirements-dev.in -nest-asyncio==1.6.0 # via ipykernel, nbclient -networkx==3.1 # via dace -ninja==1.11.1.1 # via gt4py (pyproject.toml) -nodeenv==1.8.0 # via pre-commit -numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, types-jack-client -ordered-set==4.1.0 # via deepdiff -packaging==23.2 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pyproject-api, pytest, setuptools-scm, sphinx, tox -parso==0.8.3 # via jedi -pathspec==0.12.1 # via black -pexpect==4.9.0 # via ipython -pickleshare==0.7.5 # via ipython -pillow==10.2.0 # via matplotlib -pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.2 # via -r requirements-dev.in -pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.2.0 # via black, jupyter-core, tox, virtualenv -pluggy==1.4.0 # via pytest, tox -ply==3.11 # via dace -pre-commit==3.5.0 # via -r requirements-dev.in -prompt-toolkit==3.0.43 # via ipython -psutil==5.9.8 # via -r requirements-dev.in, ipykernel, pytest-xdist -ptyprocess==0.7.0 # via pexpect -pure-eval==0.2.2 # via stack-data -pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.1 # via flake8, flake8-debugger -pycparser==2.21 # via cffi -pydocstyle==6.3.0 # via flake8-docstrings -pyflakes==3.2.0 # via flake8 -pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, ipython, nbmake, sphinx -pyparsing==3.1.1 # via matplotlib -pyproject-api==1.6.1 # via tox -pyproject-hooks==1.0.0 # via build -pytest==8.0.0 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist -pytest-cache==1.0 # via -r requirements-dev.in -pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.6.0 # via -r requirements-dev.in -pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in, pytest-xdist -python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib -pytz==2023.4 # via babel -pyyaml==6.0.1 # via dace, jupytext, pre-commit -pyzmq==25.1.2 # via ipykernel, jupyter-client -referencing==0.33.0 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx -restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.17.1 # via jsonschema, referencing -ruff==0.1.15 # via -r requirements-dev.in -setuptools-scm==8.0.4 # via fparser -six==1.16.0 # via asttokens, astunparse, python-dateutil -snowballstemmer==2.2.0 # via pydocstyle, sphinx -sortedcontainers==2.4.0 # via hypothesis -sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in -sphinxcontrib-applehelp==1.0.4 # via sphinx -sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 # via sphinx -sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme -sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 # via sphinx -stack-data==0.6.3 # via ipython -sympy==1.9 # via dace, gt4py (pyproject.toml) -tabulate==0.9.0 # via gt4py (pyproject.toml) -toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, setuptools-scm, tox -toolz==0.12.1 # via cytoolz -tornado==6.4 # via ipykernel, jupyter-client -tox==4.12.1 # via -r requirements-dev.in -traitlets==5.14.1 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-aiofiles==23.2.0.20240106 # via types-all -types-all==1.0.0 # via -r requirements-dev.in -types-annoy==1.17.8.4 # via types-all -types-atomicwrites==1.4.5.1 # via types-all -types-backports==0.1.3 # via types-all -types-backports-abc==0.5.2 # via types-all -types-bleach==6.1.0.1 # via types-all -types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.7 # via types-all -types-certifi==2021.10.8.3 # via types-all -types-cffi==1.16.0.20240106 # via types-jack-client -types-characteristic==14.3.7 # via types-all -types-chardet==5.0.4.6 # via types-all -types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.20240106 # via types-all -types-colorama==0.4.15.20240106 # via types-all -types-contextvars==2.4.7.3 # via types-all -types-croniter==2.0.0.20240106 # via types-all -types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt -types-dataclasses==0.6.6 # via types-all -types-dateparser==1.1.4.20240106 # via types-all -types-datetimerange==2.0.0.6 # via types-all -types-decorator==5.1.8.20240106 # via types-all -types-deprecated==1.2.9.20240106 # via types-all -types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.20240126 # via types-all -types-emoji==2.1.0.3 # via types-all -types-enum34==1.1.8 # via types-all -types-fb303==1.0.0 # via types-all, types-scribe -types-filelock==3.2.7 # via types-all -types-first==2.0.5.2 # via types-all -types-flask==1.1.6 # via types-all -types-freezegun==1.1.10 # via types-all -types-frozendict==2.0.9 # via types-all -types-futures==3.3.8 # via types-all -types-geoip2==3.0.0 # via types-all -types-ipaddress==1.0.8 # via types-all, types-maxminddb -types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.20240106 # via types-all -types-jinja2==2.11.9 # via types-all, types-flask -types-kazoo==0.1.3 # via types-all -types-markdown==3.5.0.20240129 # via types-all -types-markupsafe==1.1.10 # via types-all, types-jinja2 -types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.20240106 # via types-all -types-mypy-extensions==1.0.0.5 # via types-all -types-nmap==0.1.6 # via types-all -types-openssl-python==0.1.3 # via types-all -types-orjson==3.6.2 # via types-all -types-paramiko==3.4.0.20240120 # via types-all, types-pysftp -types-pathlib2==2.3.0 # via types-all -types-pillow==10.2.0.20240125 # via types-all -types-pkg-resources==0.1.3 # via types-all -types-polib==1.2.0.20240115 # via types-all -types-protobuf==4.24.0.20240129 # via types-all -types-pyaudio==0.2.16.20240106 # via types-all -types-pycurl==7.45.2.20240106 # via types-all -types-pyfarmhash==0.3.1.2 # via types-all -types-pyjwt==1.7.1 # via types-all -types-pymssql==2.1.0 # via types-all -types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==24.0.0.20240130 # via types-redis -types-pyrfc3339==1.1.1.5 # via types-all -types-pysftp==0.2.17.20240106 # via types-all -types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange -types-python-gflags==3.1.7.3 # via types-all -types-python-slugify==8.0.2.20240127 # via types-all -types-pytz==2023.4.0.20240130 # via types-all, types-tzlocal -types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.12 # via types-all -types-redis==4.6.0.20240106 # via types-all -types-requests==2.31.0.20240125 # via types-all -types-retry==0.9.9.4 # via types-all -types-routes==2.5.0 # via types-all -types-scribe==2.0.0 # via types-all -types-setuptools==69.0.0.20240125 # via types-cffi -types-simplejson==3.19.0.2 # via types-all -types-singledispatch==4.1.0.0 # via types-all -types-six==1.16.21.20240106 # via types-all -types-tabulate==0.9.0.20240106 # via types-all -types-termcolor==1.1.6.2 # via types-all -types-toml==0.10.8.7 # via types-all -types-tornado==5.1.1 # via types-all -types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.1.0.1 # via types-all -types-ujson==5.9.0.0 # via types-all -types-waitress==2.1.4.20240106 # via types-all -types-werkzeug==1.0.9 # via types-all, types-flask -types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm -urllib3==2.2.0 # via requests, types-requests -virtualenv==20.25.0 # via pre-commit, tox -wcwidth==0.2.13 # via prompt-toolkit -websockets==12.0 # via dace -werkzeug==3.0.1 # via flask -wheel==0.42.0 # via astunparse, pip-tools -xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.17.0 # via importlib-metadata, importlib-resources +aenum==3.1.15 # via -c constraints.txt, dace +alabaster==0.7.13 # via -c constraints.txt, sphinx +asttokens==2.4.1 # via -c constraints.txt, devtools, stack-data +astunparse==1.6.3 ; python_version < "3.9" # via -c constraints.txt, dace, gt4py (pyproject.toml) +attrs==23.2.0 # via -c constraints.txt, flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing +babel==2.14.0 # via -c constraints.txt, sphinx +backcall==0.2.0 # via -c constraints.txt, ipython +black==24.1.1 # via -c constraints.txt, gt4py (pyproject.toml) +blinker==1.7.0 # via -c constraints.txt, flask +boltons==23.1.1 # via -c constraints.txt, gt4py (pyproject.toml) +build==1.0.3 # via -c constraints.txt, pip-tools +cached-property==1.5.2 # via -c constraints.txt, gt4py (pyproject.toml) +cachetools==5.3.2 # via -c constraints.txt, tox +certifi==2024.2.2 # via -c constraints.txt, requests +cffi==1.16.0 # via -c constraints.txt, cryptography +cfgv==3.4.0 # via -c constraints.txt, pre-commit +chardet==5.2.0 # via -c constraints.txt, tox +charset-normalizer==3.3.2 # via -c constraints.txt, requests +clang-format==17.0.6 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +click==8.1.7 # via -c constraints.txt, black, flask, gt4py (pyproject.toml), pip-tools +cmake==3.28.1 # via -c constraints.txt, gt4py (pyproject.toml) +cogapp==3.3.0 # via -c constraints.txt, -r requirements-dev.in +colorama==0.4.6 # via -c constraints.txt, tox +comm==0.2.1 # via -c constraints.txt, ipykernel +contourpy==1.1.1 # via -c constraints.txt, matplotlib +coverage[toml]==7.4.1 # via -c constraints.txt, -r requirements-dev.in, pytest-cov +cryptography==42.0.2 # via -c constraints.txt, types-paramiko, types-pyopenssl, types-redis +cycler==0.12.1 # via -c constraints.txt, matplotlib +cytoolz==0.12.3 # via -c constraints.txt, gt4py (pyproject.toml) +dace==0.15.1 # via -c constraints.txt, gt4py (pyproject.toml) +darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in +debugpy==1.8.0 # via -c constraints.txt, ipykernel +decorator==5.1.1 # via -c constraints.txt, ipython +deepdiff==6.7.1 # via -c constraints.txt, gt4py (pyproject.toml) +devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) +dill==0.3.8 # via -c constraints.txt, dace +distlib==0.3.8 # via -c constraints.txt, virtualenv +docutils==0.20.1 # via -c constraints.txt, restructuredtext-lint, sphinx, sphinx-rtd-theme +eradicate==2.3.0 # via -c constraints.txt, flake8-eradicate +exceptiongroup==1.2.0 # via -c constraints.txt, hypothesis, pytest +execnet==2.0.2 # via -c constraints.txt, pytest-cache, pytest-xdist +executing==2.0.1 # via -c constraints.txt, devtools, stack-data +factory-boy==3.3.0 # via -c constraints.txt, -r requirements-dev.in, pytest-factoryboy +faker==22.7.0 # via -c constraints.txt, factory-boy +fastjsonschema==2.19.1 # via -c constraints.txt, nbformat +filelock==3.13.1 # via -c constraints.txt, tox, virtualenv +flake8==7.0.0 # via -c constraints.txt, -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings +flake8-bugbear==24.1.17 # via -c constraints.txt, -r requirements-dev.in +flake8-builtins==2.2.0 # via -c constraints.txt, -r requirements-dev.in +flake8-debugger==4.1.2 # via -c constraints.txt, -r requirements-dev.in +flake8-docstrings==1.7.0 # via -c constraints.txt, -r requirements-dev.in +flake8-eradicate==1.5.0 # via -c constraints.txt, -r requirements-dev.in +flake8-mutable==1.2.0 # via -c constraints.txt, -r requirements-dev.in +flake8-pyproject==1.2.3 # via -c constraints.txt, -r requirements-dev.in +flake8-rst-docstrings==0.3.0 # via -c constraints.txt, -r requirements-dev.in +flask==3.0.2 # via -c constraints.txt, dace +fonttools==4.47.2 # via -c constraints.txt, matplotlib +fparser==0.1.4 # via -c constraints.txt, dace +frozendict==2.4.0 # via -c constraints.txt, gt4py (pyproject.toml) +gridtools-cpp==2.3.2 # via -c constraints.txt, gt4py (pyproject.toml) +hypothesis==6.98.2 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via -c constraints.txt, pre-commit +idna==3.6 # via -c constraints.txt, requests +imagesize==1.4.1 # via -c constraints.txt, sphinx +importlib-metadata==7.0.1 # via -c constraints.txt, build, flask, jax, jupyter-client, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via -c constraints.txt, gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib +inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy +iniconfig==2.0.0 # via -c constraints.txt, pytest +ipykernel==6.29.1 # via -c constraints.txt, nbmake +ipython==8.12.3 # via -c constraints.txt, ipykernel +isort==5.13.2 # via -c constraints.txt, -r requirements-dev.in +itsdangerous==2.1.2 # via -c constraints.txt, flask +jax[cpu]==0.4.13 # via -c constraints.txt, gt4py (pyproject.toml) +jaxlib==0.4.13 # via -c constraints.txt, jax +jedi==0.19.1 # via -c constraints.txt, ipython +jinja2==3.1.3 # via -c constraints.txt, flask, gt4py (pyproject.toml), sphinx +jsonschema==4.21.1 # via -c constraints.txt, nbformat +jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema +jupyter-client==8.6.0 # via -c constraints.txt, ipykernel, nbclient +jupyter-core==5.7.1 # via -c constraints.txt, ipykernel, jupyter-client, nbformat +jupytext==1.16.1 # via -c constraints.txt, -r requirements-dev.in +kiwisolver==1.4.5 # via -c constraints.txt, matplotlib +lark==1.1.9 # via -c constraints.txt, gt4py (pyproject.toml) +mako==1.3.2 # via -c constraints.txt, gt4py (pyproject.toml) +markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins +markupsafe==2.1.5 # via -c constraints.txt, jinja2, mako, werkzeug +matplotlib==3.7.4 # via -c constraints.txt, -r requirements-dev.in +matplotlib-inline==0.1.6 # via -c constraints.txt, ipykernel, ipython +mccabe==0.7.0 # via -c constraints.txt, flake8 +mdit-py-plugins==0.4.0 # via -c constraints.txt, jupytext +mdurl==0.1.2 # via -c constraints.txt, markdown-it-py +ml-dtypes==0.2.0 # via -c constraints.txt, jax, jaxlib +mpmath==1.3.0 # via -c constraints.txt, sympy +mypy==1.8.0 # via -c constraints.txt, -r requirements-dev.in +mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy +nanobind==1.8.0 # via -c constraints.txt, gt4py (pyproject.toml) +nbclient==0.6.8 # via -c constraints.txt, nbmake +nbformat==5.9.2 # via -c constraints.txt, jupytext, nbclient, nbmake +nbmake==1.5.0 # via -c constraints.txt, -r requirements-dev.in +nest-asyncio==1.6.0 # via -c constraints.txt, ipykernel, nbclient +networkx==3.1 # via -c constraints.txt, dace +ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) +nodeenv==1.8.0 # via -c constraints.txt, pre-commit +numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), jax, jaxlib, matplotlib, ml-dtypes, opt-einsum, scipy, types-jack-client +opt-einsum==3.3.0 # via -c constraints.txt, jax +ordered-set==4.1.0 # via -c constraints.txt, deepdiff +packaging==23.2 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pyproject-api, pytest, setuptools-scm, sphinx, tox +parso==0.8.3 # via -c constraints.txt, jedi +pathspec==0.12.1 # via -c constraints.txt, black +pexpect==4.9.0 # via -c constraints.txt, ipython +pickleshare==0.7.5 # via -c constraints.txt, ipython +pillow==10.2.0 # via -c constraints.txt, matplotlib +pip-tools==7.3.0 # via -c constraints.txt, -r requirements-dev.in +pipdeptree==2.13.2 # via -c constraints.txt, -r requirements-dev.in +pkgutil-resolve-name==1.3.10 # via -c constraints.txt, jsonschema +platformdirs==4.2.0 # via -c constraints.txt, black, jupyter-core, tox, virtualenv +pluggy==1.4.0 # via -c constraints.txt, pytest, tox +ply==3.11 # via -c constraints.txt, dace +pre-commit==3.5.0 # via -c constraints.txt, -r requirements-dev.in +prompt-toolkit==3.0.43 # via -c constraints.txt, ipython +psutil==5.9.8 # via -c constraints.txt, -r requirements-dev.in, ipykernel, pytest-xdist +ptyprocess==0.7.0 # via -c constraints.txt, pexpect +pure-eval==0.2.2 # via -c constraints.txt, stack-data +pybind11==2.11.1 # via -c constraints.txt, gt4py (pyproject.toml) +pycodestyle==2.11.1 # via -c constraints.txt, flake8, flake8-debugger +pycparser==2.21 # via -c constraints.txt, cffi +pydocstyle==6.3.0 # via -c constraints.txt, flake8-docstrings +pyflakes==3.2.0 # via -c constraints.txt, flake8 +pygments==2.17.2 # via -c constraints.txt, -r requirements-dev.in, devtools, flake8-rst-docstrings, ipython, nbmake, sphinx +pyparsing==3.1.1 # via -c constraints.txt, matplotlib +pyproject-api==1.6.1 # via -c constraints.txt, tox +pyproject-hooks==1.0.0 # via -c constraints.txt, build +pytest==8.0.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in +pytest-cov==4.1.0 # via -c constraints.txt, -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -c constraints.txt, -r requirements-dev.in +pytest-xdist[psutil]==3.5.0 # via -c constraints.txt, -r requirements-dev.in +python-dateutil==2.8.2 # via -c constraints.txt, faker, jupyter-client, matplotlib +pytz==2024.1 # via -c constraints.txt, babel +pyyaml==6.0.1 # via -c constraints.txt, dace, jupytext, pre-commit +pyzmq==25.1.2 # via -c constraints.txt, ipykernel, jupyter-client +referencing==0.33.0 # via -c constraints.txt, jsonschema, jsonschema-specifications +requests==2.31.0 # via -c constraints.txt, dace, sphinx +restructuredtext-lint==1.4.0 # via -c constraints.txt, flake8-rst-docstrings +rpds-py==0.17.1 # via -c constraints.txt, jsonschema, referencing +ruff==0.2.1 # via -c constraints.txt, -r requirements-dev.in +scipy==1.10.1 # via -c constraints.txt, jax, jaxlib +setuptools-scm==8.0.4 # via -c constraints.txt, fparser +six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil +snowballstemmer==2.2.0 # via -c constraints.txt, pydocstyle, sphinx +sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis +sphinx==7.1.2 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery +sphinx-rtd-theme==2.0.0 # via -c constraints.txt, -r requirements-dev.in +sphinxcontrib-applehelp==1.0.4 # via -c constraints.txt, sphinx +sphinxcontrib-devhelp==1.0.2 # via -c constraints.txt, sphinx +sphinxcontrib-htmlhelp==2.0.1 # via -c constraints.txt, sphinx +sphinxcontrib-jquery==4.1 # via -c constraints.txt, sphinx-rtd-theme +sphinxcontrib-jsmath==1.0.1 # via -c constraints.txt, sphinx +sphinxcontrib-qthelp==1.0.3 # via -c constraints.txt, sphinx +sphinxcontrib-serializinghtml==1.1.5 # via -c constraints.txt, sphinx +stack-data==0.6.3 # via -c constraints.txt, ipython +sympy==1.9 # via -c constraints.txt, dace, gt4py (pyproject.toml) +tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) +toml==0.10.2 # via -c constraints.txt, jupytext +tomli==2.0.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, setuptools-scm, tox +toolz==0.12.1 # via -c constraints.txt, cytoolz +tornado==6.4 # via -c constraints.txt, ipykernel, jupyter-client +tox==4.12.1 # via -c constraints.txt, -r requirements-dev.in +traitlets==5.14.1 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat +types-aiofiles==23.2.0.20240106 # via -c constraints.txt, types-all +types-all==1.0.0 # via -c constraints.txt, -r requirements-dev.in +types-annoy==1.17.8.4 # via -c constraints.txt, types-all +types-atomicwrites==1.4.5.1 # via -c constraints.txt, types-all +types-backports==0.1.3 # via -c constraints.txt, types-all +types-backports-abc==0.5.2 # via -c constraints.txt, types-all +types-bleach==6.1.0.1 # via -c constraints.txt, types-all +types-boto==2.49.18.20240205 # via -c constraints.txt, types-all +types-cachetools==5.3.0.7 # via -c constraints.txt, types-all +types-certifi==2021.10.8.3 # via -c constraints.txt, types-all +types-cffi==1.16.0.20240106 # via -c constraints.txt, types-jack-client +types-characteristic==14.3.7 # via -c constraints.txt, types-all +types-chardet==5.0.4.6 # via -c constraints.txt, types-all +types-click==7.1.8 # via -c constraints.txt, types-all, types-flask +types-click-spinner==0.1.13.20240106 # via -c constraints.txt, types-all +types-colorama==0.4.15.20240205 # via -c constraints.txt, types-all +types-contextvars==2.4.7.3 # via -c constraints.txt, types-all +types-croniter==2.0.0.20240106 # via -c constraints.txt, types-all +types-cryptography==3.3.23.2 # via -c constraints.txt, types-all, types-openssl-python, types-pyjwt +types-dataclasses==0.6.6 # via -c constraints.txt, types-all +types-dateparser==1.1.4.20240106 # via -c constraints.txt, types-all +types-datetimerange==2.0.0.6 # via -c constraints.txt, types-all +types-decorator==5.1.8.20240106 # via -c constraints.txt, types-all +types-deprecated==1.2.9.20240106 # via -c constraints.txt, types-all +types-docopt==0.6.11.4 # via -c constraints.txt, types-all +types-docutils==0.20.0.20240201 # via -c constraints.txt, types-all +types-emoji==2.1.0.3 # via -c constraints.txt, types-all +types-enum34==1.1.8 # via -c constraints.txt, types-all +types-fb303==1.0.0 # via -c constraints.txt, types-all, types-scribe +types-filelock==3.2.7 # via -c constraints.txt, types-all +types-first==2.0.5.2 # via -c constraints.txt, types-all +types-flask==1.1.6 # via -c constraints.txt, types-all +types-freezegun==1.1.10 # via -c constraints.txt, types-all +types-frozendict==2.0.9 # via -c constraints.txt, types-all +types-futures==3.3.8 # via -c constraints.txt, types-all +types-geoip2==3.0.0 # via -c constraints.txt, types-all +types-ipaddress==1.0.8 # via -c constraints.txt, types-all, types-maxminddb +types-itsdangerous==1.1.6 # via -c constraints.txt, types-all +types-jack-client==0.5.10.20240106 # via -c constraints.txt, types-all +types-jinja2==2.11.9 # via -c constraints.txt, types-all, types-flask +types-kazoo==0.1.3 # via -c constraints.txt, types-all +types-markdown==3.5.0.20240129 # via -c constraints.txt, types-all +types-markupsafe==1.1.10 # via -c constraints.txt, types-all, types-jinja2 +types-maxminddb==1.5.0 # via -c constraints.txt, types-all, types-geoip2 +types-mock==5.1.0.20240106 # via -c constraints.txt, types-all +types-mypy-extensions==1.0.0.5 # via -c constraints.txt, types-all +types-nmap==0.1.6 # via -c constraints.txt, types-all +types-openssl-python==0.1.3 # via -c constraints.txt, types-all +types-orjson==3.6.2 # via -c constraints.txt, types-all +types-paramiko==3.4.0.20240205 # via -c constraints.txt, types-all, types-pysftp +types-pathlib2==2.3.0 # via -c constraints.txt, types-all +types-pillow==10.2.0.20240206 # via -c constraints.txt, types-all +types-pkg-resources==0.1.3 # via -c constraints.txt, types-all +types-polib==1.2.0.20240115 # via -c constraints.txt, types-all +types-protobuf==4.24.0.20240129 # via -c constraints.txt, types-all +types-pyaudio==0.2.16.20240106 # via -c constraints.txt, types-all +types-pycurl==7.45.2.20240106 # via -c constraints.txt, types-all +types-pyfarmhash==0.3.1.2 # via -c constraints.txt, types-all +types-pyjwt==1.7.1 # via -c constraints.txt, types-all +types-pymssql==2.1.0 # via -c constraints.txt, types-all +types-pymysql==1.1.0.1 # via -c constraints.txt, types-all +types-pyopenssl==24.0.0.20240130 # via -c constraints.txt, types-redis +types-pyrfc3339==1.1.1.5 # via -c constraints.txt, types-all +types-pysftp==0.2.17.20240106 # via -c constraints.txt, types-all +types-python-dateutil==2.8.19.20240106 # via -c constraints.txt, types-all, types-datetimerange +types-python-gflags==3.1.7.3 # via -c constraints.txt, types-all +types-python-slugify==8.0.2.20240127 # via -c constraints.txt, types-all +types-pytz==2024.1.0.20240203 # via -c constraints.txt, types-all, types-tzlocal +types-pyvmomi==8.0.0.6 # via -c constraints.txt, types-all +types-pyyaml==6.0.12.12 # via -c constraints.txt, types-all +types-redis==4.6.0.20240106 # via -c constraints.txt, types-all +types-requests==2.31.0.20240125 # via -c constraints.txt, types-all +types-retry==0.9.9.4 # via -c constraints.txt, types-all +types-routes==2.5.0 # via -c constraints.txt, types-all +types-scribe==2.0.0 # via -c constraints.txt, types-all +types-setuptools==69.0.0.20240125 # via -c constraints.txt, types-cffi +types-simplejson==3.19.0.2 # via -c constraints.txt, types-all +types-singledispatch==4.1.0.0 # via -c constraints.txt, types-all +types-six==1.16.21.20240106 # via -c constraints.txt, types-all +types-tabulate==0.9.0.20240106 # via -c constraints.txt, types-all +types-termcolor==1.1.6.2 # via -c constraints.txt, types-all +types-toml==0.10.8.7 # via -c constraints.txt, types-all +types-tornado==5.1.1 # via -c constraints.txt, types-all +types-typed-ast==1.5.8.7 # via -c constraints.txt, types-all +types-tzlocal==5.1.0.1 # via -c constraints.txt, types-all +types-ujson==5.9.0.0 # via -c constraints.txt, types-all +types-waitress==2.1.4.20240106 # via -c constraints.txt, types-all +types-werkzeug==1.0.9 # via -c constraints.txt, types-all, types-flask +types-xxhash==3.0.5.2 # via -c constraints.txt, types-all +typing-extensions==4.5.0 # via -c constraints.txt, black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm +urllib3==2.2.0 # via -c constraints.txt, requests, types-requests +virtualenv==20.25.0 # via -c constraints.txt, pre-commit, tox +wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit +websockets==12.0 # via -c constraints.txt, dace +werkzeug==3.0.1 # via -c constraints.txt, flask +wheel==0.42.0 # via -c constraints.txt, astunparse, pip-tools +xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) +zipp==3.17.0 # via -c constraints.txt, importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.3.2 # via pip-tools -setuptools==69.0.3 # via gt4py (pyproject.toml), nodeenv, pip-tools, setuptools-scm +pip==24.0 # via -c constraints.txt, pip-tools +setuptools==69.0.3 # via -c constraints.txt, gt4py (pyproject.toml), nodeenv, pip-tools, setuptools-scm diff --git a/tox.ini b/tox.ini index 6a21a298ba..f7e4028503 100644 --- a/tox.ini +++ b/tox.ini @@ -3,10 +3,10 @@ requires = tox>=4.2 virtualenv>20.2 envlist = + cartesian-py{310}-{internal,dace}-{cpu} eve-py{310} - storage-py{310}-{internal,dace}-{cpu} next-py{310}-{nomesh,atlas} - cartesian-py{310}-{internal,dace}-{cpu} + storage-py{310}-{internal,dace}-{cpu} linters-py{310} # docs labels = @@ -47,6 +47,7 @@ pass_env = NUM_PROCESSES set_env = PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning} +# -- Primary tests -- [testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] description = Run 'gt4py.cartesian' tests pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE @@ -97,16 +98,28 @@ commands = {cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_gpu" {posargs} tests{/}storage_tests #pytest doctest-modules {posargs} src{/}gt4py{/}storage -[testenv:notebooks-py{310,311}] -description = Run notebooks -commands = python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} - +# -- Secondary tests -- [testenv:linters-py{38,39,310,311}] description = Run linters commands = flake8 .{/}src mypy .{/}src +[testenv:notebooks-py{310,311}] +description = Run notebooks +commands = python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} + +# -- Other artefacts -- +[testenv:dev-py{38,39,310,311}{-atlas,}] +description = Initialize development environment for gt4py +deps = + -r {tox_root}{/}requirements-dev.txt + atlas: atlas4py +package = editable-legacy # => use_develop = True +set_env = + {[testenv]set_env} + PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} + # [testenv:docs] # usedevelop = true # commands_pre = @@ -135,32 +148,33 @@ commands = # git add _static # commands_post = -[testenv:requirements-{common,py38,py39,py310,py311}] +[testenv:requirements-{base,py38,py39,py310,py311}] description = - common: Update pinned development requirements + base: Update pinned development requirements py38: Update requirements for testing a specific python version py39: Update requirements for testing a specific python version py310: Update requirements for testing a specific python version py311: Update requirements for testing a specific python version base_python = - common: py38 + base: py38 py38: py38 py39: py39 py310: py310 py311: py311 deps = cogapp>=3.3 + packaging>=20.0 pip-tools>=6.10 package = skip set_env = - CUSTOM_COMPILE_COMMAND = "tox run -e requirements-common" + CUSTOM_COMPILE_COMMAND = "tox run -e requirements-base" allowlist_externals = mv commands = -mv constraints.txt constraints.txt.old -mv requirements-dev.txt requirements-dev.old # Run cog to update requirements files from pyproject - cog -r -P requirements-dev.in min-requirements-test.txt min-extra-requirements-test.txt + cog -r -P min-requirements-test.txt min-extra-requirements-test.txt # Generate constraints file removing extras # (extras are not supported by pip in constraints files) pip-compile -r --resolver=backtracking \ @@ -170,6 +184,8 @@ commands = --allow-unsafe \ --extra dace \ --extra formatting \ + --extra jax-cpu \ + --extra performance \ --extra testing \ -o constraints.txt \ pyproject.toml requirements-dev.in @@ -181,18 +197,10 @@ commands = --allow-unsafe \ --extra dace \ --extra formatting \ + --extra jax-cpu \ --extra testing \ + -c constraints.txt \ -o requirements-dev.txt \ pyproject.toml requirements-dev.in # Run cog to update .pre-commit-config.yaml with new versions common: cog -r -P .pre-commit-config.yaml - -[testenv:dev-py{38,39,310,311}{-atlas,}] -description = Initialize development environment for gt4py -deps = - -r {tox_root}{/}requirements-dev.txt - atlas: atlas4py -package = editable-legacy # => use_develop = True -set_env = - {[testenv]set_env} - PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} From b7d34b6498b99f27c72db4d22e89b887f7ccab82 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 7 Feb 2024 15:16:24 +0100 Subject: [PATCH 72/85] build: Update deployment action with trusted publisher (#1423) --------- Co-authored-by: Rico Haeuselmann --- .github/workflows/deploy-release.yml | 54 +++++++++++++++++++++++----- docs/development/tools/release.md | 8 +++-- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/.github/workflows/deploy-release.yml b/.github/workflows/deploy-release.yml index 27ddfaf8ed..048a6f73e1 100644 --- a/.github/workflows/deploy-release.yml +++ b/.github/workflows/deploy-release.yml @@ -1,13 +1,17 @@ name: Deploy Python Distribution on: + push: + branches: [main] + pull_request: + branches: [main] release: types: [published] workflow_dispatch: jobs: - build-n-publish: - name: Build and publish Python distribution + build: + name: Build Python distribution runs-on: ubuntu-latest steps: - uses: actions/checkout@master @@ -21,14 +25,46 @@ jobs: - name: Build a wheel and a source tarball run: | python -m build --sdist --wheel --outdir dist/ - - name: Publish distribution to Test PyPI - if: ${{ github.event_name == 'release' }} - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Upload artifact + uses: actions/upload-artifact@v3 + with: + name: gt4py-dist + path: ./dist/** + publish-pypi: + name: Publish Python distribution to pypi.org + runs-on: ubuntu-latest + needs: build + if: ${{ github.event_name == 'workflow_dispatch' }} # the action was triggered manually + environment: + name: pypi + url: https://pypi.org/project/gt4py + permissions: + id-token: write + steps: + - name: Download wheel + uses: actions/download-artifact@v3 with: - password: ${{ secrets.TEST_PYPI_API_TOKEN }} - repository_url: https://test.pypi.org/legacy/ + name: gt4py-dist + path: dist - name: Publish distribution to PyPI - if: ${{ github.event_name == 'workflow_dispatch' }} + uses: pypa/gh-action-pypi-publish@release/v1 + publish-test-pypi: + name: Publish Python distribution to test.pypi.org + runs-on: ubuntu-latest + needs: build + if: ${{ github.event_name == 'release' }} # triggered by releasing on github, test first before manually triggering the deployment to PyPI (see release documentation) + environment: + name: testpypi + url: https://test.pypi.org/project/gt4py/ + permissions: + id-token: write + steps: + - name: Download wheel + uses: actions/download-artifact@v3 + with: + name: gt4py-dist + path: dist + - name: Publish distribution to Test PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: - password: ${{ secrets.PYPI_API_TOKEN }} + repository-url: https://test.pypi.org/legacy/ diff --git a/docs/development/tools/release.md b/docs/development/tools/release.md index c47d4d7c42..1430a6eff8 100644 --- a/docs/development/tools/release.md +++ b/docs/development/tools/release.md @@ -22,9 +22,13 @@ Currently, GT4Py releases are published in PyPI (and TestPyPI) and also as commi 5. On the GitHub website go to _Releases_ and _Draft a new release_. Choose `v0.{M}.{m}.{p}` as tag and select a branch (usually `main`). Follow the style of the previous releases for the title (`GT4Py v0.{M}.{m}.{p}`) and description. Then _Publish release_. -6. Upload distribution package to TestPyPI and quickly test that it works properly. +6. Publishing the release will trigger a Github action to deploy to TestPyPI. Install the package from TestPyPi and do basic tests. -7. Upload distribution package to PyPI and quickly that test it works properly. +7. If tests are ok, manually trigger the deploy Github action selecting the release tag as target. This will publish the package to PyPI. Install the package and test if it works. + +## PyPi and TestPyPi accounts + +The account is called `gridtools`. Credentials can be found in the bitwarden of CSCS. For 2FA, the recovery keys are stored in bitwarden, too. In case a new developer should get access, the recovery keys can be used to setup the authentication app (for all developers who should have access). From 75fede7b83b2b9b1565b3eb23ed78460e78ed3be Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 7 Feb 2024 16:37:11 +0100 Subject: [PATCH 73/85] release: v1.0.3 (#1446) --- .bumpversion.cfg | 2 +- CHANGELOG.md | 14 ++++++++++++++ docs/development/tools/release.md | 4 ++-- src/gt4py/__about__.py | 2 +- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 9e65fd9ae0..7f5a08b19e 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.2 +current_version = 1.0.3 parse = (?P\d+)\.(?P\d+)(\.(?P\d+))? serialize = {major}.{minor}.{patch} diff --git a/CHANGELOG.md b/CHANGELOG.md index 87f3ee9d2b..bfaa0efd7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,20 @@ Notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +## [1.0.3] - 2024-02-07 + +### General + +- Support for Python 3.11 and updated dependencies + +### Testing + +- Testing of Jupyter notebooks in CI + +### Next + +See commit history. + ## [1.0.2] - 2024-01-24 ### Cartesian diff --git a/docs/development/tools/release.md b/docs/development/tools/release.md index 1430a6eff8..7feb8711b1 100644 --- a/docs/development/tools/release.md +++ b/docs/development/tools/release.md @@ -17,10 +17,10 @@ Currently, GT4Py releases are published in PyPI (and TestPyPI) and also as commi 4. Commit the changes with the following message: ```bash - $ git commit -m 'Releasing 0.{M}.{m}.{p} version.' + $ git commit -m 'Releasing {M}.{m}.{p} version.' ``` -5. On the GitHub website go to _Releases_ and _Draft a new release_. Choose `v0.{M}.{m}.{p}` as tag and select a branch (usually `main`). Follow the style of the previous releases for the title (`GT4Py v0.{M}.{m}.{p}`) and description. Then _Publish release_. +5. On the GitHub website go to _Releases_ and _Draft a new release_. Choose `v{M}.{m}.{p}` as tag and select a branch (usually `main`). Follow the style of the previous releases for the title (`GT4Py v{M}.{m}.{p}`) and description. Then _Publish release_. 6. Publishing the release will trigger a Github action to deploy to TestPyPI. Install the package from TestPyPi and do basic tests. diff --git a/src/gt4py/__about__.py b/src/gt4py/__about__.py index 10f4607724..7107c1669a 100644 --- a/src/gt4py/__about__.py +++ b/src/gt4py/__about__.py @@ -33,5 +33,5 @@ __license__: Final = "GPL-3.0-or-later" -__version__: Final = "1.0.2" +__version__: Final = "1.0.3" __version_info__: Final = pkg_version.parse(__version__) From e24f52d777837276b2657afb8410ab6941417bff Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 8 Feb 2024 11:53:05 +0100 Subject: [PATCH 74/85] feature[next]: Temporary extraction heuristics (#1341) Adds a heuristics that only extracts a temporary if the respective lift expr is derefed in more than one position. This should give reasonably good performance and avoids many unnecessary temporaries. --- .../next/iterator/transforms/global_tmps.py | 117 +++++++++--- .../next/iterator/transforms/pass_manager.py | 50 ++++-- src/gt4py/next/iterator/type_inference.py | 4 + .../codegens/gtfn/gtfn_module.py | 10 +- .../next/program_processors/runners/gtfn.py | 5 +- .../ffront_tests/test_execution.py | 2 +- .../transforms_tests/test_global_tmps.py | 166 +++++++++--------- 7 files changed, 224 insertions(+), 130 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index c423a3c277..4f4fd053b2 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -15,7 +15,7 @@ import dataclasses import functools from collections.abc import Mapping -from typing import Any, Final, Iterable, Literal, Optional, Sequence +from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence import gt4py.eve as eve import gt4py.next as gtx @@ -150,20 +150,54 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir return node -def temporary_extraction_predicate(expr: ir.Node, num_occurences: int) -> bool: - """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not is_applied_lift(expr): - return False - # do not extract when the result is a list as we can not create temporaries for - # these stencils - if isinstance(expr.annex.type.dtype, type_inference.List): - return False - stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` - used_symbols = collect_symbol_refs(stencil) - # do not extract when the stencil is capturing - if used_symbols: +@dataclasses.dataclass(frozen=True) +class TemporaryExtractionPredicate: + """ + Construct a callable that determines if a lift expr can and should be extracted to a temporary. + + The class optionally takes a heuristic that can restrict the extraction. + """ + + heuristics: Optional[Callable[[ir.Expr], bool]] = None + + def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: + """Determine if `expr` is an applied lift that should be extracted as a temporary.""" + if not is_applied_lift(expr): + return False + # do not extract when the result is a list (i.e. a lift expression used in a `reduce` call) + # as we can not create temporaries for these stencils + if isinstance(expr.annex.type.dtype, type_inference.List): + return False + if self.heuristics and not self.heuristics(expr): + return False + stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` + # do not extract when the stencil is capturing + used_symbols = collect_symbol_refs(stencil) + if used_symbols: + return False + return True + + +@dataclasses.dataclass(frozen=True) +class SimpleTemporaryExtractionHeuristics: + """ + Heuristic that extracts only if a lift expr is derefed in more than one position. + + Note that such expression result in redundant computations if inlined instead of being + placed into a temporary. + """ + + closure: ir.StencilClosure + + @functools.cached_property + def closure_shifts(self): + return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) + + def __call__(self, expr: ir.Expr) -> bool: + shifts = self.closure_shifts[id(expr)] + if len(shifts) > 1: + return True return False - return True def _closure_parameter_argument_mapping(closure: ir.StencilClosure): @@ -193,7 +227,14 @@ def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> Non assert not (set(used_symbol_refs) - {param.id for param in whitelist}) -def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemporaries: +def split_closures( + node: ir.FencilDefinition, + offset_provider, + *, + extraction_heuristics: Optional[ + Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] + ] = None, +) -> FencilWithTemporaries: """Split closures on lifted function calls and introduce new temporary buffers for return values. Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the @@ -205,6 +246,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp 3. Extract lifted function class as new closures with the previously created temporary as output. The closures are processed in reverse order to properly respect the dependencies. """ + if not extraction_heuristics: + # extract all (eligible) lifts + def always_extract_heuristics(_): + return lambda _: True + + extraction_heuristics = always_extract_heuristics + uid_gen_tmps = UIDGenerator(prefix="_tmp") type_inference.infer_all(node, offset_provider=offset_provider, save_to_annex=True) @@ -228,9 +276,13 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan ) + extraction_predicate = TemporaryExtractionPredicate( + extraction_heuristics(current_closure) + ) + stencil_body, extracted_lifts, _ = extract_subexpression( current_closure_stencil.expr, - temporary_extraction_predicate, + extraction_predicate, uid_gen_tmps, once_only=True, deepest_expr_first=True, @@ -454,7 +506,12 @@ def update_domains( if closure.domain == AUTO_DOMAIN: # every closure with auto domain should have a single out field assert isinstance(closure.output, ir.SymRef) + + if closure.output.id not in domains: + raise NotImplementedError(f"Closure output '{closure.output.id}' is never used.") + domain = domains[closure.output.id] + closure = ir.StencilClosure( domain=copy.deepcopy(domain), stencil=closure.stencil, @@ -467,14 +524,6 @@ def update_domains( closures.append(closure) - if closure.stencil == ir.SymRef(id="deref"): - # all closure inputs inherit the domain - for input_arg in _tuple_constituents(closure.inputs[0]): - assert isinstance(input_arg, ir.SymRef) - assert domains.get(input_arg.id, domain) == domain - domains[input_arg.id] = domain - continue - local_shifts = trace_shifts.TraceShifts.apply(closure) for param, shift_chains in local_shifts.items(): assert isinstance(param, str) @@ -512,13 +561,22 @@ def update_domains( (axis, range_) if axis != old_axis else (new_axis, new_range) for axis, range_ in consumed_domain.ranges.items() ) + # TODO(tehrengruber): Revisit. Somehow the order matters so preserve it. + consumed_domain.ranges = dict( + (axis, range_) if axis != old_axis else (new_axis, new_range) + for axis, range_ in consumed_domain.ranges.items() + ) else: - raise NotImplementedError + raise NotImplementedError() consumed_domains.append(consumed_domain) # compute the bounds of all consumed domains if consumed_domains: - domains[param] = domain_union(consumed_domains).as_expr() + if all( + consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() + for consumed_domain in consumed_domains + ): # scalar otherwise + domains[param] = domain_union(consumed_domains).as_expr() return FencilWithTemporaries( fencil=ir.FencilDefinition( @@ -597,10 +655,15 @@ def visit_FencilDefinition( node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any], + extraction_heuristics: Optional[ + Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] + ] = None, symbolic_sizes: Optional[dict[str, str]], ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries - res = split_closures(node, offset_provider=offset_provider) + res = split_closures( + node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics + ) # Prune unreferences closure inputs introduced in the previous step res = PruneClosureInputs().visit(res) # Prune unused temporaries possibly introduced in the previous step diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 08897861c2..fe14a8f580 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import enum -from typing import Optional +from typing import Callable, Optional from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import simple_inline_heuristic @@ -51,8 +51,6 @@ def _inline_lifts(ir, lift_mode): return InlineLifts( flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - | InlineLifts.Flag.INLINE_LIFTED_ARGS - # needed for UnrollReduce and lift args like `(↑(λ() → constant)` ).visit(ir) else: raise ValueError() @@ -73,6 +71,8 @@ def _inline_into_scan(ir, *, max_iter=10): return ir +# TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward +# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( ir: ir.Node, *, @@ -82,6 +82,9 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + temporary_extraction_heuristics: Optional[ + Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] + ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, ): if lift_mode is None: @@ -121,6 +124,33 @@ def apply_common_transforms( else: raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + if lift_mode != LiftMode.FORCE_INLINE: + assert offset_provider is not None + ir = CreateGlobalTmps().visit( + ir, + offset_provider=offset_provider, + extraction_heuristics=temporary_extraction_heuristics, + symbolic_sizes=symbolic_domain_sizes, + ) + + for _ in range(10): + inlined = InlineLifts().visit(ir) + inlined = InlineLambdas.apply( + inlined, + opcount_preserving=True, + force_inline_lift_args=True, + ) + if inlined == ir: + break + ir = inlined + else: + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + + # If after creating temporaries, the scan is not at the top, we inline. + # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. + # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` + ir = _inline_into_scan(ir) + # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. @@ -134,6 +164,7 @@ def apply_common_transforms( ir = FuseMaps().visit(ir) ir = CollapseListGet().visit(ir) + if unroll_reduce: for _ in range(10): unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) @@ -142,22 +173,11 @@ def apply_common_transforms( ir = unrolled ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, lift_mode) + ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) ir = NormalizeShifts().visit(ir) else: raise RuntimeError("Reduction unrolling failed.") - if lift_mode != LiftMode.FORCE_INLINE: - assert offset_provider is not None - ir = CreateGlobalTmps().visit( - ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes - ) - ir = InlineLifts().visit(ir) - # If after creating temporaries, the scan is not at the top, we inline. - # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. - # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - ir = _inline_into_scan(ir) - ir = EtaReduction().visit(ir) ir = ScanEtaReduction().visit(ir) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index d65f67b266..683a57561c 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -21,6 +21,7 @@ import gt4py.next as gtx from gt4py.next.common import Connectivity from gt4py.next.iterator import ir +from gt4py.next.iterator.transforms.global_tmps import FencilWithTemporaries from gt4py.next.type_inference import Type, TypeVar, freshen, reindex_vars, unify @@ -936,6 +937,9 @@ def visit_StencilClosure( ) return Closure(output=output, inputs=Tuple.from_elems(*inputs)) + def visit_FencilWithTemporaries(self, node: FencilWithTemporaries, **kwargs): + return self.visit(node.fencil, **kwargs) + def visit_FencilDefinition( self, node: ir.FencilDefinition, diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 718fef72af..c157cdcc46 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -17,7 +17,7 @@ import dataclasses import functools import warnings -from typing import Any, Final, Optional +from typing import Any, Callable, Final, Optional import numpy as np @@ -58,6 +58,9 @@ class GTFNTranslationStep( lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None + temporary_extraction_heuristics: Optional[ + Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] + ] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -179,14 +182,14 @@ def _preprocess_program( self, program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], - runtime_lift_mode: Optional[LiftMode] = None, + runtime_lift_mode: Optional[LiftMode], ) -> itir.FencilDefinition: # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added # to the interface of all (or at least all of concern) backends, but instead should be # configured in the backend itself (like it is here), until then we respect the argument # here and warn the user if it differs from the one configured. lift_mode = runtime_lift_mode or self.lift_mode - if lift_mode != self.lift_mode: + if runtime_lift_mode and runtime_lift_mode != self.lift_mode: warnings.warn( f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " f"overriden to be {str(runtime_lift_mode)} at runtime." @@ -202,6 +205,7 @@ def _preprocess_program( # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, symbolic_domain_sizes=self.symbolic_domain_sizes, + temporary_extraction_heuristics=self.temporary_extraction_heuristics, ) new_program = apply_common_transforms( diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index baa45ddc0e..157c00c368 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -22,7 +22,7 @@ import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash from gt4py.next import common -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import LiftMode, global_tmps from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import cache, compiler @@ -187,7 +187,8 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: name="run_gtfn_with_temporaries", otf_workflow=gtfn_executor.otf_workflow.replace( translation=gtfn_executor.otf_workflow.translation.replace( - lift_mode=LiftMode.FORCE_TEMPORARIES + lift_mode=LiftMode.FORCE_TEMPORARIES, + temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, ), ), ), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 7fc2d82e67..ae5e434085 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -622,7 +622,7 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.executor == gtfn.run_gtfn_with_temporaries: + if cartesian_case.executor == gtfn.run_gtfn_with_temporaries.executor: pytest.xfail("Temporary extraction does not work correctly in combination with scans.") @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 5c2802f90c..46ca02217f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -20,6 +20,7 @@ from gt4py.next.iterator.transforms.global_tmps import ( AUTO_DOMAIN, FencilWithTemporaries, + SimpleTemporaryExtractionHeuristics, Temporary, collect_tmps_info, split_closures, @@ -32,53 +33,23 @@ def test_split_closures(): testee = ir.FencilDefinition( id="f", function_definitions=[], - params=[ir.Sym(id="d"), ir.Sym(id="inp"), ir.Sym(id="out")], + params=[im.sym("d"), im.sym("inp"), im.sym("out")], closures=[ ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="bar_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="lift"), - args=[ - ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.SymRef(id="foo_inp") - ], - ), - ) - ], - ), - args=[ir.SymRef(id="bar_inp")], - ) - ], - ), - ) - ], - ), - args=[ir.SymRef(id="baz_inp")], + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("baz_inp")( + im.deref( + im.lift( + im.lambda_("bar_inp")( + im.deref( + im.lift(im.lambda_("foo_inp")(im.deref("foo_inp")))("bar_inp") + ) ) - ], - ), + )("baz_inp") + ) ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp")], + output=im.ref("out"), + inputs=[im.ref("inp")], ) ], ) @@ -87,54 +58,31 @@ def test_split_closures(): id="f", function_definitions=[], params=[ - ir.Sym(id="d"), - ir.Sym(id="inp"), - ir.Sym(id="out"), - ir.Sym(id="_tmp_1"), - ir.Sym(id="_tmp_2"), - ir.Sym(id="_gtmp_auto_domain"), + im.sym("d"), + im.sym("inp"), + im.sym("out"), + im.sym("_tmp_1"), + im.sym("_tmp_2"), + im.sym("_gtmp_auto_domain"), ], closures=[ ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="foo_inp")], - ), - ), - output=ir.SymRef(id="_tmp_2"), - inputs=[ir.SymRef(id="inp")], + stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), + output=im.ref("_tmp_2"), + inputs=[im.ref("inp")], ), ir.StencilClosure( domain=AUTO_DOMAIN, - stencil=ir.Lambda( - params=[ - ir.Sym(id="bar_inp"), - ir.Sym(id="_tmp_2"), - ], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.SymRef(id="_tmp_2"), - ], - ), - ), - output=ir.SymRef(id="_tmp_1"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_2")], + stencil=im.lambda_("bar_inp", "_tmp_2")(im.deref("_tmp_2")), + output=im.ref("_tmp_1"), + inputs=[im.ref("inp"), im.ref("_tmp_2")], ), ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_tmp_1")], - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ir.SymRef(id="_tmp_1")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="inp"), ir.SymRef(id="_tmp_1")], + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("baz_inp", "_tmp_1")(im.deref("_tmp_1")), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_tmp_1")], ), ], ) @@ -143,6 +91,60 @@ def test_split_closures(): assert actual.fencil == expected +def test_split_closures_simple_heuristics(): + UIDs.reset_sequence() + testee = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[im.sym("d"), im.sym("inp"), im.sym("out")], + closures=[ + ir.StencilClosure( + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("foo")( + im.let("lifted_it", im.lift(im.lambda_("bar")(im.deref("bar")))("foo"))( + im.plus(im.deref("lifted_it"), im.deref(im.shift("I", 1)("lifted_it"))) + ) + ), + output=im.ref("out"), + inputs=[im.ref("inp")], + ) + ], + ) + + expected = ir.FencilDefinition( + id="f", + function_definitions=[], + params=[ + im.sym("d"), + im.sym("inp"), + im.sym("out"), + im.sym("_tmp_1"), + im.sym("_gtmp_auto_domain"), + ], + closures=[ + ir.StencilClosure( + domain=AUTO_DOMAIN, + stencil=im.lambda_("bar")(im.deref("bar")), + output=im.ref("_tmp_1"), + inputs=[im.ref("inp")], + ), + ir.StencilClosure( + domain=im.call("cartesian_domain")(), + stencil=im.lambda_("foo", "_tmp_1")( + im.plus(im.deref("_tmp_1"), im.deref(im.shift("I", 1)("_tmp_1"))) + ), + output=im.ref("out"), + inputs=[im.ref("inp"), im.ref("_tmp_1")], + ), + ], + ) + actual = split_closures( + testee, extraction_heuristics=SimpleTemporaryExtractionHeuristics, offset_provider={} + ) + assert actual.tmps == [Temporary(id="_tmp_1")] + assert actual.fencil == expected + + def test_split_closures_lifted_scan(): UIDs.reset_sequence() From 54a28870f6e6c302dd270df9c3d5afea659d4dd9 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 8 Feb 2024 16:51:38 +0100 Subject: [PATCH 75/85] Fix to detect CUPY gpu runtime (#1448) This PR changes the way we detect the CUPy runtime: check cp.cuda.runtime.is_hip instead of calling cp.cuda.get_hipcc_path(). The previous method did not work on Clariden gpu node. --- src/gt4py/next/allocators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 559e78eb3e..30775dbab9 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -44,7 +44,7 @@ CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( None if not cp - else (core_defs.DeviceType.ROCM if cp.cuda.get_hipcc_path() else core_defs.DeviceType.CUDA) + else (core_defs.DeviceType.ROCM if cp.cuda.runtime.is_hip else core_defs.DeviceType.CUDA) ) From 374f043b86d39f3f5b1971965c17913a9cc8ab62 Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 8 Feb 2024 16:56:44 +0100 Subject: [PATCH 76/85] Update CI config to use new template for container build (#1450) --- ci/cscs-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 971a3cfc35..a46929537b 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -19,7 +19,7 @@ stages: - test build py38 baseimage: - extends: .container-builder + extends: .container-builder-cscs-zen2 stage: baseimage # we create a tag that depends on the SHA value of ci/base.Dockerfile, this way # a new base image is only built when the SHA of this file changes @@ -52,7 +52,7 @@ build py310 baseimage: <<: *py310 build py38 image: - extends: .container-builder + extends: .container-builder-cscs-zen2 needs: ["build py38 baseimage"] stage: image variables: From 1d305e14db62ab50a0bf64b14770201f420b06fd Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Fri, 9 Feb 2024 10:06:04 +0100 Subject: [PATCH 77/85] feat[next]: Slicing field to 0d to return field not scalar (#1427) * return array for 0d field from slicing instead of scalar --- src/gt4py/_core/definitions.py | 2 ++ src/gt4py/next/common.py | 14 ++++++---- src/gt4py/next/embedded/nd_array_field.py | 19 +++++++------ src/gt4py/next/embedded/operators.py | 3 +- src/gt4py/next/ffront/func_to_foast.py | 2 +- .../next/ffront/past_passes/type_deduction.py | 6 ++-- src/gt4py/next/ffront/past_to_itir.py | 9 +++--- src/gt4py/next/iterator/embedded.py | 28 +++++++++++++++---- .../ffront_tests/test_execution.py | 15 ++++++++++ .../test_horizontal_indirection.py | 4 +-- .../iterator_tests/test_anton_toy.py | 8 ++++-- .../embedded_tests/test_nd_array_field.py | 12 +++++--- 12 files changed, 84 insertions(+), 38 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 6237704f69..a550db4f2e 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -444,6 +444,8 @@ def shape(self) -> tuple[int, ...]: ... @property def dtype(self) -> Any: ... + def item(self) -> Any: ... + def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... def __getitem__(self, item: Any) -> NDArrayObject: ... diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 90e76d671d..fdf515d2f8 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -623,14 +623,17 @@ def asnumpy(self) -> np.ndarray: ... def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... + def restrict(self, item: AnyIndexSpec) -> Field: ... + + @abc.abstractmethod + def as_scalar(self) -> core_defs.ScalarT: ... # Operators @abc.abstractmethod def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... + def __getitem__(self, item: AnyIndexSpec) -> Field: ... @abc.abstractmethod def __abs__(self) -> Field: ... @@ -896,6 +899,9 @@ def ndarray(self) -> Never: def asnumpy(self) -> Never: raise NotImplementedError() + def as_scalar(self) -> Never: + raise NotImplementedError() + @functools.cached_property def domain(self) -> Domain: return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) @@ -947,9 +953,7 @@ def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Conne __call__ = remap - def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar: - if is_int_index(index): - return index + self.offset + def restrict(self, index: AnyIndexSpec) -> Never: raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case __getitem__ = restrict diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 65a71718e4..c39408ba3a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -120,6 +120,13 @@ def asnumpy(self) -> np.ndarray: else: return np.asarray(self._ndarray) + def as_scalar(self) -> core_defs.ScalarT: + if self.domain.ndim != 0: + raise ValueError( + "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + ) + return self.ndarray.item() + @property def codomain(self) -> type[core_defs.ScalarT]: return self.dtype.scalar_type @@ -204,15 +211,11 @@ def remap( __call__ = remap # type: ignore[assignment] - def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + def restrict(self, index: common.AnyIndexSpec) -> common.Field: new_domain, buffer_slice = self._slice(index) - new_buffer = self.ndarray[buffer_slice] - if len(new_domain) == 0: - # TODO: assert core_defs.is_scalar_type(new_buffer), new_buffer - return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new_buffer, domain=new_domain) + new_buffer = self.__class__.array_ns.asarray(new_buffer) + return self.__class__.from_array(new_buffer, domain=new_domain) __getitem__ = restrict @@ -433,7 +436,7 @@ def inverse_image( return new_dims - def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar: + def restrict(self, index: common.AnyIndexSpec) -> common.Field: cache_key = (id(self.ndarray), self.domain, index) if (restricted_connectivity := self._cache.get(cache_key, None)) is None: diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index cb03373b41..fc3ccda335 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -187,8 +187,7 @@ def _tuple_at( ) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]: @utils.tree_map def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: - res = field[pos] if common.is_field(field) else field - res = res.item() if hasattr(res, "item") else res # extract scalar value from array + res = field[pos].as_scalar() if common.is_field(field) else field assert core_defs.is_scalar_type(res) return res diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 0fd263308e..0831fc3bb2 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -289,7 +289,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: index = self._match_index(node.slice) except ValueError: raise errors.DSLError( - self.get_location(node.slice), "eXpected an integral index." + self.get_location(node.slice), "Expected an integral index." ) from None return foast.Subscript( diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index af8f5e8368..0e5be1eabd 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -217,12 +217,12 @@ def visit_Call(self, node: past.Call, **kwargs): f"'{new_kwargs['out'].type}'." ) elif new_func.id in ["minimum", "maximum"]: - if new_args[0].type != new_args[1].type: + if arg_types[0] != arg_types[1]: raise ValueError( f"First and second argument in '{new_func.id}' must be of the same type." - f"Got '{new_args[0].type}' and '{new_args[1].type}'." + f"Got '{arg_types[0]}' and '{arg_types[1]}'." ) - return_type = new_args[0].type + return_type = arg_types[0] else: raise AssertionError( "Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed." diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index ed239e0436..620e98dd4d 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -305,7 +305,7 @@ def _visit_stencil_call_out_arg( ) -> tuple[itir.Expr, itir.FunCall]: if isinstance(out_arg, past.Subscript): # as the ITIR does not support slicing a field we have to do a deeper - # inspection of the PAST to emulate the behaviour + # inspection of the PAST to emulate the behaviour out_field_name: past.Name = out_arg.value return ( self._construct_itir_out_arg(out_field_name), @@ -382,12 +382,11 @@ def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall: ) def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: - if node.func.id in ["maximum", "minimum"] and len(node.args) == 2: + if node.func.id in ["maximum", "minimum"]: + assert len(node.args) == 2 return itir.FunCall( fun=itir.SymRef(id=node.func.id), args=[self.visit(node.args[0]), self.visit(node.args[1])], ) else: - raise AssertionError( - "Only 'minimum' and 'maximum' builtins supported supported currently." - ) + raise NotImplementedError("Only 'minimum', and 'maximum' builtins supported currently.") diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 011ca4d92b..a45b81a773 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -919,7 +919,7 @@ def _translate_named_indices( return tuple(domain_slice) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: - return self._ndarrayfield[self._translate_named_indices(named_indices)] + return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar() def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if common.is_mutable_field(self._ndarrayfield): @@ -1040,6 +1040,7 @@ class IndexField(common.Field): """ _dimension: common.Dimension + _cur_index: Optional[core_defs.IntegralScalar] = None @property def __gt_domain__(self) -> common.Domain: @@ -1055,7 +1056,10 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: - return common.Domain((self._dimension, common.UnitRange.infinite())) + if self._cur_index is None: + return common.Domain((self._dimension, common.UnitRange.infinite())) + else: + return common.Domain() @property def codomain(self) -> type[core_defs.int32]: @@ -1072,16 +1076,24 @@ def ndarray(self) -> core_defs.NDArrayObject: def asnumpy(self) -> np.ndarray: raise NotImplementedError() + def as_scalar(self) -> core_defs.IntegralScalar: + if self.domain.ndim != 0: + raise ValueError( + "'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." + ) + assert self._cur_index is not None + return self._cur_index + def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32: + def restrict(self, item: common.AnyIndexSpec) -> common.Field: if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off d, r = item[0] assert d == self._dimension - assert isinstance(r, int) - return self.dtype.scalar_type(r) + assert isinstance(r, core_defs.INTEGRAL_TYPES) + return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work # TODO set a domain... raise NotImplementedError() @@ -1195,8 +1207,12 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) - # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() - def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + def restrict(self, item: common.AnyIndexSpec) -> common.Field: # TODO set a domain... + return self + + def as_scalar(self) -> core_defs.ScalarT: + assert self.domain.ndim == 0 return self._value __call__ = remap diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index ae5e434085..3c9c4e686c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -321,6 +321,21 @@ def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]: cases.verify(cartesian_case, testee, inp, out=out, ref=expected) +def test_single_value_field(cartesian_case): + @gtx.field_operator + def testee_fo(a: cases.IKField) -> cases.IKField: + return a + + @gtx.program + def testee_prog(a: cases.IKField): + testee_fo(a, out=a[1:2, 3:4]) + + a = cases.allocate(cartesian_case, testee_prog, "a")() + ref = a[1, 3] + + cases.verify(cartesian_case, testee_prog, a, inout=a[1, 3], ref=ref) + + def test_astype_int(cartesian_case): # noqa: F811 # fixtures @gtx.field_operator def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py index 69f594a2bc..e4540ba1b9 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py @@ -70,7 +70,7 @@ def test_simple_indirection(program_processor): ref = np.zeros(shape, dtype=inp.dtype) for i in range(shape[0]): - ref[i] = inp.ndarray[i + 1 - 1] if cond[i] < 0.0 else inp.ndarray[i + 1 + 1] + ref[i] = inp.asnumpy()[i + 1 - 1] if cond.asnumpy()[i] < 0.0 else inp.asnumpy()[i + 1 + 1] run_processor( conditional_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))], @@ -101,7 +101,7 @@ def test_direct_offset_for_indirection(program_processor): ref = np.zeros(shape) for i in range(shape[0]): - ref[i] = inp[i + cond[i]] + ref[i] = inp.asnumpy()[i + cond.asnumpy()[i]] run_processor( direct_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))], diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 806ab7eb9a..9a1bc6deb6 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -65,11 +65,15 @@ def fencil(x, y, z, out, inp): def naive_lap(inp): shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]] out = np.zeros(shape) + inp_data = inp.asnumpy() for i in range(1, shape[0] + 1): for j in range(1, shape[1] + 1): for k in range(0, shape[2]): - out[i - 1, j - 1, k] = -4 * inp[i, j, k] + ( - inp[i + 1, j, k] + inp[i - 1, j, k] + inp[i, j + 1, k] + inp[i, j - 1, k] + out[i - 1, j - 1, k] = -4 * inp_data[i, j, k] + ( + inp_data[i + 1, j, k] + + inp_data[i - 1, j, k] + + inp_data[i, j + 1, k] + + inp_data[i, j - 1, k] ) return out diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 70fa274457..79830a75a1 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -468,10 +468,11 @@ def test_absolute_indexing_value_return(): field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) named_index = ((IDim, 12), (JDim, 6)) + assert common.is_field(field) value = field[named_index] - assert isinstance(value, np.int32) - assert value == 21 + assert common.is_field(value) + assert value.as_scalar() == 21 @pytest.mark.parametrize( @@ -568,14 +569,17 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_value", - [((1, 0), 10), ((0, 1), 1)], + [ + ((1, 0), 10), + ((0, 1), 1), + ], ) def test_relative_indexing_value_return(index, expected_value): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12))) field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain) indexed_field = field[index] - assert indexed_field == expected_value + assert indexed_field.as_scalar() == expected_value @pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]]) From 29705759896d6ce039756ca84d0581034cff7b55 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 13 Feb 2024 13:33:32 +0100 Subject: [PATCH 78/85] feature[next]: Improve CollapseTuple pass (#1350) Significantly improves the collapse tuple pass preparing wider support for `if` statements. --- .../ir_utils/common_pattern_matcher.py | 11 + src/gt4py/next/iterator/ir_utils/ir_makers.py | 30 +- src/gt4py/next/iterator/ir_utils/misc.py | 79 +++++ .../iterator/transforms/collapse_tuple.py | 282 +++++++++++++++--- .../next/iterator/transforms/pass_manager.py | 4 + .../iterator/transforms/propagate_deref.py | 9 + .../transforms_tests/test_collapse_tuple.py | 139 ++++++++- .../transforms_tests/test_cse.py | 17 +- 8 files changed, 513 insertions(+), 58 deletions(-) create mode 100644 src/gt4py/next/iterator/ir_utils/misc.py diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 8df4723502..a4b074a4b6 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -14,6 +14,7 @@ from typing import TypeGuard from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -24,3 +25,13 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: and isinstance(arg.fun.fun, itir.SymRef) and arg.fun.fun.id == "lift" ) + + +def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expression of the form `(λ(...) → ...)(...)`.""" + return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) + + +def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]: + """Match expression of the form `if_(cond, true_branch, false_branch)`.""" + return isinstance(node, itir.FunCall) and node.fun == im.ref("if_") diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 94a2646422..4337e8512a 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -12,7 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Callable, Union +import typing +from typing import Callable, Iterable, Union from gt4py._core import definitions as core_defs from gt4py.next.iterator import ir as itir @@ -242,16 +243,31 @@ class let: -------- >>> str(let("a", "b")("a")) # doctest: +ELLIPSIS '(λ(a) → a)(b)' - >>> str(let("a", 1, - ... "b", 2 + >>> str(let(("a", 1), + ... ("b", 2) ... )(plus("a", "b"))) '(λ(a, b) → a + b)(1, 2)' """ - def __init__(self, *vars_and_values): - assert len(vars_and_values) % 2 == 0 - self.vars = vars_and_values[0::2] - self.init_forms = vars_and_values[1::2] + @typing.overload + def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ... + + @typing.overload + def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ... + + def __init__(self, *args): + if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args): + assert isinstance(args, tuple) + assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in args) + self.vars = [var for var, _ in args] + self.init_forms = [init_form for _, init_form in args] + elif len(args) == 2: + self.vars = [args[0]] + self.init_forms = [args[1]] + else: + raise TypeError( + "Invalid arguments: expected a variable name and an init form or a list thereof." + ) def __call__(self, form): return call(lambda_(*self.vars)(form))(*self.init_forms) diff --git a/src/gt4py/next/iterator/ir_utils/misc.py b/src/gt4py/next/iterator/ir_utils/misc.py new file mode 100644 index 0000000000..4336649d06 --- /dev/null +++ b/src/gt4py/next/iterator/ir_utils/misc.py @@ -0,0 +1,79 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +from collections import ChainMap + +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im + + +@dataclasses.dataclass(frozen=True) +class CannonicalizeBoundSymbolNames(eve.NodeTranslator): + """ + Given an iterator expression cannonicalize all bound symbol names. + + If two such expression are in the same scope and equal so are their values. + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> cannonicalized_testee1 = CannonicalizeBoundSymbolNames.apply(testee1) + >>> str(cannonicalized_testee1) + 'λ(_csym_1) → _csym_1 + b' + + >>> testee2 = im.lambda_("c")(im.plus("c", "b")) + >>> cannonicalized_testee2 = CannonicalizeBoundSymbolNames.apply(testee2) + >>> assert cannonicalized_testee1 == cannonicalized_testee2 + """ + + _uids: eve_utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_csym") + ) + + @classmethod + def apply(cls, node: itir.Expr): + return cls().visit(node, sym_map=ChainMap({})) + + def visit_Lambda(self, node: itir.Lambda, *, sym_map: ChainMap): + sym_map = sym_map.new_child() + for param in node.params: + sym_map[str(param.id)] = self._uids.sequential_id() + + return im.lambda_(*sym_map.values())(self.visit(node.expr, sym_map=sym_map)) + + def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]): + return im.ref(sym_map[node.id]) if node.id in sym_map else node + + +def is_equal(a: itir.Expr, b: itir.Expr): + """ + Return true if two expressions have provably equal values. + + Be aware that this function might return false even though the two expression have the same + value. + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> testee2 = im.lambda_("c")(im.plus("c", "b")) + >>> assert is_equal(testee1, testee2) + + >>> testee1 = im.lambda_("a")(im.plus("a", "b")) + >>> testee2 = im.lambda_("c")(im.plus("c", "d")) + >>> assert not is_equal(testee1, testee2) + """ + # TODO(tehrengruber): Extend this function cover more cases than just those with equal + # structure, e.g., by also canonicalization of the structure. + return a == b or ( + CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b) + ) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 42bbf28909..51daffed05 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -11,22 +11,29 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass +import dataclasses +import enum +import functools +import operator from typing import Optional from gt4py import eve +from gt4py.eve import utils as eve_utils from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference +from gt4py.next.iterator.ir_utils import ir_makers as im, misc as ir_misc +from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_if_call, is_let +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda class UnknownLength: pass -def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | type[UnknownLength]: - if node_types: - type_ = node_types[id(elem)] - # global inference should always give a length, function should fail otherwise +def _get_tuple_size(elem: ir.Node, use_global_information: bool) -> int | type[UnknownLength]: + if use_global_information: + type_ = elem.annex.type + # global inference should always give a length, fail otherwise assert isinstance(type_, it_type_inference.Val) and isinstance( type_.dtype, it_type_inference.Tuple ) @@ -47,7 +54,31 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t return len(type_.dtype) -@dataclass(frozen=True) +def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): + """Given a itir.FunCall return a new call with one of its argument replaced.""" + return ir.FunCall( + fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + ) + + +def _is_trivial_make_tuple_call(node: ir.Expr): + """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" + if not (isinstance(node, ir.FunCall) and node.fun == im.ref("make_tuple")): + return False + if not all( + isinstance(arg, (ir.SymRef, ir.Literal)) or _is_trivial_make_tuple_call(arg) + for arg in node.args + ): + return False + return True + + +# TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, +# transform each node until no transformations apply anymore, whenever a node is to be transformed +# go through all available transformation and apply them. However the final result here still +# reads a little convoluted and is also different to how we write other transformations. We +# should revisit the pattern here and try to find a more general mechanism. +@dataclasses.dataclass(frozen=True) class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -56,10 +87,44 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): - `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` """ + # TODO(tehrengruber): This Flag mechanism is a little low level. What we actually want + # is something like a pass manager, where for each pattern we have a corresponding + # transformation, etc. + class Flag(enum.Flag): + #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` + COLLAPSE_MAKE_TUPLE_TUPLE_GET = enum.auto() + #: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` + COLLAPSE_TUPLE_GET_MAKE_TUPLE = enum.auto() + #: `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` + PROPAGATE_TUPLE_GET = enum.auto() + #: `{1, 2}` -> `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` + LETIFY_MAKE_TUPLE_ELEMENTS = enum.auto() + #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` + #: -> `foo({trivial_expr1, trivial_expr2})` + INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() + #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` + PROPAGATE_TO_IF_ON_TUPLES = enum.auto() + #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` + PROPAGATE_NESTED_LET = enum.auto() + #: `let(a, 1)(a)` -> `1` + INLINE_TRIVIAL_LET = enum.auto() + + @classmethod + def all(self): # noqa: A003 # shadowing a python builtin + return functools.reduce(operator.or_, self.__members__.values()) + ignore_tuple_size: bool - collapse_make_tuple_tuple_get: bool - collapse_tuple_get_make_tuple: bool use_global_type_inference: bool + flags: Flag = Flag.all() + + PRESERVED_ANNEX_ATTRS = ("type",) + + # we use one UID generator per instance such that the generated ids are + # stable across multiple runs (required for caching to properly work) + _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") + ) + _node_types: Optional[dict[int, type_inference.Type]] = None @classmethod @@ -68,34 +133,77 @@ def apply( node: ir.Node, *, ignore_tuple_size: bool = False, - # the following options are mostly for allowing separate testing of the modes - collapse_make_tuple_tuple_get: bool = True, - collapse_tuple_get_make_tuple: bool = True, use_global_type_inference: bool = False, + remove_letified_make_tuple_elements: bool = True, + # manually passing flags is mostly for allowing separate testing of the modes + flags=None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. - If `ignore_tuple_size`, apply the transformation even if length of the inner tuple - is greater than the length of the outer tuple. + Arguments: + node: The node to transform. + + Keyword arguments: + ignore_tuple_size: Apply the transformation even if length of the inner tuple is greater + than the length of the outer tuple. + use_global_type_inference: Run global type inference to determine tuple sizes. + remove_letified_make_tuple_elements: Run `InlineLambdas` as a post-processing step + to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation. + `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ - node_types = it_type_inference.infer_all(node) if use_global_type_inference else None - return cls( - ignore_tuple_size, - collapse_make_tuple_tuple_get, - collapse_tuple_get_make_tuple, - use_global_type_inference, - node_types, + flags = flags or cls.flags + if use_global_type_inference: + it_type_inference.infer_all(node, save_to_annex=True) + + new_node = cls( + ignore_tuple_size=ignore_tuple_size, + use_global_type_inference=use_global_type_inference, + flags=flags, ).visit(node) - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: - if ( - self.collapse_make_tuple_tuple_get - and node.fun == ir.SymRef(id="make_tuple") - and all( - isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") - for arg in node.args + # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. this is important + # as otherwise two equal expressions containing a tuple will not be equal anymore + # and the CSE pass can not remove them. + # TODO(tehrengruber): test case for `scan(lambda carry: {1, 2})` + # (see solve_nonhydro_stencil_52_like_z_q_tup) + if remove_letified_make_tuple_elements: + new_node = InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lambda_args=False ) + + return new_node + + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + return self.fp_transform(node) + + def fp_transform(self, node: ir.Node) -> ir.Node: + while True: + new_node = self.transform(node) + if new_node is None: + break + assert new_node != node + node = new_node + return node + + def transform(self, node: ir.Node) -> Optional[ir.Node]: + if not isinstance(node, ir.FunCall): + return None + + for transformation in self.Flag: + if self.flags & transformation: + assert isinstance(transformation.name, str) + method = getattr(self, f"transform_{transformation.name.lower()}") + result = method(node) + if result is not None: + return result + return None + + def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="make_tuple") and all( + isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") + for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` assert isinstance(node.args[0], ir.FunCall) @@ -104,17 +212,19 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: for i, v in enumerate(node.args): assert isinstance(v, ir.FunCall) assert isinstance(v.args[0], ir.Literal) - if not (int(v.args[0].value) == i and v.args[1] == first_expr): + if not (int(v.args[0].value) == i and ir_misc.is_equal(v.args[1], first_expr)): # tuple argument differs, just continue with the rest of the tree - return self.generic_visit(node) + return None - if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len( - node.args - ): + if self.ignore_tuple_size or _get_tuple_size( + first_expr, self.use_global_type_inference + ) == len(node.args): return first_expr + return None + + def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: if ( - self.collapse_tuple_get_make_tuple - and node.fun == ir.SymRef(id="tuple_get") + node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[1], ir.FunCall) and node.args[1].fun == ir.SymRef(id="make_tuple") and isinstance(node.args[0], ir.Literal) @@ -127,4 +237,106 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: make_tuple_call.args ), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" return node.args[1].args[idx] - return self.generic_visit(node) + return None + + def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): + # TODO(tehrengruber): extend to general symbols as long as the tail call in the let + # does not capture + # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` + if is_let(node.args[1]): + idx, let_expr = node.args + return im.call( + im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let + self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let + ) + )( + *let_expr.args # type: ignore[attr-defined] # ensured by is_let + ) + elif isinstance(node.args[1], ir.FunCall) and node.args[1].fun == im.ref("if_"): + idx = node.args[0] + cond, true_branch, false_branch = node.args[1].args + return im.if_( + cond, + self.fp_transform(im.tuple_get(idx.value, true_branch)), + self.fp_transform(im.tuple_get(idx.value, false_branch)), + ) + return None + + def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: + if node.fun == ir.SymRef(id="make_tuple"): + # `make_tuple(expr1, expr1)` + # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` + bound_vars: dict[str, ir.Expr] = {} + new_args: list[ir.Expr] = [] + for arg in node.args: + if ( + isinstance(node, ir.FunCall) + and node.fun == im.ref("make_tuple") + and not _is_trivial_make_tuple_call(node) + ): + el_name = self._letify_make_tuple_uids.sequential_id() + new_args.append(im.ref(el_name)) + bound_vars[el_name] = arg + else: + new_args.append(arg) + + if bound_vars: + return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) # type: ignore[arg-type] # mypy not smart enough + return None + + def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node): + # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` + # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` + eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] + if any(eligible_params): + return self.visit(inline_lambda(node, eligible_params=eligible_params)) + return None + + def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: + if not node.fun == im.ref("if_"): + # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. + # TODO(tehrengruber): Only inline if type of branch value is a tuple. + # Examples: + # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` + # `let (b, if cond then {1, 2} else {3, 4})) b[0]` + # -> `if cond then let(b, {1, 2})(b[0]) else let(b, {3, 4})(b[0])` + for i, arg in enumerate(node.args): + if is_if_call(arg): + cond, true_branch, false_branch = arg.args + new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) + new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) + return im.if_(cond, new_true_branch, new_false_branch) + return None + + def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node): + # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` + outer_vars = {} + inner_vars = {} + original_inner_expr = node.fun.expr # type: ignore[attr-defined] # ensured by is_let + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + assert arg_sym not in inner_vars # TODO(tehrengruber): fix collisions + if is_let(arg): + for sym, val in zip(arg.fun.params, arg.args): # type: ignore[attr-defined] # ensured by is_let + assert sym not in outer_vars # TODO(tehrengruber): fix collisions + outer_vars[sym] = val + inner_vars[arg_sym] = arg.fun.expr # type: ignore[attr-defined] # ensured by is_let + else: + inner_vars[arg_sym] = arg + if outer_vars: + return self.fp_transform( + im.let(*outer_vars.items())( # type: ignore[arg-type] # mypy not smart enough + self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) + ) + ) + return None + + def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: + if is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + # `let(a, 1)(a)` -> `1` + for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + if node.fun.expr == im.ref(arg_sym.id): # type: ignore[attr-defined] # ensured by is_let + return arg + return None diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index fe14a8f580..b9dcc094c4 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -117,6 +117,10 @@ def apply_common_transforms( # to limit number of times global type inference is executed, only in the last iterations. use_global_type_inference=inlined == ir, ) + # This pass is required such that a deref outside of a + # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the + # `tuple_get` is removed by the `CollapseTuple` pass. + inlined = PropagateDeref.apply(inlined) if inlined == ir: break diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 783e54ede0..9f8bff7a84 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -15,6 +15,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im # TODO(tehrengruber): This pass can be generalized to all builtins, e.g. @@ -56,4 +57,12 @@ def visit_FunCall(self, node: ir.FunCall): ), args=lambda_args, ) + elif ( + node.fun == im.ref("deref") + and isinstance(node.args[0], ir.FunCall) + and node.args[0].fun == im.ref("if_") + ): + cond, true_branch, false_branch = node.args[0].args + return im.if_(cond, im.deref(true_branch), im.deref(false_branch)) + return self.generic_visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index 1444b0a64f..330f66bee5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -20,7 +20,11 @@ def test_simple_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(0, tuple_of_size_2), im.tuple_get(1, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) expected = tuple_of_size_2 assert actual == expected @@ -32,7 +36,11 @@ def test_nested_make_tuple_tuple_get(): im.tuple_get(0, tup_of_size2_from_lambda), im.tuple_get(1, tup_of_size2_from_lambda) ) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == tup_of_size2_from_lambda @@ -42,7 +50,11 @@ def test_different_tuples_make_tuple_tuple_get(): t1 = im.make_tuple("foo1", "bar1") testee = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing @@ -50,24 +62,137 @@ def test_different_tuples_make_tuple_tuple_get(): def test_incompatible_order_make_tuple_tuple_get(): tuple_of_size_2 = im.make_tuple("first", "second") testee = im.make_tuple(im.tuple_get(1, tuple_of_size_2), im.tuple_get(0, tuple_of_size_2)) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing def test_incompatible_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, collapse_tuple_get_make_tuple=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + ) assert actual == testee # did nothing def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second"))) - actual = CollapseTuple.apply(testee, ignore_tuple_size=True) + actual = CollapseTuple.apply( + testee, ignore_tuple_size=True, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET + ) assert actual == im.make_tuple("first", "second") def test_simple_tuple_get_make_tuple(): expected = im.ref("bar") testee = im.tuple_get(1, im.make_tuple("foo", expected)) - actual = CollapseTuple.apply(testee, collapse_make_tuple_tuple_get=False) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, + ) + assert expected == actual + + +def test_propagate_tuple_get(): + expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2"))) + testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2"))) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, + ) assert expected == actual + + +def test_letify_make_tuple_elements(): + opaque_call = im.call("opaque")() + testee = im.make_tuple(opaque_call, opaque_call) + expected = im.let(("_tuple_el_1", opaque_call), ("_tuple_el_2", opaque_call))( + im.make_tuple("_tuple_el_1", "_tuple_el_2") + ) + + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) + assert actual == expected + + +def test_letify_make_tuple_with_trivial_elements(): + testee = im.let(("a", 1), ("b", 2))(im.make_tuple("a", "b")) + expected = testee # did nothing + + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) + assert actual == expected + + +def test_inline_trivial_make_tuple(): + testee = im.let("tup", im.make_tuple("a", "b"))("tup") + expected = im.make_tuple("a", "b") + + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, + ) + assert actual == expected + + +def test_propagate_to_if_on_tuples(): + testee = im.tuple_get(0, im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4))) + expected = im.if_( + "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) + ) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) + assert actual == expected + + +def test_propagate_to_if_on_tuples_with_let(): + testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.tuple_get(0, "val") + ) + expected = im.if_( + "cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4)) + ) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=True, + flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES + | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + ) + assert actual == expected + + +def test_propagate_nested_lift(): + testee = im.let("a", im.let("b", 1)("a_val"))("a") + expected = im.let("b", 1)(im.let("a", "a_val")("a")) + actual = CollapseTuple.apply( + testee, + remove_letified_make_tuple_elements=False, + flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, + ) + assert actual == expected + + +def test_if_on_tuples_with_let(): + testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.tuple_get(0, "val") + ) + expected = im.if_("cond", 1, 3) + actual = CollapseTuple.apply(testee, remove_letified_make_tuple_elements=False) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 065095e1c2..fb7720f4d7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -213,15 +213,14 @@ def is_let(node: ir.Expr): testee = im.plus( im.let( - "c", - im.let( - "a", - 1, - "b", - 2, - )(im.plus("a", "b")), - "d", - 3, + ( + "c", + im.let( + ("a", 1), + ("b", 2), + )(im.plus("a", "b")), + ), + ("d", 3), )(im.plus("c", "d")), 4, ) From 4276d015ef615982ea53b913d2603910158d2d81 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 16 Feb 2024 09:08:19 +0100 Subject: [PATCH 79/85] fix[next][dace]: Bugfix in neighbor reduction (#1456) When visiting a neighbor expression, the DaCe ITIR backend was ignoring the node arguments and using directly the closure symbol i_K . The problem found in one diffusion stencil is that the arguments contained a vertical offset, which returns (i_K - 1), and this offset was lost. This PR adds test coverage to GT4Py for the above case. --- .../runners/dace_iterator/itir_to_tasklet.py | 41 +++++++++++-------- .../ffront_tests/test_gt4py_builtins.py | 41 +++++++++++++++++++ 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 773a3a61f7..2e58eccec8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -400,37 +400,46 @@ def builtin_neighbors( neighbor_value_node, ) else: + data_access_index = ",".join(f"{dim}_v" for dim in sorted(iterator.dimensions)) + connector_neighbor_dim = f"{offset_provider.neighbor_axis.value}_v" data_access_tasklet = state.add_tasklet( "data_access", - code="__data = __field[__idx]" + code=f"__data = __field[{data_access_index}] " + ( - f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" + f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" if offset_provider.has_skip_values else "" ), - inputs={"__field", "__idx"}, + inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, outputs={"__data"}, debuginfo=di, ) - # select full shape only in the neighbor-axis dimension - field_subset = tuple( - f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}" - for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape) - ) state.add_memlet_path( iterator.field, me, data_access_tasklet, - memlet=create_memlet_at(iterator.field.data, field_subset), + memlet=create_memlet_full(iterator.field.data, field_desc), dst_conn="__field", ) - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) + for dim in iterator.dimensions: + connector = f"{dim}_v" + if dim == offset_provider.neighbor_axis.value: + state.add_edge( + neighbor_index_node, + None, + data_access_tasklet, + connector, + dace.Memlet(data=neighbor_index_var, subset="0"), + ) + else: + state.add_memlet_path( + iterator.indices[dim], + me, + data_access_tasklet, + dst_conn=connector, + memlet=dace.Memlet(data=iterator.indices[dim].data, subset="0"), + ) + state.add_memlet_path( data_access_tasklet, mx, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index e27e73c80d..05824fa779 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -12,6 +12,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later +from typing import TypeAlias import numpy as np import pytest @@ -29,7 +30,9 @@ JDim, Joff, KDim, + Koff, V2EDim, + Vertex, cartesian_case, unstructured_case, ) @@ -108,6 +111,44 @@ def fencil(edge_f: cases.EField, out: cases.VField): ) +@pytest.mark.uses_unstructured_shift +def test_reduction_execution_with_offset(unstructured_case): + EKField: TypeAlias = gtx.Field[[Edge, KDim], np.int32] + VKField: TypeAlias = gtx.Field[[Vertex, KDim], np.int32] + + @gtx.field_operator + def reduction(edge_f: EKField) -> VKField: + return neighbor_sum(edge_f(V2E), axis=V2EDim) + + @gtx.field_operator + def fencil_op(edge_f: EKField) -> VKField: + red = reduction(edge_f) + return red(Koff[1]) + + @gtx.program + def fencil(edge_f: EKField, out: VKField): + fencil_op(edge_f, out=out) + + v2e_table = unstructured_case.offset_provider["V2E"].table + field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() + out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() + + cases.verify( + unstructured_case, + fencil, + field, + out, + inout=out, + ref=np.sum( + field.asnumpy()[:, 1][v2e_table], + axis=1, + initial=0, + where=v2e_table != common.SKIP_VALUE, + ).reshape(out.shape), + offset_provider=unstructured_case.offset_provider | {"Koff": KDim}, + ) + + @pytest.mark.uses_unstructured_shift @pytest.mark.uses_constant_fields def test_reduction_expression_in_call(unstructured_case): From 077c786af8b17618fda9bd5f39d337bee7d7c2ff Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 16 Feb 2024 10:00:29 +0100 Subject: [PATCH 80/85] refactor[next]: Change foast lowering from iterator of tuple to tuple of iterator (#1449) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes lowering from FOAST to ITIR to use tuples of iterators instead of iterators of tuples. - Allows the collapse tuple pass to better optimize as what was previously lowered to `↑(λ(it) → it[0])(↑(λ() → {1, 2})` is now `({↑(λ() → 1)(), ↑(λ() → 2)()})[0]` - Fixes #1413. In particular in combination with temporaries creating an iterator of tuples is problematic in case the respective backing fields are defined on different domains. This did not surface before as inlining all lifts everything always removed the invalid derefs. --- src/gt4py/next/ffront/foast_to_itir.py | 156 ++++++++++-------- src/gt4py/next/ffront/lowering_utils.py | 144 ++++++++++++++++ src/gt4py/next/ffront/past_to_itir.py | 34 +++- src/gt4py/next/ffront/type_info.py | 15 ++ src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +- src/gt4py/next/type_system/type_info.py | 45 +++-- .../ffront_tests/test_foast_to_itir.py | 18 +- .../ffront_tests/test_past_to_itir.py | 20 ++- 8 files changed, 337 insertions(+), 101 deletions(-) create mode 100644 src/gt4py/next/ffront/lowering_utils.py diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index c0e618a42d..2e5c158c23 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -21,6 +21,7 @@ dialect_ast_enums, fbuiltins, field_operator_ast as foast, + lowering_utils, type_specifications as ts_ffront, ) from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES @@ -83,7 +84,8 @@ def visit_FunctionDefinition( def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.FunctionDefinition: func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = im.deref(func_definition.expr) + + new_body = func_definition.expr return itir.FunctionDefinition( id=func_definition.id, @@ -94,28 +96,55 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.Funct def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.FunctionDefinition: # note: we don't need the axis here as this is handled by the program # decorator + assert isinstance(node.type, ts_ffront.ScanOperatorType) # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. # In iterator IR we didn't properly specify if this is legal, # however after lift-inlining the expressions are transformed back to literals. forward = im.deref(self.visit(node.forward, **kwargs)) - init = im.deref(self.visit(node.init, **kwargs)) + init = lowering_utils.process_elements( + im.deref, self.visit(node.init, **kwargs), node.init.type + ) # lower definition function func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = func_definition.expr - - # promote carry to iterator - # (this is the only place in the lowering were a variable is captured in a lifted lambda) new_body = im.let( func_definition.params[0].id, - im.promote_to_const_iterator(func_definition.params[0].id), - )(im.deref(new_body)) - definition = itir.Lambda(params=func_definition.params, expr=new_body) - body = im.call(im.call("scan")(definition, forward, init))( - *(param.id for param in definition.params[1:]) + # promote carry to iterator of tuples + # (this is the only place in the lowering were a variable is captured in a lifted lambda) + lowering_utils.to_tuples_of_iterator( + im.promote_to_const_iterator(func_definition.params[0].id), + [*node.type.definition.pos_or_kw_args.values()][0], + ), + )( + # the function itself returns a tuple of iterators, deref element-wise + lowering_utils.process_elements( + im.deref, func_definition.expr, node.type.definition.returns + ) ) + stencil_args = [] + assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args + for param, arg_type in zip( + func_definition.params[1:], + [*node.type.definition.pos_or_kw_args.values()][1:], + strict=True, + ): + if isinstance(arg_type, ts.TupleType): + # convert into iterator of tuples + stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) + + new_body = im.let( + param.id, + lowering_utils.to_tuples_of_iterator(param.id, arg_type), + )(new_body) + else: + stencil_args.append(param.id) + + definition = itir.Lambda(params=func_definition.params, expr=new_body) + + body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) + return itir.FunctionDefinition( id=node.id, params=definition.params[1:], @@ -216,14 +245,10 @@ def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef: return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs) -> itir.Expr: - return im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))( - self.visit(node.value, **kwargs) - ) + return im.tuple_get(node.index, self.visit(node.value, **kwargs)) def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs) -> itir.Expr: - return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( - *[self.visit(el, **kwargs) for el in node.elts], - ) + return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators @@ -243,7 +268,19 @@ def visit_BinOp(self, node: foast.BinOp, **kwargs) -> itir.FunCall: return self._map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs) -> itir.FunCall: - return self._map("if_", node.condition, node.true_expr, node.false_expr) + op = "if_" + args = (node.condition, node.true_expr, node.false_expr) + lowered_args = [ + lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) + for arg in args + ] + if any(type_info.contains_local_field(arg.type) for arg in args): + lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] + op = im.call("map_")(op) + + return lowering_utils.to_tuples_of_iterator( + im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type + ) def visit_Compare(self, node: foast.Compare, **kwargs) -> itir.FunCall: return self._map(node.op.value, node.left, node.right) @@ -280,28 +317,11 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: elif isinstance( node.func.type, ( + ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType, ), ): - # Operators are lowered into lifted stencils. - lowered_func = self.visit(node.func, **kwargs) - # ITIR has no support for keyword arguments. Instead, we concatenate both positional - # and keyword arguments and use the unique order as given in the function signature. - lowered_args, lowered_kwargs = type_info.canonicalize_arguments( - node.func.type, - [self.visit(arg, **kwargs) for arg in node.args], - {name: self.visit(arg, **kwargs) for name, arg in node.kwargs.items()}, - use_signature_ordering=True, - ) - call_args = [f"__arg{i}" for i in range(len(lowered_args))] - call_kwargs = [f"__kwarg_{name}" for name in lowered_kwargs.keys()] - return im.lift( - im.lambda_(*call_args, *call_kwargs)( - im.call(lowered_func)(*call_args, *call_kwargs) - ) - )(*lowered_args, *lowered_kwargs.values()) - elif isinstance(node.func.type, ts.FunctionType): # ITIR has no support for keyword arguments. Instead, we concatenate both positional # and keyword arguments and use the unique order as given in the function signature. lowered_args, lowered_kwargs = type_info.canonicalize_arguments( @@ -310,7 +330,17 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: self.visit(node.kwargs, **kwargs), use_signature_ordering=True, ) - return im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) + result = im.call(self.visit(node.func, **kwargs))( + *lowered_args, *lowered_kwargs.values() + ) + + # scan operators return an iterator of tuples, transform into tuples of iterator again + if isinstance(node.func.type, ts_ffront.ScanOperatorType): + result = lowering_utils.to_tuples_of_iterator( + result, node.func.type.definition.returns + ) + + return result raise AssertionError( f"Call to object of type '{type(node.func.type).__name__}' not understood." @@ -319,12 +349,23 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = node.args[0], node.args[1].id - return self._process_elements( - lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs + return lowering_utils.process_elements( + lambda x: im.promote_to_lifted_stencil( + im.lambda_("it")(im.call("cast_")("it", str(new_type))) + )(x), + self.visit(obj, **kwargs), + obj.type, ) def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: - return self._map("if_", *node.args) + condition, true_value, false_value = node.args + + lowered_condition = self.visit(condition, **kwargs) + return lowering_utils.process_elements( + lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv), + [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], + node.type, + ) def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall: return self.visit(node.args[0], **kwargs) @@ -379,13 +420,8 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; # the following constructs work if they are removed by inlining. if isinstance(type_, ts.TupleType): - return im.promote_to_const_iterator( - im.make_tuple( - *( - im.deref(self._make_literal(val, type_)) - for val, type_ in zip(val, type_.types) - ) - ) + return im.make_tuple( + *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) ) elif isinstance(type_, ts.ScalarType): typename = type_.kind.name.lower() @@ -403,31 +439,5 @@ def _map(self, op, *args, **kwargs): return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - def _process_elements( - self, - process_func: Callable[[itir.Expr], itir.Expr], - obj: foast.Expr, - current_el_type: ts.TypeSpec, - current_el_expr: itir.Expr = im.ref("expr"), - ): - """Recursively applies a processing function to all primitive constituents of a tuple.""" - if isinstance(current_el_type, ts.TupleType): - # TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element. - return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( - *[ - self._process_elements( - process_func, - obj, - current_el_type.types[i], - im.tuple_get(i, current_el_expr), - ) - for i in range(len(current_el_type.types)) - ] - ) - elif type_info.contains_local_field(current_el_type): - raise NotImplementedError("Processing fields with local dimension is not implemented.") - else: - return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj) - class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py new file mode 100644 index 0000000000..f2b221083a --- /dev/null +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -0,0 +1,144 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +from typing import Callable, TypeVar + +from gt4py.next.ffront import type_info as ti_ffront +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_info, type_specifications as ts + + +def _expr_hash(expr: itir.Expr | str) -> str: + """Small utility function that returns a string hash of an expression.""" + return str(abs(hash(expr)) % (10**12)).zfill(12) + + +def to_tuples_of_iterator(expr: itir.Expr | str, arg_type: ts.TypeSpec): + """ + Convert iterator of tuples into tuples of iterator. + + Supports arbitrary nesting. + + >>> print(to_tuples_of_iterator("arg", ts.TupleType(types=[ts.FieldType(dims=[], + ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))]))) # doctest: +ELLIPSIS + (λ(__toi_...) → {(↑(λ(it) → (·it)[0]))(__toi_...)})(arg) + """ + param = f"__toi_{_expr_hash(expr)}" + + def fun(primitive_type, path): + inner_expr = im.deref("it") + for path_part in path: + inner_expr = im.tuple_get(path_part, inner_expr) + + return im.lift(im.lambda_("it")(inner_expr))(param) + + return im.let(param, expr)( + type_info.apply_to_primitive_constituents( + arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple + ) + ) + + +def to_iterator_of_tuples(expr: itir.Expr | str, arg_type: ts.TypeSpec): + """ + Convert tuples of iterator into iterator of tuples. + + Supports arbitrary nesting. + + >>> print(to_iterator_of_tuples("arg", ts.TupleType(types=[ts.FieldType(dims=[], + ... dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))]))) # doctest: +ELLIPSIS + (λ(__iot_...) → (↑(λ(__iot_el_0) → {·__iot_el_0}))(__iot_...[0]))(arg) + """ + param = f"__iot_{_expr_hash(expr)}" + + type_constituents = [ + ti_ffront.promote_scalars_to_zero_dim_field(type_) + for type_ in type_info.primitive_constituents(arg_type) + ] + assert all(isinstance(type_, ts.FieldType) and type_.dims == type_constituents[0].dims for type_ in type_constituents) # type: ignore[attr-defined] # ensure by assert above + + def fun(_, path): + param_name = "__iot_el" + for path_part in path: + param_name = f"{param_name}_{path_part}" + return im.deref(param_name) + + lift_params, lift_args = [], [] + for _, path in type_info.primitive_constituents(arg_type, with_path_arg=True): + param_name, arg_expr = "__iot_el", param + for path_part in path: + param_name = f"{param_name}_{path_part}" + arg_expr = im.tuple_get(path_part, arg_expr) + + lift_params.append(param_name) + lift_args.append(arg_expr) + + stencil_expr = type_info.apply_to_primitive_constituents( + arg_type, fun, with_path_arg=True, tuple_constructor=im.make_tuple + ) + return im.let(param, expr)(im.lift(im.lambda_(*lift_params)(stencil_expr))(*lift_args)) + + +# TODO(tehrengruber): The code quality of this function is poor. We should rewrite it. +def process_elements( + process_func: Callable[..., itir.Expr], + objs: itir.Expr | list[itir.Expr], + current_el_type: ts.TypeSpec, +): + """ + Recursively applies a processing function to all primitive constituents of a tuple. + + Arguments: + process_func: A callable that takes an itir.Expr representing a leaf-element of `objs`. + If multiple `objs` are given the callable takes equally many arguments. + objs: The object whose elements are to be transformed. + current_el_type: A type with the same structure as the elements of `objs`. The leaf-types + are not used and thus not relevant. + """ + if isinstance(objs, itir.Expr): + objs = [objs] + + _current_el_exprs = [im.ref(f"__val_{_expr_hash(obj)}") for i, obj in enumerate(objs)] + body = _process_elements_impl(process_func, _current_el_exprs, current_el_type) + + return im.let(*((f"__val_{_expr_hash(obj)}", obj) for i, obj in enumerate(objs)))( # type: ignore[arg-type] # mypy not smart enough + body + ) + + +T = TypeVar("T", bound=itir.Expr, covariant=True) + + +def _process_elements_impl( + process_func: Callable[..., itir.Expr], + _current_el_exprs: list[T], + current_el_type: ts.TypeSpec, +): + if isinstance(current_el_type, ts.TupleType): + result = im.make_tuple( + *[ + _process_elements_impl( + process_func, + [im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs], + current_el_type.types[i], + ) + for i in range(len(current_el_type.types)) + ] + ) + elif type_info.contains_local_field(current_el_type): + raise NotImplementedError("Processing fields with local dimension is not implemented.") + else: + result = process_func(*_current_el_exprs) + + return result diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 620e98dd4d..8be9309630 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -18,8 +18,9 @@ from gt4py.eve import NodeTranslator, concepts, traits from gt4py.next.common import Dimension, DimensionKind, GridType -from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront +from gt4py.next.ffront import lowering_utils, program_ast as past, type_specifications as ts_ffront from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_info, type_specifications as ts @@ -141,16 +142,39 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure: assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) - lowered_args, lowered_kwargs = type_info.canonicalize_arguments( + args, node_kwargs = type_info.canonicalize_arguments( node.func.type, - self.visit(node.args, **kwargs), - self.visit(node_kwargs, **kwargs), + node.args, + node_kwargs, use_signature_ordering=True, ) + lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) + + stencil_params = [] + stencil_args = [] + for i, arg in enumerate([*args, *node_kwargs]): + stencil_params.append(f"__stencil_arg{i}") + if isinstance(arg.type, ts.TupleType): + # convert into tuple of iterators + stencil_args.append( + lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) + ) + else: + stencil_args.append(f"__stencil_arg{i}") + + if isinstance(node.func.type, ts_ffront.ScanOperatorType): + # scan operators return an iterator of tuples, just deref directly + stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) + else: + # field operators return a tuple of iterators, deref element-wise + stencil_body = lowering_utils.process_elements( + im.deref, im.call(node.func.id)(*stencil_args), node.func.type.definition.returns + ) + return itir.StencilClosure( domain=lowered_domain, - stencil=itir.SymRef(id=node.func.id), + stencil=im.lambda_(*stencil_params)(stencil_body), inputs=[*lowered_args, *lowered_kwargs.values()], output=output, location=node.location, diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index affae8fbca..c25b7dd829 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -25,6 +25,21 @@ def _is_zero_dim_field(field: ts.TypeSpec) -> bool: return isinstance(field, ts.FieldType) and len(field.dims) == 0 +def promote_scalars_to_zero_dim_field(type_: ts.TypeSpec) -> ts.TypeSpec: + """ + Promote scalar primitive constituents to zero dimensional fields. + + E.g. all elements of a tuple which are scalars are promoted to a zero dimensional field. + """ + + def promote_el(type_el: ts.TypeSpec) -> ts.TypeSpec: + if isinstance(type_el, ts.ScalarType): + return ts.FieldType(dims=[], dtype=type_el) + return type_el + + return type_info.apply_to_primitive_constituents(type_, promote_el) + + def promote_zero_dims( function_type: ts.FunctionType, args: list[ts.TypeSpec], kwargs: dict[str, ts.TypeSpec] ) -> tuple[list, dict]: diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 4337e8512a..f6655e9d41 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -250,10 +250,10 @@ class let: """ @typing.overload - def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ... + def __init__(self, var: str | itir.Sym, init_form: itir.Expr | str): ... @typing.overload - def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ... + def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr | str]]): ... def __init__(self, *args): if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args): @@ -369,7 +369,7 @@ def promote_to_lifted_stencil(op: str | itir.SymRef | Callable) -> Callable[..., >>> str(promote_to_lifted_stencil("op")("a", "b")) '(↑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)' """ - if isinstance(op, (str, itir.SymRef)): + if isinstance(op, (str, itir.SymRef, itir.Lambda)): op = call(op) def _impl(*its: itir.Expr) -> itir.Expr: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 7c4c8e6e23..5cfb901ff1 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -14,6 +14,7 @@ import functools import types +import typing from typing import Any, Callable, Iterator, Type, TypeGuard, cast import numpy as np @@ -86,9 +87,24 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: ) +@typing.overload def primitive_constituents( symbol_type: ts.TypeSpec, -) -> XIterable[ts.TypeSpec]: + with_path_arg: typing.Literal[False] = False, +) -> XIterable[ts.TypeSpec]: ... + + +@typing.overload +def primitive_constituents( + symbol_type: ts.TypeSpec, + with_path_arg: typing.Literal[True], +) -> XIterable[tuple[ts.TypeSpec, tuple[str, ...]]]: ... + + +def primitive_constituents( + symbol_type: ts.TypeSpec, + with_path_arg: bool = False, +) -> XIterable[ts.TypeSpec] | XIterable[tuple[ts.TypeSpec, tuple[str, ...]]]: """ Return the primitive types contained in a composite type. @@ -106,14 +122,17 @@ def primitive_constituents( [FieldType(...), ScalarType(...), FieldType(...)] """ - def constituents_yielder(symbol_type: ts.TypeSpec): + def constituents_yielder(symbol_type: ts.TypeSpec, path: tuple[int, ...]): if isinstance(symbol_type, ts.TupleType): - for el_type in symbol_type.types: - yield from constituents_yielder(el_type) + for i, el_type in enumerate(symbol_type.types): + yield from constituents_yielder(el_type, (*path, i)) else: - yield symbol_type + if with_path_arg: + yield (symbol_type, path) + else: + yield symbol_type - return xiter(constituents_yielder(symbol_type)) + return xiter(constituents_yielder(symbol_type, ())) def apply_to_primitive_constituents( @@ -121,8 +140,10 @@ def apply_to_primitive_constituents( fun: ( Callable[[ts.TypeSpec], ts.TypeSpec] | Callable[[ts.TypeSpec, tuple[int, ...]], ts.TypeSpec] ), - with_path_arg=False, _path=(), + *, + with_path_arg=False, + tuple_constructor=lambda *elements: ts.TupleType(types=[*elements]), ): """ Apply function to all primitive constituents of a type. @@ -133,10 +154,14 @@ def apply_to_primitive_constituents( tuple[Field[[], int64], Field[[], int64]] """ if isinstance(symbol_type, ts.TupleType): - return ts.TupleType( - types=[ + return tuple_constructor( + *[ apply_to_primitive_constituents( - el, fun, _path=(*_path, i), with_path_arg=with_path_arg + el, + fun, + _path=(*_path, i), + with_path_arg=with_path_arg, + tuple_constructor=tuple_constructor, ) for i, el in enumerate(symbol_type.types) ] diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index bc92efc02c..a0035348ad 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -85,7 +85,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64] parsed = FieldOperatorParser.apply_to_function(multicopy) lowered = FieldOperatorLowering.apply(parsed) - reference = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") + reference = im.make_tuple("inp1", "inp2") assert lowered.expr == reference @@ -195,9 +195,9 @@ def unpacking( parsed = FieldOperatorParser.apply_to_function(unpacking) lowered = FieldOperatorLowering.apply(parsed) - tuple_expr = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") - tuple_access_0 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(0, x))("__tuple_tmp_0") - tuple_access_1 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(1, x))("__tuple_tmp_0") + tuple_expr = im.make_tuple("inp1", "inp2") + tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") + tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") reference = im.let("__tuple_tmp_0", tuple_expr)( im.let( @@ -248,7 +248,7 @@ def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: parsed = FieldOperatorParser.apply_to_function(call) lowered = FieldOperatorLowering.apply(parsed) - reference = im.lift(im.lambda_("__arg0")(im.call("identity")("__arg0")))("inp") + reference = im.call("identity")("inp") assert lowered.expr == reference @@ -263,7 +263,7 @@ def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): parsed = FieldOperatorParser.apply_to_function(temp_tuple) lowered = FieldOperatorLowering.apply(parsed) - tuple_expr = im.promote_to_lifted_stencil("make_tuple")("a", "b") + tuple_expr = im.make_tuple("a", "b") reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) assert lowered.expr == reference @@ -538,7 +538,7 @@ def int_constrs() -> tuple[ parsed = FieldOperatorParser.apply_to_function(int_constrs) lowered = FieldOperatorLowering.apply(parsed) - reference = im.promote_to_lifted_stencil("make_tuple")( + reference = im.make_tuple( im.promote_to_const_iterator(im.literal("1", "int32")), im.promote_to_const_iterator(im.literal("1", "int32")), im.promote_to_const_iterator(im.literal("1", "int64")), @@ -572,7 +572,7 @@ def float_constrs() -> tuple[ parsed = FieldOperatorParser.apply_to_function(float_constrs) lowered = FieldOperatorLowering.apply(parsed) - reference = im.promote_to_lifted_stencil("make_tuple")( + reference = im.make_tuple( im.promote_to_const_iterator(im.literal("0.1", "float64")), im.promote_to_const_iterator(im.literal("0.1", "float64")), im.promote_to_const_iterator(im.literal("0.1", "float32")), @@ -592,7 +592,7 @@ def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: parsed = FieldOperatorParser.apply_to_function(bool_constrs) lowered = FieldOperatorLowering.apply(parsed) - reference = im.promote_to_lifted_stencil("make_tuple")( + reference = im.make_tuple( im.promote_to_const_iterator(im.literal(str(True), "bool")), im.promote_to_const_iterator(im.literal(str(False), "bool")), im.promote_to_const_iterator(im.literal(str(True), "bool")), diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index a1a7b79cec..05947996c1 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -65,7 +65,25 @@ def test_copy_lowering(copy_program_def, itir_identity_fundef): ) ], ), - stencil=P(itir.SymRef, id=eve.SymbolRef("identity")), + stencil=P( + itir.Lambda, + params=[P(itir.Sym, id=eve.SymbolName("__stencil_arg0"))], + expr=P( + itir.FunCall, + fun=P( + itir.Lambda, + params=[P(itir.Sym)], + expr=P(itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("deref"))), + ), + args=[ + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("identity")), + args=[P(itir.SymRef, id=eve.SymbolRef("__stencil_arg0"))], + ) + ], + ), + ), inputs=[P(itir.SymRef, id=eve.SymbolRef("in_field"))], output=P(itir.SymRef, id=eve.SymbolRef("out")), ) From d31d2cf046cfada8eeff1d1bc4e02bea9352ac81 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 16 Feb 2024 10:36:41 +0100 Subject: [PATCH 81/85] fix[next]: Disable collapse tuple `if` propagation (#1454) We observed some increase in compile time due to the `PROPAGATE_TO_IF_ON_TUPLES` option of the `CollapseTuple` pass. This option is needed for `if` statements to work properly, which is not fully functional yet anyway and will be taken care of separately. We disable the option for now until #1414 is merged. --- src/gt4py/next/iterator/transforms/pass_manager.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b9dcc094c4..cd8ebb5516 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -116,6 +116,8 @@ def apply_common_transforms( inlined, # to limit number of times global type inference is executed, only in the last iterations. use_global_type_inference=inlined == ir, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) # This pass is required such that a deref outside of a # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the @@ -159,7 +161,12 @@ def apply_common_transforms( # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. if unconditionally_collapse_tuples: - ir = CollapseTuple.apply(ir, ignore_tuple_size=unconditionally_collapse_tuples) + ir = CollapseTuple.apply( + ir, + ignore_tuple_size=unconditionally_collapse_tuples, + # TODO(tehrengruber): disabled since it increases compile-time too much right now + flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + ) if lift_mode == LiftMode.FORCE_INLINE: ir = _inline_into_scan(ir) From 0a292611aa655c34abe497bcfeb72e33361939c5 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 16 Feb 2024 11:01:40 +0100 Subject: [PATCH 82/85] fix[next][dace]: Bugfix for nested neighbor reduction (#1457) In case of nested neighbor reduction with lift expression on inner node, the DaCe backend should generate a conditional state transition to field access, based on the value of neighbor index provided by the outer connectivity table. Additional change. The previous selection of valid neighbors implemented as conditional inter-state edge is replaced by a select tasklet, which makes the SDFG easier to read. --- .../runners/dace_iterator/itir_to_tasklet.py | 102 ++++++++++++------ .../ffront_tests/test_execution.py | 12 +-- 2 files changed, 77 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e58eccec8..3a33ee1e35 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -191,6 +191,7 @@ def _visit_lift_in_neighbors_reduction( neighbor_index_node: dace.nodes.AccessNode, neighbor_value_node: dace.nodes.AccessNode, ) -> list[ValueExpr]: + assert transformer.context.reduce_identity is not None neighbor_dim = offset_provider.neighbor_axis.value origin_dim = offset_provider.origin_axis.value @@ -220,7 +221,7 @@ def _visit_lift_in_neighbors_reduction( input_nodes = {} iterator_index_nodes = {} - lifted_index_connectors = set() + lifted_index_connectors = [] for x, y in inner_inputs: if isinstance(y, IteratorExpr): @@ -228,7 +229,7 @@ def _visit_lift_in_neighbors_reduction( input_nodes[field_connector] = y.field for dim, connector in inner_index_table.items(): if dim == neighbor_dim: - lifted_index_connectors.add(connector) + lifted_index_connectors.append(connector) iterator_index_nodes[connector] = y.indices[dim] else: assert isinstance(y, ValueExpr) @@ -298,6 +299,30 @@ def _visit_lift_in_neighbors_reduction( memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), ) + if offset_provider.has_skip_values: + # check neighbor validity on if/else inter-state edge + start_state = lift_context.body.add_state("start", is_start_block=True) + skip_neighbor_state = lift_context.body.add_state("skip_neighbor") + skip_neighbor_state.add_edge( + skip_neighbor_state.add_tasklet( + "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" + ), + "val", + skip_neighbor_state.add_access(inner_outputs[0].value.data), + None, + dace.Memlet(data=inner_outputs[0].value.data, subset="0"), + ) + lift_context.body.add_edge( + start_state, + skip_neighbor_state, + dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), + ) + lift_context.body.add_edge( + start_state, + lift_context.state, + dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}"), + ) + return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] @@ -467,7 +492,7 @@ def builtin_neighbors( neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) neighbor_valid_tasklet = state.add_tasklet( - "check_valid_neighbor", + f"check_valid_neighbor_{offset_dim}", {"__idx"}, {"__valid"}, f"__valid = True if __idx != {neighbor_skip_value} else False", @@ -1223,7 +1248,7 @@ def _visit_reduce(self, node: itir.FunCall): nreduce_shape = args_shape[0] input_args = [arg[0] for arg in args] - input_valid = [arg[1] for arg in args if len(arg) == 2] + input_valid_args = [arg[1] for arg in args if len(arg) == 2] nreduce_index = tuple(f"_i{i}" for i in range(len(nreduce_shape))) nreduce_domain = {idx: f"0:{size}" for idx, size in zip(nreduce_index, nreduce_shape)} @@ -1255,41 +1280,56 @@ def _visit_reduce(self, node: itir.FunCall): self.context.body, lambda_context.body, input_mapping ) - if input_valid: + if input_valid_args: """ - The neighbors builtin returns an array of booleans in case the connectivity table - contains skip values. These boolean values indicate whether the neighbor value is present or not, - and are used below to construct an if/else branch to bypass the lambda call for neighbor skip values. + The neighbors builtin returns an array of booleans in case the connectivity table contains skip values. + These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select + the result of field access or the identity value, respectively. If the neighbor table has full connectivity (no skip values by type definition), the input_valid node - is not built, and the construction of the if/else branch below is also skipped. + is not built, and the construction of the select tasklet below is also skipped. """ - input_args.append(input_valid[0]) - input_valid_node = input_valid[0].value + input_args.append(input_valid_args[0]) + input_valid_node = input_valid_args[0].value + lambda_output_node = inner_outputs[0].value # add input connector to nested sdfg - input_mapping["is_valid"] = create_memlet_at(input_valid_node.data, nreduce_index) - # check neighbor validity on if/else inter-state edge - start_state = lambda_context.body.add_state("start", is_start_block=True) - skip_neighbor_state = lambda_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {reduce_identity}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), + lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool) + input_mapping["_valid_neighbor"] = create_memlet_at( + input_valid_node.data, nreduce_index + ) + # add select tasklet before writing to output node + # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API + output_edge = lambda_context.state.in_edges(lambda_output_node)[0] + assert isinstance( + lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar + ) + select_tasklet = lambda_context.state.add_tasklet( + "neighbor_select", + {"_inp", "_valid"}, + {"_out"}, + f"_out = _inp if _valid else {reduce_identity}", + ) + lambda_context.state.add_edge( + output_edge.src, None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), + select_tasklet, + "_inp", + dace.Memlet(data=output_edge.src.data, subset="0"), ) - lambda_context.body.add_scalar("is_valid", dace.dtypes.bool) - lambda_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition="is_valid == False"), + lambda_context.state.add_edge( + lambda_context.state.add_access("_valid_neighbor"), + None, + select_tasklet, + "_valid", + dace.Memlet(data="_valid_neighbor", subset="0"), ) - lambda_context.body.add_edge( - start_state, - lambda_context.state, - dace.InterstateEdge(condition="is_valid == True"), + lambda_context.state.add_edge( + select_tasklet, + "_out", + lambda_output_node, + None, + dace.Memlet(data=lambda_output_node.data, subset="0"), ) + lambda_context.state.remove_edge(output_edge) reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 3c9c4e686c..e499f83f86 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -515,9 +515,9 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: @pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator - def testee(a: cases.EField) -> cases.EField: - tmp = neighbor_sum(a(V2E), axis=V2EDim) - tmp_2 = neighbor_sum(tmp(E2V), axis=E2VDim) + def testee(a: cases.VField) -> cases.VField: + tmp = neighbor_sum(a(E2V), axis=E2VDim) + tmp_2 = neighbor_sum(tmp(V2E), axis=V2EDim) return tmp_2 cases.verify_with_default_data( @@ -525,12 +525,12 @@ def testee(a: cases.EField) -> cases.EField: testee, ref=lambda a: np.sum( np.sum( - a[unstructured_case.offset_provider["V2E"].table], + a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0, - where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table], + )[unstructured_case.offset_provider["V2E"].table], axis=1, + where=unstructured_case.offset_provider["V2E"].table != common.SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), ) From e0a8734467f685a973dab485c4425dda8639b68d Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 19 Feb 2024 09:13:33 +0100 Subject: [PATCH 83/85] feat[next][dace]: Use gt4py symbols for field size as dace array shape (#1458) This PR changes the ITIR DaCe backend to reuse the gt4py symbols for field size as dace symbols in array shape. The gt4py symbols are passed as scalar arguments to the program and are used in the definition of the closure domain. --- .../runners/dace_iterator/__init__.py | 15 +++++--- .../runners/dace_iterator/itir_to_sdfg.py | 38 +++++++++++-------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 432bf3e1bf..5a5df5ce14 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -215,7 +215,7 @@ def offset_invariants(offset): return m.hexdigest() -def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: +def get_sdfg_args(sdfg: dace.SDFG, *args, check_args: bool = False, **kwargs) -> dict[str, Any]: """Extracts the arguments needed to call the SDFG. This function can handle the same arguments that are passed to `run_dace_iterator()`. @@ -229,7 +229,6 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: neighbor_tables = filter_neighbor_tables(offset_provider) device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - sdfg_sig = sdfg.signature_arglist(with_types=False) dace_args = get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_connectivity_args(neighbor_tables, device) @@ -247,9 +246,13 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: **dace_conn_strides, **dace_offsets, } - expected_args = {key: all_args[key] for key in sdfg_sig} - return expected_args + if check_args: + # return only arguments expected in SDFG signature (note hat `signature_arglist` takes time) + sdfg_sig = sdfg.signature_arglist(with_types=False) + return {key: all_args[key] for key in sdfg_sig} + + return all_args def build_sdfg_from_itir( @@ -390,12 +393,12 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): if build_cache is not None: build_cache[cache_id] = sdfg_program - expected_args = get_sdfg_args(sdfg, *args, **kwargs) + sdfg_args = get_sdfg_args(sdfg, *args, **kwargs) with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) dace.config.Config.set("frontend", "check_args", value=True) - sdfg_program(**expected_args) + sdfg_program(**sdfg_args) def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a578e9c19b..05a104449c 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -104,7 +104,7 @@ def _make_array_shape_and_strides( dims: Sequence[Dimension], neighbor_tables: Mapping[str, NeighborTable], sort_dims: bool, -) -> tuple[list[dace.symbol], list[dace.symbol]]: +) -> tuple[list[dace.symbol], list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -116,18 +116,20 @@ def _make_array_shape_and_strides( tuple(shape, strides) The output tuple fields are arrays of dace symbolic expressions. """ - dtype = dace.int64 - sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims + dtype = dace.int32 + sorted_dims = get_sorted_dims(dims) if sort_dims else enumerate(dims) shape = [ ( neighbor_tables[dim.value].max_neighbors if dim.kind == DimensionKind.LOCAL - else dace.symbol(unique_name(f"{name}_shape{i}"), dtype) + # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain + else dace.symbol(f"__{name}_size_{i}", dtype) ) - for i, dim in enumerate(sorted_dims) + for i, dim in sorted_dims ] - strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)] - return shape, strides + offset = [dace.symbol(unique_name(f"{name}_offset{i}"), dtype) for i, _ in sorted_dims] + strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in sorted_dims] + return shape, offset, strides def _check_no_lifts(node: itir.StencilClosure): @@ -179,19 +181,24 @@ def add_storage( sort_dimensions: bool = True, ): if isinstance(type_, ts.FieldType): - shape, strides = _make_array_shape_and_strides( + shape, offset, strides = _make_array_shape_and_strides( name, type_.dims, neighbor_tables, sort_dimensions ) - offset = ( - [dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))] - if has_offset - else None - ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) + sdfg.add_array( + name, + shape=shape, + strides=strides, + offset=(offset if has_offset else None), + dtype=dtype, + ) elif isinstance(type_, ts.ScalarType): - sdfg.add_symbol(name, as_dace_type(type_)) + dtype = as_dace_type(type_) + if name in sdfg.symbols: + assert sdfg.symbols[name].dtype == dtype + else: + sdfg.add_symbol(name, dtype) else: raise NotImplementedError() @@ -429,7 +436,6 @@ def visit_StencilClosure( for name, type_ in self.storage_types.items(): if isinstance(type_, ts.ScalarType): dtype = as_dace_type(type_) - closure_sdfg.add_symbol(name, dtype) if name in input_names: out_name = unique_var_name() closure_sdfg.add_scalar(out_name, dtype, transient=True) From e631c7f02ce00f5573c75c45eb64875ee1357e70 Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Mon, 19 Feb 2024 10:49:01 +0100 Subject: [PATCH 84/85] feature[next]: toolchain configuration interfaces (#1438) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * GTFN factories: high level toolchain creation interface * gt4py.next.config module: initialize global configuration options from env at import time * ADR for this iteration of configuration management --------- Co-authored-by: Rico Häuselmann Co-authored-by: Hannes Vogt --- .pre-commit-config.yaml | 1 + constraints.txt | 1 + .../ADRs/0017-Toolchain-Configuration.md | 264 ++++++++++++++++++ docs/development/ADRs/Index.md | 1 + min-extra-requirements-test.txt | 1 + min-requirements-test.txt | 1 + pyproject.toml | 1 + src/gt4py/next/config.py | 75 +++++ .../otf/compilation/build_systems/cmake.py | 20 +- .../compilation/build_systems/compiledb.py | 23 +- src/gt4py/next/otf/compilation/cache.py | 21 +- src/gt4py/next/otf/compilation/compiler.py | 16 +- .../codegens/gtfn/gtfn_module.py | 6 + .../program_processors/formatters/gtfn.py | 4 +- .../next/program_processors/runners/gtfn.py | 162 +++++------ tests/next_tests/__init__.py | 10 + .../otf_tests/test_nanobind_build.py | 7 +- .../test_with_toy_connectivity.py | 2 +- .../build_systems_tests/conftest.py | 4 +- .../build_systems_tests/test_cmake.py | 3 +- .../build_systems_tests/test_compiledb.py | 3 +- .../runners_tests/__init__.py | 13 + .../runners_tests/test_gtfn.py | 89 ++++++ 23 files changed, 586 insertions(+), 142 deletions(-) create mode 100644 docs/development/ADRs/0017-Toolchain-Configuration.md create mode 100644 src/gt4py/next/config.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/__init__.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 62bf3ce0ab..0b8eca844d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -186,6 +186,7 @@ repos: - cytoolz==0.12.3 - deepdiff==6.7.1 - devtools==0.12.2 + - factory-boy==3.3.0 - frozendict==2.4.0 - gridtools-cpp==2.3.2 - importlib-resources==6.1.1 diff --git a/constraints.txt b/constraints.txt index 61bc04e671..0ee69751ea 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,6 +6,7 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx +appnope==0.1.3 # via ipykernel, ipython asttokens==2.4.1 # via devtools, stack-data astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing diff --git a/docs/development/ADRs/0017-Toolchain-Configuration.md b/docs/development/ADRs/0017-Toolchain-Configuration.md new file mode 100644 index 0000000000..3448314316 --- /dev/null +++ b/docs/development/ADRs/0017-Toolchain-Configuration.md @@ -0,0 +1,264 @@ +--- +tags: [backend, otf, workflows, toolchain] +--- + +# Toolchain Configuration + +- **Status**: valid +- **Authors**: Rico Häuselmann (@DropD), Till Ehrengruber (@tehrengruber) +- **Created**: 2024-02-13 +- **Updated**: 2024-02-13 + +In order to provide a streamlined user experience, we attempt to standardize how users of GT4Py stencils can configure how those stencils are optimized without editing GT4Py code. This describes the design of the first minimal implementation. + +## Context + +In this document the word toolchain is used to mean all the code components that work together to go from DSL code to an optimized, runnable python callable. +It includes JIT / OTF pipelines or workflows but also transformation passes, lowerings, parsers etc. + +In this document the term "end user" refers to someone who runs an application which uses GT4Py internally. The end user may or may not be aware of GT4Py, only of the documentation that the application provides. + +The most pressing issue is the developer experience. One debugging technique is to have the generated C++ code written to a permanent file location for inspection. This requires changing the code to reconfigure the build cache location. However this is forseeably only one of multiple values that stencil developers and their end users will wish to configure without touching GT4Py code. + +At the time of creation of this document, at least one additional toolchain is actively being worked on and there are plans to make additional parts of the existing toolchains configurable, both of which will compound the issue. + +**Concerns** + +- Some ways of configuring toolchains involve configuring multiple components in synch +- Some ways of configuring toolchains involve switching out or nesting toolchain steps +- What configuration options are avaiable will depend on what toolchain will be used and how that is configured. +- Hierarchical configuration defaults and overrides can be confusing from a user perspective. +- Leaving the configuration interface completely up to toolchain developers could lead to a confusing ad fragmented user experience. + +## Decision + +**All decisions below are in the spirit of keeping the scope of the initial implementation small and can be changed whenever changing them is suitably justified.** + +### Opt-in pattern for building toolchains from user configuration and client code + +Any toolchain that has user configurable options should provide a high level interface for building a toolchain that is consistent with the options set by the end user. If a default toolchain instance is provided in GT4Py code, it should use that interface. This ensures that the simplest way of obtaining an instance of a toolchain respects user configuration. + +The pattern established for the 'GTFN' toolchain uses [`factory-boy`](https://factoryboy.readthedocs.io/en/stable/index.html), a package designed for constructing [ORM](https://en.wikipedia.org/wiki/Object%E2%80%93relational_mapping) models. + +```python +class ToolchainFactory(factory.Factory): + class Meta: + model = ToolchainClass + + class Params: + high_level_parameters = ... # check factoryboy docs for possibilities + some_option = config.SOME_OPTION # default read from config module + + attribute_defaults = ... # may use parameter values etc +``` + +### Limit configuration options exposed to the end user + +Any option that the end user can change in order to influence the toolchain behavior must be defined in `gt4py.next.config` with + +- an internal name (a module level variable) +- an external name used to load from environment variables (possibly with a common prefix) +- a fallback default value in case no environment variable is defined + +Any other toolchain option is considered an implementation detail from the point of view of the end user. + +```python +# gt4py.next.config + +#: information about the configuration option +INTERNAL_NAME_1 = os.environ.get(f"{_PREFIX}_EXTERNAL_NAME_1", ) + +#: information about the configuration option +INTERNAL_NAME_2 = os.environ.get(f"{_PREFIX}_EXTERNAL_NAME_2", ) +``` + +Note that this module thus contains a handy list of all environment variables one can set to influence GT4Py behavior from the outside. It might be used to create the end user configuration documentation with sphinx, if variable docstrings are consistently used. + +### Read end user configuration only once at import time + +We design `gt4py.next.config` as a module with module level variables, which are initialized at import time from environment variables (if they exist). +We are aware that this decision has significant drawbacks. The main justification for it is to keep scope minimal by reusing the pattern from `gt4py.cartesian`. + +```python +# gt4py.next.config +MASTER_SWITCH = os.environ.get(f"{_PREFIX}_MASTER_SWITCH", "false") +DEPENDENT_OPT = os.environ.get(f"{_PREFIX}_DEPENDENT_OPT", "one_default" if MASTER_SWITCH else "another_default") + +if MASTER_SWITCH: + ... # more complex config related side effects +``` + +### Environment variables are the primary end user interface + +Each user configurable option must be loaded from an environment variable if it is set. If there is an in-code default value it must be overridden by the environment variable. This, particularly, was decided only to keep the implementation minimal. + +Changing this without changing the import time initialization for adding a configuration file might look something like the following: + +```python +# gt4py.next.config + +_FILE_CONFIG = read_config_file() + +OPTION1 = os.environ.get(f"{_PREFIX}_OPTION1", _FILE_CONFIG.option1 or "") +``` + +## Consequences + +### Changing configuration variables at runtime can lead to inconsistencies + +Config variables are module-level and initialized at import time. Therefore any + +- logic that switches one of them based on another or any other module-level +- initialization of dependent module-level defaults +- side effects + +Will also have to happen at import time. At least in the first case it can **only** happen at import time. +This means changing the variables after import time will lead to inconsistencies if any of those patterns are present. + +Implementations of the two latter patterns can be designed to mitigate this but at the cost of increased complexity elsewhere in the code. The first pattern can not. + +```python +# gt4py.next.config +MASTER_SWITCH = os.environ.get(f"{_PREFIX}_MASTER_SWITCH", "false") +DEPENDENT_OPT = os.environ.get(f"{_PREFIX}_DEPENDENT_OPT", "one_default" if MASTER_SWITCH else "another_default") + +if MASTER_SWITCH: + ... # more complex config related side effects + +# in client code +from gt4py.next import config + +config.MASTER_SWITCH = "true" +# config.DEPENDENT_OPT has not been changed and the logic of how to change it is not accessible to be called at this point +``` + +#### Testability is limited as a result + +- the patterns outlined above are not repeatable for testing purposes +- the potential resulting inconsistencies in configuration limit usefulness of changing the config variables for testing + +The example above illustrates this, the test being the client code in this case. + +#### Implementation is kept minimal + +Since we accept that changing the configuration at runtime may cause inconsistencies, + +- we do not have to implement any way of delaying the point when we read the configuration to the last possible moment +- no new pattern needs to be established for how to expose end user config vars + +### It is not possible to track where configuration values come from. + +- the user-set environment variables must override the config module fallback value (can not be tested) +- the config module value must override the default values of the high-level toolchain building interface (testing up to toolchain implementer, requires monkey patching) +- the high-level toolchain building interface's defaults should override defaults in the toolchain modules (possibly testable, up to toolchain implementer) +- arguments passed to the high-level toolchain building interface should override all others, including user config (testing up to toolchain implementer) +- toolchain instances created without the high-level interface can do whatever they want (not testable in general) + +The situation may arise that toolchain instances are used (possibly for good reason) in a program / library using GT4Py, which do not respect user configuration. It remains up to the implementer of such a program or library to communicate this to the user. + +```python +toolchain = ToolchainFactory(some_option="foo") # this will override any `config.SOME_OPTION` default + +@gtx.program(backend=toolchain) +def foo(...): + ... +``` + +In this case the client program author has chosen to hardcode 'some_option', disregarding any 'config.SOME_OPTION' configuration variable read from the user environment. This must be documented clearly for the end user of the client code. + +## Alternatives Considered + +### Warn the user of workflow instances that disregard user configuration + +The user would receive a warning, either + +- when a toolchain instance is created, which disregards user config, or +- when such a toolchain is used + +```python +toolchain = ToolchainFactory(some_option="foo") # either at this point a warning would be issued: +# warning: 'client_code.toolchain' is overriding your configuration option 'GT4PY_SOME_OPTION'. + +toolchain2 = ToolchainClass(...) # it would be more effort to emit warnings when toolchains are constructed directly + +@gtx.program(backend=toolchain) # or at this point: +def foo(...): + ... +# warning: Program 'client_code.foo' is using a toolchain that overrides your configuration option ... +``` + +The latter was never tried due to the obvious run time overhead in checking before every use. The former was experimented with and two types of implementation were considered: black-box analysis of the effects of configuration sources and configuration tracking. + +Black-box analysis of the effects of different configuration sources would be far less intrusive and lightweight. In contrast to tracking, it can only follow configuration sources down to the toolchain construction entry point, not to lower level defaults. It was shown to be feasible, however we could not justify the maintenance burden. The following is a sketch of an algorithm for such analysis: + +```python +env_defined_vars: dict[str, Any] # This would contain the environment variables and all the dependent configuration variables +config_file_defined_vars: dict[str, Any] # this would contain the configuration variables in the case of a clean environment +code_defined_vars: dict[str, Any] # This would contain the parameters passed to the toolchain building entry point + +var_sources = { + "env": env_defined_vars, + "config": config_file_defined_vars, + "in-code": code_defined_vars +} + +effects_per_var_source: dict[str, dict[str, list[str]]] = [] + +for var_source, vars in var_sources.items(): + res = factory(**vars) + # compute what attributes are different in res from default, i.e. factory() + effect = ... + effects_per_var_source[var_source][var] = effect + +for (source_a, effect_a), (source_b, effect_b) in itertools.product(effects_per_var_source.items(), effects_per_var_source.items()) + for name_outer, effect_outer in user_defined_vars_effects.items(): + for name_inner, effect_inner in code_defined_vars_effects.items(): + if intersection(effect_outer, effect_inner): + print("{source_a}:{name_outer} conflicts with {source_b}:{name_inner}") +``` + +Note that: + +- The current decision to initialize configuration variables at import time makes the distinction between `env_defined_vars` and `config_file_defined_vars` impractical. +- The algorithm would have to be inserted into every toolchain building entry point, which we want checked. This could be more or less high level and could be more or less involved. +- Logic for excluding conflicts between two levels of configuration which will be overridden anyway is not in the sketch but could be added. + +As opposed to black-box analysis, tracking would mean to annotate each config variable with it's source. This would involve refactoring at every level down to each configurable toolchain component. +Implementing tracking was briefly considered but looked like it would be too heavy weight to justify the maintenance burden and too much work for the appetite of the implementation project. The current choice of high-level toolchain building pattern (`factory-boy`) does not particularly lend itself to implementing tracking. + +### Dynamical loading of user configuration + +The first PoC used a function call to load user configuration just before using it. This would have increased testability at the cost of a less minimal implementation. + +```python + +class Configuration: + ... + @property + def option_1(self): + ... + + @option_1.setter + def option_1(self, value): + self._option_1 = value + self._dependent_option = ... + ... # more dependent behavior + +OPTION_1_DEFAULT + +def get_configuration(): + conf = Configuration() + conf.option_1 = read_from_env() or read_from_other_source() or default + +current_configuration: Configuration = get_configuration() + +# later on + +config.current_configuration.option_1 = "foo" # now all the dependent logic is handled correctly +``` + +A side effect of this would be that tests could work with independent configuration objects when necessary. + +### Dynamical exposing of configuration options + +It is in principle possible for every toolchain building interface to pick what it considers to look like configuration options from the environment variables. In practice this would make it very difficult to keep the experience consistent between toolchains. diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/Index.md index 24272d9cee..072e6dc2ea 100644 --- a/docs/development/ADRs/Index.md +++ b/docs/development/ADRs/Index.md @@ -46,6 +46,7 @@ _None_ - [0007 - Fencil Processors](0007-Fencil-Processors.md) - [0008 - Mapping Domain to Cpp Backend](0008-Mapping_Domain_to_Cpp-Backend.md) - [0016 - Multiple Backends and Build Systems](0016-Multiple-Backends-and-Build-Systems.md) +- [0017 - Toolchain Configuration](0017-Toolchain-Configuration.md) ### Python Integration diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index d97ed299e1..6579db0f8f 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -65,6 +65,7 @@ darglint==1.6 deepdiff==5.6.0 devtools==0.6 factory-boy==3.1 +factory-boy==3.3.0 flake8-bugbear==20.11.1 flake8-builtins==1.5.3 flake8-debugger==4.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 63553623a2..6a05dbb9ac 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -61,6 +61,7 @@ darglint==1.6 deepdiff==5.6.0 devtools==0.6 factory-boy==3.1 +factory-boy==3.3.0 flake8-bugbear==20.11.1 flake8-builtins==1.5.3 flake8-debugger==4.0.0 diff --git a/pyproject.toml b/pyproject.toml index 5297dfb45b..3abbe8cd73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ 'cytoolz>=0.12.1', 'deepdiff>=5.6.0', 'devtools>=0.6', + 'factory-boy>=3.3.0', 'frozendict>=2.3', 'gridtools-cpp>=2.3.2,==2.*', "importlib-resources>=5.0;python_version<'3.9'", diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py new file mode 100644 index 0000000000..74bf56d6e8 --- /dev/null +++ b/src/gt4py/next/config.py @@ -0,0 +1,75 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + +import enum +import os +import pathlib +import tempfile +from typing import Final + + +class BuildCacheLifetime(enum.Enum): + SESSION = 1 + PERSISTENT = 2 + + +class CMakeBuildType(enum.Enum): + def _generate_next_value_(name, start, count, last_values): + return "".join(part.capitalize() for part in name.split("_")) + + DEBUG = enum.auto() + RELEASE = enum.auto() + REL_WITH_DEB_INFO = enum.auto() + MIN_SIZE_REL = enum.auto() + + +def env_flag_to_bool(flag_value: str) -> bool: + """Like in gt4py.cartesian, env vars for flags should be set to '0' or '1'.""" + match flag_value: + case "0" | "1": + return bool(int(flag_value)) + case _: + raise ValueError("GT4Py flag environment variables must have value '0' or '1'.") + + +_PREFIX: Final[str] = "GT4PY" + +#: Master debug flag +#: Changes defaults for all the other options to be as helpful for debugging as possible. +#: Does not override values set in environment variables. +DEBUG: Final[bool] = env_flag_to_bool(os.environ.get(f"{_PREFIX}_DEBUG", "0")) + +#: Where generated code projects should be persisted. +#: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT +BUILD_CACHE_DIR: Final[pathlib.Path] = ( + pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) + / "gt4py_cache" +) + + +#: Whether generated code projects should be kept around between runs. +#: - SESSION: generated code projects get destroyed when the interpreter shuts down +#: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs +BUILD_CACHE_LIFETIME: Final[BuildCacheLifetime] = getattr( + BuildCacheLifetime, + os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper(), +) + +#: Build type to be used when CMake is used to compile generated code. +#: Might have no effect when CMake is not used as part of the toolchain. +CMAKE_BUILD_TYPE: Final[CMakeBuildType] = getattr( + CMakeBuildType, + os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper(), +) diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 3d36f5d985..2aadd4e21f 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -15,26 +15,16 @@ from __future__ import annotations import dataclasses -import enum import pathlib import subprocess from typing import Optional +from gt4py.next import config from gt4py.next.otf import languages, stages from gt4py.next.otf.compilation import build_data, cache, common, compiler from gt4py.next.otf.compilation.build_systems import cmake_lists -class BuildType(enum.Enum): - def _generate_next_value_(name, start, count, last_values): - return "".join(part.capitalize() for part in name.split("_")) - - DEBUG = enum.auto() - RELEASE = enum.auto() - REL_WITH_DEB_INFO = enum.auto() - MIN_SIZE_REL = enum.auto() - - @dataclasses.dataclass class CMakeFactory( compiler.BuildSystemProjectGenerator[ @@ -44,7 +34,7 @@ class CMakeFactory( """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" cmake_generator_name: str = "Ninja" - cmake_build_type: BuildType = BuildType.DEBUG + cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG cmake_extra_flags: Optional[list[str]] = None def __call__( @@ -54,7 +44,7 @@ def __call__( languages.LanguageWithHeaderFilesSettings, languages.Python, ], - cache_strategy: cache.Strategy, + cache_lifetime: config.BuildCacheLifetime, ) -> CMakeProject: if not source.binding_source: raise NotImplementedError( @@ -73,7 +63,7 @@ def __call__( languages=cmake_languages, ) return CMakeProject( - root_path=cache.get_cache_folder(source, cache_strategy), + root_path=cache.get_cache_folder(source, cache_lifetime), source_files={ header_name: source.program_source.source_code, bindings_name: source.binding_source.source_code, @@ -105,7 +95,7 @@ class CMakeProject( source_files: dict[str, str] program_name: str generator_name: str = "Ninja" - build_type: BuildType = BuildType.DEBUG + build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG extra_cmake_flags: list[str] = dataclasses.field(default_factory=list) def build(self): diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 84a69859c0..140ea6a5fc 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -22,6 +22,7 @@ import subprocess from typing import Optional, TypeVar +from gt4py.next import config from gt4py.next.otf import languages, stages from gt4py.next.otf.binding import interface from gt4py.next.otf.compilation import build_data, cache, compiler @@ -44,7 +45,7 @@ class CompiledbFactory( Generate a compiledb only if there isn't one for the given combination of cmake configuration and library dependencies. """ - cmake_build_type: cmake.BuildType = cmake.BuildType.DEBUG + cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG cmake_extra_flags: list[str] = dataclasses.field(default_factory=list) renew_compiledb: bool = False @@ -55,7 +56,7 @@ def __call__( languages.LanguageWithHeaderFilesSettings, languages.Python, ], - cache_strategy: cache.Strategy, + cache_lifetime: config.BuildCacheLifetime, ) -> CompiledbProject: if not source.binding_source: raise NotImplementedError( @@ -74,17 +75,17 @@ def __call__( ) if self.renew_compiledb or not ( - compiledb_template := _cc_find_compiledb(cc_prototype_program_source, cache_strategy) + compiledb_template := _cc_find_compiledb(cc_prototype_program_source, cache_lifetime) ): compiledb_template = _cc_create_compiledb( cc_prototype_program_source, build_type=self.cmake_build_type, cmake_flags=self.cmake_extra_flags or [], - cache_strategy=cache_strategy, + cache_lifetime=cache_lifetime, ) return CompiledbProject( - root_path=cache.get_cache_folder(source, cache_strategy), + root_path=cache.get_cache_folder(source, cache_lifetime), program_name=name, source_files={ header_name: source.program_source.source_code, @@ -216,7 +217,7 @@ def _cc_prototype_program_name( def _cc_prototype_program_source( deps: tuple[interface.LibraryDependency, ...], - build_type: cmake.BuildType, + build_type: config.CMakeBuildType, cmake_flags: list[str], language: type[SrcL], language_settings: languages.LanguageWithHeaderFilesSettings, @@ -232,10 +233,10 @@ def _cc_prototype_program_source( def _cc_find_compiledb( - prototype_program_source: stages.ProgramSource, cache_strategy: cache.Strategy + prototype_program_source: stages.ProgramSource, cache_lifetime: config.BuildCacheLifetime ) -> Optional[pathlib.Path]: cache_path = cache.get_cache_folder( - stages.CompilableSource(prototype_program_source, None), cache_strategy + stages.CompilableSource(prototype_program_source, None), cache_lifetime ) compile_db_path = cache_path / "compile_commands.json" if compile_db_path.exists(): @@ -245,13 +246,13 @@ def _cc_find_compiledb( def _cc_create_compiledb( prototype_program_source: stages.ProgramSource, - build_type: cmake.BuildType, + build_type: config.CMakeBuildType, cmake_flags: list[str], - cache_strategy: cache.Strategy, + cache_lifetime: config.BuildCacheLifetime, ) -> pathlib.Path: name = prototype_program_source.entry_point.name cache_path = cache.get_cache_folder( - stages.CompilableSource(prototype_program_source, None), cache_strategy + stages.CompilableSource(prototype_program_source, None), cache_lifetime ) header_ext = prototype_program_source.language_settings.header_extension diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index ed59522410..ee5ec650e0 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -15,28 +15,23 @@ """Caching for compiled backend artifacts.""" -import enum import hashlib import pathlib import tempfile +from gt4py.next import config from gt4py.next.otf import stages from gt4py.next.otf.binding import interface -class Strategy(enum.Enum): - SESSION = 1 - PERSISTENT = 2 - - _session_cache_dir = tempfile.TemporaryDirectory(prefix="gt4py_session_") _session_cache_dir_path = pathlib.Path(_session_cache_dir.name) -_persistent_cache_dir_path = pathlib.Path(tempfile.gettempdir()) / "gt4py_cache" +_persistent_cache_dir_path = config.BUILD_CACHE_DIR def _serialize_param(parameter: interface.Parameter) -> str: - return f"{parameter.name}: {str(parameter.type_)}" + return f"{parameter.name}: {parameter.type_!s}" def _serialize_library_dependency(dependency: interface.LibraryDependency) -> str: @@ -63,7 +58,7 @@ def _cache_folder_name(source: stages.ProgramSource) -> str: def get_cache_folder( - compilable_source: stages.CompilableSource, strategy: Strategy + compilable_source: stages.CompilableSource, lifetime: config.BuildCacheLifetime ) -> pathlib.Path: """ Construct the path to where the build system project artifact of a compilable source should be cached. @@ -73,13 +68,13 @@ def get_cache_folder( # TODO(ricoh): make dependent on binding source too or add alternative that depends on bindings folder_name = _cache_folder_name(compilable_source.program_source) - match strategy: - case Strategy.SESSION: + match lifetime: + case config.BuildCacheLifetime.SESSION: base_path = _session_cache_dir_path - case Strategy.PERSISTENT: + case config.BuildCacheLifetime.PERSISTENT: base_path = _persistent_cache_dir_path case _: - raise ValueError("Unsupported caching strategy.") + raise ValueError("Unsupported caching lifetime.") base_path.mkdir(exist_ok=True) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 29541a3ae5..fb85d074df 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -18,6 +18,9 @@ import pathlib from typing import Protocol, TypeVar +import factory + +from gt4py.next import config from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.compilation import build_data, cache, importer from gt4py.next.otf.step_types import LS, SrcL, TgtL @@ -40,7 +43,7 @@ class BuildSystemProjectGenerator(Protocol[SrcL, LS, TgtL]): def __call__( self, source: stages.CompilableSource[SrcL, LS, TgtL], - cache_strategy: cache.Strategy, + cache_lifetime: config.BuildCacheLifetime, ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... @@ -58,7 +61,7 @@ class Compiler( ): """Use any build system (via configured factory) to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" - cache_strategy: cache.Strategy + cache_lifetime: config.BuildCacheLifetime builder_factory: BuildSystemProjectGenerator[ SourceLanguageType, LanguageSettingsType, languages.Python ] @@ -68,12 +71,12 @@ def __call__( self, inp: stages.CompilableSource[SourceLanguageType, LanguageSettingsType, languages.Python], ) -> stages.CompiledProgram: - src_dir = cache.get_cache_folder(inp, self.cache_strategy) + src_dir = cache.get_cache_folder(inp, self.cache_lifetime) data = build_data.read_data(src_dir) if not data or not is_compiled(data) or self.force_recompile: - self.builder_factory(inp, self.cache_strategy).build() + self.builder_factory(inp, self.cache_lifetime).build() new_data = build_data.read_data(src_dir) @@ -87,4 +90,9 @@ def __call__( ) +class CompilerFactory(factory.Factory): + class Meta: + model = Compiler + + class CompilationError(RuntimeError): ... diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index c157cdcc46..46861197fe 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -19,6 +19,7 @@ import warnings from typing import Any, Callable, Final, Optional +import factory import numpy as np from gt4py._core import definitions as core_defs @@ -351,6 +352,11 @@ def _not_implemented_for_device_type(self) -> NotImplementedError: ) +class GTFNTranslationStepFactory(factory.Factory): + class Meta: + model = GTFNTranslationStep + + translate_program_cpu: Final[step_types.TranslationStep] = GTFNTranslationStep() translate_program_gpu: Final[step_types.TranslationStep] = GTFNTranslationStep( diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index 27dec77ed1..6c8d4478c2 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -17,13 +17,13 @@ from gt4py.next.iterator import ir as itir from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep from gt4py.next.program_processors.processor_interface import program_formatter -from gt4py.next.program_processors.runners.gtfn import gtfn_executor +from gt4py.next.program_processors.runners import gtfn @program_formatter def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. - gtfn_translation = gtfn_executor.otf_workflow.translation + gtfn_translation = gtfn.GTFNBackendFactory().executor.otf_workflow.translation assert isinstance(gtfn_translation, GTFNTranslationStep) return gtfn_translation.generate_stencil_source( program, diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 157c00c368..04af4a5283 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -16,16 +16,17 @@ import warnings from typing import Any +import factory import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve.utils import content_hash -from gt4py.next import common +from gt4py.next import common, config from gt4py.next.iterator.transforms import LiftMode, global_tmps -from gt4py.next.otf import languages, recipes, stages, step_types, workflow +from gt4py.next.otf import recipes, stages, workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import cache, compiler +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb from gt4py.next.program_processors import otf_compile_executor from gt4py.next.program_processors.codegens.gtfn import gtfn_module @@ -113,104 +114,87 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: ) -GTFN_DEFAULT_TRANSLATION_STEP: step_types.TranslationStep[ - languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings -] = gtfn_module.GTFNTranslationStep( - enable_itir_transforms=True, - use_imperative_backend=False, - device_type=core_defs.DeviceType.CPU, -) - -GTFN_GPU_TRANSLATION_STEP: step_types.TranslationStep[ - languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings -] = gtfn_module.GTFNTranslationStep( - enable_itir_transforms=True, - use_imperative_backend=False, - device_type=core_defs.DeviceType.CUDA, -) +class GTFNCompileWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFCompileWorkflow -GTFN_DEFAULT_COMPILE_STEP: step_types.CompilationStep = compiler.Compiler( - cache_strategy=cache.Strategy.SESSION, builder_factory=compiledb.CompiledbFactory() -) + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = factory.LazyFunction( + lambda: config.CMAKE_BUILD_TYPE + ) + builder_factory: compiler.BuildSystemProjectGenerator = factory.LazyAttribute( + lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) + ) + translation = factory.SubFactory( + gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type") + ) + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( + nanobind.bind_source + ) + compilation = factory.SubFactory( + compiler.CompilerFactory, + cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), + builder_factory=factory.SelfAttribute("..builder_factory"), + ) + decoration = factory.LazyAttribute( + lambda o: functools.partial(convert_args, device=o.device_type) + ) -GTFN_DEFAULT_WORKFLOW = recipes.OTFCompileWorkflow( - translation=GTFN_DEFAULT_TRANSLATION_STEP, - bindings=nanobind.bind_source, - compilation=GTFN_DEFAULT_COMPILE_STEP, - decoration=convert_args, -) +class GTFNBackendFactory(factory.Factory): + class Meta: + model = otf_compile_executor.OTFBackend -GTFN_GPU_WORKFLOW = recipes.OTFCompileWorkflow( - translation=GTFN_GPU_TRANSLATION_STEP, - bindings=nanobind.bind_source, - compilation=GTFN_DEFAULT_COMPILE_STEP, - decoration=functools.partial(convert_args, device=core_defs.DeviceType.CUDA), -) + class Params: + name_device = "cpu" + name_cached = "" + name_postfix = "" + gpu = factory.Trait( + allocator=next_allocators.StandardGPUFieldBufferAllocator(), + device_type=core_defs.DeviceType.CUDA, + name_device="gpu", + ) + cached = factory.Trait( + executor=factory.LazyAttribute( + lambda o: otf_compile_executor.CachedOTFCompileExecutor( + otf_workflow=workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function), + name=o.name, + ) + ), + name_cached="_cached", + ) + device_type = core_defs.DeviceType.CPU + hash_function = compilation_hash + otf_workflow = factory.SubFactory( + GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") + ) + name = factory.LazyAttribute( + lambda o: f"run_gtfn_{o.name_device}{o.name_cached}{o.name_postfix}" + ) + executor = factory.LazyAttribute( + lambda o: otf_compile_executor.OTFCompileExecutor(otf_workflow=o.otf_workflow, name=o.name) + ) + allocator = next_allocators.StandardCPUFieldBufferAllocator() -gtfn_executor = otf_compile_executor.OTFCompileExecutor( - name="run_gtfn", otf_workflow=GTFN_DEFAULT_WORKFLOW -) -run_gtfn = otf_compile_executor.OTFBackend( - executor=gtfn_executor, - allocator=next_allocators.StandardCPUFieldBufferAllocator(), -) -gtfn_imperative_executor = otf_compile_executor.OTFCompileExecutor( - name="run_gtfn_imperative", - otf_workflow=gtfn_executor.otf_workflow.replace( - translation=gtfn_executor.otf_workflow.translation.replace(use_imperative_backend=True), - ), -) -run_gtfn_imperative = otf_compile_executor.OTFBackend( - executor=gtfn_imperative_executor, - allocator=next_allocators.StandardCPUFieldBufferAllocator(), -) +run_gtfn = GTFNBackendFactory() -# TODO(ricoh): add API for converting an executor to a cached version of itself and vice versa -gtfn_cached_executor = otf_compile_executor.CachedOTFCompileExecutor( - name="run_gtfn_cached", - otf_workflow=workflow.CachedStep( - step=gtfn_executor.otf_workflow, hash_function=compilation_hash - ), -) -run_gtfn_cached = otf_compile_executor.OTFBackend( - executor=gtfn_cached_executor, - allocator=next_allocators.StandardCPUFieldBufferAllocator(), +run_gtfn_imperative = GTFNBackendFactory( + name_postfix="_imperative", + otf_workflow__translation__use_imperative_backend=True, ) +run_gtfn_cached = GTFNBackendFactory(cached=True) -run_gtfn_with_temporaries = otf_compile_executor.OTFBackend( - executor=otf_compile_executor.OTFCompileExecutor( - name="run_gtfn_with_temporaries", - otf_workflow=gtfn_executor.otf_workflow.replace( - translation=gtfn_executor.otf_workflow.translation.replace( - lift_mode=LiftMode.FORCE_TEMPORARIES, - temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, - ), - ), - ), - allocator=next_allocators.StandardCPUFieldBufferAllocator(), +run_gtfn_with_temporaries = GTFNBackendFactory( + name_postfix="_with_temporaries", + otf_workflow__translation__lift_mode=LiftMode.FORCE_TEMPORARIES, + otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, ) -gtfn_gpu_executor = otf_compile_executor.OTFCompileExecutor( - name="run_gtfn_gpu", otf_workflow=GTFN_GPU_WORKFLOW -) -run_gtfn_gpu = otf_compile_executor.OTFBackend( - executor=gtfn_gpu_executor, - allocator=next_allocators.StandardGPUFieldBufferAllocator(), -) +run_gtfn_gpu = GTFNBackendFactory(gpu=True) - -gtfn_gpu_cached_executor = otf_compile_executor.CachedOTFCompileExecutor( - name="run_gtfn_gpu_cached", - otf_workflow=workflow.CachedStep( - step=gtfn_gpu_executor.otf_workflow, hash_function=compilation_hash - ), -) -run_gtfn_gpu_cached = otf_compile_executor.OTFBackend( - executor=gtfn_gpu_cached_executor, - allocator=next_allocators.StandardGPUFieldBufferAllocator(), -) +run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) diff --git a/tests/next_tests/__init__.py b/tests/next_tests/__init__.py index 1745dac6ef..96a106a1e6 100644 --- a/tests/next_tests/__init__.py +++ b/tests/next_tests/__init__.py @@ -12,12 +12,22 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import warnings + +from gt4py.next import config + from . import definitions __all__ = ["definitions", "get_processor_id"] +if config.BUILD_CACHE_LIFETIME is config.BuildCacheLifetime.PERSISTENT: + warnings.warn( + "You are running GT4Py tests with BUILD_CACHE_LIFETIME set to PERSISTENT!", UserWarning + ) + + def get_processor_id(processor): if hasattr(processor, "__module__") and hasattr(processor, "__name__"): module_path = processor.__module__.split(".")[-1] diff --git a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py index f24dc4bc59..936bb3ee58 100644 --- a/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py +++ b/tests/next_tests/integration_tests/feature_tests/otf_tests/test_nanobind_build.py @@ -16,9 +16,10 @@ import numpy as np +from gt4py.next import config from gt4py.next.otf import workflow from gt4py.next.otf.binding import nanobind -from gt4py.next.otf.compilation import cache, compiler +from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import cmake, compiledb from next_tests.unit_tests.otf_tests.compilation_tests.build_systems_tests.conftest import ( @@ -30,7 +31,7 @@ def test_gtfn_cpp_with_cmake(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_cmake") build_the_program = workflow.make_step(nanobind.bind_source).chain( compiler.Compiler( - cache_strategy=cache.Strategy.SESSION, builder_factory=cmake.CMakeFactory() + cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=cmake.CMakeFactory() ), ) compiled_program = build_the_program(example_program_source) @@ -48,7 +49,7 @@ def test_gtfn_cpp_with_compiledb(program_source_with_name): example_program_source = program_source_with_name("gtfn_cpp_with_compiledb") build_the_program = workflow.make_step(nanobind.bind_source).chain( compiler.Compiler( - cache_strategy=cache.Strategy.SESSION, + cache_lifetime=config.BuildCacheLifetime.SESSION, builder_factory=compiledb.CompiledbFactory(), ), ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 000d3c4822..0b9b639b08 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -413,7 +413,7 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): ]: pytest.xfail( "Bug in bindings/compilation/caching: only the first program seems to be compiled." - ) # observed in `cache.Strategy.PERSISTENT` mode + ) # observed in `config.BuildCacheLifetime.PERSISTENT` mode inp = vertex_index_field() inp_sparse = gtx.as_field([Edge, E2VDim], e2v_arr) out1 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py index 45ef85e37c..0911576fd6 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py @@ -15,11 +15,11 @@ import shutil import jinja2 -import numpy as np import pytest import gt4py.next as gtx import gt4py.next.type_system.type_specifications as ts +from gt4py.next import config from gt4py.next.otf import languages, stages from gt4py.next.otf.binding import cpp_interface, interface, nanobind from gt4py.next.otf.compilation import cache @@ -105,7 +105,7 @@ def compilable_source_example(program_source_example): @pytest.fixture def clean_example_session_cache(compilable_source_example): - cache_dir = cache.get_cache_folder(compilable_source_example, cache.Strategy.SESSION) + cache_dir = cache.get_cache_folder(compilable_source_example, config.BuildCacheLifetime.SESSION) if cache_dir.exists(): shutil.rmtree(cache_dir) yield diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_cmake.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_cmake.py index efe3b145dd..036060e5f1 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_cmake.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_cmake.py @@ -14,13 +14,14 @@ import pathlib +from gt4py.next import config from gt4py.next.otf.compilation import build_data, cache, importer from gt4py.next.otf.compilation.build_systems import cmake def test_default_cmake_factory(compilable_source_example, clean_example_session_cache): otf_builder = cmake.CMakeFactory()( - source=compilable_source_example, cache_strategy=cache.Strategy.SESSION + source=compilable_source_example, cache_lifetime=config.BuildCacheLifetime.SESSION ) assert not build_data.contains_data(otf_builder.root_path) diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py index 5fab08f52c..ef6980eb88 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/test_compiledb.py @@ -14,13 +14,14 @@ import pathlib +from gt4py.next import config from gt4py.next.otf.compilation import build_data, cache, importer from gt4py.next.otf.compilation.build_systems import compiledb def test_default_compiledb_factory(compilable_source_example, clean_example_session_cache): otf_builder = compiledb.CompiledbFactory()( - compilable_source_example, cache_strategy=cache.Strategy.SESSION + compilable_source_example, cache_lifetime=config.BuildCacheLifetime.SESSION ) # make sure the example project has not been written yet diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/__init__.py new file mode 100644 index 0000000000..6c43e2f12a --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/__init__.py @@ -0,0 +1,13 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py new file mode 100644 index 0000000000..897ddcfc08 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -0,0 +1,89 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +""" +Test that the high level gtfn interface respects user config. + +Note: certain features of the config system can not be tested. + +These features include: +- build cache location +- debug mode + +Because monkey patching the config variables is not enough, as +other variables are computed at import time based on them. +""" + +import gt4py._core.definitions as core_defs +from gt4py.next import allocators, config +from gt4py.next.otf import workflow +from gt4py.next.program_processors.runners import gtfn + + +def test_backend_factory_set_device(): + cpu_version = gtfn.GTFNBackendFactory(gpu=False, cached=False) + gpu_version = gtfn.GTFNBackendFactory(gpu=True, cached=False) + + assert cpu_version.executor.__name__ == "run_gtfn_cpu" + assert gpu_version.executor.__name__ == "run_gtfn_gpu" + + assert cpu_version.executor.otf_workflow.translation.device_type is core_defs.DeviceType.CPU + assert gpu_version.executor.otf_workflow.translation.device_type is core_defs.DeviceType.CUDA + + assert ( + cpu_version.executor.otf_workflow.decoration.keywords["device"] is core_defs.DeviceType.CPU + ) + assert ( + gpu_version.executor.otf_workflow.decoration.keywords["device"] is core_defs.DeviceType.CUDA + ) + + assert allocators.is_field_allocator_for(cpu_version.allocator, core_defs.DeviceType.CPU) + assert allocators.is_field_allocator_for(gpu_version.allocator, core_defs.DeviceType.CUDA) + + +def test_backend_factory_set_cached(): + cached_version = gtfn.GTFNBackendFactory(gpu=False, cached=True) + assert isinstance(cached_version.executor.otf_workflow, workflow.CachedStep) + assert cached_version.executor.__name__ == "run_gtfn_cpu_cached" + + +def test_backend_factory_build_cache_config(monkeypatch): + monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) + session_version = gtfn.GTFNBackendFactory() + monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.PERSISTENT) + persistent_version = gtfn.GTFNBackendFactory() + + assert ( + session_version.executor.otf_workflow.compilation.cache_lifetime + is config.BuildCacheLifetime.SESSION + ) + assert ( + persistent_version.executor.otf_workflow.compilation.cache_lifetime + is config.BuildCacheLifetime.PERSISTENT + ) + + +def test_backend_factory_build_type_config(monkeypatch): + monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.RELEASE) + release_version = gtfn.GTFNBackendFactory() + monkeypatch.setattr(config, "CMAKE_BUILD_TYPE", config.CMakeBuildType.MIN_SIZE_REL) + min_size_version = gtfn.GTFNBackendFactory() + + assert ( + release_version.executor.otf_workflow.compilation.builder_factory.cmake_build_type + is config.CMakeBuildType.RELEASE + ) + assert ( + min_size_version.executor.otf_workflow.compilation.builder_factory.cmake_build_type + is config.CMakeBuildType.MIN_SIZE_REL + ) From ba353d3bf7c5352906d34bf4ab770341401d8938 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 19 Feb 2024 15:31:12 +0100 Subject: [PATCH 85/85] feat[next][dace]: Remove offsets in connectivity arrays (#1460) Remove generation of offset symbols for connectivity arrays. --- .../runners/dace_iterator/__init__.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 5a5df5ce14..aba5656192 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -140,16 +140,23 @@ def get_shape_args( return shape_args -def get_offset_args( - sdfg: dace.SDFG, - args: Sequence[Any], -) -> Mapping[str, int]: +def get_offset_args(sdfg: dace.SDFG, args: Sequence[Any]) -> Mapping[str, int]: sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays sdfg_params: Sequence[str] = sdfg.arg_names + field_args = {param: arg for param, arg in zip(sdfg_params, args) if common.is_field(arg)} + + # assume that arrays for connectivity tables do not use offset + assert all( + drange.start == 0 + for sdfg_param, arg in field_args.items() + if sdfg_param.startswith("__connectivity") + for drange in arg.domain.ranges + ) + return { str(sym): -drange.start - for sdfg_param, arg in zip(sdfg_params, args) - if common.is_field(arg) + for sdfg_param, arg in field_args.items() + if not sdfg_param.startswith("__connectivity") for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain)) } @@ -331,6 +338,8 @@ def build_sdfg_from_itir( symbols: dict[str, int] = {} device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) + elif on_gpu: + autoopt.apply_gpu_storage(sdfg) if on_gpu: sdfg.apply_gpu_transformations()