Workgroup strided transforms #143

Open
wants to merge 7 commits into
base: main

@FMarno FMarno commented Mar 1, 2024

Add the option for strided transforms at the workgroup level + testing

Checklist

Tick if relevant:

  • [N/A] New files have a copyright
  • [N/A] New headers have include guards
  • API is documented with Doxygen
  • New functionalities are tested
  • Tests pass locally
  • Files are clang-formatted

src/portfft/committed_descriptor_impl.hpp (outdated, resolved)
src/portfft/common/workgroup.hpp (outdated, resolved)
src/portfft/common/workgroup.hpp (outdated, resolved)
src/portfft/common/workgroup.hpp (outdated, resolved)
src/portfft/common/workgroup.hpp (resolved)
test/unit_test/instantiate_fft_tests.hpp (resolved)
FMarno and others added 3 commits March 1, 2024 16:36
hjabird previously approved these changes Mar 4, 2024
t4c1 previously approved these changes Mar 5, 2024
Comment on lines +178 to +195
  global_data.log_message_global(__func__, "storing data from local to global memory (with 2 transposes)");
  if (storage == complex_storage::INTERLEAVED_COMPLEX) {
    std::array<IdxGlobal, 4> global_strides{2 * output_distance, 2 * factor_n * output_stride, 2 * output_stride,
                                            1};
    std::array<Idx, 4> local_strides{2, 2 * max_num_batches_in_local_mem,
                                     2 * factor_m * max_num_batches_in_local_mem, 1};
    std::array<Idx, 4> copy_lengths{num_batches_in_local_mem, factor_m, factor_n, 2};

    detail::md_view global_output_view{output, global_strides, output_global_offset};
    detail::md_view local_output_view{loc_view, local_strides};

    copy_group<level::WORKGROUP>(global_data, local_output_view, global_output_view, copy_lengths);
  } else {  // storage == complex_storage::SPLIT_COMPLEX
    std::array<IdxGlobal, 3> global_strides{output_distance, factor_n * output_stride, output_stride};
    std::array<Idx, 3> local_strides{1, max_num_batches_in_local_mem, factor_m * max_num_batches_in_local_mem};
    std::array<Idx, 3> copy_lengths{num_batches_in_local_mem, factor_m, factor_n};

    detail::md_view global_output_real_view{output, global_strides, output_global_offset};

Contributor

Let's move this logic out into a new file and lower the number of lines in workgroup_dispatcher, as the created views are not used anywhere else, similar to what I have done here.
The same applies from line 206 onwards.

Contributor Author

@FMarno FMarno Mar 5, 2024

I'll have a look at options for splitting the code into smaller chunks, but your example only really hides the easy bit. I would rather just inline the definition of the view objects.
I find statements like this

  subgroup_impl_local2global_strided_copy<level::SUBGROUP, 3, 3, 3>(
              const_cast<T*>(input), loc_view, {input_distance * 2, input_stride * 2, 1}, {fft_size * 2, 2, 1},
              input_distance * 2 * (i - static_cast<IdxGlobal>(id_of_fft_in_sg)), local_offset,
              {n_ffts_worked_on_by_sg, fft_size, 2}, global_data, detail::transfer_direction::GLOBAL_TO_LOCAL);

quite awkward to understand, since there are a lot of things going on in one big statement, and if I wanted to understand it I would need to somehow map things out from parameter to argument. The copies in the workgroup dispatcher are currently verbose, but that is because they split the definition into individual chunks of information, which I find very helpful. We need to prioritise readability over writability.
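
As a rough illustration of the named-strides style described above, the sketch below copies between two strided layouts on the host. `linear_offset`, `copy_strided` and `transpose_2x3_example` are hypothetical helpers written for this example only; they are not portFFT's `detail::md_view` or `copy_group`, and the real copies run on device across a workgroup.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical helper (not portFFT API): map a multi-dimensional index to a
// linear offset, given one stride per dimension.
template <std::size_t Rank>
std::size_t linear_offset(const std::array<std::size_t, Rank>& index,
                          const std::array<std::size_t, Rank>& strides) {
  std::size_t offset = 0;
  for (std::size_t d = 0; d < Rank; ++d) {
    offset += index[d] * strides[d];
  }
  return offset;
}

// Hypothetical helper (not portFFT API): copy every element addressed by
// `lengths` from a strided source layout to a strided destination layout.
template <typename T, std::size_t Rank>
void copy_strided(const T* src, const std::array<std::size_t, Rank>& src_strides, T* dst,
                  const std::array<std::size_t, Rank>& dst_strides,
                  const std::array<std::size_t, Rank>& lengths) {
  assert(src != nullptr && dst != nullptr);
  std::size_t total = 1;
  for (std::size_t len : lengths) {
    total *= len;
  }
  std::array<std::size_t, Rank> index{};
  for (std::size_t flat = 0; flat < total; ++flat) {
    // Decompose the flat counter into a multi-index, last dimension fastest.
    std::size_t rem = flat;
    for (std::size_t d = Rank; d-- > 0;) {
      index[d] = rem % lengths[d];
      rem /= lengths[d];
    }
    dst[linear_offset(index, dst_strides)] = src[linear_offset(index, src_strides)];
  }
}

// Each piece of information gets its own named object, as in the PR's style.
inline std::array<int, 6> transpose_2x3_example() {
  const std::array<int, 6> src{0, 1, 2, 3, 4, 5};        // 2 rows x 3 columns, row-major
  std::array<int, 6> dst{};                              // 3 rows x 2 columns, row-major
  const std::array<std::size_t, 2> copy_lengths{2, 3};   // {rows, columns} of the source
  const std::array<std::size_t, 2> src_strides{3, 1};    // row-major source
  const std::array<std::size_t, 2> dst_strides{1, 2};    // swapped strides give a transpose
  copy_strided(src.data(), src_strides, dst.data(), dst_strides, copy_lengths);
  return dst;
}
```

The transpose is expressed purely by the choice of destination strides, which is the information a reader has to reconstruct by hand when the same values are passed as anonymous braced lists in a single call.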

Contributor

"We need to prioritise readability over writability."

Yet we cannot prioritise readability to the point where our kernel functions start to span 600+ lines (as in subgroup_dispatcher, where the majority of the lines come from creating views and then copying them).

The views are only ever needed temporarily and add a LOT of lines, so I would say it's best to move them to a different function.

Contributor Author

subgroup_dispatcher definitely needs some work but I would rather sort that in a different PR.
I would also like to see the *_impl functions become smaller. I'll have a look at the workgroup_impl and see what I can do.
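
One possible middle ground between the two positions in this thread is a small helper that keeps each stride set named but moves its construction out of the dispatcher. The sketch below reuses the interleaved-complex stride expressions from the commented hunk; the struct `interleaved_store_layout`, the function `make_interleaved_store_layout`, the precondition, and the `Idx`/`IdxGlobal` aliases are assumptions made for illustration and are not part of this PR.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Assumed index types for this sketch; the real aliases live in portFFT.
using Idx = std::int32_t;
using IdxGlobal = std::int64_t;

// Hypothetical aggregate: keeps the three pieces of copy information named.
struct interleaved_store_layout {
  std::array<IdxGlobal, 4> global_strides;
  std::array<Idx, 4> local_strides;
  std::array<Idx, 4> copy_lengths;
};

// Hypothetical helper: builds the layout for the local-to-global store with
// two transposes, interleaved complex storage. The expressions mirror the
// hunk under review; the function itself is not in the PR.
inline interleaved_store_layout make_interleaved_store_layout(IdxGlobal output_distance,
                                                              IdxGlobal output_stride, Idx factor_m,
                                                              Idx factor_n, Idx num_batches_in_local_mem,
                                                              Idx max_num_batches_in_local_mem) {
  // Assumed precondition for this sketch only.
  assert(num_batches_in_local_mem <= max_num_batches_in_local_mem);
  return interleaved_store_layout{
      {2 * output_distance, 2 * factor_n * output_stride, 2 * output_stride, 1},
      {2, 2 * max_num_batches_in_local_mem, 2 * factor_m * max_num_batches_in_local_mem, 1},
      {num_batches_in_local_mem, factor_m, factor_n, 2}};
}
```

The call site would then shrink to constructing the layout and the two views, while a reader can still inspect `global_strides`, `local_strides` and `copy_lengths` by name inside the helper.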

 * @return a factorization for workgroup dft or null if the size won't work with the implementation of workgroup dfts.
*/
template <typename Scalar>
std::optional<wg_factorization> factorize_for_wg(IdxGlobal fft_size, Idx subgroup_size) {
Contributor

Let's move this to utils.hpp; we only have device-callable functions in the common folder.

Contributor Author

Currently we have the factorization functions for workitem and subgroup in common/workitem.hpp and common/subgroup.hpp respectively, so I was following the example there.
factorize_sg is not called from device anywhere, along with fits_in_sg and fits_in_wi, so I wouldn't say we only have device callable functions in the common folder.
If we do want to refactor to put the factorization functions in a utility file, then we should group them and put them in a "factorization.hpp" or something like that. Generic util files are a bit of a code smell imo (though I am guilty of committing that sin).

Contributor

I'd agree that a factorisation.hpp would be better than having everything in utils.hpp.
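
For reference, a host-only factorization helper of the kind that could live in such a header might look like the sketch below. It only illustrates the `std::optional` return convention from the signature under review. The names `two_factor_split` and `factorize_two`, the `max_factor` parameter, and the near-square-root strategy are all assumptions for this example and do not describe how `factorize_for_wg` actually chooses its factors.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Assumed index types for this sketch; the real aliases live in portFFT.
using Idx = std::int32_t;
using IdxGlobal = std::int64_t;

// Hypothetical result type: fft_size == n * m with n <= m.
struct two_factor_split {
  Idx n;
  Idx m;
};

// Hypothetical host-only helper. Splits fft_size into two factors that are as
// close to each other as possible. Returns std::nullopt when no non-trivial
// split exists (size below 4 or prime) or when the larger factor would exceed
// max_factor.
inline std::optional<two_factor_split> factorize_two(IdxGlobal fft_size, Idx max_factor) {
  assert(max_factor > 0);
  if (fft_size < 4) {
    return std::nullopt;
  }
  // Largest n with n * n <= fft_size.
  IdxGlobal n = 1;
  while ((n + 1) * (n + 1) <= fft_size) {
    ++n;
  }
  // Walk down from the square root; the first divisor found gives the
  // smallest possible larger factor m.
  for (; n > 1; --n) {
    if (fft_size % n == 0) {
      const IdxGlobal m = fft_size / n;
      if (m > max_factor) {
        return std::nullopt;  // any smaller n would only make m larger
      }
      return two_factor_split{static_cast<Idx>(n), static_cast<Idx>(m)};
    }
  }
  return std::nullopt;  // fft_size is prime
}
```

Because nothing here is device code, grouping functions like this in a dedicated factorization header would keep them apart from the device-callable code without resorting to a generic utils file.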

@FMarno FMarno dismissed stale reviews from t4c1 and hjabird via 2e7f777 March 7, 2024 15:03