Skip to content

Commit

Permalink
Correcting Misleading Test Skip Reasons. (#7808)
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov authored Oct 11, 2024
1 parent 433ffb3 commit 4a6bfa4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
28 changes: 19 additions & 9 deletions tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,22 @@ def to_on_disk_numpy(test_dir, name, t):
return path


def _skip_condition_cached_feature():
return (F._default_context_str != "gpu") or (
torch.cuda.get_device_capability()[0] < 7
)


def _reason_to_skip_cached_feature():
if F._default_context_str != "gpu":
return "GPUCachedFeature tests are available only when testing the GPU backend."

return "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."


@unittest.skipIf(
F._default_context_str != "gpu"
or torch.cuda.get_device_capability()[0] < 7,
reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
_skip_condition_cached_feature(),
reason=_reason_to_skip_cached_feature(),
)
@pytest.mark.parametrize(
"dtype",
Expand Down Expand Up @@ -116,9 +128,8 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):


@unittest.skipIf(
F._default_context_str != "gpu"
or torch.cuda.get_device_capability()[0] < 7,
reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
_skip_condition_cached_feature(),
reason=_reason_to_skip_cached_feature(),
)
@pytest.mark.parametrize(
"dtype",
Expand Down Expand Up @@ -155,9 +166,8 @@ def test_gpu_cached_feature_read_async(dtype, pin_memory):


@unittest.skipIf(
F._default_context_str != "gpu"
or torch.cuda.get_device_capability()[0] < 7,
reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
_skip_condition_cached_feature(),
reason=_reason_to_skip_cached_feature(),
)
@unittest.skipIf(
not torch.ops.graphbolt.detect_io_uring(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@unittest.skipIf(
F._default_context_str != "gpu"
or torch.cuda.get_device_capability()[0] < 7,
reason="GPUCachedFeature tests are available only on GPU."
reason="GPUCachedFeature tests are available only when testing the GPU backend."
if F._default_context_str != "gpu"
else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def test_hetero_cached_feature(cached_feature_type):
or torch.cuda.get_device_capability()[0] < 7
):
pytest.skip(
"GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
"GPUCachedFeature tests are available only when testing the GPU backend."
if F._default_context_str != "gpu"
else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
)
device = F.ctx() if cached_feature_type == gb.gpu_cached_feature else None
pin_memory = cached_feature_type == gb.gpu_cached_feature
Expand Down

0 comments on commit 4a6bfa4

Please sign in to comment.