diff --git a/tests/test_models.py b/tests/test_models.py index b5490ff..1f9e854 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,7 @@ def has_gpu(): - return tf.test.is_gpu_available() + return len(tf.config.list_physical_devices('GPU')) > 0 MODELS = [