diff --git a/src/main/java/qupath/ext/djl/DjlDnnModel.java b/src/main/java/qupath/ext/djl/DjlDnnModel.java index 635e3eb..2beb026 100644 --- a/src/main/java/qupath/ext/djl/DjlDnnModel.java +++ b/src/main/java/qupath/ext/djl/DjlDnnModel.java @@ -16,35 +16,31 @@ package qupath.ext.djl; -import java.io.IOException; -import java.net.URI; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -import org.bytedeco.opencv.opencv_core.Mat; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import ai.djl.inference.Predictor; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.types.LayoutType; import ai.djl.repository.zoo.ZooModel; -import ai.djl.translate.Batchifier; -import ai.djl.translate.NoopTranslator; +import ai.djl.translate.NoBatchifyTranslator; import ai.djl.translate.TranslateException; +import ai.djl.translate.TranslatorContext; +import org.bytedeco.opencv.opencv_core.Mat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import qupath.lib.io.UriResource; -import qupath.opencv.dnn.BlobFunction; import qupath.opencv.dnn.DnnModel; import qupath.opencv.dnn.DnnShape; -import qupath.opencv.dnn.PredictionFunction; -class DjlDnnModel implements DnnModel, AutoCloseable, UriResource { +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +class DjlDnnModel implements DnnModel, AutoCloseable, UriResource { private static final Logger logger = LoggerFactory.getLogger(DjlDnnModel.class); @@ -58,10 +54,13 @@ class DjlDnnModel implements DnnModel, AutoCloseable, UriResource { private boolean lazyInitialize; private transient boolean failed; - private transient ZooModel model; - private transient Predictor predictor; - private transient BlobFunction blobFun; - private transient PredictionFunction predictFun; + private transient ZooModel model; + private transient Predictor predictor; + + /** + * Default layout for an OpenCV Mat + */ + private static final String DEFAULT_MAT_LAYOUT = getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL); DjlDnnModel(String engine, Collection uris, String ndLayout, Map inputs, Map outputs, boolean lazyInitialize) { this.engine = engine; @@ -85,11 +84,13 @@ private void ensureInitialized() { if (!failed && model == null) { try { logger.debug("Initializing DjlDnnModel"); - model = DjlTools.loadModel(engine, uris.toArray(URI[]::new)); - if (ndLayout != null && ndLayout.contains("N")) - predictor = model.newPredictor(); - else - predictor = model.newPredictor(new NoopTranslator(Batchifier.STACK)); + model = DjlTools.loadModel(engine, + Mat[].class, Mat[].class, + new ModelMatTranslator(), + uris.toArray(URI[]::new)); +// if (ndLayout != null && ndLayout.contains("N")) + predictor = model.newPredictor(); + // TODO: Better handling of missing inputs/outputs - we may need to run a prediction for this to work if (this.inputs == null || this.inputs.isEmpty()) { @@ -111,9 +112,6 @@ private void ensureInitialized() { if (this.outputs == null || this.outputs.isEmpty()) outputs = Map.of(DnnModel.DEFAULT_OUTPUT_NAME, DnnShape.UNKNOWN_SHAPE); } - - blobFun = new BlobFun(); - predictFun = new PredictFun(); } catch (Exception e) { failed = true; logger.debug("Failed to create DjlDnnModel"); @@ -125,21 +123,25 @@ private void ensureInitialized() { } @Override - public BlobFunction getBlobFunction() { - ensureInitialized(); - return blobFun; + public Map predict(Map blobs) { + synchronized (predictor) { + try { + var result = predictor.predict(blobs.values().stream().toArray(Mat[]::new)); + return Map.of(DnnModel.DEFAULT_OUTPUT_NAME, result[0]); + } catch (TranslateException e) { + throw new RuntimeException(e); + } + } } @Override - public BlobFunction getBlobFunction(String name) { - ensureInitialized(); - return blobFun; + public Mat predict(Mat mat) { + return DnnModel.super.predict(mat); } @Override - public PredictionFunction getPredictionFunction() { - ensureInitialized(); - return predictFun; + public List batchPredict(List mats) { + return DnnModel.super.batchPredict(mats); } @Override @@ -147,14 +149,10 @@ public synchronized void close() throws Exception { if (model != null) { model.close(); model = null; - blobFun = null; - predictFun = null; logger.debug("Closed DjlDnnModel"); } } - private static final String DEFAULT_MAT_LAYOUT = getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL); - private static String getLayout(LayoutType... layouts) { return LayoutType.toString(layouts); } @@ -192,63 +190,32 @@ private static String estimateOutputLayout(NDArray array) { - private class BlobFun implements BlobFunction { + private class ModelMatTranslator implements NoBatchifyTranslator { @Override - public NDList toBlob(Mat... mats) { - NDList list = new NDList(); - String layout = ndLayout; - for (var mat : mats) { - // Try to figure out the layout - if (layout == null) { - layout = estimateInputLayout(mat); - } - list.add(DjlTools.matToNDArray(model.getNDManager(), mat, layout)); - } - return list; - } - - @Override - public List fromBlob(NDList blob) { + public Mat[] processOutput(TranslatorContext ctx, NDList list) throws Exception { String layout; - if ((ndLayout == null || ndLayout.length() != blob.singletonOrThrow().getShape().dimension()) && !blob.isEmpty()) - layout = estimateOutputLayout(blob.get(0)); + if ((ndLayout == null || ndLayout.length() != list.singletonOrThrow().getShape().dimension()) && !list.isEmpty()) + layout = estimateOutputLayout(list.get(0)); else layout = ndLayout; - var output = blob.stream().map(b -> DjlTools.ndArrayToMat(b, layout)).collect(Collectors.toList()); - blob.close(); + var output = list.stream().map(b -> DjlTools.ndArrayToMat(b, layout)).toArray(Mat[]::new); return output; } - } - - private class PredictFun implements PredictionFunction { - @Override - public NDList predict(NDList input) { - try { - NDList output; - // TODO: Check whether to support per-thread predictors - synchronized (predictor) { - output = predictor.batchPredict(Collections.singletonList(input)).get(0); + public NDList processInput(TranslatorContext ctx, Mat... input) throws Exception { + NDList list = new NDList(); + String layout = ndLayout; + for (var mat : input) { + // Try to figure out the layout + if (layout == null) { + layout = estimateInputLayout(mat); } - input.close(); - return output; - } catch (TranslateException e) { - throw new RuntimeException(e); + list.add(DjlTools.matToNDArray(ctx.getNDManager(), mat, layout)); } + return list; } - - @Override - public Map getInputs() { - return inputs; - } - - @Override - public Map getOutputs(DnnShape... inputShapes) { - return outputs; - } - } @Override diff --git a/src/main/java/qupath/ext/djl/DjlDnnModelBuilder.java b/src/main/java/qupath/ext/djl/DjlDnnModelBuilder.java index a8f4920..f4d6356 100644 --- a/src/main/java/qupath/ext/djl/DjlDnnModelBuilder.java +++ b/src/main/java/qupath/ext/djl/DjlDnnModelBuilder.java @@ -32,7 +32,7 @@ * * @author Pete Bankhead */ -public class DjlDnnModelBuilder implements DnnModelBuilder { +public class DjlDnnModelBuilder implements DnnModelBuilder { private static String getEngineName(String framework) { if (DjlTools.ALL_ENGINES.contains(framework)) @@ -41,6 +41,8 @@ private static String getEngineName(String framework) { switch(framework) { case DnnModelParams.FRAMEWORK_TENSORFLOW: return DjlTools.ENGINE_TENSORFLOW; + case DnnModelParams.FRAMEWORK_TF_LITE: + return DjlTools.ENGINE_TFLITE; case DnnModelParams.FRAMEWORK_ONNX_RUNTIME: return DjlTools.ENGINE_ONNX_RUNTIME; case DnnModelParams.FRAMEWORK_PYTORCH: @@ -112,7 +114,7 @@ private static String axesToLayout(String axes) { } @Override - public DnnModel buildModel(DnnModelParams params) { + public DnnModel buildModel(DnnModelParams params) { var framework = params.getFramework(); String engineName = null; if (framework == null) { diff --git a/src/main/java/qupath/ext/djl/DjlExtension.java b/src/main/java/qupath/ext/djl/DjlExtension.java index 17b729c..92ab62b 100644 --- a/src/main/java/qupath/ext/djl/DjlExtension.java +++ b/src/main/java/qupath/ext/djl/DjlExtension.java @@ -37,7 +37,7 @@ public class DjlExtension implements QuPathExtension, GitHubProject { private final static Logger logger = LoggerFactory.getLogger(DjlExtension.class); - private final static DnnModelBuilder builder = new DjlDnnModelBuilder(); + private final static DnnModelBuilder builder = new DjlDnnModelBuilder(); static { // Prevent downloading engines automatically diff --git a/src/main/java/qupath/ext/djl/DjlTools.java b/src/main/java/qupath/ext/djl/DjlTools.java index e260e8e..6b438e9 100644 --- a/src/main/java/qupath/ext/djl/DjlTools.java +++ b/src/main/java/qupath/ext/djl/DjlTools.java @@ -153,7 +153,7 @@ public class DjlTools { * @param inputShape expected input shape, according to ndLayout * @return */ - public static DnnModel createDnnModel(URI uri, String ndLayout, int[] inputShape) { + public static DnnModel createDnnModel(URI uri, String ndLayout, int[] inputShape) { DnnShape shape = null; if (inputShape != null) shape = DnnShape.of(Arrays.stream(inputShape).mapToLong(i -> i).toArray()); @@ -169,7 +169,7 @@ public static DnnModel createDnnModel(URI uri, String ndLayout, int[] in * @param outputs outputs shapes, if known; if these are null, an attempt will be made to get them from DJL (but this does not always work) * @return */ - public static DnnModel createDnnModel(String engine, URI uri, String ndLayout, Map inputs, Map outputs) { + public static DnnModel createDnnModel(String engine, URI uri, String ndLayout, Map inputs, Map outputs) { return createDnnModel(engine, Collections.singletonList(uri), ndLayout, inputs, outputs); } @@ -182,7 +182,7 @@ public static DnnModel createDnnModel(String engine, URI uri, String ndL * @param outputs outputs shapes, if known; if these are null, an attempt will be made to get them from DJL (but this does not always work) * @return */ - private static DnnModel createDnnModel(String engine, Collection uris, String ndLayout, Map inputs, Map outputs) { + private static DnnModel createDnnModel(String engine, Collection uris, String ndLayout, Map inputs, Map outputs) { return new DjlDnnModel(engine, uris, ndLayout, inputs, outputs, false); // Eagerly initialize (so we know if it doesn't work sooner) } @@ -291,8 +291,12 @@ public static Engine getEngine(String name, boolean downloadIfNeeded) throws Ill static DnnShape convertShape(Shape shape) { return DnnShape.of(shape.getShape()); } - + static ZooModel loadModel(String engineName, URI... uris) throws ModelNotFoundException, MalformedModelException, IOException { + return loadModel(engineName, NDList.class, NDList.class, null, uris); + } + + static ZooModel loadModel(String engineName, Class

inputClass, Class outputClass, Translator translator, URI... uris) throws ModelNotFoundException, MalformedModelException, IOException { var sb = new StringBuilder(); boolean isFirst = true; for (var uri : uris) { @@ -307,13 +311,14 @@ static ZooModel loadModel(String engineName, URI... uris) throws } sb.append(uri.toString()); } - return loadModel(engineName, sb.toString()); + return loadModel(engineName, inputClass, outputClass, translator, sb.toString()); } - private static ZooModel loadModel(String engineName, String urls) throws ModelNotFoundException, MalformedModelException, IOException { + private static ZooModel loadModel(String engineName, Class

inputClass, Class outputClass, Translator translator, String urls) throws ModelNotFoundException, MalformedModelException, IOException { var builder = Criteria.builder() - .setTypes(NDList.class, NDList.class) + .setTypes(inputClass, outputClass) .optModelUrls(urls) + .optTranslator(translator) .optProgress(new ProgressBar()); String selectedEngine = null; @@ -330,6 +335,8 @@ private static ZooModel loadModel(String engineName, String urls selectedEngine = "OnnxRuntime"; else if ((urlString.endsWith("pytorch") || urlString.endsWith(".pt")) && Engine.hasEngine("PyTorch")) selectedEngine = "PyTorch"; + else if (urlString.endsWith(".tflite") && Engine.hasEngine("TFLite")) + selectedEngine = "TFLite"; else if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") || urlString.endsWith("tf_savedmodel")) && Engine.hasEngine("TensorFlow")) selectedEngine = "TensorFlow"; }