diff --git a/src/main/java/qupath/ext/djl/DjlTools.java b/src/main/java/qupath/ext/djl/DjlTools.java
index 21251fc..fdf9217 100644
--- a/src/main/java/qupath/ext/djl/DjlTools.java
+++ b/src/main/java/qupath/ext/djl/DjlTools.java
@@ -16,35 +16,7 @@
package qupath.ext.djl;
-import java.io.IOException;
-import java.net.URI;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
import ai.djl.Device;
-import org.bytedeco.javacpp.Loader;
-import org.bytedeco.javacpp.PointerScope;
-import org.bytedeco.javacpp.indexer.BooleanIndexer;
-import org.bytedeco.javacpp.indexer.ByteIndexer;
-import org.bytedeco.javacpp.indexer.DoubleIndexer;
-import org.bytedeco.javacpp.indexer.FloatIndexer;
-import org.bytedeco.javacpp.indexer.HalfIndexer;
-import org.bytedeco.javacpp.indexer.IntIndexer;
-import org.bytedeco.javacpp.indexer.LongIndexer;
-import org.bytedeco.javacpp.indexer.UByteIndexer;
-import org.bytedeco.javacpp.indexer.UShortIndexer;
-import org.bytedeco.opencv.global.opencv_core;
-import org.bytedeco.opencv.opencv_core.Mat;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
@@ -64,10 +36,38 @@
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
+import org.bytedeco.javacpp.PointerScope;
+import org.bytedeco.javacpp.indexer.BooleanIndexer;
+import org.bytedeco.javacpp.indexer.ByteIndexer;
+import org.bytedeco.javacpp.indexer.DoubleIndexer;
+import org.bytedeco.javacpp.indexer.FloatIndexer;
+import org.bytedeco.javacpp.indexer.HalfIndexer;
+import org.bytedeco.javacpp.indexer.IntIndexer;
+import org.bytedeco.javacpp.indexer.LongIndexer;
+import org.bytedeco.javacpp.indexer.UByteIndexer;
+import org.bytedeco.javacpp.indexer.UShortIndexer;
+import org.bytedeco.opencv.global.opencv_core;
+import org.bytedeco.opencv.opencv_core.Mat;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import qupath.lib.common.GeneralTools;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.tools.OpenCVTools;
+import java.io.IOException;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
/**
* Tools to help work with Deep Java Library within QuPath.
*
@@ -341,6 +341,39 @@ else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") |
return ModelZoo.loadModel(criteria);
}
+ /**
+ * Get the available devices for an Engine, including MPS if Apple Silicon.
+ * Does not attempt to download the engine if not available.
+ * @return An empty list if the engine isn't available, otherwise the possible devices you can currently use.
+ */
+ public static Collection getAvailableDevices(String engineName) {
+ Set availableDevices = new LinkedHashSet<>();
+ var engine = DjlTools.getEngine(engineName, false);
+ if (engine == null) {
+ return List.of();
+ }
+ boolean includesMPS = false; // Don't add MPS twice
+ // This is expected to return GPUs if available, or CPU otherwise
+ for (var device : engine.getDevices()) {
+ String name = device.getDeviceType();
+ availableDevices.add(name);
+ if (name.toLowerCase().startsWith("mps"))
+ includesMPS = true;
+ }
+ // CPU should always be available
+ availableDevices.add("cpu");
+
+ // If we could use MPS, but don't have it already, add it
+ if (engineName.equalsIgnoreCase("pytorch")
+ && !includesMPS
+ && GeneralTools.isMac()
+ && "aarch64".equals(System.getProperty("os.arch"))) {
+ availableDevices.add("mps");
+ }
+ return availableDevices;
+ }
+
+
/**
* Set the default device for the specified engine.