Skip to content

Commit

Permalink
fix long-standing race condition in split
Browse files Browse the repository at this point in the history
  • Loading branch information
ericniebler committed Dec 31, 2024
1 parent c211de1 commit 434e917
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 85 deletions.
9 changes: 4 additions & 5 deletions include/stdexec/__detail/__ensure_started.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,13 @@ namespace stdexec {
static_cast<_Sender&&>(__sndr),
[&]<class _Env, class _Child>(__ignore, _Env&& __env, _Child&& __child) {
// The shared state starts life with a ref-count of one.
auto __sh_state = __make_intrusive<__shared_state<_Child, __decay_t<_Env>>, 2>(
static_cast<_Child&&>(__child), static_cast<_Env&&>(__env));
auto* __sh_state =
new __shared_state{static_cast<_Child&&>(__child), static_cast<_Env&&>(__env)};

// Eagerly start the work:
__sh_state->__try_start();
__sh_state->__try_start(); // cannot throw

return __make_sexpr<__ensure_started_t>(
__box{__ensure_started_t(), std::move(__sh_state)});
return __make_sexpr<__ensure_started_t>(__box{__ensure_started_t(), __sh_state});
});
}
};
Expand Down
176 changes: 99 additions & 77 deletions include/stdexec/__detail/__shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "__basic_sender.hpp"
#include "__cpo.hpp"
#include "__env.hpp"
#include "__intrusive_ptr.hpp"
#include "__intrusive_slist.hpp"
#include "__optional.hpp"
#include "__meta.hpp"
Expand All @@ -32,8 +31,11 @@
#include "../stop_token.hpp"
#include "../functional.hpp"

#include <atomic>
#include <exception>
#include <mutex>
#include <type_traits>
#include <utility>

namespace stdexec {
////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -102,7 +104,9 @@ namespace stdexec {
}

~__local_state() {
  // Release this consumer's reference to the shared state, unless it was
  // already transferred away (the stop and sync-completion paths null out
  // __sh_state_ via std::exchange before detaching).
  if (__sh_state_) {
    __sh_state_->__detach();
  }
}

// Stop request callback:
Expand All @@ -126,13 +130,14 @@ namespace stdexec {
// __notify function is called from the shared state's __notify_waiters function, which
// first sets __waiters_ to the completed state. As a result, the attempt to remove `this`
// from the waiters list above will fail and this stop request is ignored.
__sh_state_t::__detach(__sh_state_);
std::exchange(__sh_state_, nullptr)->__detach();
stdexec::set_stopped(static_cast<_Receiver&&>(this->__receiver()));
}

// This is called from __shared_state::__notify_waiters when the input async operation
// completes; or, if it has already completed when start is called, it is called from start:
// __notify cannot race with __on_stop_request. See comment in __on_stop_request.
// __notify cannot race with __local_state::operator(). See comment in
// __local_state::operator().
template <class _Tag>
static void __notify(__local_state_base* __base) noexcept {
auto* const __self = static_cast<__local_state*>(__base);
Expand All @@ -150,11 +155,12 @@ namespace stdexec {
}

// Extracts the raw shared-state pointer from the sender's data (__box),
// transferring ownership of the reference to the caller: the box is left
// empty (nullptr) so its destructor does not release the reference again.
static auto __get_sh_state(_CvrefSender& __sndr) noexcept {
  auto __box = __sndr.apply(static_cast<_CvrefSender&&>(__sndr), __detail::__get_data());
  return std::exchange(__box.__sh_state_, nullptr);
}

using __sh_state_ptr_t = __result_of<__get_sh_state, _CvrefSender&>;
using __sh_state_t = std::remove_pointer_t<__sh_state_ptr_t>;

__optional<stop_callback_for_t<__stok_t, __local_state&>> __on_stop_{};
__sh_state_ptr_t __sh_state_;
Expand Down Expand Up @@ -193,14 +199,13 @@ namespace stdexec {
};

// Returns the sentinel __local_state_base whose address serves as the
// "tombstone" value for a shared state's waiters list, marking the shared
// operation as already completed.
// constinit guarantees constant (non-dynamic) initialization, so there is no
// thread-safe-init guard and no initialization-order hazard.
inline __local_state_base* __get_tombstone() noexcept {
  static constinit __local_state_base __tombstone_{{}, nullptr, nullptr};
  return &__tombstone_;
}

//! Heap-allocatable shared state for things like `stdexec::split`.
template <class _CvrefSender, class _Env>
struct __shared_state
: private __enable_intrusive_from_this<__shared_state<_CvrefSender, _Env>, 2> {
struct __shared_state {
using __receiver_t = __t<__receiver<__cvref_id<_CvrefSender>, __id<_Env>>>;
using __waiters_list_t = __intrusive_slist<&__local_state_base::__next_>;

Expand All @@ -213,70 +218,82 @@ namespace stdexec {
__munique<__mbind_front_q<__variant_for, __tuple_for<set_stopped_t>>>::__f,
__tuple_for<set_error_t, std::exception_ptr>>;

static constexpr std::size_t __started_bit = 0;
static constexpr std::size_t __completed_bit = 1;

inplace_stop_source __stop_source_{};
__env_t<_Env> __env_;
__variant_t __results_{}; // Defaults to the "set_stopped" state
std::mutex __mutex_; // This mutex guards access to __waiters_.
__waiters_list_t __waiters_{};
connect_result_t<_CvrefSender, __receiver_t> __shared_op_;
std::atomic_flag __started_{};
std::atomic<std::size_t> __ref_count_{2};

// Let a "consumer" be either a split/ensure_started sender, or an operation
// state created by connecting a split/ensure_started sender to a receiver.
// Let is_running be 1 if the shared operation is currently executing (after
// start has been called but before the receiver's completion functions have
// executed), and 0 otherwise. Then __ref_count_ is equal to:
//
// (2 * (nbr of consumers)) + is_running

explicit __shared_state(_CvrefSender&& __sndr, _Env __env)
: __env_(
__env::__join(
prop{get_stop_token, __stop_source_.get_token()},
static_cast<_Env&&>(__env)))
, __shared_op_(connect(static_cast<_CvrefSender&&>(__sndr), __receiver_t{this})) {
// add one ref count to account for the case where there are no watchers left but the
// shared op is still running.
this->__inc_ref();
}

// The caller of this wants to release their reference to the shared state. The ref
// count must be at least 2 at this point: one owned by the caller, and one added in the
// __shared_state ctor.
static void __detach(__intrusive_ptr<__shared_state, 2>& __ptr) noexcept {
// Ask the intrusive ptr to stop managing the reference count so we can manage it manually.
if (auto* __self = __ptr.__release_()) {
auto __old = __self->__dec_ref();
STDEXEC_ASSERT(__count(__old) >= 2);

if (__count(__old) == 2) {
// The last watcher has released its reference. Asked the shared op to stop.
static_cast<__shared_state*>(__self)->__stop_source_.request_stop();

// Additionally, if the shared op was never started, or if it has already completed,
// then the shared state is no longer needed. Decrement the ref count to 0 here, which
// will delete __self.
if (!__bit<__started_bit>(__old) || __bit<__completed_bit>(__old)) {
__self->__dec_ref();
}
}
void __inc_ref() noexcept {
__ref_count_.fetch_add(2ul, std::memory_order_relaxed);
}

void __dec_ref() noexcept {
if (2ul == __ref_count_.fetch_sub(2ul, std::memory_order_acq_rel)) {
delete this;
}
}

/// Attempts to claim the right to start the shared operation.
/// @returns true iff this caller was the first to start it. On success, also
/// increments __ref_count_ by one, setting the "is running" low bit so the
/// shared state stays alive while the underlying operation executes.
bool __set_started() noexcept {
  if (__started_.test_and_set(std::memory_order_acq_rel)) {
    return false; // already started
  }
  // Only the single winner of the test_and_set race reaches this line.
  __ref_count_.fetch_add(1ul, std::memory_order_relaxed);
  return true;
}

// Called when the shared operation finishes: clears the "is running" low bit
// of the ref count. If the count drops to zero (no consumers remain), the
// shared state deletes itself.
void __set_completed() noexcept {
  if (1ul == __ref_count_.fetch_sub(1ul, std::memory_order_acq_rel)) {
    delete this;
  }
}

// Releases one consumer's reference to the shared state. A count below 4
// means at most one consumer reference (plus possibly the "is running" bit)
// remains, i.e. we are the final consumer.
// FIX: the load previously used std::memory_order_acq_rel, which is not a
// valid order for an atomic load (only relaxed/consume/acquire/seq_cst are
// permitted); acquire preserves the intended synchronization.
void __detach() noexcept {
  if (__ref_count_.load(std::memory_order_acquire) < 4ul) {
    // We are the final "consumer", and we are about to release our reference
    // to the shared state. Ask the operation to stop early.
    __stop_source_.request_stop();
  }
  __dec_ref();
}

/// Ensures the shared async operation is started at most once.
/// @post The "is running" bit is set in the shared state's ref count, OR the
/// __waiters_ list is set to the known "tombstone" value indicating completion.
void __try_start() noexcept {
  // With the split algorithm, multiple split senders can be started simultaneously,
  // but only one should start the shared async operation. If __set_started()
  // returns false, someone else has already started it; do nothing.
  if (__set_started()) {
    // we are the first to start the underlying operation
    if (__stop_source_.stop_requested()) {
      // Stop has already been requested. Rather than starting the operation,
      // complete with set_stopped immediately. __notify_waiters:
      // 1. Sets __waiters_ to a known "tombstone" value.
      // 2. Notifies all the waiters that the operation has stopped.
      // 3. Sets the "is running" bit in the ref count to 0.
      __notify_waiters();
    } else {
      stdexec::start(__shared_op_);
    }
  }
}

Expand Down Expand Up @@ -328,22 +345,22 @@ namespace stdexec {
for (auto __itr = __waiters_copy.begin(); __itr != __waiters_copy.end();) {
__local_state_base* __item = *__itr;

// We must increment the iterator before calling notify, since notify
// may end up triggering *__item to be destructed on another thread,
// and the intrusive slist's iterator increment relies on __item.
// We must increment the iterator before calling notify, since notify may end up
// triggering *__item to be destructed on another thread, and the intrusive slist's
// iterator increment relies on __item.
++__itr;

__item->__notify_(__item);
}

// Set the "completed" bit in the ref count. If the ref count is 1, then there are no more
// waiters. Release the final reference.
if (__count(this->template __set_bit<__completed_bit>()) == 1) {
this->__dec_ref(); // release the extra ref count, deletes this
}
// Set the "is running" bit in the ref count to zero. Delete the shared state if the
// ref-count is now zero.
__set_completed();
}
};

// Deduction guide: lets `new __shared_state{sndr, env}` (as written in the
// split/ensure_started transform lambdas) deduce the class template arguments
// from the constructor arguments.
template <class _CvrefSender, class _Env>
__shared_state(_CvrefSender&&, _Env) -> __shared_state<_CvrefSender, _Env>;

template <class _Cvref, class _CvrefSender, class _Env>
using __make_completions = //
__try_make_completion_signatures<
Expand Down Expand Up @@ -374,30 +391,36 @@ namespace stdexec {
using __tag_t = __if_c<_Copyable, __split::__split_t, __ensure_started::__ensure_started_t>;
using __sh_state_t = __shared_state<_CvrefSender, _Env>;

__box(__tag_t, __intrusive_ptr<__sh_state_t, 2> __sh_state) noexcept
: __sh_state_(std::move(__sh_state)) {
__box(__tag_t, __sh_state_t* __sh_state) noexcept
: __sh_state_(__sh_state) {
}

__box(__box&& __other) noexcept
: __sh_state_(std::exchange(__other.__sh_state_, nullptr)) {
}

__box(__box&&) noexcept = default;
__box(const __box&) noexcept
__box(const __box& __other) noexcept
requires _Copyable
= default;
: __sh_state_(__other.__sh_state_) {
__sh_state_->__inc_ref();
}

~__box() {
__sh_state_t::__detach(__sh_state_);
if (__sh_state_) {
__sh_state_->__detach();
}
}

__intrusive_ptr<__sh_state_t, 2> __sh_state_;
__sh_state_t* __sh_state_;
};

// Deduction guides: a box constructed with the split tag is copyable
// (_Copyable = true); one constructed with the ensure_started tag is
// move-only (_Copyable = false).
template <class _CvrefSender, class _Env>
__box(__split::__split_t, __shared_state<_CvrefSender, _Env>*) //
  ->__box<_CvrefSender, _Env, true>;

template <class _CvrefSender, class _Env>
__box(__ensure_started::__ensure_started_t, __shared_state<_CvrefSender, _Env>*)
  -> __box<_CvrefSender, _Env, false>;

template <class _Tag>
struct __shared_impl : __sexpr_defaults {
Expand All @@ -419,14 +442,13 @@ namespace stdexec {
[]<class _Sender, class _Receiver>(
__local_state<_Sender, _Receiver>& __self,
_Receiver& __rcvr) noexcept -> void {
using __sh_state_t = typename __local_state<_Sender, _Receiver>::__sh_state_t;
// Scenario: there are no more split senders, this is the only operation state, the
// underlying operation has not yet been started, and the receiver's stop token is already
// in the "stop requested" state. Then registering the stop callback will call
// __on_stop_request on __self synchronously. It may also be called asynchronously at
// any point after the callback is registered. Beware. We are guaranteed, however, that
// __on_stop_request will not complete the operation or decrement the shared state's ref
// count until after __self has been added to the waiters list.
// __local_state::operator() on __self synchronously. It may also be called asynchronously
// at any point after the callback is registered. Beware. We are guaranteed, however, that
// __local_state::operator() will not complete the operation or decrement the shared state's
// ref count until after __self has been added to the waiters list.
const auto __stok = stdexec::get_stop_token(stdexec::get_env(__rcvr));
__self.__on_stop_.emplace(__stok, __self);

Expand All @@ -446,7 +468,7 @@ namespace stdexec {
// Otherwise, failed to add the waiter because of a stop-request.
// Complete synchronously with set_stopped().
__self.__on_stop_.reset();
__sh_state_t::__detach(__self.__sh_state_);
std::exchange(__self.__sh_state_, nullptr)->__detach();
stdexec::set_stopped(static_cast<_Receiver&&>(__rcvr));
};
};
Expand Down
6 changes: 3 additions & 3 deletions include/stdexec/__detail/__split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ namespace stdexec {
static_cast<_Sender&&>(__sndr),
[&]<class _Env, class _Child>(__ignore, _Env&& __env, _Child&& __child) {
// The shared state starts life with a ref-count of one.
auto __sh_state = __make_intrusive<__shared_state<_Child, __decay_t<_Env>>, 2>(
static_cast<_Child&&>(__child), static_cast<_Env&&>(__env));
auto* __sh_state =
new __shared_state{static_cast<_Child&&>(__child), static_cast<_Env&&>(__env)};

return __make_sexpr<__split_t>(__box{__split_t(), std::move(__sh_state)});
return __make_sexpr<__split_t>(__box{__split_t(), __sh_state});
});
}
};
Expand Down

0 comments on commit 434e917

Please sign in to comment.