From cf9c442346f3e4c2543c0e1c289bbed75c09fdd8 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 14 Feb 2025 17:54:31 +0100 Subject: [PATCH 1/2] refactor[cartesian]: use DeviceType from core_defn --- src/gt4py/cartesian/backend/base.py | 39 +++++----- src/gt4py/cartesian/backend/dace_backend.py | 24 ++----- .../cartesian/backend/dace_stencil_object.py | 7 +- src/gt4py/cartesian/backend/gtc_common.py | 3 +- src/gt4py/cartesian/backend/gtcpp_backend.py | 3 +- src/gt4py/cartesian/stencil_object.py | 12 ++-- src/gt4py/cartesian/testing/definitions.py | 24 +++++++ src/gt4py/cartesian/testing/suites.py | 5 +- src/gt4py/storage/cartesian/cupy_device.py | 24 +++++++ src/gt4py/storage/cartesian/interface.py | 11 ++- src/gt4py/storage/cartesian/layout.py | 71 ++++++++++++++----- src/gt4py/storage/cartesian/utils.py | 25 +++---- tests/cartesian_tests/definitions.py | 45 ++++++------ .../backend_tests/test_backend_api.py | 17 +++-- .../unit_tests/test_interface.py | 29 ++++---- 15 files changed, 203 insertions(+), 136 deletions(-) create mode 100644 src/gt4py/cartesian/testing/definitions.py create mode 100644 src/gt4py/storage/cartesian/cupy_device.py diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 5bab0453a9..4dd3757187 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -45,9 +45,9 @@ REGISTRY = gt_utils.Registry() -def from_name(name: str) -> Optional[Type[Backend]]: +def from_name(name: str) -> Type[Backend]: backend = REGISTRY.get(name, None) - if not backend: + if backend is None: raise NotImplementedError( f"Backend {name} is not implemented, options are: {REGISTRY.names}" ) @@ -84,7 +84,7 @@ class Backend(abc.ABC): #: Backend-specific storage parametrization: #: #: - "alignment": in bytes - #: - "device": "cpu" | "gpu" + #: - "device": core_defs.DeviceType | None #: - "layout_map": callback converting a mask to a layout #: - "is_optimal_layout": callback checking if a storage has compatible layout storage_info: ClassVar[gt_storage.layout.LayoutInfo] @@ -435,21 +435,20 @@ def disabled(message: str, *, enabled_env_var: str) -> Callable[[Type[Backend]], enabled = bool(int(os.environ.get(enabled_env_var, "0"))) if enabled: return deprecated(message) - else: - def _decorator(cls: Type[Backend]) -> Type[Backend]: - def _no_generate(obj) -> Type[StencilObject]: - raise NotImplementedError( - f"Disabled '{cls.name}' backend: 'f{message}'\n", - f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'", - ) - - # Replace generate method with raise - if not hasattr(cls, "generate"): - raise ValueError(f"Coding error. Expected a generate method on {cls}") - # Flag that it got disabled for register lookup - cls.disabled = True # type: ignore - cls.generate = _no_generate # type: ignore - return cls - - return _decorator + def _decorator(cls: Type[Backend]) -> Type[Backend]: + def _no_generate(obj) -> Type[StencilObject]: + raise NotImplementedError( + f"Disabled '{cls.name}' backend: 'f{message}'\n", + f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'", + ) + + # Replace generate method with raise + if not hasattr(cls, "generate"): + raise ValueError(f"Coding error. Expected a generate method on {cls}") + # Flag that it got disabled for register lookup + cls.disabled = True # type: ignore + cls.generate = _no_generate # type: ignore + return cls + + return _decorator diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5b822a1ab5..e56e60f7e7 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -421,7 +421,7 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: stencil_ir, sdfg, module_name=self.module_name, backend=self.backend ) - bindings_ext = "cu" if self.backend.storage_info["device"] == "gpu" else "cpp" + bindings_ext = "cu" if self.backend.name.endswith(":gpu") else "cpp" sources = { "computation": {"computation.hpp": implementation}, "bindings": {f"bindings.{bindings_ext}": bindings}, @@ -674,9 +674,7 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li assert isinstance(data, dace.data.Array) res[name] = ( "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type=( - "object" if self.backend.storage_info["device"] == "gpu" else "buffer" - ), + pybind_type=("object" if self.backend.name.endswith(":gpu") else "buffer"), name=name, ndim=len(data.shape), ) @@ -777,14 +775,7 @@ def generate(self) -> Type[StencilObject]: class DaceCPUBackend(BaseDaceBackend): name = "dace:cpu" languages = {"computation": "c++", "bindings": ["python"]} - storage_info = { - "alignment": 1, - "device": "cpu", - "layout_map": gt_storage.layout.layout_maker_factory((0, 1, 2)), - "is_optimal_layout": gt_storage.layout.layout_checker_factory( - gt_storage.layout.layout_maker_factory((0, 1, 2)) - ), - } + storage_info = gt_storage.layout.DaceCPULayout MODULE_GENERATOR_CLASS = DaCePyExtModuleGenerator options = BaseGTBackend.GT_BACKEND_OPTS @@ -799,14 +790,7 @@ class DaceGPUBackend(BaseDaceBackend): name = "dace:gpu" languages = {"computation": "cuda", "bindings": ["python"]} - storage_info = { - "alignment": 32, - "device": "gpu", - "layout_map": gt_storage.layout.layout_maker_factory((2, 1, 0)), - "is_optimal_layout": gt_storage.layout.layout_checker_factory( - gt_storage.layout.layout_maker_factory((2, 1, 0)) - ), - } + storage_info = gt_storage.layout.DaceGPULayout MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}} diff --git a/src/gt4py/cartesian/backend/dace_stencil_object.py b/src/gt4py/cartesian/backend/dace_stencil_object.py index 21006475a0..4a2d52289b 100644 --- a/src/gt4py/cartesian/backend/dace_stencil_object.py +++ b/src/gt4py/cartesian/backend/dace_stencil_object.py @@ -26,12 +26,11 @@ from gt4py.cartesian.utils import shash -def _extract_array_infos(field_args, device) -> Dict[str, Optional[ArgsInfo]]: +def _extract_array_infos(field_args) -> Dict[str, Optional[ArgsInfo]]: return { name: ArgsInfo( array=arg, dimensions=getattr(arg, "__gt_dims__", None), - device=device, origin=getattr(arg, "__gt_origin__", None), ) for name, arg in field_args.items() @@ -186,9 +185,7 @@ def normalize_args( args_as_kwargs = { name: (kwargs[name] if name in kwargs else next(args_iter)) for name in arg_names } - arg_infos = _extract_array_infos( - field_args=args_as_kwargs, device=backend_cls.storage_info["device"] - ) + arg_infos = _extract_array_infos(field_args=args_as_kwargs) origin = DaCeStencilObject._normalize_origins(arg_infos, field_info, origin) diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index 348e85de92..8213d567f9 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -23,6 +23,7 @@ from gt4py.cartesian.gtc.passes.gtir_pipeline import GtirPipeline from gt4py.cartesian.gtc.passes.oir_pipeline import OirPipeline from gt4py.eve.codegen import MakoTemplate as as_mako +from gt4py.storage.cartesian.layout import is_gpu_device if TYPE_CHECKING: @@ -51,7 +52,7 @@ def pybuffer_to_sid( domain_ndim = domain_dim_flags.count(True) sid_ndim = domain_ndim + data_ndim - as_sid = "as_cuda_sid" if backend.storage_info["device"] == "gpu" else "as_sid" + as_sid = "as_cuda_sid" if is_gpu_device(backend.storage_info) else "as_sid" sid_def = """gt::{as_sid}<{ctype}, {sid_ndim}, gt::integral_constant>({name})""".format( diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index 96f5672ae4..39281eaf45 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -25,6 +25,7 @@ from gt4py.cartesian.gtc.passes.gtir_pipeline import GtirPipeline from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline from gt4py.eve import codegen +from gt4py.storage.cartesian.layout import is_gpu_device from .gtc_common import BaseGTBackend, CUDAPyExtModuleGenerator @@ -85,7 +86,7 @@ def visit_FieldDecl(self, node: gtcpp.FieldDecl, **kwargs): if kwargs["external_arg"]: return "py::{pybind_type} {name}, std::array {name}_origin".format( pybind_type=( - "object" if self.backend.storage_info["device"] == "gpu" else "buffer" + "object" if is_gpu_device(self.backend.storage_info) else "buffer" ), name=node.name, sid_ndim=sid_ndim, diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index 5e5976e3e5..6f8eac9f73 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from numbers import Number from pickle import dumps -from typing import Any, Callable, ClassVar, Dict, Literal, Optional, Tuple, Union, cast +from typing import Any, Callable, ClassVar, Dict, Optional, Tuple, Union, cast import numpy as np @@ -24,6 +24,7 @@ from gt4py import cartesian as gt4pyc from gt4py.cartesian.definitions import AccessKind, DomainInfo, FieldInfo, ParameterInfo from gt4py.cartesian.gtc.definitions import Index, Shape +from gt4py.storage.cartesian.layout import LayoutInfo try: @@ -51,7 +52,6 @@ def _compute_domain_origin_cache_key( @dataclass class ArgsInfo: - device: str array: FieldType original_object: Optional[Any] = None origin: Optional[Tuple[int, ...]] = None @@ -59,14 +59,14 @@ class ArgsInfo: def _extract_array_infos( - field_args: Dict[str, Optional[FieldType]], device: Literal["cpu", "gpu"] + field_args: Dict[str, Optional[FieldType]], layout_info: LayoutInfo ) -> Dict[str, Optional[ArgsInfo]]: array_infos: Dict[str, Optional[ArgsInfo]] = {} for name, arg in field_args.items(): if arg is None: array_infos[name] = None else: - array = storage_utils.asarray(arg, device=device) + array = storage_utils.asarray(arg, layout_info=layout_info) dimensions = storage_utils.get_dims(arg) if dimensions is not None: sorted_dimensions = [d for d in "IJK" if d in dimensions] @@ -79,7 +79,6 @@ def _extract_array_infos( array=array, original_object=arg, dimensions=dimensions, - device=device, origin=storage_utils.get_origin(arg), ) return array_infos @@ -562,8 +561,7 @@ def _call_run( exec_info["call_run_start_time"] = time.perf_counter() backend_cls = gt4pyc.backend.from_name(self.backend) assert backend_cls is not None - device = backend_cls.storage_info["device"] - array_infos = _extract_array_infos(field_args, device) + array_infos = _extract_array_infos(field_args, backend_cls.storage_info) cache_key = _compute_domain_origin_cache_key(array_infos, parameter_args, domain, origin) if cache_key not in self._domain_origin_cache: diff --git a/src/gt4py/cartesian/testing/definitions.py b/src/gt4py/cartesian/testing/definitions.py new file mode 100644 index 0000000000..fcc4e98a42 --- /dev/null +++ b/src/gt4py/cartesian/testing/definitions.py @@ -0,0 +1,24 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.cartesian.backend.base import REGISTRY as BACKEND_REGISTRY + + +# TODO (romanc) +# This file can move to the tests/ folder once we refactor the `StencilTestSuite` +# class (i.e. https://github.com/GEOS-ESM/NDSL/issues/72). The stencil test suites +# use the `GPU_BACKEND_NAMES` and I didn't wanna have `gt4py/cartesian/testing` depend +# on the `tests/` directory. That sounded the wrong way around, so I moved them here +# for now. + +ALL_BACKEND_NAMES = list(BACKEND_REGISTRY.keys()) + +GPU_BACKEND_NAMES = ["cuda", "gt:gpu", "dace:gpu"] +CPU_BACKEND_NAMES = [name for name in ALL_BACKEND_NAMES if name not in GPU_BACKEND_NAMES] + +PERFORMANCE_BACKEND_NAMES = [name for name in ALL_BACKEND_NAMES if name not in ("numpy", "cuda")] diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index 423f834f51..39dd1326c6 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -23,6 +23,7 @@ from gt4py.cartesian.definitions import AccessKind, FieldInfo from gt4py.cartesian.gtc.definitions import Boundary, CartesianSpace, Index, Shape from gt4py.cartesian.stencil_object import StencilObject +from gt4py.cartesian.testing.definitions import GPU_BACKEND_NAMES from gt4py.storage.cartesian import utils as storage_utils from .input_strategies import ( @@ -206,7 +207,7 @@ def hyp_wrapper(test_hyp, hypothesis_data): ) marks = test["marks"].copy() - if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": + if test["backend"] in GPU_BACKEND_NAMES: marks.append(pytest.mark.requires_gpu) # Run generation and implementation tests in the same group to ensure # (thread-) safe parallelization of stencil tests. @@ -240,7 +241,7 @@ def hyp_wrapper(test_hyp, hypothesis_data): ) marks = test["marks"].copy() - if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": + if test["backend"] in GPU_BACKEND_NAMES: marks.append(pytest.mark.requires_gpu) # Run generation and implementation tests in the same group to ensure # (thread-) safe parallelization of stencil tests. diff --git a/src/gt4py/storage/cartesian/cupy_device.py b/src/gt4py/storage/cartesian/cupy_device.py new file mode 100644 index 0000000000..9d3ba0439c --- /dev/null +++ b/src/gt4py/storage/cartesian/cupy_device.py @@ -0,0 +1,24 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import Final, Literal + +from gt4py._core.definitions import DeviceType + + +try: + import cupy as cp +except ImportError: + cp = None + + +CUPY_DEVICE: Final[Literal[DeviceType.CUDA, DeviceType.ROCM] | None] = ( + None if not cp else (DeviceType.ROCM if cp.cuda.get_hipcc_path() else DeviceType.CUDA) +) diff --git a/src/gt4py/storage/cartesian/interface.py b/src/gt4py/storage/cartesian/interface.py index 8b38bcdd42..64149100d0 100644 --- a/src/gt4py/storage/cartesian/interface.py +++ b/src/gt4py/storage/cartesian/interface.py @@ -81,10 +81,15 @@ def empty( _error_on_invalid_preset(backend) storage_info = layout.from_name(backend) assert storage_info is not None - if storage_info["device"] == "gpu": + + if storage_info["device"] is None: + raise ValueError("device is None") + elif layout.is_gpu_device(storage_info): allocate_f = storage_utils.allocate_gpu - else: + elif layout.is_cpu_device(storage_info): allocate_f = storage_utils.allocate_cpu + else: + raise ValueError("Unknown device") aligned_index, shape, dtype, dimensions = storage_utils.normalize_storage_spec( aligned_index, shape, dtype, dimensions @@ -322,6 +327,6 @@ def from_array( layout_info = layout.from_name(backend) assert layout_info is not None - storage[...] = storage_utils.asarray(data, device=layout_info["device"]) + storage[...] = storage_utils.asarray(data, layout_info=layout_info) return storage diff --git a/src/gt4py/storage/cartesian/layout.py b/src/gt4py/storage/cartesian/layout.py index 868f7af7c4..760fabe9a6 100644 --- a/src/gt4py/storage/cartesian/layout.py +++ b/src/gt4py/storage/cartesian/layout.py @@ -12,16 +12,19 @@ Callable, Dict, Final, - Literal, Optional, Sequence, Tuple, + TypeAlias, TypedDict, Union, ) import numpy as np +import gt4py._core.definitions as core_defs +from gt4py.storage.cartesian.cupy_device import CUPY_DEVICE + if TYPE_CHECKING: try: @@ -29,30 +32,50 @@ except ImportError: cp = None +LayoutMap: TypeAlias = Callable[[Tuple[str, ...]], Tuple[Optional[int], ...]] + class LayoutInfo(TypedDict): alignment: int # measured in bytes - device: Literal["cpu", "gpu"] - layout_map: Callable[[Tuple[str, ...]], Tuple[Optional[int], ...]] + device: core_defs.DeviceType | None + layout_map: LayoutMap is_optimal_layout: Callable[[Any, Tuple[str, ...]], bool] +# Registry of LayoutInfos per backend REGISTRY: Dict[str, LayoutInfo] = {} -def from_name(name: str) -> Optional[LayoutInfo]: - return REGISTRY.get(name, None) +def from_name(backend_name: str) -> Optional[LayoutInfo]: + """Fetch LayoutInfo from the registry for a given backend name.""" + return REGISTRY.get(backend_name, None) -def register(name: str, info: Optional[LayoutInfo]) -> None: +def register(backend_name: str, info: Optional[LayoutInfo]) -> None: + """ "Register LayoutInfo under the given backend name. Clears an existing registry entry if None is given as info.""" if info is None: - if name in REGISTRY: - del REGISTRY[name] - else: - assert isinstance(name, str) - assert isinstance(info, dict) + if backend_name in REGISTRY: + del REGISTRY[backend_name] + return + + assert isinstance(backend_name, str) + assert isinstance(info, dict) + + REGISTRY[backend_name] = info + + +def is_cpu_device(layout_info: LayoutInfo) -> bool: + device = layout_info["device"] + if device is None: + raise ValueError("Can't determine if device is CPU because layout_info['device'] is None.") + return device == core_defs.DeviceType.CPU + - REGISTRY[name] = info +def is_gpu_device(layout_info: LayoutInfo) -> bool: + device = layout_info["device"] + if device is None: + raise ValueError("Can't determine if device is GPU because layout_info['device'] is None.") + return device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] def check_layout(layout_map, strides): @@ -136,7 +159,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. NaiveCPULayout: Final[LayoutInfo] = { "alignment": 1, - "device": "cpu", + "device": core_defs.DeviceType.CPU, "layout_map": lambda axes: tuple(i for i in range(len(axes))), "is_optimal_layout": lambda *_: True, } @@ -144,7 +167,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. CPUIFirstLayout: Final[LayoutInfo] = { "alignment": 8, - "device": "cpu", + "device": core_defs.DeviceType.CPU, "layout_map": make_gtcpu_ifirst_layout_map, "is_optimal_layout": layout_checker_factory(make_gtcpu_ifirst_layout_map), } @@ -153,7 +176,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. CPUKFirstLayout: Final[LayoutInfo] = { "alignment": 1, - "device": "cpu", + "device": core_defs.DeviceType.CPU, "layout_map": make_gtcpu_kfirst_layout_map, "is_optimal_layout": layout_checker_factory(make_gtcpu_kfirst_layout_map), } @@ -162,7 +185,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. CUDALayout: Final[LayoutInfo] = { "alignment": 32, - "device": "gpu", + "device": CUPY_DEVICE, "layout_map": make_cuda_layout_map, "is_optimal_layout": layout_checker_factory(make_cuda_layout_map), } @@ -170,3 +193,19 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. GPULayout: Final[LayoutInfo] = CUDALayout register("gpu", GPULayout) + +DaceCPULayout: Final[LayoutInfo] = { + "alignment": 1, + "device": core_defs.DeviceType.CPU, + "layout_map": layout_maker_factory((0, 1, 2)), + "is_optimal_layout": layout_checker_factory(layout_maker_factory((0, 1, 2))), +} +register("dace:cpu", DaceCPULayout) + +DaceGPULayout: Final[LayoutInfo] = { + "alignment": 32, + "device": CUPY_DEVICE, + "layout_map": layout_maker_factory((2, 1, 0)), + "is_optimal_layout": layout_checker_factory(layout_maker_factory((2, 1, 0))), +} +register("dace:gpu", DaceGPULayout) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index bd89c85052..0f22fa5cfb 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -12,7 +12,7 @@ import functools import math import numbers -from typing import Final, Literal, Optional, Sequence, Tuple, Union, cast +from typing import Optional, Sequence, Tuple, Union, cast import numpy as np import numpy.typing as npt @@ -22,6 +22,8 @@ from gt4py.cartesian import config as gt_config from gt4py.eve.extended_typing import ArrayInterface, CUDAArrayInterface from gt4py.storage import allocators +from gt4py.storage.cartesian.cupy_device import CUPY_DEVICE +from gt4py.storage.cartesian.layout import LayoutInfo, is_cpu_device, is_gpu_device try: @@ -30,13 +32,6 @@ cp = None -CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( - None - if not cp - else (core_defs.DeviceType.ROCM if cp.cuda.get_hipcc_path() else core_defs.DeviceType.CUDA) -) - - FieldLike = Union["cp.ndarray", np.ndarray, ArrayInterface, CUDAArrayInterface] _CPUBufferAllocator = allocators.NDArrayBufferAllocator( @@ -182,21 +177,19 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray: return np.array(array) -def asarray( - array: FieldLike, *, device: Literal["cpu", "gpu", None] = None -) -> np.ndarray | cp.ndarray: +def asarray(array: FieldLike, *, layout_info: LayoutInfo | None = None) -> np.ndarray | cp.ndarray: if hasattr(array, "ndarray"): # extract the buffer from a gt4py.next.Field # TODO(havogt): probably `Field` should provide the array interface methods when applicable array = array.ndarray xp = None - if device == "cpu": + if layout_info is not None and is_cpu_device(layout_info): xp = np - elif device == "gpu": + elif layout_info is not None and is_gpu_device(layout_info): assert cp is not None xp = cp - elif not device: + elif layout_info is None: if hasattr(array, "__dlpack_device__"): kind, _ = array.__dlpack_device__() if kind in [core_defs.DeviceType.CPU, core_defs.DeviceType.CPU_PINNED]: @@ -218,8 +211,8 @@ def asarray( if xp: return xp.asarray(array) - if device is not None: - raise ValueError(f"Invalid device: {device!s}") + if layout_info is not None: + raise ValueError(f"Invalid device: {layout_info['device']!s}") raise TypeError(f"Cannot convert {type(array)} to ndarray") diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 7499ad4a95..7812f59fb8 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -14,42 +14,47 @@ cp = None import datetime +from typing import Callable, List import numpy as np import pytest from gt4py import cartesian as gt4pyc from gt4py.cartesian import utils as gt_utils - - -def _backend_name_as_param(name): +from gt4py.cartesian.backend.base import Backend, from_name +from gt4py.cartesian.testing.definitions import ( + ALL_BACKEND_NAMES, + CPU_BACKEND_NAMES, + GPU_BACKEND_NAMES, + PERFORMANCE_BACKEND_NAMES, +) +from gt4py.cartesian.testing.suites import ParameterSet +from gt4py.storage.cartesian.layout import is_gpu_device + + +def _backend_name_as_param(backend_name: str) -> ParameterSet: marks = [] - if gt4pyc.backend.from_name(name).storage_info["device"] == "gpu": + if backend_name in GPU_BACKEND_NAMES: marks.append(pytest.mark.requires_gpu) - if "dace" in name: + if "dace" in backend_name: marks.append(pytest.mark.requires_dace) - return pytest.param(name, marks=marks) - - -_ALL_BACKEND_NAMES = list(gt4pyc.backend.REGISTRY.keys()) + return pytest.param(backend_name, marks=marks) -def _get_backends_with_storage_info(storage_info_kind: str): +def _filter_backends(filter: Callable[[Backend], bool]) -> List[str]: res = [] - for name in _ALL_BACKEND_NAMES: - backend = gt4pyc.backend.from_name(name) - if not getattr(backend, "disabled", False): - if backend.storage_info["device"] == storage_info_kind: - res.append(_backend_name_as_param(name)) + for name in ALL_BACKEND_NAMES: + backend = from_name(name) + if not getattr(backend, "disabled", False) and filter(backend): + res.append(_backend_name_as_param(name)) return res -CPU_BACKENDS = _get_backends_with_storage_info("cpu") -GPU_BACKENDS = _get_backends_with_storage_info("gpu") +CPU_BACKENDS = _filter_backends(lambda backend: backend.name in CPU_BACKEND_NAMES) +GPU_BACKENDS = _filter_backends(lambda backend: backend.name in GPU_BACKEND_NAMES) ALL_BACKENDS = CPU_BACKENDS + GPU_BACKENDS -_PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")] -PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES] +PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in PERFORMANCE_BACKEND_NAMES] @pytest.fixture() @@ -61,7 +66,7 @@ def get_array_library(backend: str): """Return device ready array maker library""" backend_cls = gt4pyc.backend.from_name(backend) assert backend_cls is not None - if backend_cls.storage_info["device"] == "gpu": + if is_gpu_device(backend_cls.storage_info): assert cp is not None return cp else: diff --git a/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py b/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py index 3fbf586b35..055e6e8e24 100644 --- a/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py +++ b/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py @@ -10,15 +10,10 @@ import pytest -from gt4py import cartesian as gt4pyc +from gt4py.cartesian.backend.base import from_name from gt4py.cartesian.gtscript import PARALLEL, Field, computation, interval from gt4py.cartesian.stencil_builder import StencilBuilder - - -@pytest.fixture(params=[name for name in gt4pyc.backend.REGISTRY.keys()]) -def backend(request): - """Parametrize by backend name.""" - yield gt4pyc.backend.from_name(request.param) +from cartesian_tests.definitions import ALL_BACKENDS # mypy gets confused by gtscript @@ -28,10 +23,12 @@ def init_1(input_field: Field[float]): # type: ignore input_field = 1 # noqa # unused var is in/out field -def test_generate_computation(backend, tmp_path): +@pytest.mark.parametrize("backend_name", ALL_BACKENDS) +def test_generate_computation(backend_name, tmp_path): """Test if the :py:meth:`gt4pyc.backend.CLIBackendMixin.generate_computation` generates code.""" # note: if a backend is added that doesn't use CliBackendMixin it will # have to be special cased in the backend fixture + backend = from_name(backend_name) builder = StencilBuilder(init_1, backend=backend).with_caching( "nocaching", output_path=tmp_path / __name__ / "generate_computation" ) @@ -61,8 +58,10 @@ def test_generate_computation(backend, tmp_path): assert py_result or py_standalone_result or gt_result or gtc_result -def test_generate_bindings(backend, tmp_path): +@pytest.mark.parametrize("backend_name", ALL_BACKENDS) +def test_generate_bindings(backend_name, tmp_path): """Test :py:meth:`gt4pyc.backend.CLIBackendMixin.generate_bindings`.""" + backend = from_name(backend_name) builder = StencilBuilder(init_1, backend=backend).with_caching( "nocaching", output_path=tmp_path / __name__ / "generate_bindings" ) diff --git a/tests/storage_tests/unit_tests/test_interface.py b/tests/storage_tests/unit_tests/test_interface.py index ba7bc2aaef..69a629c65b 100644 --- a/tests/storage_tests/unit_tests/test_interface.py +++ b/tests/storage_tests/unit_tests/test_interface.py @@ -10,26 +10,21 @@ import hypothesis.strategies as hyp_st import numpy as np import pytest - +from gt4py.storage.cartesian import layout +import gt4py +from gt4py.cartesian import gtscript +from gt4py.storage.cartesian.utils import allocate_cpu, allocate_gpu, normalize_storage_spec try: import cupy as cp except ImportError: cp = None -import gt4py -from gt4py.cartesian import gtscript -from gt4py.storage.cartesian.utils import allocate_cpu, allocate_gpu, normalize_storage_spec +_ALL_LAYOUTS = [name for name, _ in layout.REGISTRY.items()] +_GPU_LAYOUTS = ["cuda", "dace:gpu", "gpu", "gt:gpu"] +_CPU_LAYOUTS = [name for name in _ALL_LAYOUTS if name not in _GPU_LAYOUTS] -CPU_LAYOUTS = [ - name for name, info in gt4py.storage.layout.REGISTRY.items() if info["device"] == "cpu" -] -GPU_LAYOUTS = [ - pytest.param(name, marks=pytest.mark.requires_gpu) - for name, info in gt4py.storage.layout.REGISTRY.items() - if info["device"] == "gpu" -] try: import dace @@ -367,28 +362,30 @@ def alloc_fun(request): return request.param -@pytest.mark.parametrize("backend", CPU_LAYOUTS) +@pytest.mark.parametrize("backend", _CPU_LAYOUTS) def test_cpu_constructor(alloc_fun, backend): stor = alloc_fun(dtype=np.float64, aligned_index=(1, 2, 3), shape=(2, 4, 6), backend=backend) assert stor.shape == (2, 4, 6) assert isinstance(stor, np.ndarray) -@pytest.mark.parametrize("backend", CPU_LAYOUTS) +@pytest.mark.parametrize("backend", _CPU_LAYOUTS) def test_cpu_constructor_0d(alloc_fun, backend): stor = alloc_fun(shape=(), dtype=np.float64, backend=backend, aligned_index=()) assert stor.shape == () assert isinstance(stor, np.ndarray) -@pytest.mark.parametrize("backend", GPU_LAYOUTS) +@pytest.mark.requires_gpu +@pytest.mark.parametrize("backend", _GPU_LAYOUTS) def test_gpu_constructor(alloc_fun, backend): stor = alloc_fun(dtype=np.float64, aligned_index=(1, 2, 3), shape=(2, 4, 6), backend=backend) assert stor.shape == (2, 4, 6) assert isinstance(stor, cp.ndarray) -@pytest.mark.parametrize("backend", GPU_LAYOUTS) +@pytest.mark.requires_gpu +@pytest.mark.parametrize("backend", _GPU_LAYOUTS) def test_gpu_constructor_0d(alloc_fun, backend): stor = alloc_fun(shape=(), dtype=np.float64, backend=backend, aligned_index=()) assert stor.shape == () From 4ecdfa1473f82508be5650cde9fe228d8effa8cc Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 14 Feb 2025 20:16:05 +0100 Subject: [PATCH 2/2] Fixup: cleanups from self-review of code --- src/gt4py/cartesian/backend/base.py | 10 ++++------ src/gt4py/cartesian/backend/dace_backend.py | 7 +++++-- src/gt4py/storage/__init__.py | 3 +-- src/gt4py/storage/cartesian/layout.py | 16 ++++++++-------- src/gt4py/storage/cartesian/utils.py | 12 ++++++------ tests/storage_tests/unit_tests/test_interface.py | 2 +- 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 4dd3757187..a7a0385217 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -58,15 +58,13 @@ def register(backend_cls: Type[Backend]) -> Type[Backend]: assert issubclass(backend_cls, Backend) and backend_cls.name is not None if isinstance(backend_cls.name, str): - gt_storage.register(backend_cls.name, backend_cls.storage_info) return REGISTRY.register(backend_cls.name, backend_cls) - else: - raise ValueError( - "Invalid 'name' attribute ('{name}') in backend class '{cls}'".format( - name=backend_cls.name, cls=backend_cls - ) + raise ValueError( + "Invalid 'name' attribute ('{name}') in backend class '{cls}'".format( + name=backend_cls.name, cls=backend_cls ) + ) class Backend(abc.ABC): diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index e56e60f7e7..89e21e0c91 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -47,6 +47,7 @@ from gt4py.cartesian.utils import shash from gt4py.eve import codegen from gt4py.eve.codegen import MakoTemplate as as_mako +from gt4py.storage.cartesian.layout import is_gpu_device if TYPE_CHECKING: @@ -421,7 +422,7 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: stencil_ir, sdfg, module_name=self.module_name, backend=self.backend ) - bindings_ext = "cu" if self.backend.name.endswith(":gpu") else "cpp" + bindings_ext = "cu" if is_gpu_device(self.backend.storage_info) else "cpp" sources = { "computation": {"computation.hpp": implementation}, "bindings": {f"bindings.{bindings_ext}": bindings}, @@ -674,7 +675,9 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li assert isinstance(data, dace.data.Array) res[name] = ( "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type=("object" if self.backend.name.endswith(":gpu") else "buffer"), + pybind_type=( + "object" if is_gpu_device(self.backend.storage_info) else "buffer" + ), name=name, ndim=len(data.shape), ) diff --git a/src/gt4py/storage/__init__.py b/src/gt4py/storage/__init__.py index 5986baa65e..12716962fb 100644 --- a/src/gt4py/storage/__init__.py +++ b/src/gt4py/storage/__init__.py @@ -11,7 +11,7 @@ from . import cartesian from .cartesian import layout from .cartesian.interface import empty, from_array, full, ones, zeros -from .cartesian.layout import from_name, register +from .cartesian.layout import from_name __all__ = [ @@ -22,6 +22,5 @@ "full", "layout", "ones", - "register", "zeros", ] diff --git a/src/gt4py/storage/cartesian/layout.py b/src/gt4py/storage/cartesian/layout.py index 760fabe9a6..90d99505d6 100644 --- a/src/gt4py/storage/cartesian/layout.py +++ b/src/gt4py/storage/cartesian/layout.py @@ -51,7 +51,7 @@ def from_name(backend_name: str) -> Optional[LayoutInfo]: return REGISTRY.get(backend_name, None) -def register(backend_name: str, info: Optional[LayoutInfo]) -> None: +def _register(backend_name: str, info: Optional[LayoutInfo]) -> None: """ "Register LayoutInfo under the given backend name. Clears an existing registry entry if None is given as info.""" if info is None: if backend_name in REGISTRY: @@ -163,7 +163,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. "layout_map": lambda axes: tuple(i for i in range(len(axes))), "is_optimal_layout": lambda *_: True, } -register("naive_cpu", NaiveCPULayout) +_register("numpy", NaiveCPULayout) CPUIFirstLayout: Final[LayoutInfo] = { "alignment": 8, @@ -171,7 +171,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. "layout_map": make_gtcpu_ifirst_layout_map, "is_optimal_layout": layout_checker_factory(make_gtcpu_ifirst_layout_map), } -register("cpu_ifirst", CPUIFirstLayout) +_register("gt:cpu_ifirst", CPUIFirstLayout) CPUKFirstLayout: Final[LayoutInfo] = { @@ -180,7 +180,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. "layout_map": make_gtcpu_kfirst_layout_map, "is_optimal_layout": layout_checker_factory(make_gtcpu_kfirst_layout_map), } -register("cpu_kfirst", CPUKFirstLayout) +_register("gt:cpu_kfirst", CPUKFirstLayout) CUDALayout: Final[LayoutInfo] = { @@ -189,10 +189,10 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. "layout_map": make_cuda_layout_map, "is_optimal_layout": layout_checker_factory(make_cuda_layout_map), } -register("cuda", CUDALayout) +_register("cuda", CUDALayout) GPULayout: Final[LayoutInfo] = CUDALayout -register("gpu", GPULayout) +_register("gt:gpu", GPULayout) DaceCPULayout: Final[LayoutInfo] = { "alignment": 1, @@ -200,7 +200,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. "layout_map": layout_maker_factory((0, 1, 2)), "is_optimal_layout": layout_checker_factory(layout_maker_factory((0, 1, 2))), } -register("dace:cpu", DaceCPULayout) +_register("dace:cpu", DaceCPULayout) DaceGPULayout: Final[LayoutInfo] = { "alignment": 32, @@ -208,4 +208,4 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], .. "layout_map": layout_maker_factory((2, 1, 0)), "is_optimal_layout": layout_checker_factory(layout_maker_factory((2, 1, 0))), } -register("dace:gpu", DaceGPULayout) +_register("dace:gpu", DaceGPULayout) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 0f22fa5cfb..300bb99688 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -184,12 +184,7 @@ def asarray(array: FieldLike, *, layout_info: LayoutInfo | None = None) -> np.nd array = array.ndarray xp = None - if layout_info is not None and is_cpu_device(layout_info): - xp = np - elif layout_info is not None and is_gpu_device(layout_info): - assert cp is not None - xp = cp - elif layout_info is None: + if layout_info is None: if hasattr(array, "__dlpack_device__"): kind, _ = array.__dlpack_device__() if kind in [core_defs.DeviceType.CPU, core_defs.DeviceType.CPU_PINNED]: @@ -207,6 +202,11 @@ def asarray(array: FieldLike, *, layout_info: LayoutInfo | None = None) -> np.nd xp = cp elif hasattr(array, "__array_interface__") or hasattr(array, "__array__"): xp = np + elif is_cpu_device(layout_info): + xp = np + elif is_gpu_device(layout_info): + assert cp is not None + xp = cp if xp: return xp.asarray(array) diff --git a/tests/storage_tests/unit_tests/test_interface.py b/tests/storage_tests/unit_tests/test_interface.py index 69a629c65b..8eae8bb36b 100644 --- a/tests/storage_tests/unit_tests/test_interface.py +++ b/tests/storage_tests/unit_tests/test_interface.py @@ -22,7 +22,7 @@ _ALL_LAYOUTS = [name for name, _ in layout.REGISTRY.items()] -_GPU_LAYOUTS = ["cuda", "dace:gpu", "gpu", "gt:gpu"] +_GPU_LAYOUTS = ["cuda", "dace:gpu", "gt:gpu"] _CPU_LAYOUTS = [name for name in _ALL_LAYOUTS if name not in _GPU_LAYOUTS]