Introducing groups dimension
lnstadrum committed Mar 27, 2024
1 parent 2f30617 commit e3e0e50
Showing 4 changed files with 75 additions and 55 deletions.
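In short, this commit adds an optional "groups" dimension to the input: alongside the existing (H, W, C) and (N, H, W, C) layouts, a 5-dimensional (N, G, H, W, C) batch is now accepted, and all G images within a group receive the same randomly sampled transformation parameters (only the mixup partner index is shifted per member, so each image mixes with the matching member of the partner group). A minimal usage sketch mirroring test_input_dims below — the module import path is an assumption for illustration, not part of this diff:

```python
import torch
from fast_augment_torch import FastAugment  # hypothetical import path

augment = FastAugment()

# classic batched input (N, H, W, C): one random transform per image
batch = torch.zeros(5, 10, 20, 3, dtype=torch.uint8).cuda()
assert augment(batch).shape == (5, 10, 20, 3)

# grouped input (N, G, H, W, C): the G images of a group share the
# same randomly sampled transform parameters
grouped = torch.zeros(5, 2, 10, 20, 3, dtype=torch.uint8).cuda()
assert augment(grouped).shape == (5, 2, 10, 20, 3)
```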
34 changes: 21 additions & 13 deletions pytorch/test.py
@@ -22,6 +22,14 @@ def test_specific_output_size(self):
output_batch = self.augment(input_batch, output_size=[width, height])
self.assertEqual(output_batch.shape, (7, height, width, 3))

+ def test_input_dims(self):
+ x = self.augment(torch.zeros(10, 20, 3).cuda())
+ self.assertEqual(x.shape, (10, 20, 3))
+ x = self.augment(torch.zeros(5, 10, 20, 3).cuda())
+ self.assertEqual(x.shape, (5, 10, 20, 3))
+ x = self.augment(torch.zeros(5, 2, 10, 20, 3).cuda())
+ self.assertEqual(x.shape, (5, 2, 10, 20, 3))


class ColorTests(unittest.TestCase):
"""Color-related tests"""
@@ -108,9 +116,9 @@ class MixupLabelsTests(unittest.TestCase):

def test_no_mixup(self):
# make random input
- input_batch = torch.randint(size=(8, 8, 8, 3), high=255)
+ input_batch = torch.randint(size=(8, 5, 8, 8, 3), high=255)
input_batch = input_batch.to(torch.uint8).cuda()
- input_labels = torch.rand(size=(8, 1000))
+ input_labels = torch.rand(size=(8, 5, 1000))

# apply random transformation
_, output_labels = FastAugment()(input_batch, input_labels)
@@ -127,9 +135,8 @@ def test_yes_mixup(self):
input_batch = input_batch.reshape(-1, 1, 1, 1).repeat(1, 5, 5, 3)

# transform labels to one-hot
- input_proba = torch.nn.functional.one_hot(input_labels.to(torch.long), 2).to(
- torch.float32
- )
+ input_labels = torch.nn.functional.one_hot(input_labels.to(torch.long), 2) \
+ .to(torch.float32)

# apply mixup
augment = FastAugment(
@@ -143,13 +150,13 @@
cutout=0,
mixup=0.9,
)
- output_batch, output_proba = augment(input_batch, input_proba)
+ output_batch, output_labels = augment(input_batch, input_labels)

# check that probabilities sum up to 1
- assert torch.allclose(output_proba[:, 0] + output_proba[:, 1], torch.ones((50)))
+ assert torch.allclose(output_labels[:, 0] + output_labels[:, 1], torch.ones((50)))

# compare probabilities to center pixel values
- assert torch.allclose(output_proba[:, 1], output_batch[:, 3, 3, 0].cpu())
+ assert torch.allclose(output_labels[:, 1], output_batch[:, 3, 3, 0].cpu())


class SeedTests(unittest.TestCase):
@@ -231,9 +238,9 @@ def test_uint8_vs_float32(self):
class CoordinatesMappingTest(unittest.TestCase):
def test_coordinates_mapping(self):
# generate random batch of zeros with a bright spot at a known position
- input_batch = torch.zeros((30, 120, 250, 3), dtype=torch.uint8).cuda()
+ input_batch = torch.zeros((30, 2, 120, 250, 3), dtype=torch.uint8).cuda()
y, x = 28, 222
- input_batch[:, y-2:y+2, x-2:x+2, :] = 255
+ input_batch[..., y-2:y+2, x-2:x+2, :] = 255

# perform augmentation
augment = FastAugment(gamma_corr=0,
@@ -259,9 +266,10 @@ def test_coordinates_mapping(self):
coords = (coords[:, :2] / coords[:, 2:3]).round().to(torch.int32).numpy()

# make sure it is in the output images
- for image, (x, y) in zip(output_batch, coords):
- if x >= 0 and x < image.shape[-2] and y >= 0 and y < image.shape[-3]:
- self.assertEqual(image[y, x, 0], 255)
+ for group, (x, y) in zip(output_batch, coords):
+ if x >= 0 and x < group.shape[-2] and y >= 0 and y < group.shape[-3]:
+ for image in group:
+ self.assertEqual(image[y, x, 0], 255)


if __name__ == "__main__":
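For reference, the reworked mapping test above applies the returned per-sample 3x3 homography by hand: multiply in homogeneous coordinates, then divide by the w component. The same step as a small standalone helper — a sketch only, not an API of this repo, and it assumes the point is already expressed in the mapping's coordinate convention:

```python
import torch

def map_point(homography: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
    # homography: (3, 3) projective mapping; point: (2,) position.
    # Lift to homogeneous coordinates, transform, then divide by w,
    # as done in test_coordinates_mapping.
    p = homography @ torch.cat([point, torch.ones(1)])
    return (p[:2] / p[2]).round().to(torch.int32)
```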
50 changes: 25 additions & 25 deletions src/kernel_base.hpp
@@ -117,33 +117,25 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
* @param settings Randomization settings
* @param inputPtr Pointer to the input batch tensor in GPU memory
* @param outputPtr Pointer to the output batch tensor in GPU memory
- * @param inputLabelsPtr Pointer to the input class probabilities tensor
- * in host memory
- * @param outputLabelsPtr Pointer to the output class probabilities tensor
- * in host memory
+ * @param inputLabelsPtr Pointer to the input class probabilities tensor in host memory
+ * @param outputLabelsPtr Pointer to the output class probabilities tensor in host memory
* @param outputMappingPtr Pointer to the output homography tensor in host memory
- * @param batchSize Batch size; 0 if 3-dimensional input tensor is
- * given
+ * @param batchSize Batch size
+ * @param groups Group size
* @param inputHeight Input batch height in pixels
* @param inputWidth Input batch width in pixels
* @param outputHeight Output batch height in pixels
* @param outputWidth Output batch width in pixels
* @param numClasses Number of classes
* @param stream CUDA stream
- * @param allocationArgs Frontend-specific TempGPUBuffer allocation
- * arguments
+ * @param allocationArgs Frontend-specific TempGPUBuffer allocation arguments
*/
template <typename in_t, typename out_t>
void run(const Settings &settings, const in_t *inputPtr, out_t *outputPtr, const float *inputLabelsPtr,
- float *outputLabelsPtr, float *outputMappingPtr, int64_t batchSize, int64_t inputHeight,
+ float *outputLabelsPtr, float *outputMappingPtr, int64_t batchSize, int64_t groups, int64_t inputHeight,
int64_t inputWidth, int64_t outputHeight, int64_t outputWidth, int64_t numClasses, cudaStream_t stream,
BufferAllocationArgs... allocationArgs)
{
- // correct batchSize value (can be zero if input is a 3-dim tensor)
- const bool isBatch = batchSize > 0;
- if (!isBatch)
- batchSize = 1;

// compute scale factors to keep the aspect ratio
float arScaleX = 1, arScaleY = 1;
if (inputWidth * outputHeight >= inputHeight * outputWidth)
@@ -154,7 +146,7 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
// allocate a temporary buffer ensuring starting address and pitch
// alignment
const int64_t pitchBytes = roundUp(inputWidth * 4 * sizeof(in_t), texturePitchAlignment);
- const size_t bufferSizeBytes = textureAlignment + batchSize * inputHeight * pitchBytes;
+ const size_t bufferSizeBytes = textureAlignment + batchSize * groups * inputHeight * pitchBytes;
TempGPUBuffer buffer(bufferSizeBytes, allocationArgs...);
auto unalignedBufferAddress = buffer();

@@ -163,7 +155,8 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
reinterpret_cast<in_t *>(roundUp(reinterpret_cast<size_t>(unalignedBufferAddress), textureAlignment));

// pad the input to have 4 channels and an aligned pitch
- padChannels(stream, inputPtr, bufferPtr, inputWidth, inputHeight, batchSize, pitchBytes / (4 * sizeof(in_t)));
+ padChannels(stream, inputPtr, bufferPtr, inputWidth, inputHeight, batchSize * groups,
+ pitchBytes / (4 * sizeof(in_t)));
reportCudaError(cudaGetLastError(), "Cannot pad the input image");

// check if no labels but mixup
@@ -172,7 +165,7 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas

// prepare parameters samplers
std::vector<Params> paramsCpu;
- paramsCpu.resize(batchSize);
+ paramsCpu.resize(batchSize * groups);
std::uniform_real_distribution<float> xShiftFactor(-settings.translation[0], settings.translation[0]),
yShiftFactor(-settings.translation[1], settings.translation[1]),
xScaleFactor(settings.prescale - settings.scale[0], settings.prescale + settings.scale[0]),
@@ -189,7 +182,7 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
std::gamma_distribution<> mixupGamma(settings.mixupAlpha, 1);

// sample transformation parameters
- for (size_t i = 0; i < paramsCpu.size(); ++i)
+ for (size_t i = 0; i < paramsCpu.size(); i += groups)
{
auto &img = paramsCpu[i];
img.flags = 0;
@@ -226,11 +219,11 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
// Mixup params
if (mixupApplication(rnd) < settings.mixupProb)
{
- img.mixImgIdx = (i + mixIdx(rnd)) % batchSize;
+ img.mixImgIdx = ((i / groups + mixIdx(rnd)) % batchSize) * groups;
float x = mixupGamma(rnd);
img.mixFactor = x / (x + mixupGamma(rnd)); // beta distribution generation
// trick using gamma distribution
- if (img.mixFactor > 0.5)
+ if (img.mixFactor > 0.5f)
img.mixFactor = 1 - img.mixFactor;
// making sure the current image has higher contribution to avoid
// duplicates
@@ -239,10 +232,17 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
{
img.mixImgIdx = i;
}

+ // propagate the same parameters across the group
+ for (size_t j = 1; j < static_cast<size_t>(groups); ++j)
+ {
+ paramsCpu[i + j] = img;
+ paramsCpu[i + j].mixImgIdx += j;
+ }
}

// create temporary tensor for parameters
- const size_t paramsSizeBytes = sizeof(Params) * batchSize;
+ const size_t paramsSizeBytes = sizeof(Params) * batchSize * groups;
TempGPUBuffer paramsGpu(paramsSizeBytes, allocationArgs...);

// copy parameters to GPU
@@ -259,8 +259,8 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
inputWidth, inputHeight,
pitchBytes, // input sizes
outputWidth,
- outputHeight, // output sizes
- batchSize, // batch size
+ outputHeight, // output sizes
+ batchSize * groups, // batch size
maxTextureHeight,
paramsGpuPtr); // transformation description
}
@@ -274,7 +274,7 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
{
const float *inLabel = inputLabelsPtr;
float *outLabel = outputLabelsPtr;
- for (int64_t n = 0; n < batchSize; ++n, inLabel += numClasses, outLabel += numClasses)
+ for (int64_t n = 0; n < batchSize * groups; ++n, inLabel += numClasses, outLabel += numClasses)
{
float f = paramsCpu[n].mixFactor;
const float *mixLabel = inputLabelsPtr + paramsCpu[n].mixImgIdx * numClasses;
@@ -287,7 +287,7 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
if (outputMappingPtr)
{
float *ptr = outputMappingPtr;
- for (size_t i = 0; i < paramsCpu.size(); ++i, ptr += 9)
+ for (size_t i = 0; i < paramsCpu.size(); i += groups, ptr += 9)
{
// compute homography in normalized coordinates following the kernel implementation
const auto &a = paramsCpu[i].geom;
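A note on the mixup factor sampled above: if X and Y are independent Gamma(alpha, 1) draws, then X / (X + Y) follows Beta(alpha, alpha) — this is the standard two-gamma construction the kernel relies on — and the factor is then folded below 0.5 so the current image always contributes at least half and no duplicate pairs arise. The same sampling as a standalone Python sketch:

```python
import random

def sample_mix_factor(alpha: float) -> float:
    # Beta(alpha, alpha) via the two-gamma trick used in kernel_base.hpp
    x = random.gammavariate(alpha, 1.0)
    f = x / (x + random.gammavariate(alpha, 1.0))
    # keep the current image's contribution >= 0.5
    return 1.0 - f if f > 0.5 else f
```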
42 changes: 27 additions & 15 deletions src/pytorch.cpp
@@ -150,8 +150,8 @@ class TorchKernel : public fastaugment::KernelBase<TorchTempGPUBuffer, c10::cuda
std::to_string(outputSize.size()));

// check the input tensor
- if (input.dim() != 3 && input.dim() != 4)
- throw std::invalid_argument("Expected a 3- or 4-dimensional input tensor, got " +
+ if (input.dim() < 3 || input.dim() > 5)
+ throw std::invalid_argument("Expected a 3-, 4- or 5-dimensional input tensor, got " +
std::to_string(input.dim()) + " dimensions");
if (!input.is_cuda())
throw std::invalid_argument("Expected an input tensor in GPU memory (likely missing a .cuda() "
@@ -162,31 +162,40 @@ class TorchKernel : public fastaugment::KernelBase<TorchTempGPUBuffer, c10::cuda
throw std::invalid_argument("Expected uint8 or float input tensor");

// get input sizes
- const bool isBatch = input.dim() == 4;
- const int64_t batchSize = isBatch ? input.size(0) : 0, inputHeight = input.size(isBatch ? 1 : 0),
- inputWidth = input.size(isBatch ? 2 : 1), inputChannels = input.size(isBatch ? 3 : 2),
+ const int64_t batchSize = input.dim() >= 4 ? input.size(0) : 0, groups = input.dim() == 5 ? input.size(1) : 0,
+ inputHeight = input.size(input.dim() - 3), inputWidth = input.size(input.dim() - 2),
+ inputChannels = input.size(input.dim() - 1),
outputWidth = outputSize.empty() ? inputWidth : outputSize[0],
outputHeight = outputSize.empty() ? inputHeight : outputSize[1];

// check number of input channels
if (inputChannels != 3)
throw std::invalid_argument("Expected a 3-channel channels-last (NHWC) input tensor, got " +
throw std::invalid_argument("Expected a 3-channel channels-last (*HW3) input tensor, got " +
std::to_string(inputChannels) + " channels");

// get CUDA stream
auto stream = c10::cuda::getCurrentCUDAStream(input.device().index());

// get input labels tensor
const bool noLabels = labels.dim() == 0 || (labels.dim() == 1 && labels.size(0) == 0);
+ const auto numClasses = noLabels ? 0 : labels.size(labels.dim() - 1);
if (!noLabels)
{
- if (labels.dim() != 2)
- throw std::invalid_argument("Expected a 2-dimensional input_labels tensor, got " +
- std::to_string(labels.dim()) + " dimensions");
+ const auto expectedDims = input.dim() - 2;
+ if (labels.dim() != expectedDims)
+ throw std::invalid_argument("Expected a " + std::to_string(expectedDims) +
+ "-dimensional input_labels tensor, got " + std::to_string(labels.dim()) +
+ " dimensions");
if (labels.size(0) != batchSize)
throw std::invalid_argument("First dimension of the input labels tensor is expected to match "
"the batch size, but got " +
std::to_string(labels.size(0)));

+ const auto expectedNumElems = batchSize * std::max<int64_t>(groups, 1) * numClasses;
+ if (labels.numel() != expectedNumElems)
+ throw std::invalid_argument("Expected " + std::to_string(expectedNumElems) +
+ " elements in the labels tensor but got " + std::to_string(labels.numel()));

if (!labels.is_cpu())
throw std::invalid_argument("Expected an input_labels tensor stored in RAM (likely missing a "
".cpu() call)");
@@ -198,8 +207,11 @@ class TorchKernel : public fastaugment::KernelBase<TorchTempGPUBuffer, c10::cuda
// allocate output tensors
auto outputOptions =
torch::TensorOptions().device(input.device()).dtype(isFloat32Output ? torch::kFloat32 : torch::kUInt8);
- auto outputShape(isBatch ? std::vector<int64_t>{batchSize, outputHeight, outputWidth, 3}
- : std::vector<int64_t>{outputHeight, outputWidth, 3});
+ std::vector<int64_t> outputShape{outputHeight, outputWidth, 3};
+ if (groups > 0)
+ outputShape.emplace(outputShape.begin(), groups);
+ if (batchSize > 0)
+ outputShape.emplace(outputShape.begin(), batchSize);

torch::Tensor output = torch::empty(outputShape, outputOptions);
torch::Tensor outputLabels = torch::empty_like(labels);
@@ -208,14 +220,14 @@ class TorchKernel : public fastaugment::KernelBase<TorchTempGPUBuffer, c10::cuda
if (outputMapping)
{
auto opts = torch::TensorOptions().dtype(torch::kFloat32);
- mapping = torch::empty({batchSize, 3, 3}, opts);
+ mapping = batchSize > 0 ? torch::empty({batchSize, 3, 3}, opts) : torch::empty({3, 3}, opts);
}
auto outputMappingPtr = outputMapping ? mapping.expect_contiguous()->data_ptr<float>() : nullptr;

// launch the kernel
- launchKernel(input, output, inputLabelsPtr, outputLabels.data_ptr<float>(), outputMappingPtr, batchSize,
- inputHeight, inputWidth, outputHeight, outputWidth, noLabels ? 0 : labels.size(1), stream.stream(),
- stream);
+ launchKernel(input, output, inputLabelsPtr, outputLabels.data_ptr<float>(), outputMappingPtr,
+ std::max<int64_t>(batchSize, 1), std::max<int64_t>(groups, 1), inputHeight, inputWidth,
+ outputHeight, outputWidth, numClasses, stream.stream(), stream);

return {output, outputLabels, mapping};
}
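With the groups dimension, the PyTorch binding above now expects the labels tensor rank to track the input rank (labels.dim() == input.dim() - 2), so an (N, G, H, W, C) batch pairs with (N, G, num_classes) labels kept on the CPU. A usage sketch under the same hypothetical import path as above:

```python
import torch
from fast_augment_torch import FastAugment  # hypothetical import path

augment = FastAugment(mixup=0.5)

# batched input (N, H, W, C) pairs with (N, num_classes) labels on the CPU
images = torch.zeros(8, 16, 16, 3, dtype=torch.uint8).cuda()
labels = torch.rand(8, 1000)
_, out_labels = augment(images, labels)
assert out_labels.shape == (8, 1000)

# grouped input (N, G, H, W, C) pairs with (N, G, num_classes) labels,
# i.e. labels.dim() == input.dim() - 2
grouped = torch.zeros(8, 5, 16, 16, 3, dtype=torch.uint8).cuda()
group_labels = torch.rand(8, 5, 1000)
_, out_labels = augment(grouped, group_labels)
assert out_labels.shape == (8, 5, 1000)
```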
4 changes: 2 additions & 2 deletions src/tensorflow.cpp
@@ -214,8 +214,8 @@ class FastAugmentTFOpKernel : public OpKernel,
{
fastaugment::KernelBase<TFTempGPUBuffer, OpKernelContext *>::run<in_t, out_t>(
*this, inputTensor.flat<in_t>().data(), outputTensor->flat<out_t>().data(), inputLabelsPtr,
outputLabelsTensor->flat<float>().data(), nullptr, batchSize, inputHeight, inputWidth, outputHeight,
outputWidth, noLabels ? 0 : labelsShape.dim_size(1), stream, context);
outputLabelsTensor->flat<float>().data(), nullptr, isBatch ? batchSize : 1, 1, inputHeight, inputWidth,
outputHeight, outputWidth, noLabels ? 0 : labelsShape.dim_size(1), stream, context);
}
catch (std::exception &ex)
{
