review comments 2
AD2605 committed Jan 23, 2024
1 parent 5fe5eaa commit 8ea2076
Showing 3 changed files with 33 additions and 34 deletions.
src/portfft/common/helpers.hpp (6 changes: 3 additions & 3 deletions)
@@ -184,10 +184,10 @@ PORTFFT_INLINE constexpr Idx int_log2(Idx x) {
 }
 
 /**
- * Takes the conjugate of the complex data in private array
+ * Conjugates complex data in an array in place (expected to be used on private memory)
  * @tparam T Scalar type
- * @param priv pointer to the data in registers
- * @param num_complex number of complex numbers in the private memory
+ * @param priv pointer to the data
+ * @param num_complex number of complex numbers to conjugate
  */
 template <typename T>
 PORTFFT_INLINE void conjugate_inplace(T* priv, Idx num_complex) {
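
Note: the body of conjugate_inplace is not part of this diff. A minimal sketch of what such a helper can look like, assuming an interleaved {re, im} layout for the complex values (the helper name below is illustrative, not portfft code):

template <typename T>
inline void conjugate_inplace_sketch(T* priv, int num_complex) {
  // Data assumed interleaved as {re0, im0, re1, im1, ...}; conjugation only
  // flips the sign of each imaginary component.
  for (int i = 0; i < num_complex; ++i) {
    priv[2 * i + 1] = -priv[2 * i + 1];
  }
}
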
src/portfft/common/workgroup.hpp (8 changes: 4 additions & 4 deletions)
@@ -77,8 +77,8 @@ namespace detail {
  * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
  * @param MultiplyOnStore Whether the input data is multiplied with some data array after fft computation.
  * @param ApplyScaleFactor Whether or not the scale factor is applied
- * @param conjugate_on_load whether or not to take conjugate of the input
- * @param conjugate_on_store whether or not to take conjugate of the output
+ * @param conjugate_on_load whether or not to conjugate the input
+ * @param conjugate_on_store whether or not to conjugate the output
  * @param global_data global data for the kernel
  */
 template <Idx SubgroupSize, typename LocalT, typename T>
@@ -317,8 +317,8 @@ __attribute__((always_inline)) inline void dimension_dft(
  * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
  * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
  * @param apply_scale_factor Whether or not the scale factor is applied
- * @param conjugate_on_load whether or not to take conjugate of the input
- * @param conjugate_on_store whether or not to take conjugate of the output
+ * @param conjugate_on_load whether or not to conjugate the input
+ * @param conjugate_on_store whether or not to conjugate the output
  * @param global_data global data for the kernel
  */
 template <Idx SubgroupSize, typename LocalT, typename T>
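
The descriptor.hpp changes below use these two flags to obtain the backward transform from a forward kernel via the identity IDFT(x) = conj(DFT(conj(x))) / N. A standalone check of that identity, as an illustration only (naive_dft and all names here are stand-ins, not portfft code):

#include <cmath>
#include <complex>
#include <cstddef>
#include <iostream>
#include <vector>

// Naive forward DFT, used only to demonstrate the conjugation identity.
std::vector<std::complex<double>> naive_dft(const std::vector<std::complex<double>>& in) {
  const std::size_t n = in.size();
  const double pi = std::acos(-1.0);
  std::vector<std::complex<double>> out(n);
  for (std::size_t k = 0; k < n; ++k) {
    for (std::size_t j = 0; j < n; ++j) {
      const double angle = -2.0 * pi * static_cast<double>(k * j) / static_cast<double>(n);
      out[k] += in[j] * std::complex<double>(std::cos(angle), std::sin(angle));
    }
  }
  return out;
}

int main() {
  std::vector<std::complex<double>> x = {{1.0, 2.0}, {0.0, -1.0}, {3.0, 0.5}, {-2.0, 1.0}};
  auto tmp = x;
  for (auto& v : tmp) v = std::conj(v);  // conjugate on load
  auto y = naive_dft(tmp);               // run the forward kernel
  for (auto& v : y) v = std::conj(v);    // conjugate on store
  // y is now the unscaled inverse DFT of x; the 1/N factor is handled
  // separately by the scale-factor machinery.
  for (const auto& v : y) std::cout << v << '\n';
}
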
src/portfft/descriptor.hpp (53 changes: 26 additions & 27 deletions)
@@ -527,21 +527,23 @@ class committed_descriptor {
    * returned from prepare_implementation
    * @tparam SubgroupSize Subgroup size
    * @param top_level selected level of implementation
-   * @param prepared_vec vector as returned from prepare_implementation
+   * @param prepared_vec vector of tuples of: implementation to use for a kernel,
+   * vector of kernel ids, factors
    * @param compute_direction direction of compute, as in forward or backward
    * @param dimension_num which dimension are the kernels being built for
    * @param is_compatible flag to be set if the kernels are compatible
-   * @param set_scale_as_unity whether or not scale factor needs to be set as unity
+   * @param skip_scaling whether or not to skip scaling
    * @return
    */
   template <Idx SubgroupSize>
   std::vector<kernel_data_struct> set_spec_constants_driver(
       detail::level top_level,
       std::vector<std::tuple<detail::level, std::vector<sycl::kernel_id>, std::vector<Idx>>>& prepared_vec,
-      direction compute_direction, std::size_t dimension_num, bool& is_compatible, bool set_scale_as_unity) {
+      direction compute_direction, std::size_t dimension_num, bool& is_compatible, bool skip_scaling) {
     Scalar scale_factor = compute_direction == direction::FORWARD ? params.forward_scale : params.backward_scale;
-    if (set_scale_as_unity) {
-      scale_factor = static_cast<Scalar>(1.0);
+    detail::apply_scale_factor scale_factor_applied = detail::apply_scale_factor::APPLIED;
+    if (skip_scaling) {
+      scale_factor_applied = detail::apply_scale_factor::NOT_APPLIED;
     }
     std::size_t counter = 0;
     auto conjugate_on_load = detail::complex_conjugate::NOT_APPLIED;
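
The review change keeps the real scale factor and instead records, via the apply_scale_factor spec constant, whether it should be applied at all (previously the factor itself was forced to 1.0). A hedged sketch of the equivalent selection and of how a kernel epilogue could branch on it; the types and the epilogue below are illustrative, not portfft's kernel code:

#include <iostream>

enum class apply_scale_factor { APPLIED, NOT_APPLIED };  // stand-in for detail::apply_scale_factor

struct scaling_config {
  apply_scale_factor applied;
  float factor;
};

scaling_config make_scaling(bool skip_scaling, float scale_factor) {
  // Keep the real factor, but record whether the kernel should use it at all.
  return {skip_scaling ? apply_scale_factor::NOT_APPLIED : apply_scale_factor::APPLIED, scale_factor};
}

float epilogue(float value, const scaling_config& cfg) {
  // With a specialization constant, this branch (and the multiply) can be
  // compiled out entirely when scaling is skipped.
  return cfg.applied == apply_scale_factor::APPLIED ? value * cfg.factor : value;
}

int main() {
  std::cout << epilogue(2.0f, make_scaling(false, 0.25f)) << '\n';  // prints 0.5
  std::cout << epilogue(2.0f, make_scaling(true, 0.25f)) << '\n';   // prints 2
}
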
@@ -550,28 +552,26 @@ class committed_descriptor {
     for (auto& [level, ids, factors] : prepared_vec) {
       auto in_bundle = sycl::get_kernel_bundle<sycl::bundle_state::input>(queue.get_context(), ids);
       if (top_level == detail::level::GLOBAL) {
+        std::size_t factor_size =
+            static_cast<std::size_t>(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies<Idx>()));
         if (counter == prepared_vec.size() - 1) {
           if (compute_direction == direction::BACKWARD) {
             conjugate_on_store = detail::complex_conjugate::APPLIED;
           }
-          set_spec_constants(
-              detail::level::GLOBAL, in_bundle,
-              static_cast<std::size_t>(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies<Idx>())),
-              factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED,
-              detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store, scale_factor,
-              static_cast<Idx>(counter), static_cast<Idx>(prepared_vec.size()));
+          set_spec_constants(detail::level::GLOBAL, in_bundle, factor_size, factors,
+                             detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED,
+                             detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store,
+                             scale_factor, static_cast<Idx>(counter), static_cast<Idx>(prepared_vec.size()));
           // reset conjugate_on_store
           conjugate_on_store = detail::complex_conjugate::NOT_APPLIED;
         } else {
           if (counter == 0 && compute_direction == direction::BACKWARD) {
             conjugate_on_load = detail::complex_conjugate::APPLIED;
           }
-          set_spec_constants(
-              detail::level::GLOBAL, in_bundle,
-              static_cast<std::size_t>(std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies<Idx>())),
-              factors, detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED,
-              detail::apply_scale_factor::NOT_APPLIED, level, conjugate_on_load, conjugate_on_store, scale_factor,
-              static_cast<Idx>(counter), static_cast<Idx>(prepared_vec.size()));
+          set_spec_constants(detail::level::GLOBAL, in_bundle, factor_size, factors,
+                             detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::APPLIED,
+                             detail::apply_scale_factor::NOT_APPLIED, level, conjugate_on_load, conjugate_on_store,
+                             scale_factor, static_cast<Idx>(counter), static_cast<Idx>(prepared_vec.size()));
           // reset conjugate_on_load
           conjugate_on_load = detail::complex_conjugate::NOT_APPLIED;
         }
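
The repeated std::accumulate expression is hoisted into factor_size, computed once per kernel. The idiom itself, shown standalone:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  using Idx = int;  // stand-in for portfft's Idx
  std::vector<Idx> factors = {4, 8, 16};
  // Product of this kernel's factors, i.e. the problem size handled by the kernel.
  std::size_t factor_size = static_cast<std::size_t>(
      std::accumulate(factors.begin(), factors.end(), Idx(1), std::multiplies<Idx>()));
  std::cout << factor_size << '\n';  // prints 512
}
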
@@ -582,8 +582,7 @@ class committed_descriptor {
         }
         set_spec_constants(level, in_bundle, params.lengths[dimension_num], factors,
                            detail::elementwise_multiply::NOT_APPLIED, detail::elementwise_multiply::NOT_APPLIED,
-                           detail::apply_scale_factor::APPLIED, level, conjugate_on_load, conjugate_on_store,
-                           scale_factor);
+                           scale_factor_applied, level, conjugate_on_load, conjugate_on_store, scale_factor);
       }
       try {
         result.emplace_back(sycl::build(in_bundle), factors, params.lengths[dimension_num], SubgroupSize,
@@ -605,11 +604,11 @@ class committed_descriptor {
    * @tparam SubgroupSize first subgroup size
    * @tparam OtherSGSizes other subgroup sizes
    * @param dimension_num The dimension for which the kernels are being built
-   * @param set_scale_as_unity whether or not scale factor needs to be set as unity
+   * @param skip_scaling whether or not to skip scaling
    * @return `dimension_struct` for the newly built kernels
    */
   template <Idx SubgroupSize, Idx... OtherSGSizes>
-  dimension_struct build_w_spec_const(std::size_t dimension_num, bool set_scale_as_unity) {
+  dimension_struct build_w_spec_const(std::size_t dimension_num, bool skip_scaling) {
     if (std::count(supported_sg_sizes.begin(), supported_sg_sizes.end(), SubgroupSize)) {
       auto [top_level, prepared_vec] = prepare_implementation<SubgroupSize>(dimension_num);
       bool is_compatible = true;
@@ -622,9 +621,9 @@ class committed_descriptor {
 
       if (is_compatible) {
         std::vector<kernel_data_struct> forward_kernels = set_spec_constants_driver<SubgroupSize>(
-            top_level, prepared_vec, direction::FORWARD, dimension_num, is_compatible, set_scale_as_unity);
+            top_level, prepared_vec, direction::FORWARD, dimension_num, is_compatible, skip_scaling);
         std::vector<kernel_data_struct> backward_kernels = set_spec_constants_driver<SubgroupSize>(
-            top_level, prepared_vec, direction::BACKWARD, dimension_num, is_compatible, set_scale_as_unity);
+            top_level, prepared_vec, direction::BACKWARD, dimension_num, is_compatible, skip_scaling);
         if (is_compatible) {
           return {forward_kernels, backward_kernels, top_level, params.lengths[dimension_num], SubgroupSize};
         }
@@ -633,7 +632,7 @@ class committed_descriptor {
     if constexpr (sizeof...(OtherSGSizes) == 0) {
       throw invalid_configuration("None of the compiled subgroup sizes are supported by the device");
     } else {
-      return build_w_spec_const<OtherSGSizes...>(dimension_num, set_scale_as_unity);
+      return build_w_spec_const<OtherSGSizes...>(dimension_num, skip_scaling);
     }
   }
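
build_w_spec_const tries the first compiled subgroup size and, if the device does not support it, recurses over the remaining OtherSGSizes at compile time. A hedged sketch of that fallback pattern with an illustrative name (pick_subgroup_size stands in for the real function):

#include <algorithm>
#include <stdexcept>
#include <vector>

// Returns the first compiled subgroup size that the device supports.
template <int SubgroupSize, int... OtherSGSizes>
int pick_subgroup_size(const std::vector<int>& supported_sg_sizes) {
  if (std::count(supported_sg_sizes.begin(), supported_sg_sizes.end(), SubgroupSize) != 0) {
    return SubgroupSize;
  }
  if constexpr (sizeof...(OtherSGSizes) == 0) {
    throw std::runtime_error("None of the compiled subgroup sizes are supported by the device");
  } else {
    return pick_subgroup_size<OtherSGSizes...>(supported_sg_sizes);  // try the next compiled size
  }
}

// Usage: pick_subgroup_size<32, 16, 8>({16, 32}) returns 32.
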

@@ -804,11 +803,11 @@ class committed_descriptor {
     // compile the kernels and precalculate twiddles
     std::size_t n_kernels = params.lengths.size();
     for (std::size_t i = 0; i < n_kernels; i++) {
-      bool set_scale_as_unity = true;
+      bool skip_scaling = true;
       if (i == n_kernels - 1) {
-        set_scale_as_unity = false;
+        skip_scaling = false;
       }
-      dimensions.emplace_back(build_w_spec_const<PORTFFT_SUBGROUP_SIZES>(i, set_scale_as_unity));
+      dimensions.emplace_back(build_w_spec_const<PORTFFT_SUBGROUP_SIZES>(i, skip_scaling));
       dimensions.back().forward_kernels.at(0).twiddles_forward = std::shared_ptr<Scalar>(
           calculate_twiddles(dimensions.back().level, dimensions.back().forward_kernels), [queue](Scalar* ptr) {
             if (ptr != nullptr) {
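
With the rename, the commit-time loop reads directly as "skip scaling for every dimension except the last", so a multi-dimensional transform applies its scale factor exactly once. The loop shape, sketched with an illustrative build_dimension stand-in for build_w_spec_const<PORTFFT_SUBGROUP_SIZES>:

#include <cstddef>
#include <vector>

struct dimension_stub {};  // stand-in for the real dimension_struct

dimension_stub build_dimension(std::size_t /*dim*/, bool /*skip_scaling*/) { return {}; }

std::vector<dimension_stub> commit_all_dimensions(std::size_t n_kernels) {
  std::vector<dimension_stub> dimensions;
  for (std::size_t i = 0; i < n_kernels; i++) {
    bool skip_scaling = (i != n_kernels - 1);  // only the last dimension applies the scale factor
    dimensions.push_back(build_dimension(i, skip_scaling));
  }
  return dimensions;
}
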
