Skip to content

Commit

Permalink
Relax overload resolution of any_receiver's completion functions
Browse files Browse the repository at this point in the history
  • Loading branch information
maikel committed Dec 19, 2024
1 parent ac27beb commit a762e9b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 24 deletions.
20 changes: 8 additions & 12 deletions include/exec/any_sender_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,7 @@ namespace exec {
(*__other.__vtable_)(__copy_construct, this, __other);
}

auto operator=(const __t& __other) -> __t&
requires(_Copyable)
{
auto operator=(const __t& __other) -> __t& requires(_Copyable) {
if (&__other != this) {
__t tmp(__other);
*this = std::move(tmp);
Expand Down Expand Up @@ -615,6 +613,7 @@ namespace exec {
, public __query_vfun<_Queries>... {
public:
using __query_vfun<_Queries>::operator()...;
using __any_::__rcvr_vfun<_Sigs>::operator()...;

private:
template <class _Rcvr>
Expand Down Expand Up @@ -674,24 +673,21 @@ namespace exec {
}

template <class... _As>
requires __one_of<set_value_t(_As...), _Sigs...>
requires __callable<__vtable_t, void*, set_value_t, _As...>
void set_value(_As&&... __as) noexcept {
const __any_::__rcvr_vfun<set_value_t(_As...)>* __vfun = __env_.__vtable_;
(*__vfun->__complete_)(__env_.__rcvr_, static_cast<_As&&>(__as)...);
(*__env_.__vtable_)(__env_.__rcvr_, set_value_t(), static_cast<_As&&>(__as)...);
}

template <class _Error>
requires __one_of<set_error_t(_Error), _Sigs...>
requires __callable<__vtable_t, void*, set_error_t, _Error>
void set_error(_Error&& __err) noexcept {
const __any_::__rcvr_vfun<set_error_t(_Error)>* __vfun = __env_.__vtable_;
(*__vfun->__complete_)(__env_.__rcvr_, static_cast<_Error&&>(__err));
(*__env_.__vtable_)(__env_.__rcvr_, set_error_t(), static_cast<_Error&&>(__err));
}

void set_stopped() noexcept
requires __one_of<set_stopped_t(), _Sigs...>
requires __callable<__vtable_t, void*, set_stopped_t>
{
const __any_::__rcvr_vfun<set_stopped_t()>* __vfun = __env_.__vtable_;
(*__vfun->__complete_)(__env_.__rcvr_);
(*__env_.__vtable_)(__env_.__rcvr_, set_stopped_t());
}

auto get_env() const noexcept -> const __env_t& {
Expand Down
12 changes: 8 additions & 4 deletions include/stdexec/__detail/__receiver_ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ namespace stdexec { namespace __any_ {

template <class _Tag, class... _Args>
struct __rcvr_vfun<_Tag(_Args...)> {
void (*__complete_)(void*, _Args&&...) noexcept;
void (*__complete_)(void*, _Args...) noexcept;

void operator()(void* __obj, _Tag, _Args&&... __args) const noexcept {
void operator()(void* __obj, _Tag, _Args... __args) const noexcept {
__complete_(__obj, static_cast<_Args&&>(__args)...);
}
};

template <class _GetReceiver = std::identity, class _Obj, class _Tag, class... _Args>
constexpr auto __rcvr_vfun_fn(_Obj*, _Tag (*)(_Args...)) noexcept {
return +[](void* __ptr, _Args&&... __args) noexcept {
return +[](void* __ptr, _Args... __args) noexcept {
_Obj* __obj = static_cast<_Obj*>(__ptr);
_Tag()(std::move(_GetReceiver()(*__obj)), static_cast<_Args&&>(__args)...);
};
Expand Down Expand Up @@ -95,16 +95,20 @@ namespace stdexec { namespace __any_ {
}

template <class... _As>
requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_value_t, _As...>
void set_value(_As&&... __as) noexcept {
(*__vtable_)(__op_state_, set_value_t(), static_cast<_As&&>(__as)...);
}

template <class _Error>
requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_error_t, _Error>
void set_error(_Error&& __err) noexcept {
(*__vtable_)(__op_state_, set_error_t(), static_cast<_Error&&>(__err));
}

void set_stopped() noexcept {
void set_stopped() noexcept
requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_stopped_t>
{
(*__vtable_)(__op_state_, set_stopped_t());
}

Expand Down
31 changes: 23 additions & 8 deletions test/exec/test_any_sender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,9 @@ namespace {
TEST_CASE("sync_wait works on any_sender_of", "[types][any_sender]") {
int value = 0;
any_sender_of<set_value_t()> sender = just(42) | then([&](int v) noexcept { value = v; });
CHECK(
std::same_as<
completion_signatures_of_t<any_sender_of<set_value_t()>>,
completion_signatures<set_value_t()>>);
CHECK(std::same_as<
completion_signatures_of_t<any_sender_of<set_value_t()>>,
completion_signatures<set_value_t()>>);
sync_wait(std::move(sender));
CHECK(value == 42);
}
Expand All @@ -276,10 +275,9 @@ namespace {

TEST_CASE("sync_wait returns value", "[types][any_sender]") {
any_sender_of<set_value_t(int)> sender = just(21) | then([&](int v) noexcept { return 2 * v; });
CHECK(
std::same_as<
completion_signatures_of_t<any_sender_of<set_value_t(int)>>,
completion_signatures<set_value_t(int)>>);
CHECK(std::same_as<
completion_signatures_of_t<any_sender_of<set_value_t(int)>>,
completion_signatures<set_value_t(int)>>);
auto [value1] = *sync_wait(std::move(sender));
CHECK(value1 == 42);
}
Expand Down Expand Up @@ -330,6 +328,23 @@ namespace {
}
}

template <class... Vals>
using my_stoppable_sender_of =
any_sender_of<set_value_t(Vals)..., set_error_t(std::exception_ptr), set_stopped_t()>;

TEST_CASE("any_sender uses overload rules for completion signatures", "[types][any_sender]") {
auto split_sender = split(just(42));
static_assert(sender_of<decltype(split_sender), set_error_t(const std::exception_ptr&)>);
static_assert(sender_of<decltype(split_sender), set_value_t(const int&)>);
my_stoppable_sender_of<int> sender = split_sender;

auto [value] = *sync_wait(std::move(sender));
CHECK(value == 42);

sender = just(21) | then([&](int) -> int { throw 420; });
CHECK_THROWS_AS(sync_wait(std::move(sender)), int);
}

class stopped_token {
private:
bool stopped_{true};
Expand Down

0 comments on commit a762e9b

Please sign in to comment.