From b174ba0d8dd194da3fb821e2cea213127f04ddd8 Mon Sep 17 00:00:00 2001 From: andreii Date: Mon, 12 Feb 2024 16:29:15 -0800 Subject: [PATCH 1/5] Skip test when atomic operations are not supported on GPU. --- tests/python/common/ops/test_ops.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/python/common/ops/test_ops.py b/tests/python/common/ops/test_ops.py index 2f8634c99126..119a2b1baba3 100644 --- a/tests/python/common/ops/test_ops.py +++ b/tests/python/common/ops/test_ops.py @@ -413,14 +413,16 @@ def test_gather_mm_idx_b(feat_size, dtype, tol): and not torch.cuda.is_bf16_supported() ): pytest.skip("BF16 is not supported.") - if ( F._default_context_str == "gpu" - and dtype == torch.float16 - and torch.cuda.get_device_capability() < (7, 0) + and (dtype == torch.float16 + and torch.cuda.get_device_capability() < (7, 0) + or dtype == torch.bfloat16 + and torch.cuda.get_device_capability() < (8, 0) + ) ): pytest.skip( - f"FP16 is not supported for atomic operations on GPU with " + f"{dtype} is not supported for atomic operations on GPU with " f"cuda capability ({torch.cuda.get_device_capability()})." ) From 9f29957415fd0dfffd17eb7fd13ddd00bd142e6f Mon Sep 17 00:00:00 2001 From: andreii Date: Mon, 12 Feb 2024 16:54:44 -0800 Subject: [PATCH 2/5] Fixing lint problems --- tests/python/common/ops/test_ops.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/python/common/ops/test_ops.py b/tests/python/common/ops/test_ops.py index 119a2b1baba3..dda1cba262cf 100644 --- a/tests/python/common/ops/test_ops.py +++ b/tests/python/common/ops/test_ops.py @@ -413,13 +413,11 @@ def test_gather_mm_idx_b(feat_size, dtype, tol): and not torch.cuda.is_bf16_supported() ): pytest.skip("BF16 is not supported.") - if ( - F._default_context_str == "gpu" - and (dtype == torch.float16 - and torch.cuda.get_device_capability() < (7, 0) - or dtype == torch.bfloat16 - and torch.cuda.get_device_capability() < (8, 0) - ) + if F._default_context_str == "gpu" and ( + dtype == torch.float16 + and torch.cuda.get_device_capability() < (7, 0) + or dtype == torch.bfloat16 + and torch.cuda.get_device_capability() < (8, 0) ): pytest.skip( f"{dtype} is not supported for atomic operations on GPU with " From 18b755398bcb1128fe4e0ba2a7f4d5a755b73648 Mon Sep 17 00:00:00 2001 From: andreii Date: Tue, 20 Feb 2024 09:50:58 -0800 Subject: [PATCH 3/5] Changes suggested by @frozenbugs --- tests/python/common/ops/test_ops.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/python/common/ops/test_ops.py b/tests/python/common/ops/test_ops.py index dda1cba262cf..b664742dd3cf 100644 --- a/tests/python/common/ops/test_ops.py +++ b/tests/python/common/ops/test_ops.py @@ -407,22 +407,19 @@ def test_segment_mm(idtype, feat_size, dtype, tol): def test_gather_mm_idx_b(feat_size, dtype, tol): if F._default_context_str == "cpu" and dtype == torch.float16: pytest.skip("float16 is not supported on CPU.") - if ( - F._default_context_str == "gpu" - and dtype == torch.bfloat16 - and not torch.cuda.is_bf16_supported() - ): - pytest.skip("BF16 is not supported.") - if F._default_context_str == "gpu" and ( - dtype == torch.float16 - and torch.cuda.get_device_capability() < (7, 0) - or dtype == torch.bfloat16 - and torch.cuda.get_device_capability() < (8, 0) - ): - pytest.skip( - f"{dtype} is not supported for atomic operations on GPU with " - f"cuda capability ({torch.cuda.get_device_capability()})." - ) + + if F._default_context_str == "gpu": + if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("BF16 is not supported.") + + device_capability = torch.cuda.get_device_capability() + if (dtype == torch.float16 and device_capability < (7, 0) or + dtype == torch.bfloat16 and device_capability < (8, 0) + ): + pytest.skip( + f"{dtype} is not supported for atomic operations on GPU with " + f"cuda capability ({torch.cuda.get_device_capability()})." + ) dev = F.ctx() # input From 219d53a1f9c1ebb72e10d39244c087dc3ecc6bcf Mon Sep 17 00:00:00 2001 From: andreii Date: Tue, 20 Feb 2024 10:18:46 -0800 Subject: [PATCH 4/5] Fixing lint problems --- tests/python/common/ops/test_ops.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/python/common/ops/test_ops.py b/tests/python/common/ops/test_ops.py index b664742dd3cf..100db29c0533 100644 --- a/tests/python/common/ops/test_ops.py +++ b/tests/python/common/ops/test_ops.py @@ -408,13 +408,19 @@ def test_gather_mm_idx_b(feat_size, dtype, tol): if F._default_context_str == "cpu" and dtype == torch.float16: pytest.skip("float16 is not supported on CPU.") - if F._default_context_str == "gpu": + if F._default_context_str == "gpu": if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): pytest.skip("BF16 is not supported.") - device_capability = torch.cuda.get_device_capability() - if (dtype == torch.float16 and device_capability < (7, 0) or - dtype == torch.bfloat16 and device_capability < (8, 0) + if ( + ( + dtype == torch.float16 + and torch.cuda.get_device_capability() < (7, 0) + ) + or ( + dtype == torch.bfloat16 + and device_capability < (8, 0) + ) ): pytest.skip( f"{dtype} is not supported for atomic operations on GPU with " From 491cb459aa4cc15ea7c3a8028a15b2c69cb939d2 Mon Sep 17 00:00:00 2001 From: andreii Date: Tue, 20 Feb 2024 10:28:11 -0800 Subject: [PATCH 5/5] Fixing lint problems --- tests/python/common/ops/test_ops.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/python/common/ops/test_ops.py b/tests/python/common/ops/test_ops.py index 100db29c0533..f2e50c2e4c50 100644 --- a/tests/python/common/ops/test_ops.py +++ b/tests/python/common/ops/test_ops.py @@ -413,14 +413,11 @@ def test_gather_mm_idx_b(feat_size, dtype, tol): pytest.skip("BF16 is not supported.") if ( - ( - dtype == torch.float16 - and torch.cuda.get_device_capability() < (7, 0) - ) - or ( - dtype == torch.bfloat16 - and device_capability < (8, 0) - ) + dtype == torch.float16 + and torch.cuda.get_device_capability() < (7, 0) + ) or ( + dtype == torch.bfloat16 + and torch.cuda.get_device_capability() < (8, 0) ): pytest.skip( f"{dtype} is not supported for atomic operations on GPU with "