Document bulk including the static_thread_pool customization.
BenFrantzDale committed Oct 9, 2024
1 parent 7371dea commit 736b23b
Showing 2 changed files with 28 additions and 2 deletions.
24 changes: 22 additions & 2 deletions include/exec/static_thread_pool.hpp
@@ -475,6 +475,11 @@ namespace exec {
const nodemask& contraints = nodemask::any()) noexcept;
void enqueue(remote_queue& queue, task_base* task, std::size_t threadIndex) noexcept;

//! Enqueue a contiguous span of tasks across task queues.
//! Note: We use the concrete `TaskT` because we enqueue
//! tasks `task + 0`, `task + 1`, etc., so a std::span<task_base>
//! wouldn't be correct (its pointer arithmetic would use the base-class stride).
//! This is O(n_threads) on the calling thread.
template <std::derived_from<task_base> TaskT>
void bulk_enqueue(TaskT* task, std::uint32_t n_threads) noexcept;
void bulk_enqueue(
@@ -810,12 +815,15 @@ namespace exec {

template <std::derived_from<task_base> TaskT>
void static_thread_pool_::bulk_enqueue(TaskT* task, std::uint32_t n_threads) noexcept {
auto& queue = *get_remote_queue();
auto& queue = *this->get_remote_queue();
for (std::uint32_t i = 0; i < n_threads; ++i) {
std::uint32_t index = i % available_parallelism();
std::uint32_t index = i % this->available_parallelism();
queue.queues_[index].push_front(task + i);
threadStates_[index]->notify();
}
// At this point the calling thread can exit and the pool will take over.
// Ultimately, the last completing thread passes the result forward.
// See `if (is_last_thread)` above.
}
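
For intuition, here is a minimal, self-contained sketch of the same round-robin idea. It is not the pool's real code: `mini_pool` and `mini_task_base` are invented names for this illustration. It shows why the concrete `TaskT*` matters: `task + i` advances by `sizeof(TaskT)`, so base-class pointer arithmetic (or a `std::span<task_base>`) would step with the wrong stride.

#include <concepts>
#include <cstdint>
#include <deque>
#include <vector>

// Invented stand-ins for this sketch only; not stdexec types.
struct mini_task_base {
  void (*execute)(mini_task_base*, std::uint32_t tid) = nullptr;
};

struct mini_pool {
  std::vector<std::deque<mini_task_base*>> queues_; // one queue per worker

  explicit mini_pool(std::uint32_t n) : queues_(n) {}

  std::uint32_t available_parallelism() const {
    return static_cast<std::uint32_t>(queues_.size());
  }

  // Spread `task + 0 .. task + (n_tasks - 1)` round-robin across the queues.
  // TaskT must be the concrete type so that `task + i` lands on element i of
  // the caller's contiguous array of TaskT objects.
  template <std::derived_from<mini_task_base> TaskT>
  void bulk_enqueue(TaskT* task, std::uint32_t n_tasks) {
    for (std::uint32_t i = 0; i < n_tasks; ++i) {
      queues_[i % available_parallelism()].push_front(task + i);
      // The real pool also notifies the worker that owns this queue, and uses
      // intrusive, non-allocating queues, which is what lets it be noexcept.
    }
  }
};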

inline void static_thread_pool_::bulk_enqueue(
@@ -1115,8 +1123,11 @@ namespace exec {
}
};

//! The shared state used by the customized operation state for `stdexec::bulk` operations
template <class CvrefSender, class Receiver, class Shape, class Fun, bool MayThrow>
struct static_thread_pool_::bulk_shared_state {
//! The actual `bulk_task` holds a pointer to the shared state
//! and its `__execute` function reads from that shared state.
struct bulk_task : task_base {
bulk_shared_state* sh_state_;

@@ -1127,6 +1138,9 @@
auto total_threads = sh_state.num_agents_required();

auto computation = [&](auto&... args) {
// Each call to computation makes one or more calls to the bulk function.
// If the shape is much larger than the total number of threads,
// each call to computation will invoke the function many times.
auto [begin, end] = even_share(sh_state.shape_, tid, total_threads);
for (Shape i = begin; i < end; ++i) {
sh_state.fun_(i, args...);
@@ -1192,6 +1206,8 @@
std::exception_ptr exception_;
std::vector<bulk_task> tasks_;

//! The number of agents required is the minimum of `shape_` and the available parallelism.
//! That is, we don't need an agent for each of the shape values.
[[nodiscard]]
auto num_agents_required() const -> std::uint32_t {
return static_cast<std::uint32_t>(
@@ -1205,6 +1221,8 @@
data_);
}

//! Construct from a pool, receiver, shape, and function.
//! Allocates O(min(shape, available_parallelism())) memory.
bulk_shared_state(static_thread_pool_& pool, Receiver rcvr, Shape shape, Fun fun)
: pool_{pool}
, rcvr_{static_cast<Receiver&&>(rcvr)}
@@ -1215,6 +1233,8 @@
}
};
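
As a rough worked example of the agent-count math above: assuming `even_share` splits `[0, shape)` into near-equal contiguous ranges (its apparent intent here; its implementation is not shown in this diff), a sketch of the arithmetic might look like the following. `even_share_sketch` is an invented name, not the pool's helper.

#include <algorithm>
#include <cstdint>
#include <utility>

// Split [0, shape) into `total` near-equal contiguous chunks; return chunk `rank`.
std::pair<std::uint32_t, std::uint32_t>
even_share_sketch(std::uint32_t shape, std::uint32_t rank, std::uint32_t total) {
  const std::uint32_t quotient = shape / total;
  const std::uint32_t remainder = shape % total;
  // The first `remainder` chunks each get one extra element.
  const std::uint32_t begin = rank * quotient + std::min(rank, remainder);
  const std::uint32_t end = begin + quotient + (rank < remainder ? 1u : 0u);
  return {begin, end};
}

// With shape = 10 and available_parallelism() = 4:
//   num_agents_required() == min(10, 4) == 4
//   agent 0 -> [0, 3), agent 1 -> [3, 6), agent 2 -> [6, 8), agent 3 -> [8, 10)
// With shape = 3 and available_parallelism() = 4:
//   num_agents_required() == min(3, 4) == 3, so only 3 bulk_task objects are needed.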


//! A customized receiver to allow parallel execution of `stdexec::bulk` operations:
template <class CvrefSenderId, class ReceiverId, class Shape, class Fun, bool MayThrow>
struct static_thread_pool_::bulk_receiver<CvrefSenderId, ReceiverId, Shape, Fun, MayThrow>::__t {
using __id = bulk_receiver;
6 changes: 6 additions & 0 deletions include/stdexec/__detail/__bulk.hpp
@@ -122,6 +122,10 @@ namespace stdexec {
return {};
};

//! This implements the core default behavior for `bulk`:
//! When setting value, it loops over the shape and invokes the function.
//! Note: This is not done in parallel. That is customized by the scheduler.
//! See, e.g., static_thread_pool::bulk_receiver::__t.
static constexpr auto complete = //
[]<class _Tag, class _State, class _Receiver, class... _Args>(
__ignore,
@@ -130,8 +134,10 @@
_Tag,
_Args&&... __args) noexcept -> void {
if constexpr (std::same_as<_Tag, set_value_t>) {
// Intercept set_value and dispatch to the bulk operation.
using __shape_t = decltype(__state.__shape_);
if constexpr (noexcept(__state.__fun_(__shape_t{}, __args...))) {
// The noexcept version that doesn't need try/catch:
for (__shape_t __i{}; __i != __state.__shape_; ++__i) {
__state.__fun_(__i, __args...);
}
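
For contrast with the pool's parallel customization, here is a hedged paraphrase of what the default serial completion above amounts to. `default_bulk_complete`, `fun`, `shape`, and `on_error` are stand-ins for `__state.__fun_`, `__state.__shape_`, and the receiver's error channel; this is not the library's exact code, and `Shape` is assumed default-constructible as in the loop above.

#include <exception>

template <class Shape, class Fun, class OnError, class... Args>
void default_bulk_complete(Shape shape, Fun fun, OnError on_error, Args&&... args) {
  if constexpr (noexcept(fun(Shape{}, args...))) {
    // Non-throwing path: no try/catch needed.
    for (Shape i{}; i != shape; ++i) {
      fun(i, args...);
    }
  } else {
    try {
      for (Shape i{}; i != shape; ++i) {
        fun(i, args...);
      }
    } catch (...) {
      on_error(std::current_exception()); // the real code calls set_error on the receiver
      return;
    }
  }
  // On success the real code forwards the original values via set_value.
}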
