diff --git a/pytorch/test.py b/pytorch/test.py index 596cbee..d37cc25 100644 --- a/pytorch/test.py +++ b/pytorch/test.py @@ -22,6 +22,14 @@ def test_specific_output_size(self): output_batch = self.augment(input_batch, output_size=[width, height]) self.assertEqual(output_batch.shape, (7, height, width, 3)) + def test_input_dims(self): + x = self.augment(torch.zeros(10, 20, 3).cuda()) + self.assertEqual(x.shape, (10, 20, 3)) + x = self.augment(torch.zeros(5, 10, 20, 3).cuda()) + self.assertEqual(x.shape, (5, 10, 20, 3)) + x = self.augment(torch.zeros(5, 2, 10, 20, 3).cuda()) + self.assertEqual(x.shape, (5, 2, 10, 20, 3)) + class ColorTests(unittest.TestCase): """Color-related tests""" @@ -108,9 +116,9 @@ class MixupLabelsTests(unittest.TestCase): def test_no_mixup(self): # make random input - input_batch = torch.randint(size=(8, 8, 8, 3), high=255) + input_batch = torch.randint(size=(8, 5, 8, 8, 3), high=255) input_batch = input_batch.to(torch.uint8).cuda() - input_labels = torch.rand(size=(8, 1000)) + input_labels = torch.rand(size=(8, 5, 1000)) # apply random transformation _, output_labels = FastAugment()(input_batch, input_labels) @@ -127,9 +135,8 @@ def test_yes_mixup(self): input_batch = input_batch.reshape(-1, 1, 1, 1).repeat(1, 5, 5, 3) # transform labels to one-hot - input_proba = torch.nn.functional.one_hot(input_labels.to(torch.long), 2).to( - torch.float32 - ) + input_labels = torch.nn.functional.one_hot(input_labels.to(torch.long), 2) \ + .to(torch.float32) # apply mixup augment = FastAugment( @@ -143,13 +150,13 @@ def test_yes_mixup(self): cutout=0, mixup=0.9, ) - output_batch, output_proba = augment(input_batch, input_proba) + output_batch, output_labels = augment(input_batch, input_labels) # check that probabilities sum up to 1 - assert torch.allclose(output_proba[:, 0] + output_proba[:, 1], torch.ones((50))) + assert torch.allclose(output_labels[:, 0] + output_labels[:, 1], torch.ones((50))) # compare probabilities to center pixel values - assert torch.allclose(output_proba[:, 1], output_batch[:, 3, 3, 0].cpu()) + assert torch.allclose(output_labels[:, 1], output_batch[:, 3, 3, 0].cpu()) class SeedTests(unittest.TestCase): @@ -231,9 +238,9 @@ def test_uint8_vs_float32(self): class CoordinatesMappingTest(unittest.TestCase): def test_coordinates_mapping(self): # generate random batch of zeros with a bright spot at a known position - input_batch = torch.zeros((30, 120, 250, 3), dtype=torch.uint8).cuda() + input_batch = torch.zeros((30, 2, 120, 250, 3), dtype=torch.uint8).cuda() y, x = 28, 222 - input_batch[:, y-2:y+2, x-2:x+2, :] = 255 + input_batch[..., y-2:y+2, x-2:x+2, :] = 255 # perform augmentation augment = FastAugment(gamma_corr=0, @@ -259,9 +266,10 @@ def test_coordinates_mapping(self): coords = (coords[:, :2] / coords[:, 2:3]).round().to(torch.int32).numpy() # make sure it is in the output images - for image, (x, y) in zip(output_batch, coords): - if x >= 0 and x < image.shape[-2] and y >= 0 and y < image.shape[-3]: - self.assertEqual(image[y, x, 0], 255) + for group, (x, y) in zip(output_batch, coords): + if x >= 0 and x < group.shape[-2] and y >= 0 and y < group.shape[-3]: + for image in group: + self.assertEqual(image[y, x, 0], 255) if __name__ == "__main__": diff --git a/src/kernel_base.hpp b/src/kernel_base.hpp index e6b2fa7..35e8fc0 100644 --- a/src/kernel_base.hpp +++ b/src/kernel_base.hpp @@ -117,33 +117,25 @@ template class KernelBas * @param settings Randomization settings * @param inputPtr Pointer to the input batch tensor in GPU memory * @param outputPtr Pointer to the output batch tensor in GPU memory - * @param inputLabelsPtr Pointer to the input class probabilities tensor - * in host memory - * @param outputLabelsPtr Pointer to the output class probabilities tensor - * in host memory + * @param inputLabelsPtr Pointer to the input class probabilities tensor in host memory + * @param outputLabelsPtr Pointer to the output class probabilities tensor in host memory * @param outputMappingPtr Pointer to the output homography tensor in host memory - * @param batchSize Batch size; 0 if 3-dimensional input tensor is - * given + * @param batchSize Batch size + * @param groups Group size * @param inputHeight Input batch height in pixels * @param inputWidth Input batch width in pixels * @param outputHeight Output batch height in pixels * @param outputWidth Output batch width in pixels * @param numClasses Number of classes * @param stream CUDA stream - * @param allocationArgs Frontend-specific TempGPUBuffer allocation - * arguments + * @param allocationArgs Frontend-specific TempGPUBuffer allocation arguments */ template void run(const Settings &settings, const in_t *inputPtr, out_t *outputPtr, const float *inputLabelsPtr, - float *outputLabelsPtr, float *outputMappingPtr, int64_t batchSize, int64_t inputHeight, + float *outputLabelsPtr, float *outputMappingPtr, int64_t batchSize, int64_t groups, int64_t inputHeight, int64_t inputWidth, int64_t outputHeight, int64_t outputWidth, int64_t numClasses, cudaStream_t stream, BufferAllocationArgs... allocationArgs) { - // correct batchSize value (can be zero if input is a 3-dim tensor) - const bool isBatch = batchSize > 0; - if (!isBatch) - batchSize = 1; - // compute scale factors to keep the aspect ratio float arScaleX = 1, arScaleY = 1; if (inputWidth * outputHeight >= inputHeight * outputWidth) @@ -154,7 +146,7 @@ template class KernelBas // allocate a temporary buffer ensuring starting address and pitch // alignment const int64_t pitchBytes = roundUp(inputWidth * 4 * sizeof(in_t), texturePitchAlignment); - const size_t bufferSizeBytes = textureAlignment + batchSize * inputHeight * pitchBytes; + const size_t bufferSizeBytes = textureAlignment + batchSize * groups * inputHeight * pitchBytes; TempGPUBuffer buffer(bufferSizeBytes, allocationArgs...); auto unalignedBufferAddress = buffer(); @@ -163,7 +155,8 @@ template class KernelBas reinterpret_cast(roundUp(reinterpret_cast(unalignedBufferAddress), textureAlignment)); // pad the input to have 4 channels and an aligned pitch - padChannels(stream, inputPtr, bufferPtr, inputWidth, inputHeight, batchSize, pitchBytes / (4 * sizeof(in_t))); + padChannels(stream, inputPtr, bufferPtr, inputWidth, inputHeight, batchSize * groups, + pitchBytes / (4 * sizeof(in_t))); reportCudaError(cudaGetLastError(), "Cannot pad the input image"); // check if no labels but mixup @@ -172,7 +165,7 @@ template class KernelBas // prepare parameters samplers std::vector paramsCpu; - paramsCpu.resize(batchSize); + paramsCpu.resize(batchSize * groups); std::uniform_real_distribution xShiftFactor(-settings.translation[0], settings.translation[0]), yShiftFactor(-settings.translation[1], settings.translation[1]), xScaleFactor(settings.prescale - settings.scale[0], settings.prescale + settings.scale[0]), @@ -189,7 +182,7 @@ template class KernelBas std::gamma_distribution<> mixupGamma(settings.mixupAlpha, 1); // sample transformation parameters - for (size_t i = 0; i < paramsCpu.size(); ++i) + for (size_t i = 0; i < paramsCpu.size(); i += groups) { auto &img = paramsCpu[i]; img.flags = 0; @@ -226,11 +219,11 @@ template class KernelBas // Mixup params if (mixupApplication(rnd) < settings.mixupProb) { - img.mixImgIdx = (i + mixIdx(rnd)) % batchSize; + img.mixImgIdx = ((i / groups + mixIdx(rnd)) % batchSize) * groups; float x = mixupGamma(rnd); img.mixFactor = x / (x + mixupGamma(rnd)); // beta distribution generation // trick using gamma distribution - if (img.mixFactor > 0.5) + if (img.mixFactor > 0.5f) img.mixFactor = 1 - img.mixFactor; // making sure the current image has higher contribution to avoid // duplicates @@ -239,10 +232,17 @@ template class KernelBas { img.mixImgIdx = i; } + + // propagate the same parameters across the group + for (size_t j = 1; j < static_cast(groups); ++j) + { + paramsCpu[i + j] = img; + paramsCpu[i + j].mixImgIdx += j; + } } // create temporary tensor for parameters - const size_t paramsSizeBytes = sizeof(Params) * batchSize; + const size_t paramsSizeBytes = sizeof(Params) * batchSize * groups; TempGPUBuffer paramsGpu(paramsSizeBytes, allocationArgs...); // copy parameters to GPU @@ -259,8 +259,8 @@ template class KernelBas inputWidth, inputHeight, pitchBytes, // input sizes outputWidth, - outputHeight, // output sizes - batchSize, // batch size + outputHeight, // output sizes + batchSize * groups, // batch size maxTextureHeight, paramsGpuPtr); // transformation description } @@ -274,7 +274,7 @@ template class KernelBas { const float *inLabel = inputLabelsPtr; float *outLabel = outputLabelsPtr; - for (int64_t n = 0; n < batchSize; ++n, inLabel += numClasses, outLabel += numClasses) + for (int64_t n = 0; n < batchSize * groups; ++n, inLabel += numClasses, outLabel += numClasses) { float f = paramsCpu[n].mixFactor; const float *mixLabel = inputLabelsPtr + paramsCpu[n].mixImgIdx * numClasses; @@ -287,7 +287,7 @@ template class KernelBas if (outputMappingPtr) { float *ptr = outputMappingPtr; - for (size_t i = 0; i < paramsCpu.size(); ++i, ptr += 9) + for (size_t i = 0; i < paramsCpu.size(); i += groups, ptr += 9) { // compute homography in normalized coordinates following the kernel implementation const auto &a = paramsCpu[i].geom; diff --git a/src/pytorch.cpp b/src/pytorch.cpp index f1859fd..fc0c0c1 100644 --- a/src/pytorch.cpp +++ b/src/pytorch.cpp @@ -150,8 +150,8 @@ class TorchKernel : public fastaugment::KernelBase 5) + throw std::invalid_argument("Expected a 3-, 4- or 5-dimensional input tensor, got " + std::to_string(input.dim()) + " dimensions"); if (!input.is_cuda()) throw std::invalid_argument("Expected an input tensor in GPU memory (likely missing a .cuda() " @@ -162,15 +162,15 @@ class TorchKernel : public fastaugment::KernelBase= 4 ? input.size(0) : 0, groups = input.dim() == 5 ? input.size(1) : 0, + inputHeight = input.size(input.dim() - 3), inputWidth = input.size(input.dim() - 2), + inputChannels = input.size(input.dim() - 1), outputWidth = outputSize.empty() ? inputWidth : outputSize[0], outputHeight = outputSize.empty() ? inputHeight : outputSize[1]; // check number of input channels if (inputChannels != 3) - throw std::invalid_argument("Expected a 3-channel channels-last (NHWC) input tensor, got " + + throw std::invalid_argument("Expected a 3-channel channels-last (*HW3) input tensor, got " + std::to_string(inputChannels) + " channels"); // get CUDA stream @@ -178,15 +178,24 @@ class TorchKernel : public fastaugment::KernelBase(groups, 1) * numClasses; + if (labels.numel() != expectedNumElems) + throw std::invalid_argument("Expected " + std::to_string(expectedNumElems) + + " elements in the labels tensor but got " + std::to_string(labels.numel())); + if (!labels.is_cpu()) throw std::invalid_argument("Expected an input_labels tensor stored in RAM (likely missing a " ".cpu() call)"); @@ -198,8 +207,11 @@ class TorchKernel : public fastaugment::KernelBase{batchSize, outputHeight, outputWidth, 3} - : std::vector{outputHeight, outputWidth, 3}); + std::vector outputShape{outputHeight, outputWidth, 3}; + if (groups > 0) + outputShape.emplace(outputShape.begin(), groups); + if (batchSize > 0) + outputShape.emplace(outputShape.begin(), batchSize); torch::Tensor output = torch::empty(outputShape, outputOptions); torch::Tensor outputLabels = torch::empty_like(labels); @@ -208,14 +220,14 @@ class TorchKernel : public fastaugment::KernelBase 0 ? torch::empty({batchSize, 3, 3}, opts) : torch::empty({3, 3}, opts); } auto outputMappingPtr = outputMapping ? mapping.expect_contiguous()->data_ptr() : nullptr; // launch the kernel - launchKernel(input, output, inputLabelsPtr, outputLabels.data_ptr(), outputMappingPtr, batchSize, - inputHeight, inputWidth, outputHeight, outputWidth, noLabels ? 0 : labels.size(1), stream.stream(), - stream); + launchKernel(input, output, inputLabelsPtr, outputLabels.data_ptr(), outputMappingPtr, + std::max(batchSize, 1), std::max(groups, 1), inputHeight, inputWidth, + outputHeight, outputWidth, numClasses, stream.stream(), stream); return {output, outputLabels, mapping}; } diff --git a/src/tensorflow.cpp b/src/tensorflow.cpp index d0b51a9..00e1a76 100644 --- a/src/tensorflow.cpp +++ b/src/tensorflow.cpp @@ -214,8 +214,8 @@ class FastAugmentTFOpKernel : public OpKernel, { fastaugment::KernelBase::run( *this, inputTensor.flat().data(), outputTensor->flat().data(), inputLabelsPtr, - outputLabelsTensor->flat().data(), nullptr, batchSize, inputHeight, inputWidth, outputHeight, - outputWidth, noLabels ? 0 : labelsShape.dim_size(1), stream, context); + outputLabelsTensor->flat().data(), nullptr, isBatch ? batchSize : 1, 1, inputHeight, inputWidth, + outputHeight, outputWidth, noLabels ? 0 : labelsShape.dim_size(1), stream, context); } catch (std::exception &ex) {