From 7363c31785064e8fb197bfbde3ff1ccab3382a4d Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Wed, 1 May 2024 00:52:28 +0100 Subject: [PATCH 1/2] Add ability to poll engines for available devices --- src/main/java/qupath/ext/djl/DjlTools.java | 86 +++++++++++++++------- 1 file changed, 58 insertions(+), 28 deletions(-) diff --git a/src/main/java/qupath/ext/djl/DjlTools.java b/src/main/java/qupath/ext/djl/DjlTools.java index 21251fc..4abf567 100644 --- a/src/main/java/qupath/ext/djl/DjlTools.java +++ b/src/main/java/qupath/ext/djl/DjlTools.java @@ -16,35 +16,7 @@ package qupath.ext.djl; -import java.io.IOException; -import java.net.URI; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - import ai.djl.Device; -import org.bytedeco.javacpp.Loader; -import org.bytedeco.javacpp.PointerScope; -import org.bytedeco.javacpp.indexer.BooleanIndexer; -import org.bytedeco.javacpp.indexer.ByteIndexer; -import org.bytedeco.javacpp.indexer.DoubleIndexer; -import org.bytedeco.javacpp.indexer.FloatIndexer; -import org.bytedeco.javacpp.indexer.HalfIndexer; -import org.bytedeco.javacpp.indexer.IntIndexer; -import org.bytedeco.javacpp.indexer.LongIndexer; -import org.bytedeco.javacpp.indexer.UByteIndexer; -import org.bytedeco.javacpp.indexer.UShortIndexer; -import org.bytedeco.opencv.global.opencv_core; -import org.bytedeco.opencv.opencv_core.Mat; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import ai.djl.MalformedModelException; import ai.djl.Model; import ai.djl.engine.Engine; @@ -64,10 +36,38 @@ import ai.djl.translate.TranslatorContext; import ai.djl.util.Pair; import ai.djl.util.PairList; +import org.bytedeco.javacpp.PointerScope; +import org.bytedeco.javacpp.indexer.BooleanIndexer; +import org.bytedeco.javacpp.indexer.ByteIndexer; +import org.bytedeco.javacpp.indexer.DoubleIndexer; +import org.bytedeco.javacpp.indexer.FloatIndexer; +import org.bytedeco.javacpp.indexer.HalfIndexer; +import org.bytedeco.javacpp.indexer.IntIndexer; +import org.bytedeco.javacpp.indexer.LongIndexer; +import org.bytedeco.javacpp.indexer.UByteIndexer; +import org.bytedeco.javacpp.indexer.UShortIndexer; +import org.bytedeco.opencv.global.opencv_core; +import org.bytedeco.opencv.opencv_core.Mat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import qupath.lib.common.GeneralTools; import qupath.opencv.dnn.DnnModel; import qupath.opencv.dnn.DnnShape; import qupath.opencv.tools.OpenCVTools; +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + /** * Tools to help work with Deep Java Library within QuPath. * @@ -341,6 +341,36 @@ else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") | return ModelZoo.loadModel(criteria); } + /** + * Get the available devices for an Engine, including MPS if Apple Silicon. + * Does not attempt to download the engine if not available. + * @return An empty list if the engine isn't available, otherwise the possible devices you can currently use. + */ + public static Collection getAvailableDevices(String engineName) { + Set availableDevices = new LinkedHashSet<>(); + boolean includesMPS = false; // Don't add MPS twice + var engine = DjlTools.getEngine(engineName, false); + if (engine == null) { + return List.of(); + } + // This is expected to return GPUs if available, or CPU otherwise + for (var device : engine.getDevices()) { + String name = device.getDeviceType(); + availableDevices.add(name); + if (name.toLowerCase().startsWith("mps")) + includesMPS = true; + } + // CPU should always be available + availableDevices.add("cpu"); + + // If we could use MPS, but don't have it already, add it + if (!includesMPS && GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) { + availableDevices.add("mps"); + } + return availableDevices; + } + + /** * Set the default device for the specified engine. From 1bfaa1a771ed99939f7b2db03a6df08fa0e3d859 Mon Sep 17 00:00:00 2001 From: Alan O'Callaghan Date: Wed, 1 May 2024 10:53:44 +0100 Subject: [PATCH 2/2] MPS only on pytorch --- src/main/java/qupath/ext/djl/DjlTools.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/java/qupath/ext/djl/DjlTools.java b/src/main/java/qupath/ext/djl/DjlTools.java index 4abf567..fdf9217 100644 --- a/src/main/java/qupath/ext/djl/DjlTools.java +++ b/src/main/java/qupath/ext/djl/DjlTools.java @@ -348,11 +348,11 @@ else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") | */ public static Collection getAvailableDevices(String engineName) { Set availableDevices = new LinkedHashSet<>(); - boolean includesMPS = false; // Don't add MPS twice var engine = DjlTools.getEngine(engineName, false); if (engine == null) { return List.of(); } + boolean includesMPS = false; // Don't add MPS twice // This is expected to return GPUs if available, or CPU otherwise for (var device : engine.getDevices()) { String name = device.getDeviceType(); @@ -364,7 +364,10 @@ public static Collection getAvailableDevices(String engineName) { availableDevices.add("cpu"); // If we could use MPS, but don't have it already, add it - if (!includesMPS && GeneralTools.isMac() && "aarch64".equals(System.getProperty("os.arch"))) { + if (engineName.equalsIgnoreCase("pytorch") + && !includesMPS + && GeneralTools.isMac() + && "aarch64".equals(System.getProperty("os.arch"))) { availableDevices.add("mps"); } return availableDevices;