From 64655d61402f4179682110c2b5fb47c426abb3e0 Mon Sep 17 00:00:00 2001 From: DivyaSesh <64513125+gitttt-1234@users.noreply.github.com> Date: Wed, 6 Sep 2023 09:50:35 -0700 Subject: [PATCH] Fix Auto-select GPU (#1474) * Fix Auto-select GPU * Format file * Add variable in init * Format files * Add small test to ensure environment variable is set * Make linter happy --------- Co-authored-by: roomrys --- sleap/nn/__init__.py | 3 +++ sleap/nn/system.py | 1 + tests/nn/test_system.py | 6 ++++++ 3 files changed, 10 insertions(+) diff --git a/sleap/nn/__init__.py b/sleap/nn/__init__.py index b3c4eacd3..648fd49ff 100644 --- a/sleap/nn/__init__.py +++ b/sleap/nn/__init__.py @@ -14,3 +14,6 @@ import sleap.nn.tracking import sleap.nn.viz import sleap.nn.identity +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" diff --git a/sleap/nn/system.py b/sleap/nn/system.py index 24b4c14b3..eeb3f3ca4 100644 --- a/sleap/nn/system.py +++ b/sleap/nn/system.py @@ -195,6 +195,7 @@ def get_gpu_memory() -> List[int]: A list of the available memory on each GPU in MiB. """ + if shutil.which("nvidia-smi") is None: return [] diff --git a/tests/nn/test_system.py b/tests/nn/test_system.py index ea835e3c3..fc95bb0ea 100644 --- a/tests/nn/test_system.py +++ b/tests/nn/test_system.py @@ -87,3 +87,9 @@ def test_gpu_order_and_length(): # Assert that the order and length of GPU indices match assert sleap_indices == nvidia_indices + + +def test_gpu_device_order(): + """Indirectly tests GPU device order by ensuring environment variable is set.""" + + assert os.environ["CUDA_DEVICE_ORDER"] == "PCI_BUS_ID"