Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to poll engines for available devices #21

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 61 additions & 28 deletions src/main/java/qupath/ext/djl/DjlTools.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,7 @@

package qupath.ext.djl;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import ai.djl.Device;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
Expand All @@ -64,10 +36,38 @@
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.common.GeneralTools;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.tools.OpenCVTools;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Tools to help work with Deep Java Library within QuPath.
*
Expand Down Expand Up @@ -341,6 +341,39 @@ else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") |
return ModelZoo.loadModel(criteria);
}

/**
* Get the available devices for an Engine, including MPS if Apple Silicon.
* Does <b>not</b> attempt to download the engine if not available.
* @return An empty list if the engine isn't available, otherwise the possible devices you can currently use.
*/
public static Collection<String> getAvailableDevices(String engineName) {
Set<String> availableDevices = new LinkedHashSet<>();
var engine = DjlTools.getEngine(engineName, false);
if (engine == null) {
return List.of();
}
boolean includesMPS = false; // Don't add MPS twice
// This is expected to return GPUs if available, or CPU otherwise
for (var device : engine.getDevices()) {
String name = device.getDeviceType();
availableDevices.add(name);
if (name.toLowerCase().startsWith("mps"))
includesMPS = true;
}
// CPU should always be available
availableDevices.add("cpu");

// If we could use MPS, but don't have it already, add it
if (engineName.equalsIgnoreCase("pytorch")
&& !includesMPS
&& GeneralTools.isMac()
&& "aarch64".equals(System.getProperty("os.arch"))) {
availableDevices.add("mps");
}
return availableDevices;
}



/**
* Set the default device for the specified engine.
Expand Down
Loading