Skip to content

Commit

Permalink
Fix maxwell / repeat_n for static thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Nov 5, 2023
1 parent 323c761 commit 215ffbe
Showing 1 changed file with 91 additions and 25 deletions.
116 changes: 91 additions & 25 deletions examples/nvexec/maxwell/snr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -185,64 +185,130 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS { //
#endif

namespace repeat_n_detail {

// NOTE(review): this span appears to be a rendered diff with the +/- markers
// stripped: declarations naming `receiver_t` look like the pre-commit lines and
// the ones naming `receiver_2_t` like their replacements — confirm against the
// repository before treating it as compilable code.
//
// receiver_2_t is the receiver connected to each `schedule(sch) | closure_`
// sender. It holds a reference to the owning operation state `OpT` and reaches
// the downstream receiver through `op_state_.rcvr_`.
template <class OpT>
class receiver_t {
class receiver_2_t : public stdexec::__receiver_base {
using Sender = typename OpT::PredSender;
using Receiver = typename OpT::Receiver;

// Owning operation state; stored by reference, so it must outlive this receiver.
OpT& op_state_;

public:
using __t = receiver_t;
using __id = receiver_t;
using is_receiver = void;

public:
// set_error / set_stopped: pass the completion tag and its arguments straight
// through to the downstream receiver `rcvr_`.
template <stdexec::__one_of<ex::set_error_t, ex::set_stopped_t> _Tag, class... _Args>
STDEXEC_ATTRIBUTE((host, device))
friend void tag_invoke(_Tag __tag, receiver_t&& __self, _Args&&... __args) noexcept {
__tag(std::move(__self.op_state_.rcvr_), (_Args&&) __args...);
friend void tag_invoke(_Tag __tag, receiver_2_t&& __self, _Args&&... __args) noexcept {
// Bind the operation state to a local before touching anything else.
OpT& op_state = __self.op_state_;
__tag(std::move(op_state.rcvr_), (_Args&&) __args...);
}

// set_value: increments `i_`; once it equals `n_` the downstream receiver is
// completed with set_value, otherwise another `schedule(sch) | closure_`
// operation is connected and started.
friend void tag_invoke(ex::set_value_t, receiver_t&& __self) noexcept {
friend void tag_invoke(ex::set_value_t, receiver_2_t&& __self) noexcept {
using inner_op_state_t = typename OpT::inner_op_state_t;

// NOTE(review): presumably `__self` lives inside the operation state held by
// `inner_op_state_`, which the emplace below replaces — hence the local
// reference taken up front. TODO confirm.
OpT& op_state = __self.op_state_;
op_state.i_++;

// The `for` + `sync_wait` pair ran the closure `n_` times on
// `exec::inline_scheduler`, blocking on each run; the `if` that follows
// compares the counter against `n_` instead.
for (std::size_t i = 0; i < op_state.n_; i++) {
stdexec::sync_wait(ex::schedule(exec::inline_scheduler{}) | op_state.closure_);
if (op_state.i_ == op_state.n_) {
stdexec::set_value(std::move(op_state.rcvr_));
return;
}

stdexec::set_value(std::move(op_state.rcvr_));
// The scheduler is taken from the downstream receiver's environment.
auto sch = stdexec::get_scheduler(stdexec::get_env(op_state.rcvr_));
// `inner_op_state_` is a std::optional; emplace destroys any held value and
// builds the new one in place. `__conv` wraps the lambda so the connect
// result is produced during that construction.
inner_op_state_t& inner_op_state = op_state.inner_op_state_.emplace(
stdexec::__conv{[&]() noexcept {
return ex::connect(ex::schedule(sch) | op_state.closure_, receiver_2_t<OpT>{op_state});
}});

ex::start(inner_op_state);
}

// get_env: expose the downstream receiver's environment unchanged.
friend auto tag_invoke(ex::get_env_t, const receiver_t& self) noexcept
friend auto tag_invoke(ex::get_env_t, const receiver_2_t& self) noexcept
-> stdexec::env_of_t<Receiver> {
return stdexec::get_env(self.op_state_.rcvr_);
}

// Stores a reference to `op_state`; no ownership is taken.
explicit receiver_t(OpT& op_state)
explicit receiver_2_t(OpT& op_state)
: op_state_(op_state) {
}
};

template <class SenderId, class Closure, class ReceiverId>
// Receiver connected to the predecessor sender of the repeat_n operation.
//
// It keeps a reference to the owning operation state `OpT`. Error and stopped
// completions are forwarded to the downstream receiver `rcvr_`. A value
// completion either completes `rcvr_` right away (when `n_` is zero) or
// connects and starts `schedule(sch) | closure_` with a `receiver_2_t`.
template <class OpT>
class receiver_1_t : public stdexec::__receiver_base {
  using Receiver = typename OpT::Receiver;

  // Owning operation state; held by reference, never owned.
  OpT& op_state_;

 public:
  explicit receiver_1_t(OpT& op_state)
    : op_state_(op_state) {
  }

  // Environment queries are answered by the downstream receiver.
  friend auto tag_invoke(ex::get_env_t, const receiver_1_t& self) noexcept
    -> stdexec::env_of_t<Receiver> {
    return stdexec::get_env(self.op_state_.rcvr_);
  }

  // set_error / set_stopped: hand the completion to the downstream receiver as-is.
  template <stdexec::__one_of<ex::set_error_t, ex::set_stopped_t> _Tag, class... _Args>
  friend void tag_invoke(_Tag __tag, receiver_1_t&& __self, _Args&&... __args) noexcept {
    OpT& owner = __self.op_state_;
    __tag(std::move(owner.rcvr_), (_Args&&) __args...);
  }

  // set_value: the predecessor has finished.
  friend void tag_invoke(ex::set_value_t, receiver_1_t&& __self) noexcept {
    OpT& owner = __self.op_state_;

    // Nothing to repeat: complete the downstream receiver immediately.
    if (owner.n_ == 0) {
      stdexec::set_value(std::move(owner.rcvr_));
      return;
    }

    using iteration_op_t = typename OpT::inner_op_state_t;

    // The scheduler comes from the downstream receiver's environment.
    auto scheduler = stdexec::get_scheduler(stdexec::get_env(owner.rcvr_));

    // Build the operation state directly inside the optional; `__conv` defers
    // the connect call until the optional constructs its value.
    iteration_op_t& iteration = owner.inner_op_state_.emplace(
      stdexec::__conv{[&]() noexcept {
        return ex::connect(
          ex::schedule(scheduler) | owner.closure_, receiver_2_t<OpT>{owner});
      }});

    ex::start(iteration);
  }
};

// NOTE(review): this span appears to be a rendered diff with the +/- markers
// stripped: the `Sender` / `op_state_` / `receiver_t` lines look like the
// pre-commit version, the `PredSender` / `pred_op_state_` / `receiver_1_t` /
// `receiver_2_t` lines like their replacements — confirm against the repository.
//
// Operation state for repeat_n. It stores the predecessor sender, the closure,
// the downstream receiver, optional slots for the predecessor and per-iteration
// operation states, and the counters `n_` and `i_`.
template <class PredecessorSenderId, class Closure, class ReceiverId>
struct operation_state_t {
using Sender = stdexec::__t<SenderId>;
using PredSender = stdexec::__t<PredecessorSenderId>;
using Receiver = stdexec::__t<ReceiverId>;
// Result type of querying the downstream receiver's environment with get_scheduler.
using Scheduler =
stdexec::tag_invoke_result_t<stdexec::get_scheduler_t, stdexec::env_of_t<Receiver>>;
// Sender type obtained by applying the closure to that scheduler's schedule() sender.
using InnerSender =
std::invoke_result_t<Closure, stdexec::tag_invoke_result_t<stdexec::schedule_t, Scheduler>>;

using inner_op_state_t = stdexec::connect_result_t<Sender, receiver_t<operation_state_t>>;
// Operation state of the predecessor sender connected to receiver_1_t.
using predecessor_op_state_t =
ex::connect_result_t<PredSender, receiver_1_t<operation_state_t>>;
// Operation state of one closure run connected to receiver_2_t.
using inner_op_state_t = ex::connect_result_t<InnerSender, receiver_2_t<operation_state_t>>;

inner_op_state_t op_state_;
PredSender pred_sender_;
Closure closure_;
Receiver rcvr_;
// Optionals so the operation states can be built in place after the other members.
std::optional<predecessor_op_state_t> pred_op_state_;
std::optional<inner_op_state_t> inner_op_state_;
// `n_` is compared against `i_` by receiver_2_t; both are value-initialised to 0.
std::size_t n_{};
std::size_t i_{};

// start: with a non-zero `n_` the predecessor operation is started; with
// `n_ == 0` the downstream receiver is completed immediately with set_value.
friend void tag_invoke(stdexec::start_t, operation_state_t& self) noexcept {
stdexec::start(self.op_state_);
friend void tag_invoke(stdexec::start_t, operation_state_t& op) noexcept {
if (op.n_) {
stdexec::start(*op.pred_op_state_);
} else {
stdexec::set_value(std::move(op.rcvr_));
}
}

// Constructor: stores the arguments, then connects `pred_sender_` to a
// receiver_1_t bound to `*this`. The connect happens regardless of `n`.
operation_state_t(Sender&& sender, Closure closure, Receiver&& rcvr, std::size_t n)
: op_state_{stdexec::connect((Sender&&) sender, receiver_t<operation_state_t>{*this})}
, closure_{closure}
, rcvr_{(Receiver&&) rcvr}
operation_state_t(PredSender&& pred_sender, Closure closure, Receiver&& rcvr, std::size_t n)
: pred_sender_{(PredSender&&) pred_sender}
, closure_(closure)
// NOTE(review): `rcvr` is a named rvalue reference, so `rcvr_(rcvr)` selects
// the copy constructor; the `rcvr_{(Receiver&&) rcvr}` line above forwards
// instead. Confirm whether the copy is intended.
, rcvr_(rcvr)
, n_(n) {
// NOTE(review): receiver_1_t keeps a reference to `*this`; if the connected
// operation stores that receiver, copying or moving this object afterwards
// would leave it referring to the old address — TODO confirm the type is
// only ever constructed in place.
pred_op_state_.emplace(stdexec::__conv{[&]() noexcept {
return ex::connect((PredSender&&) pred_sender_, receiver_1_t{*this});
}});
}
};

Expand Down

0 comments on commit 215ffbe

Please sign in to comment.