
WIP: Re-use DeviceType from core_definitions
To be split from the other changes, probably in a follow-up PR.
romanc committed Feb 12, 2025
1 parent ed59a6a commit 4ad9458
Showing 13 changed files with 93 additions and 83 deletions.
24 changes: 5 additions & 19 deletions src/gt4py/cartesian/backend/dace_backend.py
@@ -47,7 +47,7 @@
from gt4py.cartesian.utils import shash
from gt4py.eve import codegen
from gt4py.eve.codegen import MakoTemplate as as_mako
from gt4py.storage.cartesian.layout import StorageDevice
from gt4py.storage.cartesian.layout import is_cuda_device


if TYPE_CHECKING:
@@ -409,7 +409,7 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]:
stencil_ir, sdfg, module_name=self.module_name, backend=self.backend
)

bindings_ext = "cu" if self.backend.storage_info["device"] == StorageDevice.GPU else "cpp"
bindings_ext = "cu" if is_cuda_device(self.backend.storage_info) else "cpp"
sources = {
"computation": {"computation.hpp": implementation},
"bindings": {f"bindings.{bindings_ext}": bindings},
@@ -664,7 +664,7 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li
"py::{pybind_type} {name}, std::array<gt::int_t,{ndim}> {name}_origin".format(
pybind_type=(
"object"
if self.backend.storage_info["device"] == StorageDevice.GPU
if is_cuda_device(self.backend.storage_info)
else "buffer"
),
name=name,
@@ -767,14 +767,7 @@ def generate(self) -> Type[StencilObject]:
class DaceCPUBackend(BaseDaceBackend):
name = "dace:cpu"
languages = {"computation": "c++", "bindings": ["python"]}
storage_info = {
"alignment": 1,
"device": StorageDevice.CPU,
"layout_map": gt_storage.layout.layout_maker_factory((0, 1, 2)),
"is_optimal_layout": gt_storage.layout.layout_checker_factory(
gt_storage.layout.layout_maker_factory((0, 1, 2))
),
}
storage_info = gt_storage.layout.DaceCPULayout
MODULE_GENERATOR_CLASS = DaCePyExtModuleGenerator

options = BaseGTBackend.GT_BACKEND_OPTS
@@ -789,14 +782,7 @@ class DaceGPUBackend(BaseDaceBackend):

name = "dace:gpu"
languages = {"computation": "cuda", "bindings": ["python"]}
storage_info = {
"alignment": 32,
"device": StorageDevice.GPU,
"layout_map": gt_storage.layout.layout_maker_factory((2, 1, 0)),
"is_optimal_layout": gt_storage.layout.layout_checker_factory(
gt_storage.layout.layout_maker_factory((2, 1, 0))
),
}
storage_info = gt_storage.layout.DaceGPULayout
MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator
options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}}

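For context (not part of this diff): a minimal sketch of the pattern introduced above, where the DaCe backends reference the shared layout constants instead of re-declaring the dicts inline and check the device through the new helper. Module paths and names are taken from this commit; the snippet itself is illustrative only.

```python
# Sketch: reuse the shared LayoutInfo constants and the is_cuda_device helper.
from gt4py.storage.cartesian.layout import DaceCPULayout, DaceGPULayout, is_cuda_device

cpu_info = DaceCPULayout  # alignment 1, DeviceType.CPU, (0, 1, 2) layout map
gpu_info = DaceGPULayout  # alignment 32, DeviceType.CUDA, (2, 1, 0) layout map

# Pick the bindings extension via the predicate instead of comparing enum values.
bindings_ext = "cu" if is_cuda_device(gpu_info) else "cpp"
assert bindings_ext == "cu"
```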
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/backend/dace_stencil_object.py
@@ -20,14 +20,14 @@
from dace.frontend.python.common import SDFGClosure, SDFGConvertible

from gt4py import cartesian as gt4pyc
from gt4py._core.definitions import DeviceType
from gt4py.cartesian.backend.dace_backend import freeze_origin_domain_sdfg
from gt4py.cartesian.definitions import AccessKind, DomainInfo, FieldInfo
from gt4py.cartesian.stencil_object import ArgsInfo, FrozenStencil, StencilObject
from gt4py.cartesian.utils import shash
from gt4py.storage.cartesian.layout import StorageDevice


def _extract_array_infos(field_args, device: StorageDevice) -> Dict[str, Optional[ArgsInfo]]:
def _extract_array_infos(field_args, device: DeviceType) -> Dict[str, Optional[ArgsInfo]]:
return {
name: ArgsInfo(
array=arg,
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/backend/gtc_common.py
@@ -23,7 +23,7 @@
from gt4py.cartesian.gtc.passes.gtir_pipeline import GtirPipeline
from gt4py.cartesian.gtc.passes.oir_pipeline import OirPipeline
from gt4py.eve.codegen import MakoTemplate as as_mako
from gt4py.storage.cartesian.layout import StorageDevice
from gt4py.storage.cartesian.layout import is_cuda_device


if TYPE_CHECKING:
@@ -52,7 +52,7 @@ def pybuffer_to_sid(
domain_ndim = domain_dim_flags.count(True)
sid_ndim = domain_ndim + data_ndim

as_sid = "as_cuda_sid" if backend.storage_info["device"] == StorageDevice.GPU else "as_sid"
as_sid = "as_cuda_sid" if is_cuda_device(backend.storage_info) else "as_sid"

sid_def = """gt::{as_sid}<{ctype}, {sid_ndim},
gt::integral_constant<int, {unique_index}>>({name})""".format(
6 changes: 2 additions & 4 deletions src/gt4py/cartesian/backend/gtcpp_backend.py
@@ -25,7 +25,7 @@
from gt4py.cartesian.gtc.passes.gtir_pipeline import GtirPipeline
from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline
from gt4py.eve import codegen
from gt4py.storage.cartesian.layout import StorageDevice
from gt4py.storage.cartesian.layout import is_gpu_device

from .gtc_common import BaseGTBackend, CUDAPyExtModuleGenerator

@@ -86,9 +86,7 @@ def visit_FieldDecl(self, node: gtcpp.FieldDecl, **kwargs):
if kwargs["external_arg"]:
return "py::{pybind_type} {name}, std::array<gt::int_t,{sid_ndim}> {name}_origin".format(
pybind_type=(
"object"
if self.backend.storage_info["device"] == StorageDevice.GPU
else "buffer"
"object" if is_gpu_device(self.backend.storage_info) else "buffer"
),
name=node.name,
sid_ndim=sid_ndim,
18 changes: 9 additions & 9 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
@@ -17,6 +17,7 @@

import gt4py.cartesian.gtc.oir as oir
from gt4py import eve
from gt4py._core.definitions import DeviceType
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
@@ -31,17 +32,16 @@
AccessCollector,
compute_horizontal_block_extents,
)
from gt4py.storage.cartesian.layout import StorageDevice


transient_storage_per_device: Dict[StorageDevice, dace.StorageType] = {
StorageDevice.CPU: dace.StorageType.Default,
StorageDevice.GPU: dace.StorageType.GPU_Global,
transient_storage_per_device: Dict[DeviceType, dace.StorageType] = {
DeviceType.CPU: dace.StorageType.Default,
DeviceType.CUDA: dace.StorageType.GPU_Global,
}

device_type_per_device: Dict[StorageDevice, dace.DeviceType] = {
StorageDevice.CPU: dace.DeviceType.CPU,
StorageDevice.GPU: dace.DeviceType.GPU,
device_type_per_device: Dict[DeviceType, dace.DeviceType] = {
DeviceType.CPU: dace.DeviceType.CPU,
DeviceType.CUDA: dace.DeviceType.GPU,
}


@@ -115,7 +115,7 @@ def visit_VerticalLoop(
node: oir.VerticalLoop,
*,
ctx: OirSDFGBuilder.SDFGContext,
device: StorageDevice,
device: DeviceType,
) -> None:
declarations = {
acc.name: ctx.decls[acc.name]
@@ -156,7 +156,7 @@ def visit_VerticalLoop(
library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset)
)

def visit_Stencil(self, node: oir.Stencil, *, device: StorageDevice) -> dace.SDFG:
def visit_Stencil(self, node: oir.Stencil, *, device: DeviceType) -> dace.SDFG:
ctx = OirSDFGBuilder.SDFGContext(node)
for param in node.params:
if isinstance(param, oir.FieldDecl):
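For context (not part of this diff): the two module-level dicts above translate the storage DeviceType into DaCe enums when the SDFG is built. A hedged usage sketch, assuming the dicts stay importable from oir_to_dace as shown:

```python
# Sketch: look up DaCe storage/device enums from a gt4py DeviceType.
import dace

from gt4py._core.definitions import DeviceType
from gt4py.cartesian.gtc.dace.oir_to_dace import (
    device_type_per_device,
    transient_storage_per_device,
)

device = DeviceType.CUDA
assert transient_storage_per_device[device] is dace.StorageType.GPU_Global
assert device_type_per_device[device] is dace.DeviceType.GPU
```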
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/stencil_object.py
@@ -22,9 +22,9 @@
import gt4py.cartesian.gtc.utils as gtc_utils
import gt4py.storage.cartesian.utils as storage_utils
from gt4py import cartesian as gt4pyc
from gt4py._core.definitions import DeviceType
from gt4py.cartesian.definitions import AccessKind, DomainInfo, FieldInfo, ParameterInfo
from gt4py.cartesian.gtc.definitions import Index, Shape
from gt4py.storage.cartesian.layout import StorageDevice


try:
@@ -52,15 +52,15 @@ def _compute_domain_origin_cache_key(

@dataclass
class ArgsInfo:
device: StorageDevice
device: DeviceType
array: FieldType
original_object: Optional[Any] = None
origin: Optional[Tuple[int, ...]] = None
dimensions: Optional[Tuple[str, ...]] = None


def _extract_array_infos(
field_args: Dict[str, Optional[FieldType]], device: StorageDevice
field_args: Dict[str, Optional[FieldType]], device: DeviceType
) -> Dict[str, Optional[ArgsInfo]]:
array_infos: Dict[str, Optional[ArgsInfo]] = {}
for name, arg in field_args.items():
12 changes: 3 additions & 9 deletions src/gt4py/cartesian/testing/suites.py
@@ -24,7 +24,7 @@
from gt4py.cartesian.gtc.definitions import Boundary, CartesianSpace, Index, Shape
from gt4py.cartesian.stencil_object import StencilObject
from gt4py.storage.cartesian import utils as storage_utils
from gt4py.storage.cartesian.layout import StorageDevice
from gt4py.storage.cartesian.layout import is_gpu_device

from .input_strategies import (
SymbolKind,
@@ -207,10 +207,7 @@ def hyp_wrapper(test_hyp, hypothesis_data):
)

marks = test["marks"].copy()
if (
gt4pyc.backend.from_name(test["backend"]).storage_info["device"]
== StorageDevice.GPU
):
if is_gpu_device(gt4pyc.backend.from_name(test["backend"]).storage_info):
marks.append(pytest.mark.requires_gpu)
# Run generation and implementation tests in the same group to ensure
# (thread-) safe parallelization of stencil tests.
@@ -244,10 +241,7 @@ def hyp_wrapper(test_hyp, hypothesis_data):
)

marks = test["marks"].copy()
if (
gt4pyc.backend.from_name(test["backend"]).storage_info["device"]
== StorageDevice.GPU
):
if is_gpu_device(gt4pyc.backend.from_name(test["backend"]).storage_info):
marks.append(pytest.mark.requires_gpu)
# Run generation and implementation tests in the same group to ensure
# (thread-) safe parallelization of stencil tests.
9 changes: 7 additions & 2 deletions src/gt4py/storage/cartesian/interface.py
@@ -13,6 +13,7 @@

import numpy as np

from gt4py._core.definitions import DeviceType
from gt4py.storage import allocators
from gt4py.storage.cartesian import layout, utils as storage_utils

@@ -81,10 +82,14 @@ def empty(
_error_on_invalid_preset(backend)
storage_info = layout.from_name(backend)
assert storage_info is not None
if storage_info["device"] == layout.StorageDevice.GPU:
if storage_info["device"] == DeviceType.CUDA:
allocate_f = storage_utils.allocate_gpu
else:
elif storage_info["device"] == DeviceType.CPU:
allocate_f = storage_utils.allocate_cpu
else:
raise ValueError(
f"Allocation is only defined for DeviceTypes {DeviceType.CPU} (CPU) and {DeviceType.CUDA} (CUDA). Got {storage_info['device']} instead."
)

aligned_index, shape, dtype, dimensions = storage_utils.normalize_storage_spec(
aligned_index, shape, dtype, dimensions
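For context (not part of this diff): the new branch in empty() is a three-way dispatch on the layout's device. A standalone sketch of the same logic; the helper name _select_allocator is invented for illustration, while allocate_cpu/allocate_gpu are the utilities referenced in the diff.

```python
from gt4py._core.definitions import DeviceType
from gt4py.storage.cartesian import layout, utils as storage_utils


def _select_allocator(backend: str):
    """Mirror the device dispatch used by interface.empty()."""
    storage_info = layout.from_name(backend)
    if storage_info is None:
        raise ValueError(f"Unknown backend: {backend!r}")
    device = storage_info["device"]
    if device == DeviceType.CUDA:
        return storage_utils.allocate_gpu
    if device == DeviceType.CPU:
        return storage_utils.allocate_cpu
    raise ValueError(f"Allocation is only defined for CPU and CUDA devices, got {device}.")
```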
63 changes: 45 additions & 18 deletions src/gt4py/storage/cartesian/layout.py
@@ -6,7 +6,6 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
@@ -22,6 +21,8 @@

import numpy as np

from gt4py._core.definitions import DeviceType


if TYPE_CHECKING:
try:
@@ -30,34 +31,32 @@
cp = None


class StorageDevice(Enum):
CPU = 1
GPU = 2


class LayoutInfo(TypedDict):
alignment: int # measured in bytes
device: StorageDevice
device: DeviceType
layout_map: Callable[[Tuple[str, ...]], Tuple[Optional[int], ...]]
is_optimal_layout: Callable[[Any, Tuple[str, ...]], bool]


# Registry of LayoutInfo per backend name.
REGISTRY: Dict[str, LayoutInfo] = {}


def from_name(name: str) -> Optional[LayoutInfo]:
return REGISTRY.get(name, None)
def from_name(backend_name: str) -> Optional[LayoutInfo]:
"""Get LayoutInfo from backend name."""
return REGISTRY.get(backend_name, None)


def register(name: str, info: Optional[LayoutInfo]) -> None:
def register(backend_name: str, info: Optional[LayoutInfo]) -> None:
"""Register LayoutInfo with backend name."""
if info is None:
if name in REGISTRY:
del REGISTRY[name]
if backend_name in REGISTRY:
del REGISTRY[backend_name]
else:
assert isinstance(name, str)
assert isinstance(backend_name, str)
assert isinstance(info, dict)

REGISTRY[name] = info
REGISTRY[backend_name] = info


def check_layout(layout_map, strides):
@@ -71,6 +70,18 @@ def check_layout(layout_map, strides):
return True


def is_cpu_device(layout: LayoutInfo) -> bool:
return layout["device"] == DeviceType.CPU


def is_gpu_device(layout: LayoutInfo) -> bool:
return layout["device"] in [DeviceType.CUDA, DeviceType.ROCM]


def is_cuda_device(layout: LayoutInfo) -> bool:
return layout["device"] == DeviceType.CUDA


def layout_maker_factory(
base_layout: Tuple[int, ...],
) -> Callable[[Tuple[str, ...]], Tuple[int, ...]]:
@@ -141,15 +152,15 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], ..

NaiveCPULayout: Final[LayoutInfo] = {
"alignment": 1,
"device": StorageDevice.CPU,
"device": DeviceType.CPU,
"layout_map": lambda axes: tuple(i for i in range(len(axes))),
"is_optimal_layout": lambda *_: True,
}
register("naive_cpu", NaiveCPULayout)

CPUIFirstLayout: Final[LayoutInfo] = {
"alignment": 8,
"device": StorageDevice.CPU,
"device": DeviceType.CPU,
"layout_map": make_gtcpu_ifirst_layout_map,
"is_optimal_layout": layout_checker_factory(make_gtcpu_ifirst_layout_map),
}
@@ -158,7 +169,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], ..

CPUKFirstLayout: Final[LayoutInfo] = {
"alignment": 1,
"device": StorageDevice.CPU,
"device": DeviceType.CPU,
"layout_map": make_gtcpu_kfirst_layout_map,
"is_optimal_layout": layout_checker_factory(make_gtcpu_kfirst_layout_map),
}
@@ -167,11 +178,27 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], ..

CUDALayout: Final[LayoutInfo] = {
"alignment": 32,
"device": StorageDevice.GPU,
"device": DeviceType.CUDA,
"layout_map": make_cuda_layout_map,
"is_optimal_layout": layout_checker_factory(make_cuda_layout_map),
}
register("cuda", CUDALayout)

GPULayout: Final[LayoutInfo] = CUDALayout
register("gpu", GPULayout)

DaceCPULayout: Final[LayoutInfo] = {
"alignment": 1,
"device": DeviceType.CPU,
"layout_map": layout_maker_factory((0, 1, 2)),
"is_optimal_layout": layout_checker_factory(layout_maker_factory((0, 1, 2))),
}
register("dace:cpu", DaceCPULayout)

DaceGPULayout: Final[LayoutInfo] = {
"alignment": 32,
"device": DeviceType.CUDA,
"layout_map": layout_maker_factory((2, 1, 0)),
"is_optimal_layout": layout_checker_factory(layout_maker_factory((2, 1, 0))),
}
register("dace:gpu", DaceGPULayout)