Skip to content

Commit

Permalink
renamed all take_conjugate_* to conjugate_*
Browse files Browse the repository at this point in the history
  • Loading branch information
AD2605 committed Jan 23, 2024
1 parent e4167ae commit 5fe5eaa
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 70 deletions.
2 changes: 1 addition & 1 deletion src/portfft/common/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ PORTFFT_INLINE constexpr Idx int_log2(Idx x) {
* @param num_complex number of complex numbers in the private memory
*/
template <typename T>
PORTFFT_INLINE void take_conjugate_inplace(T* priv, Idx num_complex) {
PORTFFT_INLINE void conjugate_inplace(T* priv, Idx num_complex) {
PORTFFT_UNROLL
for (Idx i = 0; i < num_complex; i++) {
priv[2 * i + 1] *= -1;
Expand Down
39 changes: 20 additions & 19 deletions src/portfft/common/workgroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ namespace detail {
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
* @param apply_scale_factor Whether or not the scale factor is applied
* @param take_conjugate_on_load whether or not to take conjugate of the input
* @param take_conjugate_on_store whether or not to take conjugate of the output
* @param conjugate_on_load whether or not to take conjugate of the input
* @param conjugate_on_store whether or not to take conjugate of the output
* @param global_data global data for the kernel
*/
template <Idx SubgroupSize, typename LocalT, typename T>
Expand All @@ -88,7 +88,7 @@ __attribute__((always_inline)) inline void dimension_dft(
Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, complex_storage storage,
detail::layout layout_in, detail::elementwise_multiply multiply_on_load,
detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor apply_scale_factor,
detail::complex_conjugate take_conjugate_on_load, detail::complex_conjugate take_conjugate_on_store,
detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store,
global_data_struct<1> global_data) {
static_assert(std::is_same_v<detail::get_element_t<LocalT>, T>, "Real type mismatch");
global_data.log_message_global(__func__, "entered", "DFTSize", dft_size, "stride_within_dft", stride_within_dft,
Expand Down Expand Up @@ -226,12 +226,12 @@ __attribute__((always_inline)) inline void dimension_dft(
}
}
}
if (take_conjugate_on_load == detail::complex_conjugate::TAKEN) {
take_conjugate_inplace(priv, fact_wi);
if (conjugate_on_load == detail::complex_conjugate::APPLIED) {
conjugate_inplace(priv, fact_wi);
}
sg_dft<SubgroupSize>(priv, global_data.sg, fact_wi, fact_sg, loc_twiddles, wi_private_scratch);
if (take_conjugate_on_store == detail::complex_conjugate::TAKEN) {
take_conjugate_inplace(priv, fact_wi);
if (conjugate_on_store == detail::complex_conjugate::APPLIED) {
conjugate_inplace(priv, fact_wi);
}
if (working) {
if (multiply_on_store == detail::elementwise_multiply::APPLIED) {
Expand Down Expand Up @@ -317,34 +317,35 @@ __attribute__((always_inline)) inline void dimension_dft(
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
* @param apply_scale_factor Whether or not the scale factor is applied
* @param take_conjugate_on_load whether or not to take conjugate of the input
* @param take_conjugate_on_store whether or not to take conjugate of the output
* @param conjugate_on_load whether or not to take conjugate of the input
* @param conjugate_on_store whether or not to take conjugate of the output
* @param global_data global data for the kernel
*/
template <Idx SubgroupSize, typename LocalT, typename T>
PORTFFT_INLINE void wg_dft(
LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem,
Idx batch_num_in_local, IdxGlobal batch_num_in_kernel, const T* load_modifier_data, const T* store_modifier_data,
Idx fft_size, Idx N, Idx M, complex_storage storage, detail::layout layout_in,
detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store,
detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate take_conjugate_on_load,
detail::complex_conjugate take_conjugate_on_store, detail::global_data_struct<1> global_data) {
PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor,
Idx max_num_batches_in_local_mem, Idx batch_num_in_local, IdxGlobal batch_num_in_kernel,
const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M,
complex_storage storage, detail::layout layout_in,
detail::elementwise_multiply multiply_on_load,
detail::elementwise_multiply multiply_on_store,
detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate conjugate_on_load,
detail::complex_conjugate conjugate_on_store, detail::global_data_struct<1> global_data) {
global_data.log_message_global(__func__, "entered", "FFTSize", fft_size, "N", N, "M", M,
"max_num_batches_in_local_mem", max_num_batches_in_local_mem, "batch_num_in_local",
batch_num_in_local);
// column-wise DFTs
detail::dimension_dft<SubgroupSize, LocalT, T>(
loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data,
store_modifier_data, batch_num_in_kernel, N, M, 1, storage, layout_in, multiply_on_load,
detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, take_conjugate_on_load,
detail::complex_conjugate::NOT_TAKEN, global_data);
detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, conjugate_on_load,
detail::complex_conjugate::NOT_APPLIED, global_data);
sycl::group_barrier(global_data.it.get_group());
// row-wise DFTs, including twiddle multiplications and scaling
detail::dimension_dft<SubgroupSize, LocalT, T>(
loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local,
load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, layout_in,
detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor,
detail::complex_conjugate::NOT_TAKEN, take_conjugate_on_store, global_data);
detail::complex_conjugate::NOT_APPLIED, conjugate_on_store, global_data);
global_data.log_message_global(__func__, "exited");
}

Expand Down
41 changes: 20 additions & 21 deletions src/portfft/descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ class committed_descriptor {
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation
* @param scale_factor_applied whether or not to multiply scale factor
* @param level sub implementation to run which will be set as a spec constant
* @param take_conjugate_on_load whether or not to take conjugate of the input
* @param take_conjugate_on_store whether or not to take conjugate of the output
* @param conjugate_on_load whether or not to take conjugate of the input
* @param conjugate_on_store whether or not to take conjugate of the output
* @param scale_factor Scale to be applied to the result
* @param factor_num factor number which is set as a spec constant
* @param num_factors total number of factors of the committed size, set as a spec constant
Expand All @@ -451,9 +451,8 @@ class committed_descriptor {
std::size_t length, const std::vector<Idx>& factors,
detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store,
detail::apply_scale_factor scale_factor_applied, detail::level level,
detail::complex_conjugate take_conjugate_on_load,
detail::complex_conjugate take_conjugate_on_store, Scalar scale_factor, Idx factor_num = 0,
Idx num_factors = 0) {
detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store,
Scalar scale_factor, Idx factor_num = 0, Idx num_factors = 0) {
const Idx length_idx = static_cast<Idx>(length);
// These spec constants are used in all implementations, so we set them here
in_bundle.template set_specialization_constant<detail::SpecConstComplexStorage>(params.complex_storage);
Expand All @@ -462,8 +461,8 @@ class committed_descriptor {
in_bundle.template set_specialization_constant<detail::SpecConstMultiplyOnLoad>(multiply_on_load);
in_bundle.template set_specialization_constant<detail::SpecConstMultiplyOnStore>(multiply_on_store);
in_bundle.template set_specialization_constant<detail::SpecConstApplyScaleFactor>(scale_factor_applied);
in_bundle.template set_specialization_constant<detail::SpecConstTakeConjugateOnLoad>(take_conjugate_on_load);
in_bundle.template set_specialization_constant<detail::SpecConstTakeConjugateOnStore>(take_conjugate_on_store);
in_bundle.template set_specialization_constant<detail::SpecTakeConjugateOnLoad>(conjugate_on_load);
in_bundle.template set_specialization_constant<detail::SpecConstConjugateOnStore>(conjugate_on_store);
if constexpr (std::is_same_v<Scalar, float>) {
in_bundle.template set_specialization_constant<detail::SpecConstScaleFactorFloat>(scale_factor);
} else {
Expand Down Expand Up @@ -545,45 +544,45 @@ class committed_descriptor {
scale_factor = static_cast<Scalar>(1.0);
}
std::size_t counter = 0;
auto take_conjugate_on_load = detail::complex_conjugate::NOT_TAKEN;
auto take_conjugate_on_store = detail::complex_conjugate::NOT_TAKEN;
auto conjugate_on_load = detail::complex_conjugate::NOT_APPLIED;
auto conjugate_on_store = detail::complex_conjugate::NOT_APPLIED;
std::vector<kernel_data_struct> result;
for (auto& [level, ids, factors] : prepared_vec) {
auto in_bundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(queue.get_context(), ids);
if (top_level == detail::level::GLOBAL) {
if (counter == prepared_vec.size() - 1) {
if (compute_direction == direction::BACKWARD) {
take_conjugate_on_store = detail::complex_conjugate::TAKEN;
conjugate_on_store = detail::complex_conjugate::APPLIED;
}
set_spec_constants(
detail::level::GLOBAL, in_bundle,
static_cast<std::size_t>(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies<Idx>())),
factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED,
detail::apply_scale_factor::APPLIED, level, take_conjugate_on_load, take_conjugate_on_store, scale_factor,
detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store, scale_factor,
static_cast<Idx>(counter), static_cast<Idx>(prepared_vec.size()));
// reset take_conjugate_on_store
take_conjugate_on_store = detail::complex_conjugate::NOT_TAKEN;
// reset conjugate_on_store
conjugate_on_store = detail::complex_conjugate::NOT_APPLIED;
} else {
if (counter == 0 && compute_direction == direction::BACKWARD) {
take_conjugate_on_load = detail::complex_conjugate::TAKEN;
conjugate_on_load = detail::complex_conjugate::APPLIED;
}
set_spec_constants(
detail::level::GLOBAL, in_bundle,
static_cast<std::size_t>(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies<Idx>())),
factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED,
detail::apply_scale_factor::NOT_APPLIED, level, take_conjugate_on_load, take_conjugate_on_store,
scale_factor, static_cast<Idx>(counter), static_cast<Idx>(prepared_vec.size()));
// reset take_conjugate_on_load
take_conjugate_on_load = detail::complex_conjugate::NOT_TAKEN;
detail::apply_scale_factor::NOT_APPLIED, level, conjugate_on_load, conjugate_on_store, scale_factor,
static_cast<Idx>(counter), static_cast<Idx>(prepared_vec.size()));
// reset conjugate_on_load
conjugate_on_load = detail::complex_conjugate::NOT_APPLIED;
}
} else {
if (compute_direction == direction::BACKWARD) {
take_conjugate_on_load = detail::complex_conjugate::TAKEN;
take_conjugate_on_store = detail::complex_conjugate::TAKEN;
conjugate_on_load = detail::complex_conjugate::APPLIED;
conjugate_on_store = detail::complex_conjugate::APPLIED;
}
set_spec_constants(level, in_bundle, params.lengths[dimension_num], factors,
detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED,
detail::apply_scale_factor::APPLIED, level, take_conjugate_on_load, take_conjugate_on_store,
detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store,
scale_factor);
}
try {
Expand Down
22 changes: 10 additions & 12 deletions src/portfft/dispatcher/subgroup_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant<detail::SpecConstMultiplyOnLoad>();
detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant<detail::SpecConstMultiplyOnStore>();
detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant<detail::SpecConstApplyScaleFactor>();
detail::complex_conjugate take_conjugate_on_load =
kh.get_specialization_constant<detail::SpecConstTakeConjugateOnLoad>();
detail::complex_conjugate take_conjugate_on_store =
kh.get_specialization_constant<detail::SpecConstTakeConjugateOnStore>();
detail::complex_conjugate conjugate_on_load = kh.get_specialization_constant<detail::SpecTakeConjugateOnLoad>();
detail::complex_conjugate conjugate_on_store = kh.get_specialization_constant<detail::SpecConstConjugateOnStore>();
T scaling_factor = [&]() {
if constexpr (std::is_same_v<T, float>) {
return kh.get_specialization_constant<detail::SpecConstScaleFactorFloat>();
Expand Down Expand Up @@ -265,12 +263,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
}
}
}
if (take_conjugate_on_load == detail::complex_conjugate::TAKEN) {
take_conjugate_inplace(priv, factor_wi);
if (conjugate_on_load == detail::complex_conjugate::APPLIED) {
conjugate_inplace(priv, factor_wi);
}
sg_dft<SubgroupSize>(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch);
if (take_conjugate_on_store == detail::complex_conjugate::TAKEN) {
take_conjugate_inplace(priv, factor_wi);
if (conjugate_on_store == detail::complex_conjugate::APPLIED) {
conjugate_inplace(priv, factor_wi);
}
if (working_inner) {
global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi);
Expand Down Expand Up @@ -462,12 +460,12 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag
}
}
}
if (take_conjugate_on_load == detail::complex_conjugate::TAKEN) {
take_conjugate_inplace(priv, factor_wi);
if (conjugate_on_load == detail::complex_conjugate::APPLIED) {
conjugate_inplace(priv, factor_wi);
}
sg_dft<SubgroupSize>(priv, global_data.sg, factor_wi, factor_sg, loc_twiddles, wi_private_scratch);
if (take_conjugate_on_store == detail::complex_conjugate::TAKEN) {
take_conjugate_inplace(priv, factor_wi);
if (conjugate_on_store == detail::complex_conjugate::APPLIED) {
conjugate_inplace(priv, factor_wi);
}
if (working) {
global_data.log_dump_private("data in registers after computation:", priv, n_reals_per_wi);
Expand Down
10 changes: 4 additions & 6 deletions src/portfft/dispatcher/workgroup_dispatcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant<detail::SpecConstMultiplyOnLoad>();
detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant<detail::SpecConstMultiplyOnStore>();
detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant<detail::SpecConstApplyScaleFactor>();
detail::complex_conjugate take_conjugate_on_load =
kh.get_specialization_constant<detail::SpecConstTakeConjugateOnLoad>();
detail::complex_conjugate take_conjugate_on_store =
kh.get_specialization_constant<detail::SpecConstTakeConjugateOnStore>();
detail::complex_conjugate conjugate_on_load = kh.get_specialization_constant<detail::SpecTakeConjugateOnLoad>();
detail::complex_conjugate conjugate_on_store = kh.get_specialization_constant<detail::SpecConstConjugateOnStore>();
T scaling_factor = [&]() {
if constexpr (std::is_same_v<T, float>) {
return kh.get_specialization_constant<detail::SpecConstScaleFactorFloat>();
Expand Down Expand Up @@ -176,7 +174,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
wg_dft<SubgroupSize>(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem,
sub_batch, batch_start_idx, load_modifier_data, store_modifier_data, fft_size, factor_n,
factor_m, storage, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor,
take_conjugate_on_load, take_conjugate_on_store, global_data);
conjugate_on_load, conjugate_on_store, global_data);
sycl::group_barrier(global_data.it.get_group());
}
if constexpr (LayoutOut == detail::layout::PACKED) {
Expand Down Expand Up @@ -240,7 +238,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima
wg_dft<SubgroupSize>(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, 0,
batch_start_idx, load_modifier_data, store_modifier_data, fft_size, factor_n, factor_m,
storage, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor,
take_conjugate_on_load, take_conjugate_on_store, global_data);
conjugate_on_load, conjugate_on_store, global_data);
sycl::group_barrier(global_data.it.get_group());
global_data.log_message_global(__func__, "storing non-transposed data from local to global memory");
// transposition for WG CT
Expand Down
Loading

0 comments on commit 5fe5eaa

Please sign in to comment.