diff --git a/src/main/java/qupath/ext/djl/DjlTools.java b/src/main/java/qupath/ext/djl/DjlTools.java index 21251fc..fdf9217 100644 --- a/src/main/java/qupath/ext/djl/DjlTools.java +++ b/src/main/java/qupath/ext/djl/DjlTools.java @@ -16,35 +16,7 @@ package qupath.ext.djl; -import java.io.IOException; -import java.net.URI; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - import ai.djl.Device; -import org.bytedeco.javacpp.Loader; -import org.bytedeco.javacpp.PointerScope; -import org.bytedeco.javacpp.indexer.BooleanIndexer; -import org.bytedeco.javacpp.indexer.ByteIndexer; -import org.bytedeco.javacpp.indexer.DoubleIndexer; -import org.bytedeco.javacpp.indexer.FloatIndexer; -import org.bytedeco.javacpp.indexer.HalfIndexer; -import org.bytedeco.javacpp.indexer.IntIndexer; -import org.bytedeco.javacpp.indexer.LongIndexer; -import org.bytedeco.javacpp.indexer.UByteIndexer; -import org.bytedeco.javacpp.indexer.UShortIndexer; -import org.bytedeco.opencv.global.opencv_core; -import org.bytedeco.opencv.opencv_core.Mat; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.engine.Engine; @@ -64,10 +36,38 @@ import ai.djl.translate.TranslatorContext; import ai.djl.util.Pair; import ai.djl.util.PairList; +import org.bytedeco.javacpp.PointerScope; +import org.bytedeco.javacpp.indexer.BooleanIndexer; +import org.bytedeco.javacpp.indexer.ByteIndexer; +import org.bytedeco.javacpp.indexer.DoubleIndexer; +import org.bytedeco.javacpp.indexer.FloatIndexer; +import org.bytedeco.javacpp.indexer.HalfIndexer; +import org.bytedeco.javacpp.indexer.IntIndexer; +import org.bytedeco.javacpp.indexer.LongIndexer; +import org.bytedeco.javacpp.indexer.UByteIndexer; +import org.bytedeco.javacpp.indexer.UShortIndexer; +import org.bytedeco.opencv.global.opencv_core; +import org.bytedeco.opencv.opencv_core.Mat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import qupath.lib.common.GeneralTools; import qupath.opencv.dnn.DnnModel; import qupath.opencv.dnn.DnnShape; import qupath.opencv.tools.OpenCVTools; +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + /** * Tools to help work with Deep Java Library within QuPath. * @@ -341,6 +341,39 @@ else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") | return ModelZoo.loadModel(criteria); } + /** + * Get the available devices for an Engine, including MPS if Apple Silicon. + * Does not attempt to download the engine if not available. + * @return An empty list if the engine isn't available, otherwise the possible devices you can currently use. + */ + public static Collection getAvailableDevices(String engineName) { + Set availableDevices = new LinkedHashSet<>(); + var engine = DjlTools.getEngine(engineName, false); + if (engine == null) { + return List.of(); + } + boolean includesMPS = false; // Don't add MPS twice + // This is expected to return GPUs if available, or CPU otherwise + for (var device : engine.getDevices()) { + String name = device.getDeviceType(); + availableDevices.add(name); + if (name.toLowerCase().startsWith("mps")) + includesMPS = true; + } + // CPU should always be available + availableDevices.add("cpu"); + + // If we could use MPS, but don't have it already, add it + if (engineName.equalsIgnoreCase("pytorch") + && !includesMPS + && GeneralTools.isMac() + && "aarch64".equals(System.getProperty("os.arch"))) { + availableDevices.add("mps"); + } + return availableDevices; + } + + /** * Set the default device for the specified engine.