Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
-fixed R2C/C2R mode for some non-power of 2 sequences
  • Loading branch information
DTolm committed Jan 22, 2021
1 parent 9e1eca1 commit 1c6f9cb
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions vkFFT/vkFFT.h
Original file line number Diff line number Diff line change
Expand Up @@ -1686,10 +1686,10 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\
}
sprintf(output + strlen(output), " }\n");
sprintf(output + strlen(output), " else\n");
if (sc.readToRegisters)
if (sc.readToRegisters)
sprintf(output + strlen(output), " temp_%d = %s(0,0);\n", i, vecType);
else
sprintf(output + strlen(output), " sdata[gl_WorkGroupSize.x*(gl_LocalInvocationID.y+%d)+gl_LocalInvocationID.x]=%s(0,0);\n", i* sc.localSize[1], vecType);
sprintf(output + strlen(output), " sdata[gl_WorkGroupSize.x*(gl_LocalInvocationID.y+%d)+gl_LocalInvocationID.x]=%s(0,0);\n", i * sc.localSize[1], vecType);
}

}
Expand Down Expand Up @@ -2069,7 +2069,9 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\
for (uint32_t i = 0; i < ceil(sc.min_registers_per_thread / 2.0); i++) {
if ((uint32_t)ceil(sc.size[1] / 2.0) % sc.localSize[1] != 0)
sprintf(output + strlen(output), " if(gl_GlobalInvocationID.y%s < %d){", shiftY, (uint32_t)ceil(sc.size[1] / 2.0));

if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1)))
sprintf(output + strlen(output), "if (gl_LocalInvocationID.x < %d){\n", sc.fftDim / 2 - i * sc.localSize[0]);

sprintf(output + strlen(output), " inoutID = gl_LocalInvocationID.x+%d;\n", i * sc.localSize[0]);

sprintf(output + strlen(output), " if((inoutID < %d)||(inoutID >= %d)){\n", sc.fft_zeropad_left_read[sc.axis_id], sc.fft_zeropad_right_read[sc.axis_id]);
Expand Down Expand Up @@ -2102,6 +2104,8 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\
sdata[sharedStride * gl_LocalInvocationID.y + gl_LocalInvocationID.x + %d].y = (temp_0.y + temp_1.x);\n\
sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].x = (temp_0.x + temp_1.y);\n\
sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].y = (-temp_0.y + temp_1.x);\n", i * sc.localSize[0] + 1, i * sc.localSize[0] + 1, sc.fftDim - i * sc.localSize[0] - 1, sc.fftDim - i * sc.localSize[0] - 1);
if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1)))
sprintf(output + strlen(output), "}\n");
}
sprintf(output + strlen(output), "\
if (gl_LocalInvocationID.x==0) \n\
Expand Down Expand Up @@ -2134,7 +2138,8 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\
for (uint32_t i = 0; i < ceil(sc.min_registers_per_thread / 2.0); i++) {
if ((uint32_t)ceil(sc.size[1] / 2.0) % sc.localSize[1] != 0)
sprintf(output + strlen(output), " if(gl_GlobalInvocationID.y%s < %d){", shiftY, (uint32_t)ceil(sc.size[1] / 2.0));

if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1)))
sprintf(output + strlen(output), "if (gl_LocalInvocationID.x < %d){\n", sc.fftDim / 2 - i * sc.localSize[0]);
sprintf(output + strlen(output), " inoutID = indexInput(gl_LocalInvocationID.x + %d, (gl_GlobalInvocationID.y%s));\n", i * sc.localSize[0], shiftY);

if (sc.inputBufferBlockNum == 1)
Expand All @@ -2160,6 +2165,8 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\
sdata[sharedStride * gl_LocalInvocationID.y + gl_LocalInvocationID.x + %d].y = (temp_0.y + temp_1.x);\n\
sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].x = (temp_0.x + temp_1.y);\n\
sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].y = (-temp_0.y + temp_1.x);\n", i * sc.localSize[0] + 1, i * sc.localSize[0] + 1, sc.fftDim - i * sc.localSize[0] - 1, sc.fftDim - i * sc.localSize[0] - 1);
if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1)))
sprintf(output + strlen(output), "}\n");
}
sprintf(output + strlen(output), "\
if (gl_LocalInvocationID.x==0) \n\
Expand Down Expand Up @@ -4305,6 +4312,9 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\
sprintf(output + strlen(output), " }\n");

for (uint32_t i = 0; i < ceil(sc.min_registers_per_thread / 2.0); i++) {
if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1)))
sprintf(output + strlen(output), "if (gl_LocalInvocationID.x < %d){\n", sc.fftDim / 2 - i * sc.localSize[0]);

if (sc.localSize[1] == 1)
sprintf(output + strlen(output), "\
temp_0.x = 0.5 * (sdata[%d + gl_LocalInvocationID.x].x + sdata[%d - gl_LocalInvocationID.x].x);\n\
Expand Down Expand Up @@ -4352,6 +4362,9 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\
sprintf(output + strlen(output), " outputBlocks[inoutID / %d].outputs[inoutID %% %d] = %stemp_1%s;\n", sc.outputBufferBlockSize, sc.outputBufferBlockSize, convTypeLeft, convTypeRight);

}
if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1))) {
sprintf(output + strlen(output), "}\n");
}
}
if ((uint32_t)ceil(sc.size[1] / 2.0) % sc.localSize[1] != 0)
sprintf(output + strlen(output), " }");
Expand Down

0 comments on commit 1c6f9cb

Please sign in to comment.