Skip to content

Commit

Permalink
Merge pull request #16 from petebankhead/multi-output
Browse files Browse the repository at this point in the history
Update DjlDnnModel.java
  • Loading branch information
petebankhead authored Dec 29, 2023
2 parents 0b3a26c + 83ae811 commit e73df2f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ext.qupathVersion = gradle.ext.qupathVersion

base {
description = 'QuPath extension to use Deep Java Library'
version = "0.3.0"
version = "0.3.1-SNAPSHOT"
group = 'io.github.qupath'
}

Expand Down
16 changes: 14 additions & 2 deletions src/main/java/qupath/ext/djl/DjlDnnModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -127,7 +128,18 @@ public Map<String, Mat> predict(Map<String, Mat> blobs) {
synchronized (predictor) {
try {
var result = predictor.predict(blobs.values().stream().toArray(Mat[]::new));
return Map.of(DnnModel.DEFAULT_OUTPUT_NAME, result[0]);
if (result.length == 1)
return Map.of(DnnModel.DEFAULT_OUTPUT_NAME, result[0]);
else if (result.length == 0)
return Map.of();
else {
// Try to handle multiple outputs, naming them sequentially
Map<String, Mat> output = new LinkedHashMap<>();
for (int i = 0; i < result.length; i++) {
output.put(DEFAULT_OUTPUT_NAME + i, result[i]);
}
return output;
}
} catch (TranslateException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -195,7 +207,7 @@ private class ModelMatTranslator implements NoBatchifyTranslator<Mat[], Mat[]> {
@Override
public Mat[] processOutput(TranslatorContext ctx, NDList list) throws Exception {
String layout;
if ((ndLayout == null || ndLayout.length() != list.singletonOrThrow().getShape().dimension()) && !list.isEmpty())
if ((ndLayout == null || ndLayout.length() != list.get(0).getShape().dimension()) && !list.isEmpty())
layout = estimateOutputLayout(list.get(0));
else
layout = ndLayout;
Expand Down

0 comments on commit e73df2f

Please sign in to comment.