diff --git a/vkFFT/vkFFT.h b/vkFFT/vkFFT.h index 3fd01814..1baae3d0 100644 --- a/vkFFT/vkFFT.h +++ b/vkFFT/vkFFT.h @@ -1686,10 +1686,10 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\ } sprintf(output + strlen(output), " }\n"); sprintf(output + strlen(output), " else\n"); - if (sc.readToRegisters) + if (sc.readToRegisters) sprintf(output + strlen(output), " temp_%d = %s(0,0);\n", i, vecType); else - sprintf(output + strlen(output), " sdata[gl_WorkGroupSize.x*(gl_LocalInvocationID.y+%d)+gl_LocalInvocationID.x]=%s(0,0);\n", i* sc.localSize[1], vecType); + sprintf(output + strlen(output), " sdata[gl_WorkGroupSize.x*(gl_LocalInvocationID.y+%d)+gl_LocalInvocationID.x]=%s(0,0);\n", i * sc.localSize[1], vecType); } } @@ -2069,7 +2069,9 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\ for (uint32_t i = 0; i < ceil(sc.min_registers_per_thread / 2.0); i++) { if ((uint32_t)ceil(sc.size[1] / 2.0) % sc.localSize[1] != 0) sprintf(output + strlen(output), " if(gl_GlobalInvocationID.y%s < %d){", shiftY, (uint32_t)ceil(sc.size[1] / 2.0)); - + if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1))) + sprintf(output + strlen(output), "if (gl_LocalInvocationID.x < %d){\n", sc.fftDim / 2 - i * sc.localSize[0]); + sprintf(output + strlen(output), " inoutID = gl_LocalInvocationID.x+%d;\n", i * sc.localSize[0]); sprintf(output + strlen(output), " if((inoutID < %d)||(inoutID >= %d)){\n", sc.fft_zeropad_left_read[sc.axis_id], sc.fft_zeropad_right_read[sc.axis_id]); @@ -2102,6 +2104,8 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\ sdata[sharedStride * gl_LocalInvocationID.y + gl_LocalInvocationID.x + %d].y = (temp_0.y + temp_1.x);\n\ sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].x = (temp_0.x + temp_1.y);\n\ sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].y = (-temp_0.y + temp_1.x);\n", i * sc.localSize[0] + 1, i * sc.localSize[0] + 1, sc.fftDim - i * sc.localSize[0] - 1, sc.fftDim - i * sc.localSize[0] - 1); + if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1))) + sprintf(output + strlen(output), "}\n"); } sprintf(output + strlen(output), "\ if (gl_LocalInvocationID.x==0) \n\ @@ -2134,7 +2138,8 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\ for (uint32_t i = 0; i < ceil(sc.min_registers_per_thread / 2.0); i++) { if ((uint32_t)ceil(sc.size[1] / 2.0) % sc.localSize[1] != 0) sprintf(output + strlen(output), " if(gl_GlobalInvocationID.y%s < %d){", shiftY, (uint32_t)ceil(sc.size[1] / 2.0)); - + if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1))) + sprintf(output + strlen(output), "if (gl_LocalInvocationID.x < %d){\n", sc.fftDim / 2 - i * sc.localSize[0]); sprintf(output + strlen(output), " inoutID = indexInput(gl_LocalInvocationID.x + %d, (gl_GlobalInvocationID.y%s));\n", i * sc.localSize[0], shiftY); if (sc.inputBufferBlockNum == 1) @@ -2160,6 +2165,8 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\ sdata[sharedStride * gl_LocalInvocationID.y + gl_LocalInvocationID.x + %d].y = (temp_0.y + temp_1.x);\n\ sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].x = (temp_0.x + temp_1.y);\n\ sdata[sharedStride * gl_LocalInvocationID.y + %d - gl_LocalInvocationID.x].y = (-temp_0.y + temp_1.x);\n", i * sc.localSize[0] + 1, i * sc.localSize[0] + 1, sc.fftDim - i * sc.localSize[0] - 1, sc.fftDim - i * sc.localSize[0] - 1); + if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1))) + sprintf(output + strlen(output), "}\n"); } sprintf(output + strlen(output), "\ if (gl_LocalInvocationID.x==0) \n\ @@ -4305,6 +4312,9 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\ sprintf(output + strlen(output), " }\n"); for (uint32_t i = 0; i < ceil(sc.min_registers_per_thread / 2.0); i++) { + if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1))) + sprintf(output + strlen(output), "if (gl_LocalInvocationID.x < %d){\n", sc.fftDim / 2 - i * sc.localSize[0]); + if (sc.localSize[1] == 1) sprintf(output + strlen(output), "\ temp_0.x = 0.5 * (sdata[%d + gl_LocalInvocationID.x].x + sdata[%d - gl_LocalInvocationID.x].x);\n\ @@ -4352,6 +4362,9 @@ layout(std430, binding = %d) readonly buffer DataLUT {\n\ sprintf(output + strlen(output), " outputBlocks[inoutID / %d].outputs[inoutID %% %d] = %stemp_1%s;\n", sc.outputBufferBlockSize, sc.outputBufferBlockSize, convTypeLeft, convTypeRight); } + if ((ceil(sc.min_registers_per_thread / 2.0) != sc.min_registers_per_thread / 2) && (i == (ceil(sc.min_registers_per_thread / 2.0) - 1))) { + sprintf(output + strlen(output), "}\n"); + } } if ((uint32_t)ceil(sc.size[1] / 2.0) % sc.localSize[1] != 0) sprintf(output + strlen(output), " }");