diff --git a/build.gradle b/build.gradle index f2e99d8..26ffa7d 100644 --- a/build.gradle +++ b/build.gradle @@ -12,7 +12,7 @@ ext.qupathVersion = gradle.ext.qupathVersion base { description = 'QuPath extension to use Deep Java Library' - version = "0.3.0" + version = "0.3.1-SNAPSHOT" group = 'io.github.qupath' } diff --git a/src/main/java/qupath/ext/djl/DjlDnnModel.java b/src/main/java/qupath/ext/djl/DjlDnnModel.java index 2beb026..63235ec 100644 --- a/src/main/java/qupath/ext/djl/DjlDnnModel.java +++ b/src/main/java/qupath/ext/djl/DjlDnnModel.java @@ -35,6 +35,7 @@ import java.net.URI; import java.util.ArrayList; import java.util.Collection; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -127,7 +128,18 @@ public Map predict(Map blobs) { synchronized (predictor) { try { var result = predictor.predict(blobs.values().stream().toArray(Mat[]::new)); - return Map.of(DnnModel.DEFAULT_OUTPUT_NAME, result[0]); + if (result.length == 1) + return Map.of(DnnModel.DEFAULT_OUTPUT_NAME, result[0]); + else if (result.length == 0) + return Map.of(); + else { + // Try to handle multiple outputs, naming them sequentially + Map output = new LinkedHashMap<>(); + for (int i = 0; i < result.length; i++) { + output.put(DEFAULT_OUTPUT_NAME + i, result[i]); + } + return output; + } } catch (TranslateException e) { throw new RuntimeException(e); } @@ -195,7 +207,7 @@ private class ModelMatTranslator implements NoBatchifyTranslator { @Override public Mat[] processOutput(TranslatorContext ctx, NDList list) throws Exception { String layout; - if ((ndLayout == null || ndLayout.length() != list.singletonOrThrow().getShape().dimension()) && !list.isEmpty()) + if ((ndLayout == null || ndLayout.length() != list.get(0).getShape().dimension()) && !list.isEmpty()) layout = estimateOutputLayout(list.get(0)); else layout = ndLayout;