Skip to content

Commit

Permalink
Changes suggested by PR reviewer.
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Sep 25, 2024
1 parent ecf3aa0 commit ae6b970
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
15 changes: 8 additions & 7 deletions tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@ def to_on_disk_numpy(test_dir, name, t):
return path

def _skip_condition_CachedFeature():
    """Return True when the GPUCachedFeature tests must be skipped.

    The tests only run when the test backend is the GPU and the CUDA
    device is Volta (compute capability major version 7) or newer.
    """
    # Diff residue left two stacked copies of this return (the pre- and
    # post-change versions); keep only the post-change formatting.
    return (F._default_context_str != "gpu") or (
        torch.cuda.get_device_capability()[0] < 7
    )


def _reason_to_skip_CachedFeature():
    """Return the human-readable skip reason paired with
    _skip_condition_CachedFeature().

    Distinguishes the two skip causes: wrong backend vs. a CUDA device
    older than Volta (compute capability < 7).
    """
    if F._default_context_str != "gpu":
        return "GPUCachedFeature tests are available only when testing the GPU backend."

    return "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."


@unittest.skipIf(
_skip_condition_CachedFeature(),
reason=_reson_to_skip_CachedFeature(),
reason=_reason_to_skip_CachedFeature(),
)
@pytest.mark.parametrize(
"dtype",
Expand Down Expand Up @@ -127,7 +128,7 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):

@unittest.skipIf(
_skip_condition_CachedFeature(),
reason=_reson_to_skip_CachedFeature(),
reason=_reason_to_skip_CachedFeature(),
)
@pytest.mark.parametrize(
"dtype",
Expand Down Expand Up @@ -165,7 +166,7 @@ def test_gpu_cached_feature_read_async(dtype, pin_memory):

@unittest.skipIf(
_skip_condition_CachedFeature(),
reason=_reson_to_skip_CachedFeature(),
reason=_reason_to_skip_CachedFeature(),
)
@unittest.skipIf(
not torch.ops.graphbolt.detect_io_uring(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@unittest.skipIf(
F._default_context_str != "gpu"
or torch.cuda.get_device_capability()[0] < 7,
reason="GPUCachedFeature tests are available only on GPU."
reason="GPUCachedFeature tests are available only when testing the GPU backend."
if F._default_context_str != "gpu"
else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import unittest
import backend as F

import pytest
Expand All @@ -7,17 +6,19 @@
from dgl import graphbolt as gb


@unittest.skipIf(
F._default_context_str != "gpu"
or torch.cuda.get_device_capability()[0] < 7,
reason="GPUCachedFeature tests are available only on GPU."
if F._default_context_str != "gpu"
else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
)
@pytest.mark.parametrize(
"cached_feature_type", [gb.cpu_cached_feature, gb.gpu_cached_feature]
)
def test_hetero_cached_feature(cached_feature_type):
if cached_feature_type == gb.gpu_cached_feature and (
F._default_context_str != "gpu"
or torch.cuda.get_device_capability()[0] < 7
):
pytest.skip(
"GPUCachedFeature tests are available only when testing the GPU backend."
if F._default_context_str != "gpu" else
"GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
)
device = F.ctx() if cached_feature_type == gb.gpu_cached_feature else None
pin_memory = cached_feature_type == gb.gpu_cached_feature

Expand Down

0 comments on commit ae6b970

Please sign in to comment.