From a762e9bc1f2df84f1f8cb2459c09c5224b06c1a5 Mon Sep 17 00:00:00 2001 From: Maikel Nadolski Date: Wed, 18 Dec 2024 20:46:27 +0100 Subject: [PATCH] Relax overload resolution of any_receiver's completion functions --- include/exec/any_sender_of.hpp | 20 ++++++------- include/stdexec/__detail/__receiver_ref.hpp | 12 +++++--- test/exec/test_any_sender.cpp | 31 +++++++++++++++------ 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/include/exec/any_sender_of.hpp b/include/exec/any_sender_of.hpp index eff7cd631..22793582c 100644 --- a/include/exec/any_sender_of.hpp +++ b/include/exec/any_sender_of.hpp @@ -439,9 +439,7 @@ namespace exec { (*__other.__vtable_)(__copy_construct, this, __other); } - auto operator=(const __t& __other) -> __t& - requires(_Copyable) - { + auto operator=(const __t& __other) -> __t& requires(_Copyable) { if (&__other != this) { __t tmp(__other); *this = std::move(tmp); @@ -615,6 +613,7 @@ namespace exec { , public __query_vfun<_Queries>... { public: using __query_vfun<_Queries>::operator()...; + using __any_::__rcvr_vfun<_Sigs>::operator()...; private: template @@ -674,24 +673,21 @@ namespace exec { } template - requires __one_of + requires __callable<__vtable_t, void*, set_value_t, _As...> void set_value(_As&&... __as) noexcept { - const __any_::__rcvr_vfun* __vfun = __env_.__vtable_; - (*__vfun->__complete_)(__env_.__rcvr_, static_cast<_As&&>(__as)...); + (*__env_.__vtable_)(__env_.__rcvr_, set_value_t(), static_cast<_As&&>(__as)...); } template - requires __one_of + requires __callable<__vtable_t, void*, set_error_t, _Error> void set_error(_Error&& __err) noexcept { - const __any_::__rcvr_vfun* __vfun = __env_.__vtable_; - (*__vfun->__complete_)(__env_.__rcvr_, static_cast<_Error&&>(__err)); + (*__env_.__vtable_)(__env_.__rcvr_, set_error_t(), static_cast<_Error&&>(__err)); } void set_stopped() noexcept - requires __one_of + requires __callable<__vtable_t, void*, set_stopped_t> { - const __any_::__rcvr_vfun* __vfun = __env_.__vtable_; - (*__vfun->__complete_)(__env_.__rcvr_); + (*__env_.__vtable_)(__env_.__rcvr_, set_stopped_t()); } auto get_env() const noexcept -> const __env_t& { diff --git a/include/stdexec/__detail/__receiver_ref.hpp b/include/stdexec/__detail/__receiver_ref.hpp index c0638c009..35101d058 100644 --- a/include/stdexec/__detail/__receiver_ref.hpp +++ b/include/stdexec/__detail/__receiver_ref.hpp @@ -30,16 +30,16 @@ namespace stdexec { namespace __any_ { template struct __rcvr_vfun<_Tag(_Args...)> { - void (*__complete_)(void*, _Args&&...) noexcept; + void (*__complete_)(void*, _Args...) noexcept; - void operator()(void* __obj, _Tag, _Args&&... __args) const noexcept { + void operator()(void* __obj, _Tag, _Args... __args) const noexcept { __complete_(__obj, static_cast<_Args&&>(__args)...); } }; template constexpr auto __rcvr_vfun_fn(_Obj*, _Tag (*)(_Args...)) noexcept { - return +[](void* __ptr, _Args&&... __args) noexcept { + return +[](void* __ptr, _Args... __args) noexcept { _Obj* __obj = static_cast<_Obj*>(__ptr); _Tag()(std::move(_GetReceiver()(*__obj)), static_cast<_Args&&>(__args)...); }; @@ -95,16 +95,20 @@ namespace stdexec { namespace __any_ { } template + requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_value_t, _As...> void set_value(_As&&... __as) noexcept { (*__vtable_)(__op_state_, set_value_t(), static_cast<_As&&>(__as)...); } template + requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_error_t, _Error> void set_error(_Error&& __err) noexcept { (*__vtable_)(__op_state_, set_error_t(), static_cast<_Error&&>(__err)); } - void set_stopped() noexcept { + void set_stopped() noexcept + requires __callable<__receiver_vtable_for<_Sigs, _Env>, void*, set_stopped_t> + { (*__vtable_)(__op_state_, set_stopped_t()); } diff --git a/test/exec/test_any_sender.cpp b/test/exec/test_any_sender.cpp index 48a56f4e2..1fc59cea2 100644 --- a/test/exec/test_any_sender.cpp +++ b/test/exec/test_any_sender.cpp @@ -259,10 +259,9 @@ namespace { TEST_CASE("sync_wait works on any_sender_of", "[types][any_sender]") { int value = 0; any_sender_of sender = just(42) | then([&](int v) noexcept { value = v; }); - CHECK( - std::same_as< - completion_signatures_of_t>, - completion_signatures>); + CHECK(std::same_as< + completion_signatures_of_t>, + completion_signatures>); sync_wait(std::move(sender)); CHECK(value == 42); } @@ -276,10 +275,9 @@ namespace { TEST_CASE("sync_wait returns value", "[types][any_sender]") { any_sender_of sender = just(21) | then([&](int v) noexcept { return 2 * v; }); - CHECK( - std::same_as< - completion_signatures_of_t>, - completion_signatures>); + CHECK(std::same_as< + completion_signatures_of_t>, + completion_signatures>); auto [value1] = *sync_wait(std::move(sender)); CHECK(value1 == 42); } @@ -330,6 +328,23 @@ namespace { } } + template + using my_stoppable_sender_of = + any_sender_of; + + TEST_CASE("any_sender uses overload rules for completion signatures", "[types][any_sender]") { + auto split_sender = split(just(42)); + static_assert(sender_of); + static_assert(sender_of); + my_stoppable_sender_of sender = split_sender; + + auto [value] = *sync_wait(std::move(sender)); + CHECK(value == 42); + + sender = just(21) | then([&](int) -> int { throw 420; }); + CHECK_THROWS_AS(sync_wait(std::move(sender)), int); + } + class stopped_token { private: bool stopped_{true};