diff --git a/sleap/nn/__init__.py b/sleap/nn/__init__.py index b3c4eacd3..648fd49ff 100644 --- a/sleap/nn/__init__.py +++ b/sleap/nn/__init__.py @@ -14,3 +14,6 @@ import sleap.nn.tracking import sleap.nn.viz import sleap.nn.identity +import os + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" diff --git a/sleap/nn/system.py b/sleap/nn/system.py index 24b4c14b3..eeb3f3ca4 100644 --- a/sleap/nn/system.py +++ b/sleap/nn/system.py @@ -195,6 +195,7 @@ def get_gpu_memory() -> List[int]: A list of the available memory on each GPU in MiB. """ + if shutil.which("nvidia-smi") is None: return [] diff --git a/tests/nn/test_system.py b/tests/nn/test_system.py index ea835e3c3..fc95bb0ea 100644 --- a/tests/nn/test_system.py +++ b/tests/nn/test_system.py @@ -87,3 +87,9 @@ def test_gpu_order_and_length(): # Assert that the order and length of GPU indices match assert sleap_indices == nvidia_indices + + +def test_gpu_device_order(): + """Indirectly tests GPU device order by ensuring environment variable is set.""" + + assert os.environ["CUDA_DEVICE_ORDER"] == "PCI_BUS_ID"