Changing output mapping shape when groups are used
lnstadrum committed Mar 28, 2024
1 parent e3e0e50 commit 3023c99
Showing 3 changed files with 11 additions and 6 deletions.
8 changes: 4 additions & 4 deletions pytorch/test.py
@@ -263,12 +263,12 @@ def test_coordinates_mapping(self):

        # get coordinates of the spot in the augmented images
        coords = torch.matmul(mappings, torch.tensor([x, y, 1], dtype=torch.float32).t())
-        coords = (coords[:, :2] / coords[:, 2:3]).round().to(torch.int32).numpy()
+        coords = (coords[..., :2] / coords[..., 2:3]).round().to(torch.int32).numpy()

        # make sure it is in the output images
-        for group, (x, y) in zip(output_batch, coords):
-            if x >= 0 and x < group.shape[-2] and y >= 0 and y < group.shape[-3]:
-                for image in group:
+        for group in zip(output_batch, coords):
+            for image, (x, y) in zip(*group):
+                if x >= 0 and x < image.shape[-2] and y >= 0 and y < image.shape[-3]:
                    self.assertEqual(image[y, x, 0], 255)


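The ellipsis indexing keeps the slicing valid whether or not coords carries a group dimension. A minimal sketch of what the updated test exercises (not from the repository; batch and group sizes are hypothetical and identity matrices stand in for the real homographies):

import torch

# Hypothetical sizes: a batch of 4 samples with 2 output groups each.
mappings = torch.eye(3).repeat(4, 2, 1, 1)    # (4, 2, 3, 3) homographies
point = torch.tensor([10.0, 20.0, 1.0])       # homogeneous input coordinates

# Broadcasting matmul yields one homogeneous coordinate per sample and group.
coords = torch.matmul(mappings, point)        # (4, 2, 3)

# The ellipsis selects x, y for both the grouped (4, 2, 3) and ungrouped (4, 3)
# layouts, which the previous coords[:, :2] slicing did not.
xy = (coords[..., :2] / coords[..., 2:3]).round().to(torch.int32).numpy()
print(xy.shape)                               # (4, 2, 2)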
2 changes: 1 addition & 1 deletion src/kernel_base.hpp
@@ -287,7 +287,7 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
        if (outputMappingPtr)
        {
            float *ptr = outputMappingPtr;
-            for (size_t i = 0; i < paramsCpu.size(); i += groups, ptr += 9)
+            for (size_t i = 0; i < paramsCpu.size(); ++i, ptr += 9)
            {
                // compute homography in normalized coordinates following the kernel implementation
                const auto &a = paramsCpu[i].geom;
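Stepping the loop by one instead of by groups writes a 3x3 homography for every entry of paramsCpu, i.e. one per (sample, group) pair rather than one per batch sample. A rough sketch of the resulting buffer layout, assuming paramsCpu holds batchSize * groups parameter sets (sizes are hypothetical, identity matrices stand in for the computed homographies):

import numpy as np

# Hypothetical sizes: one parameter set per (sample, group) pair.
batch_size, groups = 4, 2
params_count = batch_size * groups

# One 3x3 homography (9 floats) is written per parameter set, so the flat
# output buffer holds params_count * 9 values instead of batch_size * 9.
flat_mapping = np.zeros(params_count * 9, dtype=np.float32)
for i in range(params_count):
    # identity matrices stand in for the homographies computed by the kernel
    flat_mapping[i * 9:(i + 1) * 9] = np.eye(3, dtype=np.float32).ravel()

# On the Python side this buffer is seen as a (batch, groups, 3, 3) tensor.
print(flat_mapping.reshape(batch_size, groups, 3, 3).shape)    # (4, 2, 3, 3)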
7 changes: 6 additions & 1 deletion src/pytorch.cpp
@@ -219,8 +219,13 @@ class TorchKernel : public fastaugment::KernelBase<TorchTempGPUBuffer, c10::cuda
        torch::Tensor mapping;
        if (outputMapping)
        {
+            std::vector<int64_t> shape{3, 3};
+            if (groups > 0)
+                shape.emplace(shape.begin(), groups);
+            if (batchSize > 0)
+                shape.emplace(shape.begin(), batchSize);
            auto opts = torch::TensorOptions().dtype(torch::kFloat32);
-            mapping = batchSize > 0 ? torch::empty({batchSize, 3, 3}, opts) : torch::empty({3, 3}, opts);
+            mapping = torch::empty(shape, opts);
        }
        auto outputMappingPtr = outputMapping ? mapping.expect_contiguous()->data_ptr<float>() : nullptr;

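The shape is now built by prepending the group and batch dimensions only when they are present, so the returned mapping tensor can be (3, 3), (batchSize, 3, 3) or (batchSize, groups, 3, 3). A small Python transcription of that logic (mapping_shape is a hypothetical helper, not part of the extension's API):

# mapping_shape mirrors the C++ shape construction above.
def mapping_shape(batch_size: int, groups: int) -> list:
    shape = [3, 3]
    if groups > 0:
        shape.insert(0, groups)
    if batch_size > 0:
        shape.insert(0, batch_size)
    return shape

print(mapping_shape(0, 0))    # [3, 3]          single image, no groups
print(mapping_shape(4, 0))    # [4, 3, 3]       batched, no groups
print(mapping_shape(4, 2))    # [4, 2, 3, 3]    batched with 2 groups per sample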

0 comments on commit 3023c99
