Skip to content

Commit

Permalink
Faster Mat to NDArray channels-first conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
petebankhead committed Dec 23, 2024
1 parent d2dad6c commit f2e5ab0
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/main/java/qupath/ext/djl/DjlTools.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*-
* Copyright 2022 QuPath developers, University of Edinburgh
* Copyright 2022-2024 QuPath developers, University of Edinburgh
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,7 +29,6 @@
import java.util.Set;

import ai.djl.Device;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
Expand All @@ -41,6 +40,7 @@
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_dnn;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -389,7 +389,6 @@ static Mat predict(Model model, Mat mat) throws TranslateException {

/**
* Convert an Opencv {@link Mat} to a Deep Java Library {@link NDArray}.
* Note that this ass
* @param manager an {@link NDManager}, required to create the NDArray
* @param mat the mat to convert
* @param ndLayout a layout string for the NDArray, e.g. "CHW"; currently, HW must appear together (in that order)
Expand All @@ -408,10 +407,15 @@ public static NDArray matToNDArray(NDManager manager, Mat mat, String ndLayout)
// TODO: Check what this order is!!!
NDArray array = null;
if (indC > indHW || shape.get(indC) == 1) {
// Channels-last, or single-channel
var buffer = mat.createBuffer();
array = manager.create(buffer, shape, dataType);
} else if ("NCHW".equals(ndLayout) || "CHW".equals(ndLayout)) {
// Channels-first - an OpenCV blob is defined to have the order NCHW
array = manager.create(opencv_dnn.blobFromImage(mat).createBuffer(), shape, dataType);
} else {
var shapeDims = shape.getShape();
// Really awkward strategy to handle channels in an uncommon place (shouldn't actually occur?)
var shapeDims = shape.getShape().clone();
shapeDims[indC] = 1;
var shapeChannel = new Shape(shapeDims, shape.getLayout());

Expand Down

0 comments on commit f2e5ab0

Please sign in to comment.