diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp index d985e816..c842bcc5 100644 --- a/src/portfft/committed_descriptor_impl.hpp +++ b/src/portfft/committed_descriptor_impl.hpp @@ -31,6 +31,7 @@ #include "common/exceptions.hpp" #include "common/subgroup.hpp" +#include "common/workgroup.hpp" #include "defines.hpp" #include "enums.hpp" #include "specialization_constant.hpp" @@ -215,57 +216,44 @@ class committed_descriptor_impl { throw unsupported_configuration("portFFT only supports complex to complex transforms"); } - std::vector ids; - std::vector factors; IdxGlobal fft_size = static_cast(params.lengths[kernel_num]); - if (detail::fits_in_wi(fft_size)) { - ids = detail::get_ids(); - PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size); - return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, factors}}}; - } - if (detail::fits_in_sg(fft_size, SubgroupSize)) { - Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); - Idx factor_wi = static_cast(fft_size) / factor_sg; - // This factorization is duplicated in the dispatch logic on the device. - // The CT and spec constant factors should match. - factors.push_back(factor_wi); - factors.push_back(factor_sg); - ids = detail::get_ids(); - PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg); - return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}}; - } - IdxGlobal n_idx_global = detail::factorize(fft_size); - if (detail::can_cast_safely(n_idx_global) && - detail::can_cast_safely(fft_size / n_idx_global)) { - if (n_idx_global == 1) { - throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported"); + if (static_cast(fft_size) * 2 * sizeof(Scalar) <= static_cast(local_memory_size)) { + // These implementations only work if the size fits in local memory. + // They still may not be suitable if the extra local memory needed for the algorithm exceeds the available memory. + + if (detail::fits_in_wi(fft_size)) { + auto ids = detail::get_ids(); + PORTFFT_LOG_TRACE("Prepared workitem impl for size: ", fft_size); + return {detail::level::WORKITEM, {{detail::level::WORKITEM, ids, {}}}}; } - Idx n = static_cast(n_idx_global); - Idx m = static_cast(fft_size / n_idx_global); - Idx factor_sg_n = detail::factorize_sg(n, SubgroupSize); - Idx factor_wi_n = n / factor_sg_n; - Idx factor_sg_m = detail::factorize_sg(m, SubgroupSize); - Idx factor_wi_m = m / factor_sg_m; - Idx temp_num_sgs_in_wg; - std::size_t local_memory_usage = - num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast(fft_size), SubgroupSize, - {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg, - layout::PACKED) * - sizeof(Scalar); - // Checks for PACKED layout only at the moment, as the other layout will not be supported - // by the global implementation. For such sizes, only PACKED layout will be supported - if (detail::fits_in_wi(factor_wi_n) && detail::fits_in_wi(factor_wi_m) && - (local_memory_usage <= static_cast(local_memory_size))) { - factors.push_back(factor_wi_n); - factors.push_back(factor_sg_n); - factors.push_back(factor_wi_m); - factors.push_back(factor_sg_m); - // This factorization of N and M is duplicated in the dispatch logic on the device. + if (detail::fits_in_sg(fft_size, SubgroupSize)) { + Idx factor_sg = detail::factorize_sg(static_cast(fft_size), SubgroupSize); + Idx factor_wi = static_cast(fft_size) / factor_sg; + // This factorization is duplicated in the dispatch logic on the device. // The CT and spec constant factors should match. - ids = detail::get_ids(); - PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n, - " factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m); - return {detail::level::WORKGROUP, {{detail::level::WORKGROUP, ids, factors}}}; + auto ids = detail::get_ids(); + PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg); + return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, {factor_wi, factor_sg}}}}; + } + if (auto wg_factorization = detail::factorize_for_wg(fft_size, SubgroupSize); wg_factorization) { + auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value(); + Idx temp_num_sgs_in_wg; + std::size_t local_memory_usage = + num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast(fft_size), SubgroupSize, + {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg, + layout::PACKED) * + sizeof(Scalar); + // Checks for PACKED layout only at the moment, as the other layout will not be supported + // by the global implementation. For such sizes, only PACKED layout will be supported + if (local_memory_usage <= static_cast(local_memory_size)) { + // This factorization of N and M is duplicated in the dispatch logic on the device. + // The CT and spec constant factors should match. + auto ids = detail::get_ids(); + PORTFFT_LOG_TRACE("Prepared workgroup impl with factor_wi_n:", factor_wi_n, " factor_sg_n:", factor_sg_n, + " factor_wi_m:", factor_wi_m, " factor_sg_m:", factor_sg_m); + return {detail::level::WORKGROUP, + {{detail::level::WORKGROUP, ids, {factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m}}}}; + } } } PORTFFT_LOG_TRACE("Preparing global impl"); diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp index 25038527..4d7af97d 100644 --- a/src/portfft/common/workgroup.hpp +++ b/src/portfft/common/workgroup.hpp @@ -21,12 +21,17 @@ #ifndef PORTFFT_COMMON_WORKGROUP_HPP #define PORTFFT_COMMON_WORKGROUP_HPP +#include + #include "helpers.hpp" #include "logging.hpp" +#include "memory_views.hpp" #include "portfft/defines.hpp" #include "portfft/enums.hpp" #include "portfft/traits.hpp" +#include "portfft/utils.hpp" #include "subgroup.hpp" +#include "transfers.hpp" namespace portfft { @@ -53,6 +58,45 @@ constexpr T bank_lines_per_pad_wg(T row_size) { } namespace detail { + +// struct for the result of factorize_for_wg +struct wg_factorization { + Idx factor_wi_n; + Idx factor_sg_n; + Idx factor_wi_m; + Idx factor_sg_m; +}; + +/** Calculate a valid factorization for workgroup dfts, assuming there is sufficient local memory. + * @tparam Scalar scalar type of the transform data + * @param fft_size the number of elements in the transforms + * @param subgroup_size the size of subgroup used for the transform + * @return a factorization for workgroup dft or null if the size won't work with the implemenation of workgroup dfts. + */ +template +std::optional factorize_for_wg(IdxGlobal fft_size, Idx subgroup_size) { + IdxGlobal n_idx_global = detail::factorize(fft_size); + if (n_idx_global == 1) { + return std::nullopt; + } + + IdxGlobal m_idx_global = fft_size / n_idx_global; + if (detail::can_cast_safely(n_idx_global) && detail::can_cast_safely(m_idx_global)) { + Idx n = static_cast(n_idx_global); + Idx m = static_cast(m_idx_global); + Idx factor_sg_n = detail::factorize_sg(n, subgroup_size); + Idx factor_wi_n = n / factor_sg_n; + Idx factor_sg_m = detail::factorize_sg(m, subgroup_size); + Idx factor_wi_m = m / factor_sg_m; + + if (fits_in_wi(factor_wi_n) && fits_in_wi(factor_wi_m)) { + return wg_factorization{factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m}; + } + } + + return std::nullopt; +} + /** * Calculate all dfts in one dimension of the data stored in local memory. * diff --git a/src/portfft/descriptor_validation.hpp b/src/portfft/descriptor_validation.hpp index 9c4e421a..4e819323 100644 --- a/src/portfft/descriptor_validation.hpp +++ b/src/portfft/descriptor_validation.hpp @@ -24,7 +24,8 @@ #include #include "common/exceptions.hpp" -#include "common/subgroup.hpp" +#include "common/workgroup.hpp" +#include "common/workitem.hpp" #include "enums.hpp" #include "utils.hpp" @@ -67,8 +68,8 @@ inline void validate_layout(const std::vector& lengths, portfft::de if (forward_layout == portfft::detail::layout::UNPACKED || backward_layout == portfft::detail::layout::UNPACKED) { bool fits_subgroup = false; for (auto sg_size : {PORTFFT_SUBGROUP_SIZES}) { - fits_subgroup = - fits_subgroup || portfft::detail::fits_in_sg(static_cast(lengths.back()), sg_size); + fits_subgroup = fits_subgroup || portfft::detail::fits_in_wi(lengths.back()) || + portfft::detail::factorize_for_wg(static_cast(lengths.back()), sg_size); if (fits_subgroup) { break; } diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index dbbca454..b84269ca 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -40,9 +40,11 @@ namespace detail { * @param is_batch_interleaved is the input data layout batch interleaved * @param workgroup_size The size of the work-group. Must be divisible by 2. */ -PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(bool is_batch_interleaved, - Idx workgroup_size) noexcept { - return is_batch_interleaved ? workgroup_size / 2 : 1; +PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(bool /*is_batch_interleaved*/, + Idx /*workgroup_size*/) noexcept { + // TODO re-enable when tests can run in the batch interleaved path + // return is_batch_interleaved ? workgroup_size / 2 : 1; + return 1; } /** @@ -105,11 +107,14 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima T scaling_factor = kh.get_specialization_constant()>(); const Idx fft_size = kh.get_specialization_constant(); + const IdxGlobal input_stride = kh.get_specialization_constant(); + const IdxGlobal output_stride = kh.get_specialization_constant(); const IdxGlobal input_distance = kh.get_specialization_constant(); const IdxGlobal output_distance = kh.get_specialization_constant(); - const bool input_batch_interleaved = input_distance == 1; - const bool output_batch_interleaved = output_distance == 1; + // TODO re-enable when tests can run in the batch interleaved path + const bool is_input_batch_interleaved = false; // input_stride == n_transforms && input_distance == 1; + const bool is_input_packed = input_stride == 1 && input_distance == fft_size; global_data.log_message_global(__func__, "entered", "fft_size", fft_size, "n_transforms", n_transforms); Idx num_workgroups = static_cast(global_data.it.get_group_range(0)); @@ -127,7 +132,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima global_data.log_dump_local("twiddles loaded to local memory:", loc_twiddles, 2 * (factor_m + factor_n)); Idx max_num_batches_in_local_mem = get_num_batches_in_local_mem_workgroup( - input_batch_interleaved, static_cast(global_data.it.get_local_range(0))); + is_input_batch_interleaved, static_cast(global_data.it.get_local_range(0))); IdxGlobal first_batch_start = static_cast(wg_id) * static_cast(max_num_batches_in_local_mem); IdxGlobal num_batches_in_kernel = @@ -136,8 +141,9 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima for (IdxGlobal batch_start_idx = first_batch_start; batch_start_idx < n_transforms; batch_start_idx += num_batches_in_kernel) { - IdxGlobal offset = static_cast(vec_size * fft_size) * batch_start_idx; - if (input_batch_interleaved) { + IdxGlobal input_global_offset = static_cast(vec_size * input_distance) * batch_start_idx; + IdxGlobal output_global_offset = static_cast(vec_size * output_distance) * batch_start_idx; + if (is_input_batch_interleaved) { /** * In the transposed case, the data is laid out in the local memory column-wise, viewing it as a FFT_Size x * WG_SIZE / 2 matrix, Each column contains either the real or the complex component of the batch. Loads WG_SIZE @@ -163,6 +169,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima std::array{fft_size, num_batches_in_local_mem}); } sycl::group_barrier(global_data.it.get_group()); + for (Idx sub_batch = 0; sub_batch < num_batches_in_local_mem; sub_batch++) { wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, sub_batch, batch_start_idx, load_modifier_data, store_modifier_data, fft_size, factor_n, @@ -170,62 +177,63 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima apply_scale_factor, conjugate_on_load, conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); } - if (!output_batch_interleaved) { - global_data.log_message_global(__func__, "storing data from local to global memory (with 2 transposes)"); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::md_view loc_md_view2{loc_view, std::array{2, 1, 2 * max_num_batches_in_local_mem, - 2 * max_num_batches_in_local_mem * factor_m}}; - detail::md_view output_view{output, std::array{2 * fft_size, 1, 2 * factor_n, 2}, offset}; - copy_group(global_data, loc_md_view2, output_view, - std::array{num_batches_in_local_mem, 2, factor_m, factor_n}); - } else { // storage == complex_storage::SPLIT_COMPLEX - detail::md_view loc_real_view{ - loc_view, std::array{1, max_num_batches_in_local_mem, max_num_batches_in_local_mem * factor_m}}; - detail::md_view loc_imag_view{ - loc_view, std::array{1, max_num_batches_in_local_mem, max_num_batches_in_local_mem * factor_m}, - local_imag_offset}; - detail::md_view output_real_view{output, std::array{fft_size, factor_n, 1}, offset}; - detail::md_view output_imag_view{output_imag, std::array{fft_size, factor_n, 1}, offset}; - copy_group(global_data, loc_real_view, output_real_view, - std::array{num_batches_in_local_mem, factor_m, factor_n}); - copy_group(global_data, loc_imag_view, output_imag_view, - std::array{num_batches_in_local_mem, factor_m, factor_n}); - } - } else { // batch interleaved layout out - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::md_view loc_md_view2{ - loc_view, std::array{2 * max_num_batches_in_local_mem, 2 * max_num_batches_in_local_mem * factor_m, 1}}; - detail::md_view output_view{ - output, std::array{2 * n_transforms * factor_n, 2 * n_transforms, static_cast(1)}, - 2 * batch_start_idx}; - copy_group(global_data, loc_md_view2, output_view, - std::array{factor_m, factor_n, 2 * num_batches_in_local_mem}); - } else { // storage == complex_storage::SPLIT_COMPLEX - detail::md_view loc_real_view{ - loc_view, std::array{max_num_batches_in_local_mem, max_num_batches_in_local_mem * factor_m, 1}}; - detail::md_view loc_imag_view{ - loc_view, std::array{max_num_batches_in_local_mem, max_num_batches_in_local_mem * factor_m, 1}, - local_imag_offset}; - detail::md_view output_real_view{ - output, std::array{n_transforms * factor_n, n_transforms, static_cast(1)}, batch_start_idx}; - detail::md_view output_imag_view{output_imag, - std::array{n_transforms * factor_n, n_transforms, static_cast(1)}, - batch_start_idx}; - copy_group(global_data, loc_real_view, output_real_view, - std::array{factor_m, factor_n, num_batches_in_local_mem}); - copy_group(global_data, loc_imag_view, output_imag_view, - std::array{factor_m, factor_n, num_batches_in_local_mem}); - } + + global_data.log_message_global(__func__, "storing data from local to global memory (with 2 transposes)"); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + std::array global_strides{2 * output_distance, 2 * factor_n * output_stride, 2 * output_stride, + 1}; + std::array local_strides{2, 2 * max_num_batches_in_local_mem, + 2 * factor_m * max_num_batches_in_local_mem, 1}; + std::array copy_lengths{num_batches_in_local_mem, factor_m, factor_n, 2}; + + detail::md_view global_output_view{output, global_strides, output_global_offset}; + detail::md_view local_output_view{loc_view, local_strides}; + + copy_group(global_data, local_output_view, global_output_view, copy_lengths); + } else { // storage == complex_storage::SPLIT_COMPLEX + std::array global_strides{output_distance, factor_n * output_stride, output_stride}; + std::array local_strides{1, max_num_batches_in_local_mem, factor_m * max_num_batches_in_local_mem}; + std::array copy_lengths{num_batches_in_local_mem, factor_m, factor_n}; + + detail::md_view global_output_real_view{output, global_strides, output_global_offset}; + detail::md_view global_output_imag_view{output_imag, global_strides, output_global_offset}; + detail::md_view local_output_real_view{loc_view, local_strides}; + detail::md_view local_output_imag_view{loc_view, local_strides, local_imag_offset}; + + copy_group(global_data, local_output_real_view, global_output_real_view, copy_lengths); + copy_group(global_data, local_output_imag_view, global_output_imag_view, copy_lengths); } sycl::group_barrier(global_data.it.get_group()); - } else { // packed input layout + } else { // not batch interleaved input layout global_data.log_message_global(__func__, "loading non-transposed data from global to local memory"); - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - global2local(global_data, input, loc_view, 2 * fft_size, offset); + if (is_input_packed) { + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + global2local(global_data, input, loc_view, 2 * fft_size, input_global_offset); + } else { + global2local(global_data, input, loc_view, fft_size, input_global_offset); + global2local(global_data, input_imag, loc_view, fft_size, input_global_offset, + local_imag_offset); + } } else { - global2local(global_data, input, loc_view, fft_size, offset); - global2local(global_data, input_imag, loc_view, fft_size, offset, - local_imag_offset); + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + std::array global_strides{input_stride * 2, 1}; + std::array local_strides{2, 1}; + std::array copy_lengths{fft_size, 2}; + detail::md_view global_input_view{input, global_strides, input_global_offset}; + detail::md_view local_input_view{loc_view, local_strides}; + + global_data.log_message_global(__func__, "storing data from unpacked global memory to local"); + copy_group(global_data, global_input_view, local_input_view, copy_lengths); + } else { + detail::strided_view global_input_real_view{input, input_stride, input_global_offset}; + detail::strided_view global_input_imag_view{input_imag, input_stride, input_global_offset}; + detail::offset_view local_input_imag_view{loc_view, local_imag_offset}; + + global_data.log_message_global(__func__, "storing real data from unpacked global memory to local"); + copy_group(global_data, global_input_real_view, loc_view, fft_size); + global_data.log_message_global(__func__, "storing imaginary data from unpacked global memory to local"); + copy_group(global_data, global_input_imag_view, local_input_imag_view, fft_size); + } } sycl::group_barrier(global_data.it.get_group()); wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, 0, @@ -235,35 +243,27 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima sycl::group_barrier(global_data.it.get_group()); global_data.log_message_global(__func__, "storing non-transposed data from local to global memory"); // transposition for WG CT - if (!output_batch_interleaved) { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::md_view local_md_view2{loc_view, std::array{1, 2, 2 * factor_m}}; - detail::md_view output_view{output, std::array{1, 2 * factor_n, 2}, offset}; - copy_group(global_data, local_md_view2, output_view, std::array{2, factor_m, factor_n}); - } else { - detail::md_view loc_real_view{loc_view, std::array{1, factor_m}}; - detail::md_view loc_imag_view{loc_view, std::array{1, factor_m}, local_imag_offset}; - detail::md_view output_real_view{output, std::array{factor_n, 1}, offset}; - detail::md_view output_imag_view{output_imag, std::array{factor_n, 1}, offset}; - copy_group(global_data, loc_real_view, output_real_view, std::array{factor_m, factor_n}); - copy_group(global_data, loc_imag_view, output_imag_view, std::array{factor_m, factor_n}); - } + if (storage == complex_storage::INTERLEAVED_COMPLEX) { + std::array global_strides{2 * factor_n * output_stride, 2 * output_stride, 1}; + std::array local_strides{2, 2 * factor_m, 1}; + std::array copy_lengths{factor_m, factor_n, 2}; + + detail::md_view global_output_view{output, global_strides, output_global_offset}; + detail::md_view local_output_view{loc_view, local_strides}; + + copy_group(global_data, local_output_view, global_output_view, copy_lengths); } else { - if (storage == complex_storage::INTERLEAVED_COMPLEX) { - detail::md_view local_md_view2{loc_view, std::array{2, 1, 2 * factor_m}}; - detail::md_view output_view{ - output, std::array{2 * factor_n * n_transforms, static_cast(1), 2 * n_transforms}, - 2 * batch_start_idx}; - copy_group(global_data, local_md_view2, output_view, std::array{factor_m, 2, factor_n}); - } else { - detail::md_view loc_real_view{loc_view, std::array{1, factor_m}}; - detail::md_view loc_imag_view{loc_view, std::array{1, factor_m}, local_imag_offset}; - detail::md_view output_real_view{output, std::array{factor_n * n_transforms, n_transforms}, batch_start_idx}; - detail::md_view output_imag_view{output_imag, std::array{factor_n * n_transforms, n_transforms}, - batch_start_idx}; - copy_group(global_data, loc_real_view, output_real_view, std::array{factor_m, factor_n}); - copy_group(global_data, loc_imag_view, output_imag_view, std::array{factor_m, factor_n}); - } + std::array global_strides{factor_n * output_stride, output_stride}; + std::array local_strides{1, factor_m}; + std::array copy_lengths{factor_m, factor_n}; + + detail::md_view global_output_real_view{output, global_strides, output_global_offset}; + detail::md_view global_output_imag_view{output_imag, global_strides, output_global_offset}; + detail::md_view loc_output_real_view{loc_view, local_strides}; + detail::md_view loc_output_imag_view{loc_view, local_strides, local_imag_offset}; + + copy_group(global_data, loc_output_real_view, global_output_real_view, copy_lengths); + copy_group(global_data, loc_output_imag_view, global_output_imag_view, copy_lengths); } sycl::group_barrier(global_data.it.get_group()); } @@ -283,8 +283,8 @@ struct committed_descriptor_impl::run_kernel_struct ? detail::memory::USM : detail::memory::BUFFER; Scalar* twiddles = kernel_data.twiddles_forward.get(); std::size_t local_elements = @@ -358,8 +358,8 @@ struct committed_descriptor_impl::num_scalars_in_local_mem_struc // working memory + twiddles for subgroup impl for the two sizes Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup( input_layout == layout::BATCH_INTERLEAVED, used_sg_size * PORTFFT_SGS_IN_WG); - return detail::pad_local(static_cast(2 * num_batches_in_local_mem) * length, - bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * m)) + + const auto bank_lines_per_pad = bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * m); + return detail::pad_local(static_cast(2 * num_batches_in_local_mem) * length, bank_lines_per_pad) + 2 * (m + n); } }; diff --git a/test/common/reference_data_wrangler.hpp b/test/common/reference_data_wrangler.hpp index 9609d676..08e6329d 100644 --- a/test/common/reference_data_wrangler.hpp +++ b/test/common/reference_data_wrangler.hpp @@ -109,7 +109,7 @@ auto gen_fourier_data(portfft::descriptor& desc, portfft::detail constexpr bool IsRealDomain = Domain == portfft::domain::REAL; constexpr bool IsForward = Dir == portfft::direction::FORWARD; constexpr bool IsInterleaved = Storage == portfft::complex_storage::INTERLEAVED_COMPLEX; - constexpr bool debug_input = false; + constexpr bool DebugInput = false; const auto batches = desc.number_of_transforms; const auto& dims = desc.lengths; @@ -160,7 +160,7 @@ auto gen_fourier_data(portfft::descriptor& desc, portfft::detail command << "," << (std::is_same_v ? "False" : "True"); - command << "," << (debug_input ? "True" : "False"); + command << "," << (DebugInput ? "True" : "False"); command << ")\""; @@ -243,15 +243,19 @@ auto gen_fourier_data(portfft::descriptor& desc, portfft::detail // Return a tuple in the expected order if constexpr (IsForward) { if constexpr (IsInterleaved) { - return std::make_tuple(forward, backward, forward_imag, backward_imag); + return std::make_tuple(std::move(forward), std::move(backward), std::move(forward_imag), + std::move(backward_imag)); } else { - return std::make_tuple(forward_real, backward_real, forward_imag, backward_imag); + return std::make_tuple(std::move(forward_real), std::move(backward_real), std::move(forward_imag), + std::move(backward_imag)); } } else { if constexpr (IsInterleaved) { - return std::make_tuple(backward, forward, backward_imag, forward_imag); + return std::make_tuple(std::move(backward), std::move(forward), std::move(backward_imag), + std::move(forward_imag)); } else { - return std::make_tuple(backward_real, forward_real, backward_imag, forward_imag); + return std::make_tuple(std::move(backward_real), std::move(forward_real), std::move(backward_imag), + std::move(forward_imag)); } } } diff --git a/test/unit_test/instantiate_fft_tests.hpp b/test/unit_test/instantiate_fft_tests.hpp index 026f27c1..ad41a160 100644 --- a/test/unit_test/instantiate_fft_tests.hpp +++ b/test/unit_test/instantiate_fft_tests.hpp @@ -245,9 +245,16 @@ INSTANTIATE_TEST_SUITE_P( SubgroupStridedOOPInOrder, FFTTest, ::testing::ConvertGenerator(::testing::Combine( oop_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(1, 3, 33000ul), - ::testing::Values(layout_params{{64}, {1}, {7}}, layout_params{{64}, {4}, {7}}, + ::testing::Values(layout_params{{64}, {1}, {7}}, layout_params{{64}, {4}, {1}}, layout_params{{75}, {3}, {2}, 300, 200}, layout_params{{104}, {3}, {4}}))), test_params_print()); +INSTANTIATE_TEST_SUITE_P( + WorkgroupStridedOOPInOrder, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + oop_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(1, 3, 12000ul), + ::testing::Values(layout_params{{2048}, {1}, {7}}, layout_params{{5625}, {4}, {7}}, + layout_params{{4096}, {3}, {2}, 12500, 9000}, layout_params{{2704}, {3}, {1}}))), + test_params_print()); // The LikeBatchInterleaved tests must have stride >= number of transforms INSTANTIATE_TEST_SUITE_P( workItemStridedOOPLikeBatchInterleaved, FFTTest, @@ -263,7 +270,15 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(layout_params{{64}, {33}, {99}, 1, 3}, layout_params{{96}, {33}, {2}, 1, 192}, layout_params{{70}, {2}, {66}, 140, 2}))), test_params_print()); -INSTANTIATE_TEST_SUITE_P(workItemStridedIP, FFTTest, +INSTANTIATE_TEST_SUITE_P( + WorkgroupStridedOOPLikeBatchInterleaved, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + oop_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(1, 10, 33), + ::testing::Values(layout_params{{4096}, {33}, {99}, 1, 3}, layout_params{{4096}, {99}, {33}, 3, 1}, + layout_params{{5625}, {66}, {1}, 2, 5625}, layout_params{{2704}, {66}, {4}, 2, 10816}, + layout_params{{780}, {1}, {66}, 780, 2}, layout_params{{3072}, {2}, {33}, 6144, 1}))), + test_params_print()); +INSTANTIATE_TEST_SUITE_P(workItemStridedIPInOrder, FFTTest, ::testing::ConvertGenerator(::testing::Combine( ip_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(1, 3, 33000ul), @@ -271,7 +286,7 @@ INSTANTIATE_TEST_SUITE_P(workItemStridedIP, FFTTest, // no space between last element of one batch and first of the next layout_params{{9}, {3}, {3}, 25, 25}))), test_params_print()); -INSTANTIATE_TEST_SUITE_P(SubgroupStridedIP, FFTTest, +INSTANTIATE_TEST_SUITE_P(SubgroupStridedIPInOrder, FFTTest, ::testing::ConvertGenerator(::testing::Combine( ip_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(1, 3, 33000ul), @@ -279,6 +294,14 @@ INSTANTIATE_TEST_SUITE_P(SubgroupStridedIP, FFTTest, // no space between last element of one batch and first of the next layout_params{{96}, {3}, {3}, 286, 286}))), test_params_print()); +INSTANTIATE_TEST_SUITE_P(WorkgroupStridedIPInOrder, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + ip_unpacked_unpacked_layout, both_directions, complex_storages, + ::testing::Values(1, 3, 33000ul), + ::testing::Values(layout_params{{3072}, {4}, {4}}, + // no space between last element of one batch and first of the next + layout_params{{780}, {3}, {3}, 2338, 2338}))), + test_params_print()); INSTANTIATE_TEST_SUITE_P( workItemStridedIPLikeBatchInterleaved, FFTTest, ::testing::ConvertGenerator(::testing::Combine( @@ -291,6 +314,12 @@ INSTANTIATE_TEST_SUITE_P( ip_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(1, 3, 33), ::testing::Values(layout_params{{75}, {66}, {66}, 2, 2}, layout_params{{96}, {40}, {40}, 1, 1}))), test_params_print()); +INSTANTIATE_TEST_SUITE_P( + WorkgroupStridedIPLikeBatchInterleaved, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + ip_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(1, 3, 33), + ::testing::Values(layout_params{{2048}, {66}, {66}, 2, 2}, layout_params{{5625}, {40}, {40}, 1, 1}))), + test_params_print()); // these layouts are only valid because there is only a single batch INSTANTIATE_TEST_SUITE_P(StridedStrideEqualsDistance, FFTTest, @@ -317,6 +346,11 @@ INSTANTIATE_TEST_SUITE_P(SubgroupStridedArbitraryInterleaved, FFTTest, all_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(13), ::testing::Values(layout_params{{85}, {13}, {13}, 12, 12}))), test_params_print()); +INSTANTIATE_TEST_SUITE_P(WorkgroupStridedArbitraryInterleaved, FFTTest, + ::testing::ConvertGenerator(::testing::Combine( + all_unpacked_unpacked_layout, both_directions, complex_storages, ::testing::Values(13), + ::testing::Values(layout_params{{780}, {13}, {13}, 12, 12}))), + test_params_print()); // Invalid configurations test suite INSTANTIATE_TEST_SUITE_P(InvalidLength, InvalidFFTTest,