Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to use simplified DnnModel interface #9

Merged
merged 2 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 56 additions & 89 deletions src/main/java/qupath/ext/djl/DjlDnnModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,31 @@

package qupath.ext.djl;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.io.UriResource;
import qupath.opencv.dnn.BlobFunction;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.dnn.PredictionFunction;

class DjlDnnModel implements DnnModel<NDList>, AutoCloseable, UriResource {
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

class DjlDnnModel implements DnnModel, AutoCloseable, UriResource {

private static final Logger logger = LoggerFactory.getLogger(DjlDnnModel.class);

Expand All @@ -58,10 +54,13 @@ class DjlDnnModel implements DnnModel<NDList>, AutoCloseable, UriResource {
private boolean lazyInitialize;

private transient boolean failed;
private transient ZooModel<NDList, NDList> model;
private transient Predictor<NDList, NDList> predictor;
private transient BlobFunction<NDList> blobFun;
private transient PredictionFunction<NDList> predictFun;
private transient ZooModel<Mat[], Mat[]> model;
private transient Predictor<Mat[], Mat[]> predictor;

/**
* Default layout for an OpenCV Mat
*/
private static final String DEFAULT_MAT_LAYOUT = getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL);

DjlDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs, boolean lazyInitialize) {
this.engine = engine;
Expand All @@ -85,11 +84,13 @@ private void ensureInitialized() {
if (!failed && model == null) {
try {
logger.debug("Initializing DjlDnnModel");
model = DjlTools.loadModel(engine, uris.toArray(URI[]::new));
if (ndLayout != null && ndLayout.contains("N"))
predictor = model.newPredictor();
else
predictor = model.newPredictor(new NoopTranslator(Batchifier.STACK));
model = DjlTools.loadModel(engine,
Mat[].class, Mat[].class,
new ModelMatTranslator(),
uris.toArray(URI[]::new));
// if (ndLayout != null && ndLayout.contains("N"))
predictor = model.newPredictor();


// TODO: Better handling of missing inputs/outputs - we may need to run a prediction for this to work
if (this.inputs == null || this.inputs.isEmpty()) {
Expand All @@ -111,9 +112,6 @@ private void ensureInitialized() {
if (this.outputs == null || this.outputs.isEmpty())
outputs = Map.of(DnnModel.DEFAULT_OUTPUT_NAME, DnnShape.UNKNOWN_SHAPE);
}

blobFun = new BlobFun();
predictFun = new PredictFun();
} catch (Exception e) {
failed = true;
logger.debug("Failed to create DjlDnnModel");
Expand All @@ -125,36 +123,36 @@ private void ensureInitialized() {
}

@Override
public BlobFunction<NDList> getBlobFunction() {
ensureInitialized();
return blobFun;
public Map<String, Mat> predict(Map<String, Mat> blobs) {
synchronized (predictor) {
try {
var result = predictor.predict(blobs.values().stream().toArray(Mat[]::new));
return Map.of(DnnModel.DEFAULT_OUTPUT_NAME, result[0]);
} catch (TranslateException e) {
throw new RuntimeException(e);
}
}
}

@Override
public BlobFunction<NDList> getBlobFunction(String name) {
ensureInitialized();
return blobFun;
public Mat predict(Mat mat) {
return DnnModel.super.predict(mat);
}

@Override
public PredictionFunction<NDList> getPredictionFunction() {
ensureInitialized();
return predictFun;
public List<Mat> batchPredict(List<? extends Mat> mats) {
return DnnModel.super.batchPredict(mats);
}

@Override
public synchronized void close() throws Exception {
if (model != null) {
model.close();
model = null;
blobFun = null;
predictFun = null;
logger.debug("Closed DjlDnnModel");
}
}

private static final String DEFAULT_MAT_LAYOUT = getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL);

private static String getLayout(LayoutType... layouts) {
return LayoutType.toString(layouts);
}
Expand Down Expand Up @@ -192,63 +190,32 @@ private static String estimateOutputLayout(NDArray array) {



private class BlobFun implements BlobFunction<NDList> {
private class ModelMatTranslator implements NoBatchifyTranslator<Mat[], Mat[]> {

@Override
public NDList toBlob(Mat... mats) {
NDList list = new NDList();
String layout = ndLayout;
for (var mat : mats) {
// Try to figure out the layout
if (layout == null) {
layout = estimateInputLayout(mat);
}
list.add(DjlTools.matToNDArray(model.getNDManager(), mat, layout));
}
return list;
}

@Override
public List<Mat> fromBlob(NDList blob) {
public Mat[] processOutput(TranslatorContext ctx, NDList list) throws Exception {
String layout;
if ((ndLayout == null || ndLayout.length() != blob.singletonOrThrow().getShape().dimension()) && !blob.isEmpty())
layout = estimateOutputLayout(blob.get(0));
if ((ndLayout == null || ndLayout.length() != list.singletonOrThrow().getShape().dimension()) && !list.isEmpty())
layout = estimateOutputLayout(list.get(0));
else
layout = ndLayout;
var output = blob.stream().map(b -> DjlTools.ndArrayToMat(b, layout)).collect(Collectors.toList());
blob.close();
var output = list.stream().map(b -> DjlTools.ndArrayToMat(b, layout)).toArray(Mat[]::new);
return output;
}

}

private class PredictFun implements PredictionFunction<NDList> {

@Override
public NDList predict(NDList input) {
try {
NDList output;
// TODO: Check whether to support per-thread predictors
synchronized (predictor) {
output = predictor.batchPredict(Collections.singletonList(input)).get(0);
public NDList processInput(TranslatorContext ctx, Mat... input) throws Exception {
NDList list = new NDList();
String layout = ndLayout;
for (var mat : input) {
// Try to figure out the layout
if (layout == null) {
layout = estimateInputLayout(mat);
}
input.close();
return output;
} catch (TranslateException e) {
throw new RuntimeException(e);
list.add(DjlTools.matToNDArray(ctx.getNDManager(), mat, layout));
}
return list;
}

@Override
public Map<String, DnnShape> getInputs() {
return inputs;
}

@Override
public Map<String, DnnShape> getOutputs(DnnShape... inputShapes) {
return outputs;
}

}

@Override
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/qupath/ext/djl/DjlDnnModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
*
* @author Pete Bankhead
*/
public class DjlDnnModelBuilder implements DnnModelBuilder<NDList> {
public class DjlDnnModelBuilder implements DnnModelBuilder {

private static String getEngineName(String framework) {
if (DjlTools.ALL_ENGINES.contains(framework))
Expand All @@ -41,6 +41,8 @@ private static String getEngineName(String framework) {
switch(framework) {
case DnnModelParams.FRAMEWORK_TENSORFLOW:
return DjlTools.ENGINE_TENSORFLOW;
case DnnModelParams.FRAMEWORK_TF_LITE:
return DjlTools.ENGINE_TFLITE;
case DnnModelParams.FRAMEWORK_ONNX_RUNTIME:
return DjlTools.ENGINE_ONNX_RUNTIME;
case DnnModelParams.FRAMEWORK_PYTORCH:
Expand Down Expand Up @@ -112,7 +114,7 @@ private static String axesToLayout(String axes) {
}

@Override
public DnnModel<NDList> buildModel(DnnModelParams params) {
public DnnModel buildModel(DnnModelParams params) {
var framework = params.getFramework();
String engineName = null;
if (framework == null) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/qupath/ext/djl/DjlExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class DjlExtension implements QuPathExtension, GitHubProject {

private final static Logger logger = LoggerFactory.getLogger(DjlExtension.class);

private final static DnnModelBuilder<?> builder = new DjlDnnModelBuilder();
private final static DnnModelBuilder builder = new DjlDnnModelBuilder();

static {
// Prevent downloading engines automatically
Expand Down
21 changes: 14 additions & 7 deletions src/main/java/qupath/ext/djl/DjlTools.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public class DjlTools {
* @param inputShape expected input shape, according to ndLayout
* @return
*/
public static DnnModel<NDList> createDnnModel(URI uri, String ndLayout, int[] inputShape) {
public static DnnModel createDnnModel(URI uri, String ndLayout, int[] inputShape) {
DnnShape shape = null;
if (inputShape != null)
shape = DnnShape.of(Arrays.stream(inputShape).mapToLong(i -> i).toArray());
Expand All @@ -169,7 +169,7 @@ public static DnnModel<NDList> createDnnModel(URI uri, String ndLayout, int[] in
* @param outputs outputs shapes, if known; if these are null, an attempt will be made to get them from DJL (but this does not always work)
* @return
*/
public static DnnModel<NDList> createDnnModel(String engine, URI uri, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
public static DnnModel createDnnModel(String engine, URI uri, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
return createDnnModel(engine, Collections.singletonList(uri), ndLayout, inputs, outputs);
}

Expand All @@ -182,7 +182,7 @@ public static DnnModel<NDList> createDnnModel(String engine, URI uri, String ndL
* @param outputs outputs shapes, if known; if these are null, an attempt will be made to get them from DJL (but this does not always work)
* @return
*/
private static DnnModel<NDList> createDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
private static DnnModel createDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
return new DjlDnnModel(engine, uris, ndLayout, inputs, outputs, false); // Eagerly initialize (so we know if it doesn't work sooner)
}

Expand Down Expand Up @@ -291,8 +291,12 @@ public static Engine getEngine(String name, boolean downloadIfNeeded) throws Ill
static DnnShape convertShape(Shape shape) {
return DnnShape.of(shape.getShape());
}

static ZooModel<NDList, NDList> loadModel(String engineName, URI... uris) throws ModelNotFoundException, MalformedModelException, IOException {
return loadModel(engineName, NDList.class, NDList.class, null, uris);
}

static <P, Q> ZooModel<P, Q> loadModel(String engineName, Class<P> inputClass, Class<Q> outputClass, Translator<P, Q> translator, URI... uris) throws ModelNotFoundException, MalformedModelException, IOException {
var sb = new StringBuilder();
boolean isFirst = true;
for (var uri : uris) {
Expand All @@ -307,13 +311,14 @@ static ZooModel<NDList, NDList> loadModel(String engineName, URI... uris) throws
}
sb.append(uri.toString());
}
return loadModel(engineName, sb.toString());
return loadModel(engineName, inputClass, outputClass, translator, sb.toString());
}

private static ZooModel<NDList, NDList> loadModel(String engineName, String urls) throws ModelNotFoundException, MalformedModelException, IOException {
private static <P, Q> ZooModel<P, Q> loadModel(String engineName, Class<P> inputClass, Class<Q> outputClass, Translator<P, Q> translator, String urls) throws ModelNotFoundException, MalformedModelException, IOException {
var builder = Criteria.builder()
.setTypes(NDList.class, NDList.class)
.setTypes(inputClass, outputClass)
.optModelUrls(urls)
.optTranslator(translator)
.optProgress(new ProgressBar());

String selectedEngine = null;
Expand All @@ -330,6 +335,8 @@ private static ZooModel<NDList, NDList> loadModel(String engineName, String urls
selectedEngine = "OnnxRuntime";
else if ((urlString.endsWith("pytorch") || urlString.endsWith(".pt")) && Engine.hasEngine("PyTorch"))
selectedEngine = "PyTorch";
else if (urlString.endsWith(".tflite") && Engine.hasEngine("TFLite"))
selectedEngine = "TFLite";
else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") || urlString.endsWith("tf_savedmodel")) && Engine.hasEngine("TensorFlow"))
selectedEngine = "TensorFlow";
}
Expand Down