From 5fe5eaa4632c283c497df520314ccde80710e88d Mon Sep 17 00:00:00 2001 From: "atharva.dubey" Date: Tue, 23 Jan 2024 12:23:48 +0000 Subject: [PATCH] renamed all take_conjugate_* to conjugate_* --- src/portfft/common/helpers.hpp | 2 +- src/portfft/common/workgroup.hpp | 39 +++++++++--------- src/portfft/descriptor.hpp | 41 +++++++++---------- .../dispatcher/subgroup_dispatcher.hpp | 22 +++++----- .../dispatcher/workgroup_dispatcher.hpp | 10 ++--- .../dispatcher/workitem_dispatcher.hpp | 14 +++---- src/portfft/enums.hpp | 2 +- src/portfft/specialization_constant.hpp | 4 +- 8 files changed, 64 insertions(+), 70 deletions(-) diff --git a/src/portfft/common/helpers.hpp b/src/portfft/common/helpers.hpp index 8064727d..ab4728cd 100644 --- a/src/portfft/common/helpers.hpp +++ b/src/portfft/common/helpers.hpp @@ -190,7 +190,7 @@ PORTFFT_INLINE constexpr Idx int_log2(Idx x) { * @param num_complex number of complex numbers in the private memory */ template -PORTFFT_INLINE void take_conjugate_inplace(T* priv, Idx num_complex) { +PORTFFT_INLINE void conjugate_inplace(T* priv, Idx num_complex) { PORTFFT_UNROLL for (Idx i = 0; i < num_complex; i++) { priv[2 * i + 1] *= -1; diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp index c3abde27..9996c9e5 100644 --- a/src/portfft/common/workgroup.hpp +++ b/src/portfft/common/workgroup.hpp @@ -77,8 +77,8 @@ namespace detail { * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation. * @param MultiplyOnStore Whether the input data is multiplied with some data array after fft computation. * @param ApplyScaleFactor Whether or not the scale factor is applied - * @param take_conjugate_on_load whether or not to take conjugate of the input - * @param take_conjugate_on_store whether or not to take conjugate of the output + * @param conjugate_on_load whether or not to take conjugate of the input + * @param conjugate_on_store whether or not to take conjugate of the output * @param global_data global data for the kernel */ template @@ -88,7 +88,7 @@ __attribute__((always_inline)) inline void dimension_dft( Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, complex_storage storage, detail::layout layout_in, detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor apply_scale_factor, - detail::complex_conjugate take_conjugate_on_load, detail::complex_conjugate take_conjugate_on_store, + detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store, global_data_struct<1> global_data) { static_assert(std::is_same_v, T>, "Real type mismatch"); global_data.log_message_global(__func__, "entered", "DFTSize", dft_size, "stride_within_dft", stride_within_dft, @@ -226,12 +226,12 @@ __attribute__((always_inline)) inline void dimension_dft( } } } - if (take_conjugate_on_load == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, fact_wi); + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, fact_wi); } sg_dft(priv, global_data.sg, fact_wi, fact_sg, loc_twiddles, wi_private_scratch); - if (take_conjugate_on_store == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, fact_wi); + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, fact_wi); } if (working) { if (multiply_on_store == detail::elementwise_multiply::APPLIED) { @@ -317,18 +317,19 @@ __attribute__((always_inline)) inline void dimension_dft( * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation. * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation. * @param apply_scale_factor Whether or not the scale factor is applied - * @param take_conjugate_on_load whether or not to take conjugate of the input - * @param take_conjugate_on_store whether or not to take conjugate of the output + * @param conjugate_on_load whether or not to take conjugate of the input + * @param conjugate_on_store whether or not to take conjugate of the output * @param global_data global data for the kernel */ template -PORTFFT_INLINE void wg_dft( - LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem, - Idx batch_num_in_local, IdxGlobal batch_num_in_kernel, const T* load_modifier_data, const T* store_modifier_data, - Idx fft_size, Idx N, Idx M, complex_storage storage, detail::layout layout_in, - detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, - detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate take_conjugate_on_load, - detail::complex_conjugate take_conjugate_on_store, detail::global_data_struct<1> global_data) { +PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, + Idx max_num_batches_in_local_mem, Idx batch_num_in_local, IdxGlobal batch_num_in_kernel, + const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M, + complex_storage storage, detail::layout layout_in, + detail::elementwise_multiply multiply_on_load, + detail::elementwise_multiply multiply_on_store, + detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate conjugate_on_load, + detail::complex_conjugate conjugate_on_store, detail::global_data_struct<1> global_data) { global_data.log_message_global(__func__, "entered", "FFTSize", fft_size, "N", N, "M", M, "max_num_batches_in_local_mem", max_num_batches_in_local_mem, "batch_num_in_local", batch_num_in_local); @@ -336,15 +337,15 @@ PORTFFT_INLINE void wg_dft( detail::dimension_dft( loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data, store_modifier_data, batch_num_in_kernel, N, M, 1, storage, layout_in, multiply_on_load, - detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, take_conjugate_on_load, - detail::complex_conjugate::NOT_TAKEN, global_data); + detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, conjugate_on_load, + detail::complex_conjugate::NOT_APPLIED, global_data); sycl::group_barrier(global_data.it.get_group()); // row-wise DFTs, including twiddle multiplications and scaling detail::dimension_dft( loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, layout_in, detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor, - detail::complex_conjugate::NOT_TAKEN, take_conjugate_on_store, global_data); + detail::complex_conjugate::NOT_APPLIED, conjugate_on_store, global_data); global_data.log_message_global(__func__, "exited"); } diff --git a/src/portfft/descriptor.hpp b/src/portfft/descriptor.hpp index 4d77dd67..ea3a44aa 100644 --- a/src/portfft/descriptor.hpp +++ b/src/portfft/descriptor.hpp @@ -441,8 +441,8 @@ class committed_descriptor { * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation * @param scale_factor_applied whether or not to multiply scale factor * @param level sub implementation to run which will be set as a spec constant - * @param take_conjugate_on_load whether or not to take conjugate of the input - * @param take_conjugate_on_store whether or not to take conjugate of the output + * @param conjugate_on_load whether or not to take conjugate of the input + * @param conjugate_on_store whether or not to take conjugate of the output * @param scale_factor Scale to be applied to the result * @param factor_num factor number which is set as a spec constant * @param num_factors total number of factors of the committed size, set as a spec constant @@ -451,9 +451,8 @@ class committed_descriptor { std::size_t length, const std::vector& factors, detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor scale_factor_applied, detail::level level, - detail::complex_conjugate take_conjugate_on_load, - detail::complex_conjugate take_conjugate_on_store, Scalar scale_factor, Idx factor_num = 0, - Idx num_factors = 0) { + detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store, + Scalar scale_factor, Idx factor_num = 0, Idx num_factors = 0) { const Idx length_idx = static_cast(length); // These spec constants are used in all implementations, so we set them here in_bundle.template set_specialization_constant(params.complex_storage); @@ -462,8 +461,8 @@ class committed_descriptor { in_bundle.template set_specialization_constant(multiply_on_load); in_bundle.template set_specialization_constant(multiply_on_store); in_bundle.template set_specialization_constant(scale_factor_applied); - in_bundle.template set_specialization_constant(take_conjugate_on_load); - in_bundle.template set_specialization_constant(take_conjugate_on_store); + in_bundle.template set_specialization_constant(conjugate_on_load); + in_bundle.template set_specialization_constant(conjugate_on_store); if constexpr (std::is_same_v) { in_bundle.template set_specialization_constant(scale_factor); } else { @@ -545,45 +544,45 @@ class committed_descriptor { scale_factor = static_cast(1.0); } std::size_t counter = 0; - auto take_conjugate_on_load = detail::complex_conjugate::NOT_TAKEN; - auto take_conjugate_on_store = detail::complex_conjugate::NOT_TAKEN; + auto conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; + auto conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; std::vector result; for (auto& [level, ids, factors] : prepared_vec) { auto in_bundle = sycl::get_kernel_bundle(queue.get_context(), ids); if (top_level == detail::level::GLOBAL) { if (counter == prepared_vec.size() - 1) { if (compute_direction == direction::BACKWARD) { - take_conjugate_on_store = detail::complex_conjugate::TAKEN; + conjugate_on_store = detail::complex_conjugate::APPLIED; } set_spec_constants( detail::level::GLOBAL, in_bundle, static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, - detail::apply_scale_factor::APPLIED, level, take_conjugate_on_load, take_conjugate_on_store, scale_factor, + detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store, scale_factor, static_cast(counter), static_cast(prepared_vec.size())); - // reset take_conjugate_on_store - take_conjugate_on_store = detail::complex_conjugate::NOT_TAKEN; + // reset conjugate_on_store + conjugate_on_store = detail::complex_conjugate::NOT_APPLIED; } else { if (counter == 0 && compute_direction == direction::BACKWARD) { - take_conjugate_on_load = detail::complex_conjugate::TAKEN; + conjugate_on_load = detail::complex_conjugate::APPLIED; } set_spec_constants( detail::level::GLOBAL, in_bundle, static_cast(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies())), factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED, - detail::apply_scale_factor::NOT_APPLIED, level, take_conjugate_on_load, take_conjugate_on_store, - scale_factor, static_cast(counter), static_cast(prepared_vec.size())); - // reset take_conjugate_on_load - take_conjugate_on_load = detail::complex_conjugate::NOT_TAKEN; + detail::apply_scale_factor::NOT_APPLIED, level, conjugate_on_load, conjugate_on_store, scale_factor, + static_cast(counter), static_cast(prepared_vec.size())); + // reset conjugate_on_load + conjugate_on_load = detail::complex_conjugate::NOT_APPLIED; } } else { if (compute_direction == direction::BACKWARD) { - take_conjugate_on_load = detail::complex_conjugate::TAKEN; - take_conjugate_on_store = detail::complex_conjugate::TAKEN; + conjugate_on_load = detail::complex_conjugate::APPLIED; + conjugate_on_store = detail::complex_conjugate::APPLIED; } set_spec_constants(level, in_bundle, params.lengths[dimension_num], factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED, - detail::apply_scale_factor::APPLIED, level, take_conjugate_on_load, take_conjugate_on_store, + detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store, scale_factor); } try { diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 97034d96..90e77686 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -94,10 +94,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant(); - detail::complex_conjugate take_conjugate_on_load = - kh.get_specialization_constant(); - detail::complex_conjugate take_conjugate_on_store = - kh.get_specialization_constant(); + detail::complex_conjugate conjugate_on_load = kh.get_specialization_constant(); + detail::complex_conjugate conjugate_on_store = kh.get_specialization_constant(); T scaling_factor = [&]() { if constexpr (std::is_same_v) { return kh.get_specialization_constant(); @@ -265,12 +263,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } } - if (take_conjugate_on_load == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, factor_wi); + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); } sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); - if (take_conjugate_on_store == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, factor_wi); + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); } if (working_inner) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); @@ -462,12 +460,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } } - if (take_conjugate_on_load == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, factor_wi); + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); } sg_dft(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch); - if (take_conjugate_on_store == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, factor_wi); + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, factor_wi); } if (working) { global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi); diff --git a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp index e4c26310..bef1b37a 100644 --- a/src/portfft/dispatcher/workgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp @@ -106,10 +106,8 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant(); - detail::complex_conjugate take_conjugate_on_load = - kh.get_specialization_constant(); - detail::complex_conjugate take_conjugate_on_store = - kh.get_specialization_constant(); + detail::complex_conjugate conjugate_on_load = kh.get_specialization_constant(); + detail::complex_conjugate conjugate_on_store = kh.get_specialization_constant(); T scaling_factor = [&]() { if constexpr (std::is_same_v) { return kh.get_specialization_constant(); @@ -176,7 +174,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, sub_batch, batch_start_idx, load_modifier_data, store_modifier_data, fft_size, factor_n, factor_m, storage, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor, - take_conjugate_on_load, take_conjugate_on_store, global_data); + conjugate_on_load, conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); } if constexpr (LayoutOut == detail::layout::PACKED) { @@ -240,7 +238,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, 0, batch_start_idx, load_modifier_data, store_modifier_data, fft_size, factor_n, factor_m, storage, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor, - take_conjugate_on_load, take_conjugate_on_store, global_data); + conjugate_on_load, conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); global_data.log_message_global(__func__, "storing non-transposed data from local to global memory"); // transposition for WG CT diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp index 327557a9..2f5cafb1 100644 --- a/src/portfft/dispatcher/workitem_dispatcher.hpp +++ b/src/portfft/dispatcher/workitem_dispatcher.hpp @@ -107,10 +107,8 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant(); detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant(); detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant(); - detail::complex_conjugate take_conjugate_on_load = - kh.get_specialization_constant(); - detail::complex_conjugate take_conjugate_on_store = - kh.get_specialization_constant(); + detail::complex_conjugate conjugate_on_load = kh.get_specialization_constant(); + detail::complex_conjugate conjugate_on_store = kh.get_specialization_constant(); T scaling_factor = [&]() { if constexpr (std::is_same_v) { @@ -213,12 +211,12 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag global_data.log_message_global(__func__, "applying load modifier"); detail::apply_modifier(fft_size, priv, load_modifier_data, i * n_reals); } - if (take_conjugate_on_load == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, fft_size); + if (conjugate_on_load == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, fft_size); } wi_dft<0>(priv, priv, fft_size, 1, 1, wi_private_scratch); - if (take_conjugate_on_store == detail::complex_conjugate::TAKEN) { - take_conjugate_inplace(priv, fft_size); + if (conjugate_on_store == detail::complex_conjugate::APPLIED) { + conjugate_inplace(priv, fft_size); } global_data.log_dump_private("data in registers after computation:", priv, n_reals); if (multiply_on_store == detail::elementwise_multiply::APPLIED) { diff --git a/src/portfft/enums.hpp b/src/portfft/enums.hpp index 13e3e70f..cbb1c607 100644 --- a/src/portfft/enums.hpp +++ b/src/portfft/enums.hpp @@ -70,7 +70,7 @@ enum class elementwise_multiply { APPLIED, NOT_APPLIED }; enum class apply_scale_factor { APPLIED, NOT_APPLIED }; -enum class complex_conjugate { TAKEN, NOT_TAKEN }; +enum class complex_conjugate { APPLIED, NOT_APPLIED }; } // namespace detail } // namespace portfft diff --git a/src/portfft/specialization_constant.hpp b/src/portfft/specialization_constant.hpp index 06156ad1..90f22f76 100644 --- a/src/portfft/specialization_constant.hpp +++ b/src/portfft/specialization_constant.hpp @@ -45,8 +45,8 @@ constexpr static sycl::specialization_id GlobalSpecConstLevelNum{}; constexpr static sycl::specialization_id GlobalSpecConstNumFactors{}; // Specialization constants used for IFFT, when expressed as a IFFT=(conjugate(FFT(conjugate(input)))) -constexpr static sycl::specialization_id SpecConstTakeConjugateOnLoad{}; -constexpr static sycl::specialization_id SpecConstTakeConjugateOnStore{}; +constexpr static sycl::specialization_id SpecTakeConjugateOnLoad{}; +constexpr static sycl::specialization_id SpecConstConjugateOnStore{}; constexpr static sycl::specialization_id SpecConstScaleFactorFloat{}; constexpr static sycl::specialization_id SpecConstScaleFactorDouble{};