From 215ffbea9886a1d8a49d0cc8608fb89da742897d Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Sun, 5 Nov 2023 12:13:52 -0800 Subject: [PATCH] Fix maxwell / repeat_n for static thread pool --- examples/nvexec/maxwell/snr.cuh | 116 +++++++++++++++++++++++++------- 1 file changed, 91 insertions(+), 25 deletions(-) diff --git a/examples/nvexec/maxwell/snr.cuh b/examples/nvexec/maxwell/snr.cuh index b5018cd38..56536c3f4 100644 --- a/examples/nvexec/maxwell/snr.cuh +++ b/examples/nvexec/maxwell/snr.cuh @@ -185,64 +185,130 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS { // #endif namespace repeat_n_detail { + template - class receiver_t { + class receiver_2_t : public stdexec::__receiver_base { + using Sender = typename OpT::PredSender; using Receiver = typename OpT::Receiver; OpT& op_state_; - public: - using __t = receiver_t; - using __id = receiver_t; - using is_receiver = void; - + public: template _Tag, class... _Args> - STDEXEC_ATTRIBUTE((host, device)) - friend void tag_invoke(_Tag __tag, receiver_t&& __self, _Args&&... __args) noexcept { - __tag(std::move(__self.op_state_.rcvr_), (_Args&&) __args...); + friend void tag_invoke(_Tag __tag, receiver_2_t&& __self, _Args&&... __args) noexcept { + OpT& op_state = __self.op_state_; + __tag(std::move(op_state.rcvr_), (_Args&&) __args...); } - friend void tag_invoke(ex::set_value_t, receiver_t&& __self) noexcept { + friend void tag_invoke(ex::set_value_t, receiver_2_t&& __self) noexcept { + using inner_op_state_t = typename OpT::inner_op_state_t; + OpT& op_state = __self.op_state_; + op_state.i_++; - for (std::size_t i = 0; i < op_state.n_; i++) { - stdexec::sync_wait(ex::schedule(exec::inline_scheduler{}) | op_state.closure_); + if (op_state.i_ == op_state.n_) { + stdexec::set_value(std::move(op_state.rcvr_)); + return; } - stdexec::set_value(std::move(op_state.rcvr_)); + auto sch = stdexec::get_scheduler(stdexec::get_env(op_state.rcvr_)); + inner_op_state_t& inner_op_state = op_state.inner_op_state_.emplace( + stdexec::__conv{[&]() noexcept { + return ex::connect(ex::schedule(sch) | op_state.closure_, receiver_2_t{op_state}); + }}); + + ex::start(inner_op_state); } - friend auto tag_invoke(ex::get_env_t, const receiver_t& self) noexcept + friend auto tag_invoke(ex::get_env_t, const receiver_2_t& self) noexcept -> stdexec::env_of_t { return stdexec::get_env(self.op_state_.rcvr_); } - explicit receiver_t(OpT& op_state) + explicit receiver_2_t(OpT& op_state) : op_state_(op_state) { } }; - template + template + class receiver_1_t : public stdexec::__receiver_base { + using Receiver = typename OpT::Receiver; + + OpT& op_state_; + + public: + template _Tag, class... _Args> + friend void tag_invoke(_Tag __tag, receiver_1_t&& __self, _Args&&... __args) noexcept { + OpT& op_state = __self.op_state_; + __tag(std::move(op_state.rcvr_), (_Args&&) __args...); + } + + friend void tag_invoke(ex::set_value_t, receiver_1_t&& __self) noexcept { + using inner_op_state_t = typename OpT::inner_op_state_t; + + OpT& op_state = __self.op_state_; + + if (op_state.n_) { + auto sch = stdexec::get_scheduler(stdexec::get_env(op_state.rcvr_)); + inner_op_state_t& inner_op_state = op_state.inner_op_state_.emplace( + stdexec::__conv{[&]() noexcept { + return ex::connect( + ex::schedule(sch) | op_state.closure_, receiver_2_t{op_state}); + }}); + + ex::start(inner_op_state); + } else { + stdexec::set_value(std::move(op_state.rcvr_)); + } + } + + friend auto tag_invoke(ex::get_env_t, const receiver_1_t& self) noexcept + -> stdexec::env_of_t { + return stdexec::get_env(self.op_state_.rcvr_); + } + + explicit receiver_1_t(OpT& op_state) + : op_state_(op_state) { + } + }; + + template struct operation_state_t { - using Sender = stdexec::__t; + using PredSender = stdexec::__t; using Receiver = stdexec::__t; + using Scheduler = + stdexec::tag_invoke_result_t>; + using InnerSender = + std::invoke_result_t>; - using inner_op_state_t = stdexec::connect_result_t>; + using predecessor_op_state_t = + ex::connect_result_t>; + using inner_op_state_t = ex::connect_result_t>; - inner_op_state_t op_state_; + PredSender pred_sender_; Closure closure_; Receiver rcvr_; + std::optional pred_op_state_; + std::optional inner_op_state_; std::size_t n_{}; + std::size_t i_{}; - friend void tag_invoke(stdexec::start_t, operation_state_t& self) noexcept { - stdexec::start(self.op_state_); + friend void tag_invoke(stdexec::start_t, operation_state_t& op) noexcept { + if (op.n_) { + stdexec::start(*op.pred_op_state_); + } else { + stdexec::set_value(std::move(op.rcvr_)); + } } - operation_state_t(Sender&& sender, Closure closure, Receiver&& rcvr, std::size_t n) - : op_state_{stdexec::connect((Sender&&) sender, receiver_t{*this})} - , closure_{closure} - , rcvr_{(Receiver&&) rcvr} + operation_state_t(PredSender&& pred_sender, Closure closure, Receiver&& rcvr, std::size_t n) + : pred_sender_{(PredSender&&) pred_sender} + , closure_(closure) + , rcvr_(rcvr) , n_(n) { + pred_op_state_.emplace(stdexec::__conv{[&]() noexcept { + return ex::connect((PredSender&&) pred_sender_, receiver_1_t{*this}); + }}); } };