Skip to content

Commit

Permalink
feat[cartesian]: gtc cuda backend deprecation (#1498)
Browse files Browse the repository at this point in the history
## Description

GTC `cuda` backend was made available a few years ago for AI2 team
research. It has been kept updated but a recent poll shows that it is
not in use. Recent new features break the backend and we propose here to
hard deprecate it rather than keep spending time maintaining it.

`GT4PY_GTC_ENABLE_CUDA=1` can be used to force the use of the backend, but
will warn that any feature from February 2024 are not available/not
tested.

Additionally a mechanism to deprecate all GTC backends are now in use. Using
```python
@disabled(
    message="Disable message.",
    enabled_env_var="EnvVarToEnable",
)
```

## Requirements

- [x] All fixes and/or new features come with corresponding tests.

---------

Co-authored-by: Hannes Vogt <hannes@havogt.de>
  • Loading branch information
FlorianDeconinck and havogt authored Apr 26, 2024
1 parent 6a5ae7e commit 2cd0c91
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 7 deletions.
43 changes: 41 additions & 2 deletions src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,21 @@
import pathlib
import time
import warnings
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Protocol,
Tuple,
Type,
Union,
)

from typing_extensions import deprecated

from gt4py import storage as gt_storage
from gt4py.cartesian import definitions as gt_definitions, utils as gt_utils
Expand All @@ -39,7 +53,7 @@ def from_name(name: str) -> Optional[Type["Backend"]]:
return REGISTRY.get(name, None)


def register(backend_cls: Type["Backend"]) -> None:
def register(backend_cls: Type["Backend"]) -> Type["Backend"]:
assert issubclass(backend_cls, Backend) and backend_cls.name is not None

if isinstance(backend_cls.name, str):
Expand Down Expand Up @@ -413,3 +427,28 @@ def build_extension_module(
)

return module_name, file_path


def disabled(message: str, *, enabled_env_var: str) -> Callable[[Type[Backend]], Type[Backend]]:
# We push for hard deprecation here by raising by default and warning if enabling has been forced.
enabled = bool(int(os.environ.get(enabled_env_var, "0")))
if enabled:
return deprecated(message)
else:

def _decorator(cls: Type[Backend]) -> Type[Backend]:
def _no_generate(obj) -> Type["StencilObject"]:
raise NotImplementedError(
f"Disabled '{cls.name}' backend: 'f{message}'\n",
f"You can still enable the backend by hand using the environment variable '{enabled_env_var}=1'",
)

# Replace generate method with raise
if not hasattr(cls, "generate"):
raise ValueError(f"Coding error. Expected a generate method on {cls}")
# Flag that it got disabled for register lookup
cls.disabled = True # type: ignore
cls.generate = _no_generate # type: ignore
return cls

return _decorator
11 changes: 9 additions & 2 deletions src/gt4py/cartesian/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type

from gt4py import storage as gt_storage
from gt4py.cartesian.backend.base import CLIBackendMixin, register
from gt4py.cartesian.backend.base import CLIBackendMixin, disabled, register
from gt4py.cartesian.backend.gtc_common import (
BackendCodegen,
bindings_main_template,
Expand Down Expand Up @@ -125,12 +125,19 @@ def apply_codegen(cls, root, *, module_name="stencil", backend, **kwargs) -> str
return generated_code


@disabled(
message="CUDA backend is deprecated. New features developed after February 2024 are not available.",
enabled_env_var="GT4PY_GTC_ENABLE_CUDA",
)
@register
class CudaBackend(BaseGTBackend, CLIBackendMixin):
"""CUDA backend using gtc."""

name = "cuda"
options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}}
options = {
**BaseGTBackend.GT_BACKEND_OPTS,
"device_sync": {"versioning": True, "type": bool},
}
languages = {"computation": "cuda", "bindings": ["python"]}
storage_info = gt_storage.layout.CUDALayout
PYEXT_GENERATOR_CLASS = CudaExtGenerator # type: ignore
Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/cartesian/backend/gtc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,10 @@ def make_extension(
gt_pyext_sources: Dict[str, Any]
if not self.builder.options._impl_opts.get("disable-code-generation", False):
gt_pyext_files = self.make_extension_sources(stencil_ir=stencil_ir)
gt_pyext_sources = {**gt_pyext_files["computation"], **gt_pyext_files["bindings"]}
gt_pyext_sources = {
**gt_pyext_files["computation"],
**gt_pyext_files["bindings"],
}
else:
# Pass NOTHING to the self.builder means try to reuse the source code files
gt_pyext_files = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/cartesian_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_backends_with_storage_info(storage_info_kind: str):
res = []
for name in _ALL_BACKEND_NAMES:
backend = gt4pyc.backend.from_name(name)
if backend is not None:
if not getattr(backend, "disabled", False):
if backend.storage_info["device"] == storage_info_kind:
res.append(_backend_name_as_param(name))
return res
Expand Down
24 changes: 23 additions & 1 deletion tests/cartesian_tests/unit_tests/backend_tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def stencil_def(
out = pa * fa + pb * fb - pc * fc # type: ignore # noqa


field_info_val = {0: ("out", "fa"), 1: ("out", "fa", "fb"), 2: ("out", "fa", "fb", "fc")}
field_info_val = {
0: ("out", "fa"),
1: ("out", "fa", "fb"),
2: ("out", "fa", "fb", "fc"),
}
parameter_info_val = {0: ("pa",), 1: ("pa", "pb"), 2: ("pa", "pb", "pc")}
unreferenced_val = {0: ("pb", "fb", "pc", "fc"), 1: ("pc", "fc"), 2: ()}

Expand Down Expand Up @@ -168,5 +172,23 @@ def test_toolchain_profiling(backend_name: str, mode: int, rebuild: bool):
assert build_info["load_time"] > 0.0


@pytest.mark.parametrize("backend_name", ["cuda"])
def test_deprecation_gtc_cuda(backend_name: str):
# Default deprecation, raise an error
build_info: Dict[str, Any] = {}
builder = (
StencilBuilder(cast(StencilFunc, stencil_def))
.with_backend(backend_name)
.with_externals({"MODE": 2})
.with_options(
name=stencil_def.__name__,
module=stencil_def.__module__,
build_info=build_info,
)
)
with pytest.raises(NotImplementedError):
builder.build()


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 2cd0c91

Please sign in to comment.