Skip to content

Commit

Permalink
fix assorted bugs in the exec::sequence algorithm (#1424)
Browse files Browse the repository at this point in the history
* fix assorted bugs in the `exec::sequence` algorithm

* add qualification to help nvhpc with name resolution

* try to work around nvhpc bug
  • Loading branch information
ericniebler authored Oct 22, 2024
1 parent 111d41b commit f11f711
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 114 deletions.
213 changes: 110 additions & 103 deletions include/exec/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#pragma once

#include <stdexec/execution.hpp>
#include <stdexec/__detail/__manual_lifetime.hpp>
#include <stdexec/__detail/__tuple.hpp>
#include <stdexec/__detail/__variant.hpp>

namespace exec {
namespace _seq {
Expand All @@ -25,121 +26,121 @@ namespace exec {

struct sequence_t {
template <class Sndr>
Sndr operator()(Sndr sndr) const;
STDEXEC_ATTRIBUTE((nodiscard, host, device))
Sndr
operator()(Sndr sndr) const;

template <class... Sndrs>
requires(sizeof...(Sndrs) > 1) && stdexec::__domain::__has_common_domain<Sndrs...>
_sndr<Sndrs...> operator()(Sndrs... sndrs) const;
STDEXEC_ATTRIBUTE((nodiscard, host, device))
_sndr<Sndrs...>
operator()(Sndrs... sndrs) const;
};

template <class... Args>
struct _ops_tuple;

template <class Sndr, class... Rest>
struct _ops_tuple<Sndr, Rest...> : _ops_tuple<Rest...> {
explicit _ops_tuple(Sndr&& sndr, Rest&&... rest)
: _ops_tuple<Rest...>{static_cast<Rest&&>(rest)...}
, _head{static_cast<Sndr&&>(sndr)} {
}

Sndr _head;

_ops_tuple<Rest...>& _tail() noexcept {
return *this;
}
};

template <class Rcvr>
struct _ops_tuple<Rcvr> {
using _rcvr_t = Rcvr;
Rcvr _rcvr;
};

template <class... Args>
union _ops_variant { };

template <class Sndr, class... Rest>
template <class Rcvr, class OpStateId, class Index>
struct _rcvr {
using receiver_concept = stdexec::receiver_t;
using _rcvr_t = typename _ops_tuple<Rest...>::_rcvr_t;
_ops_variant<Sndr, Rest...>* _self;
using _opstate_t = stdexec::__t<OpStateId>;
_opstate_t* _opstate;

template <class... Args>
void set_value(Args&&... args) && noexcept {
auto& sndrs = *_self->_head.__get()._sndrs;
try {
if constexpr (sizeof...(Rest) == 1) {
// destroy _head after completing the operation in case the arguments are references
// to objects owned by _head.
stdexec::set_value(static_cast<_rcvr_t&&>(sndrs._rcvr), static_cast<Args&&>(args)...);
_self->_head.__destroy();
} else {
_self->_head.__destroy();
_self->_tail.__construct(sndrs._head, sndrs._tail()); // potentially throwing
stdexec::start(_self->_tail.__get()._head.__get()._op);
}
} catch (...) {
stdexec::set_error(static_cast<_rcvr_t&&>(sndrs._rcvr), std::current_exception());
}
STDEXEC_ATTRIBUTE((always_inline, host, device))
void
set_value(Args&&... args) && noexcept {
_opstate->_set_value(Index(), static_cast<Args&&>(args)...);
}

template <class Error>
void set_error(Error&& err) && noexcept {
stdexec::set_error(
static_cast<_rcvr_t&&>(_self->_head.__get()._sndrs->_rcvr), static_cast<Error&&>(err));
_self->_head.__destroy();
STDEXEC_ATTRIBUTE((host, device))
void
set_error(Error&& err) && noexcept {
stdexec::set_error(static_cast<Rcvr&&>(_opstate->_rcvr), static_cast<Error&&>(err));
}

void set_stopped() && noexcept {
stdexec::set_stopped(static_cast<_rcvr_t&&>(_self->_head.__get()._sndrs->_rcvr));
_self->_head.__destroy();
STDEXEC_ATTRIBUTE((host, device))
void
set_stopped() && noexcept {
stdexec::set_stopped(static_cast<Rcvr&&>(_opstate->_rcvr));
}

stdexec::env_of_t<_rcvr_t> get_env() const noexcept {
return stdexec::get_env(_self->_head.__get()._sndrs->_rcvr);
// TODO: use the predecessor's completion scheduler as the current scheduler here.
STDEXEC_ATTRIBUTE((host, device))
stdexec::env_of_t<Rcvr>
get_env() const noexcept {
return stdexec::get_env(_opstate->_rcvr);
}
};

template <class Sndr, class... Rest>
requires(sizeof...(Rest) > 0)
union _ops_variant<Sndr, Rest...> {
explicit _ops_variant(Sndr& sndr, _ops_tuple<Rest...>& sndrs) {
auto connect_fn = [&] {
return stdexec::connect(static_cast<Sndr&&>(sndr), _rcvr<Sndr, Rest...>{this});
};
_head.__construct(&sndrs, stdexec::__emplace_from{connect_fn});
}

~_ops_variant() {
}

struct _head_t {
_ops_tuple<Rest...>* _sndrs;
stdexec::connect_result_t<Sndr, _rcvr<Sndr, Rest...>> _op;
};

stdexec::__manual_lifetime<_head_t> _head;
stdexec::__manual_lifetime<_ops_variant<Rest...>> _tail;
};

template <class Rcvr, class... Sndrs>
struct _opstate;

template <class Rcvr, class Sndr, class... Rest>
struct _opstate<Rcvr, Sndr, Rest...> {
template <class Rcvr, class Sndr0, class... Sndrs>
struct _opstate<Rcvr, Sndr0, Sndrs...> {
using operation_state_concept = stdexec::operation_state_t;

_ops_tuple<Rest..., Rcvr> _tupl;
_ops_variant<Sndr, Rest..., Rcvr> _var;
// We will be connecting the first sender in the opstate constructor, so we don't need to
// store it in the opstate. The use of `stdexec::__ignore` causes the first sender to not
// be stored.
using _senders_tuple_t = stdexec::__tuple_for<stdexec::__ignore, Sndrs...>;

template <size_t Idx>
using _rcvr_t = _seq::_rcvr<Rcvr, stdexec::__id<_opstate>, stdexec::__msize_t<Idx>>;

template <class Sndr, class Idx>
using _child_opstate_t = stdexec::connect_result_t<Sndr, _rcvr_t<stdexec::__v<Idx>>>;

using _mk_child_ops_variant_fn =
stdexec::__mzip_with2<stdexec::__q2<_child_opstate_t>, stdexec::__qq<stdexec::__variant_for>>;

using _ops_variant_t = stdexec::__minvoke<
_mk_child_ops_variant_fn,
stdexec::__tuple_for<Sndr0, Sndrs...>,
stdexec::__make_indices<sizeof...(Sndrs) + 1>>;

// Constructs the sequence operation state from the downstream receiver `rcvr` and the
// (possibly cv/ref-qualified) tuple of child senders `sndrs`.
//
// The receiver is stored in `_rcvr`; `sndrs` is converted into `_senders_tuple_t`; and the
// first sender is connected right away to `_rcvr_t<0>{this}`, with the result emplaced at
// index 0 of `_ops`. `emplace_from_at` is given `stdexec::connect` and its arguments, so it
// presumably constructs the child operation state in place from that call's result.
template <class CvrefSndrs>
STDEXEC_ATTRIBUTE((host, device))
explicit _opstate(Rcvr&& rcvr, CvrefSndrs&& sndrs)
: _rcvr{static_cast<Rcvr&&>(rcvr)}
, _sndrs{_senders_tuple_t::__convert_from(static_cast<CvrefSndrs&&>(sndrs))}
// Move all but the first sender into the opstate (the first slot of `_senders_tuple_t` is
// `stdexec::__ignore`).
, _ops{} {
// Below, it looks like we are using `sndrs` after it has been moved from. This is not the
// case. `sndrs` is moved into a tuple type that has `__ignore` for the first element. The
// result is that the first sender in `sndrs` is not moved from, but the rest are.
_ops.template emplace_from_at<0>(
stdexec::connect,
stdexec::__tup::get<0>(static_cast<CvrefSndrs&&>(sndrs)),
_rcvr_t<0>{this});
}

explicit _opstate(Rcvr&& rcvr, Sndr sndr, Rest&&... rest)
: _tupl{static_cast<Rest&&>(rest)..., static_cast<Rcvr&&>(rcvr)}
, _var{sndr, _tupl} {
// Handles a value completion of the child operation whose position in the sequence is
// `Index` (a compile-time integral constant type; only its type is used).
//
// - If that child was the last sender, its values are forwarded to the downstream receiver.
// - Otherwise the values are discarded (hence `[[maybe_unused]]`), the next sender is
//   connected to `_rcvr_t<Idx>{this}` inside `_ops`, and the new child operation is started.
//
// Any exception thrown in the `try` block is delivered to the downstream receiver via
// `set_error` with `std::current_exception()`.
template <class Index, class... Args>
STDEXEC_ATTRIBUTE((host, device))
void
_set_value(Index, [[maybe_unused]] Args&&... args) noexcept {
try {
// Index of the sender to run next.
constexpr size_t Idx = stdexec::__v<Index> + 1;
// There are `sizeof...(Sndrs) + 1` senders in total (`Sndr0` plus `Sndrs...`), so reaching
// that count means the sender that just completed was the last one.
if constexpr (Idx == sizeof...(Sndrs) + 1) {
stdexec::set_value(static_cast<Rcvr&&>(_rcvr), static_cast<Args&&>(args)...);
} else {
auto& sndr = stdexec::__tup::get<Idx>(_sndrs);
// Emplacing at `Idx` presumably replaces the previously active child operation state in
// the variant -- TODO confirm against `__variant_for`.
auto& op = _ops.template emplace_from_at<Idx>(
stdexec::connect, std::move(sndr), _rcvr_t<Idx>{this});
stdexec::start(op);
}
} catch (...) {
stdexec::set_error(static_cast<Rcvr&&>(_rcvr), std::current_exception());
}
}

void start() & noexcept {
stdexec::start(_var._head.__get()._op);
STDEXEC_ATTRIBUTE((host, device))
void
start() & noexcept {
stdexec::start(_ops.template get<0>());
}

Rcvr _rcvr;
_senders_tuple_t _sndrs;
_ops_variant_t _ops;
};

// The completions of the sequence sender are the error and stopped completions of all the
Expand Down Expand Up @@ -183,41 +184,47 @@ namespace exec {
};

template <class... Sndrs>
struct _sndr : stdexec::__tuple_for<sequence_t, stdexec::__, Sndrs...> {
struct _sndr {
using sender_concept = stdexec::sender_t;

template <class... Env>
using _completions_t = stdexec::__minvoke<_completions<Env...>, Sndrs...>;

template <class Self, class... Env>
requires(stdexec::__decay_copyable<stdexec::__copy_cvref_t<Self, Sndrs>> && ...)
static auto get_completion_signatures(Self&&, Env&&...) -> _completions_t<Env...> {
STDEXEC_ATTRIBUTE((host, device))
static auto
get_completion_signatures(Self&&, Env&&...) -> _completions_t<Env...> {
return {};
}

template <class Self, class Rcvr>
static auto connect(Self&& self, Rcvr rcvr) {
return self.apply(
[](Rcvr&& rcvr, auto, auto, Sndrs... sndrs) {
return _opstate<Rcvr, Sndrs...>{
static_cast<Rcvr&&>(rcvr), static_cast<Sndrs&&>(sndrs)...};
},
static_cast<typename _sndr::__tuple&&>(self),
static_cast<Rcvr&&>(rcvr));
STDEXEC_ATTRIBUTE((host, device))
static auto
connect(Self&& self, Rcvr rcvr) {
return _opstate<Rcvr, Sndrs...>{static_cast<Rcvr&&>(rcvr), static_cast<Self&&>(self)._sndrs};
}

STDEXEC_ATTRIBUTE((no_unique_address, maybe_unused))
sequence_t _tag; //
STDEXEC_ATTRIBUTE((no_unique_address, maybe_unused))
stdexec::__ignore _ignore; //
stdexec::__tuple_for<Sndrs...> _sndrs;
};

template <class Sndr>
Sndr sequence_t::operator()(Sndr sndr) const {
STDEXEC_ATTRIBUTE((host, device))
Sndr
sequence_t::operator()(Sndr sndr) const {
return sndr;
}

template <class... Sndrs>
requires(sizeof...(Sndrs) > 1) && stdexec::__domain::__has_common_domain<Sndrs...>
_sndr<Sndrs...> sequence_t::operator()(Sndrs... sndrs) const {
return _sndr<Sndrs...>{
{{}, {}, {static_cast<Sndrs&&>(sndrs)}...}
};
STDEXEC_ATTRIBUTE((host, device))
_sndr<Sndrs...>
sequence_t::operator()(Sndrs... sndrs) const {
return _sndr<Sndrs...>{{}, {}, {{static_cast<Sndrs&&>(sndrs)}...}};
}
} // namespace _seq

Expand Down
27 changes: 24 additions & 3 deletions include/stdexec/__detail/__meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ namespace stdexec {
enum class __muchar : unsigned char {
};

#if STDEXEC_MSVC()
#if STDEXEC_NVCC() || STDEXEC_NVHPC()
template <std::size_t _Np>
using __msize_t = std::integral_constant<std::size_t, _Np>;
#elif STDEXEC_MSVC()
template <std::size_t _Np>
using __msize_t = __mconstant<_Np>;
#else
Expand Down Expand Up @@ -655,6 +658,18 @@ namespace stdexec {
using __f = __minvoke<_Fn, _As...>;
};

// Specialization for a pointer to an index pack `__pack::__t<_Ns...>`: `__f<_Fn>` invokes the
// meta-callable `_Fn` with each index wrapped as the type `__msize_t<_Ns>`.
// NOTE(review): presumably this is the type produced by `__make_indices` -- confirm.
template <std::size_t... _Ns>
struct __muncurry_<__pack::__t<_Ns...> *> {
template <class _Fn>
using __f = __minvoke<_Fn, __msize_t<_Ns>...>;
};

// Specialization for any class template instantiated as `_Cp<_Np, _Ns...>`, i.e. a type
// followed by values of that type (the shape of `std::integer_sequence`): `__f<_Fn>` invokes
// `_Fn` with each value wrapped as `std::integral_constant<_Np, _Ns>`.
template <template <class _Np, _Np...> class _Cp, class _Np, _Np... _Ns>
struct __muncurry_<_Cp<_Np, _Ns...>> {
template <class _Fn>
using __f = __minvoke<_Fn, std::integral_constant<_Np, _Ns>...>;
};

template <class _What, class... _With>
struct __muncurry_<_ERROR_<_What, _With...>> {
template <class _Fn>
Expand Down Expand Up @@ -829,6 +844,7 @@ namespace stdexec {
template <class _Ty>
using __f = _Id<_Ty>;
};

template <class _Ty>
using __id = __minvoke<__id_<__has_id<_Ty>>, _Ty>;

Expand Down Expand Up @@ -882,8 +898,13 @@ namespace stdexec {
template <class _Fn>
__emplace_from(_Fn) -> __emplace_from<_Fn>;

template <class, class, class, class>
struct __mzip_with2_;
// Primary template: handles arbitrary list-like arguments by re-expressing both lists as
// `__types<...>` (via `__mapply<__qq<__types>, _ListN>`) and deriving from the resulting
// `__mzip_with2_` instantiation.
// NOTE(review): this relies on a partial specialization for `__types<...>` lists, presumably
// the one that follows, to end the recursion -- confirm.
template <class _Fn, class _Continuation, class _List1, class _List2>
struct __mzip_with2_
: __mzip_with2_<
_Fn,
_Continuation,
__mapply<__qq<__types>, _List1>,
__mapply<__qq<__types>, _List2>> { };

template < //
class _Fn, //
Expand Down
9 changes: 9 additions & 0 deletions include/stdexec/__detail/__tuple.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ namespace stdexec {

template <std::size_t... _Is, __indices<_Is...> _Idx, class... _Ts>
struct __tuple<_Idx, _Ts...> : __box<_Ts, _Is>... {
// Builds a `__tuple<_Idx, _Ts...>` from an rvalue tuple with the same index sequence but
// element types `_Us...`. Each source element is cast to `_Us &&` and used to initialize the
// corresponding `__box<_Ts, _Is>`.
// NOTE(review): unlike the neighbouring members this carries no
// `STDEXEC_ATTRIBUTE((host, device))` -- confirm that is intended.
template <class... _Us>
static __tuple __convert_from(__tuple<_Idx, _Us...> &&__tup) {
return __tuple{{static_cast<_Us &&>(__tup.__box<_Us, _Is>::__value)}...};
}

// Builds a `__tuple<_Idx, _Ts...>` from a const lvalue tuple with the same index sequence but
// element types `_Us...`. Each corresponding `__box<_Ts, _Is>` is initialized from the
// (const) source element, leaving the source tuple unmodified.
// NOTE(review): unlike the neighbouring members this carries no
// `STDEXEC_ATTRIBUTE((host, device))` -- confirm that is intended.
template <class... _Us>
static __tuple __convert_from(__tuple<_Idx, _Us...> const &__tup) {
return __tuple{{__tup.__box<_Us, _Is>::__value}...};
}

template <class _Fn, class _Self, class... _Us>
STDEXEC_ATTRIBUTE((host, device, always_inline))
Expand Down
Loading

0 comments on commit f11f711

Please sign in to comment.