From 3023c99fa3fbc89954a460f6fed5eec58c2eabe3 Mon Sep 17 00:00:00 2001 From: lnstadrum <21985366+lnstadrum@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:09:46 +0100 Subject: [PATCH] Changing output mapping shape when groups are used --- pytorch/test.py | 8 ++++---- src/kernel_base.hpp | 2 +- src/pytorch.cpp | 7 ++++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pytorch/test.py b/pytorch/test.py index d37cc25..78a0762 100644 --- a/pytorch/test.py +++ b/pytorch/test.py @@ -263,12 +263,12 @@ def test_coordinates_mapping(self): # get coordinates of the spot in the augmented images coords = torch.matmul(mappings, torch.tensor([x, y, 1], dtype=torch.float32).t()) - coords = (coords[:, :2] / coords[:, 2:3]).round().to(torch.int32).numpy() + coords = (coords[..., :2] / coords[..., 2:3]).round().to(torch.int32).numpy() # make sure it is in the output images - for group, (x, y) in zip(output_batch, coords): - if x >= 0 and x < group.shape[-2] and y >= 0 and y < group.shape[-3]: - for image in group: + for group in zip(output_batch, coords): + for image, (x, y) in zip(*group): + if x >= 0 and x < image.shape[-2] and y >= 0 and y < image.shape[-3]: self.assertEqual(image[y, x, 0], 255) diff --git a/src/kernel_base.hpp b/src/kernel_base.hpp index 35e8fc0..e4421f1 100644 --- a/src/kernel_base.hpp +++ b/src/kernel_base.hpp @@ -287,7 +287,7 @@ template class KernelBas if (outputMappingPtr) { float *ptr = outputMappingPtr; - for (size_t i = 0; i < paramsCpu.size(); i += groups, ptr += 9) + for (size_t i = 0; i < paramsCpu.size(); ++i, ptr += 9) { // compute homography in normalized coordinates following the kernel implementation const auto &a = paramsCpu[i].geom; diff --git a/src/pytorch.cpp b/src/pytorch.cpp index fc0c0c1..49ccd60 100644 --- a/src/pytorch.cpp +++ b/src/pytorch.cpp @@ -219,8 +219,13 @@ class TorchKernel : public fastaugment::KernelBase shape{3, 3}; + if (groups > 0) + shape.emplace(shape.begin(), groups); + if (batchSize > 0) + shape.emplace(shape.begin(), batchSize); auto opts = torch::TensorOptions().dtype(torch::kFloat32); - mapping = batchSize > 0 ? torch::empty({batchSize, 3, 3}, opts) : torch::empty({3, 3}, opts); + mapping = torch::empty(shape, opts); } auto outputMappingPtr = outputMapping ? mapping.expect_contiguous()->data_ptr() : nullptr;