From f2e5ab06895baaa7fe3f2ba926010f633fe52511 Mon Sep 17 00:00:00 2001 From: petebankhead Date: Mon, 23 Dec 2024 13:01:38 +0000 Subject: [PATCH] Faster Mat to NDArray channels-first conversion --- src/main/java/qupath/ext/djl/DjlTools.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/main/java/qupath/ext/djl/DjlTools.java b/src/main/java/qupath/ext/djl/DjlTools.java index ff1cba3..1cdb6bf 100644 --- a/src/main/java/qupath/ext/djl/DjlTools.java +++ b/src/main/java/qupath/ext/djl/DjlTools.java @@ -1,5 +1,5 @@ /*- - * Copyright 2022 QuPath developers, University of Edinburgh + * Copyright 2022-2024 QuPath developers, University of Edinburgh * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,6 @@ import java.util.Set; import ai.djl.Device; -import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.PointerScope; import org.bytedeco.javacpp.indexer.BooleanIndexer; import org.bytedeco.javacpp.indexer.ByteIndexer; @@ -41,6 +40,7 @@ import org.bytedeco.javacpp.indexer.UByteIndexer; import org.bytedeco.javacpp.indexer.UShortIndexer; import org.bytedeco.opencv.global.opencv_core; +import org.bytedeco.opencv.global.opencv_dnn; import org.bytedeco.opencv.opencv_core.Mat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -389,7 +389,6 @@ static Mat predict(Model model, Mat mat) throws TranslateException { /** * Convert an Opencv {@link Mat} to a Deep Java Library {@link NDArray}. - * Note that this ass * @param manager an {@link NDManager}, required to create the NDArray * @param mat the mat to convert * @param ndLayout a layout string for the NDArray, e.g. "CHW"; currently, HW must appear together (in that order) @@ -408,10 +407,15 @@ public static NDArray matToNDArray(NDManager manager, Mat mat, String ndLayout) // TODO: Check what this order is!!! NDArray array = null; if (indC > indHW || shape.get(indC) == 1) { + // Channels-last, or single-channel var buffer = mat.createBuffer(); array = manager.create(buffer, shape, dataType); + } else if ("NCHW".equals(ndLayout) || "CHW".equals(ndLayout)) { + // Channels-first - an OpenCV blob is defined to have the order NCHW + array = manager.create(opencv_dnn.blobFromImage(mat).createBuffer(), shape, dataType); } else { - var shapeDims = shape.getShape(); + // Really awkward strategy to handle channels in an uncommon place (shouldn't actually occur?) + var shapeDims = shape.getShape().clone(); shapeDims[indC] = 1; var shapeChannel = new Shape(shapeDims, shape.getLayout());