From 5a9744199a0a7ee3f0aef7de1d32a6a8b3b451a6 Mon Sep 17 00:00:00 2001 From: Eric Niebler Date: Wed, 1 Nov 2023 16:00:52 -0700 Subject: [PATCH] factor out the allocator garbage --- include/stdexec/__detail/__execution_fwd.hpp | 3 + include/stdexec/__detail/__memory.hpp | 96 +++++++++++++++++++ include/stdexec/execution.hpp | 66 ++++--------- .../algos/consumers/test_start_detached.cpp | 41 ++++++++ 4 files changed, 161 insertions(+), 45 deletions(-) create mode 100644 include/stdexec/__detail/__memory.hpp diff --git a/include/stdexec/__detail/__execution_fwd.hpp b/include/stdexec/__detail/__execution_fwd.hpp index 27dd5c1aa..7f7721d21 100644 --- a/include/stdexec/__detail/__execution_fwd.hpp +++ b/include/stdexec/__detail/__execution_fwd.hpp @@ -67,6 +67,7 @@ namespace stdexec { namespace __queries { struct forwarding_query_t; + struct query_or_t; struct execute_may_block_caller_t; struct get_forward_progress_guarantee_t; struct __has_algorithm_customizations_t; @@ -79,6 +80,7 @@ namespace stdexec { } // namespace __queries using __queries::forwarding_query_t; + using __queries::query_or_t; using __queries::execute_may_block_caller_t; using __queries::__has_algorithm_customizations_t; using __queries::get_forward_progress_guarantee_t; @@ -89,6 +91,7 @@ namespace stdexec { using __queries::get_completion_scheduler_t; extern const forwarding_query_t forwarding_query; + extern const query_or_t query_or; extern const execute_may_block_caller_t execute_may_block_caller; extern const __has_algorithm_customizations_t __has_algorithm_customizations; extern const get_forward_progress_guarantee_t get_forward_progress_guarantee; diff --git a/include/stdexec/__detail/__memory.hpp b/include/stdexec/__detail/__memory.hpp new file mode 100644 index 000000000..b3d5cf844 --- /dev/null +++ b/include/stdexec/__detail/__memory.hpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021-2022 NVIDIA Corporation + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "__execution_fwd.hpp" +#include "__scope.hpp" + +#include +#include + +namespace stdexec { + namespace __mem { + + template + struct __allocator { + using __value_type = typename std::allocator_traits<_Alloc>::value_type; + + explicit __allocator(_Alloc __alloc) + : __alloc_(std::move(__alloc)) {} + + auto __allocate() { + return std::allocator_traits<_Alloc>::allocate(__alloc_, 1); + } + + void __destroy(__value_type* __ptr) noexcept { + return std::allocator_traits<_Alloc>::destroy(__alloc_, __ptr); + } + + void __deallocate(__value_type* __ptr) noexcept { + return std::allocator_traits<_Alloc>::deallocate(__alloc_, __ptr, 1); + } + + private: + STDEXEC_ATTRIBUTE((no_unique_address)) _Alloc __alloc_; + }; + + template + struct __with_allocator_provider { + // _GetAllocator returns an allocator for the derived type. + using _Derived = typename __result_of<_GetAllocator, _Data&>::__value_type; + + union { + _Data __data_; + }; + + // The __data_ field is already initialized; don't touch it. + constexpr __with_allocator_provider() noexcept {} + + static void* operator new(std::size_t size, _Data&& __data) { + auto __alloc = _GetAllocator(__data); + _Derived* __ptr = __alloc.__allocate(); + [[maybe_unused]] __scope_guard __g{[&]() noexcept { __alloc.__deallocate(__ptr); }}; + ::new(&__ptr->__data_) _Data((_Data&&) __data); + __g.__dismiss(); + return __ptr; + } + + static void operator delete(void* __self, _Data&&) { + _Derived* __ptr = static_cast<_Derived*>(__self); + auto __alloc = _GetAllocator(__ptr->__data_); + __ptr->__data_.~_Data(); + __alloc.__deallocate(__ptr); + } + + static void operator delete(__with_allocator_provider* __self, std::destroying_delete_t) noexcept { + _Derived* __ptr = static_cast<_Derived*>(__self); + auto __alloc = _GetAllocator(__ptr->__data_); + __ptr->__data_.~_Data(); + __alloc.__destroy(__ptr); + __alloc.__deallocate(__ptr); + } + }; + + template + using __alloc_rebind_t = typename std::allocator_traits<_Alloc>::template rebind_alloc<_Ty>; + + template + inline constexpr auto __get_allocator_for = [](auto& __data) noexcept { + auto __alloc = query_or(get_allocator, get_env(__data), std::allocator()); + return __allocator{__alloc_rebind_t{__alloc}}; + }; + } +} diff --git a/include/stdexec/execution.hpp b/include/stdexec/execution.hpp index 9eb7d8768..442a51885 100644 --- a/include/stdexec/execution.hpp +++ b/include/stdexec/execution.hpp @@ -34,6 +34,7 @@ #include "__detail/__meta.hpp" #include "__detail/__scope.hpp" #include "__detail/__basic_sender.hpp" +#include "__detail/__memory.hpp" #include "functional.hpp" #include "concepts.hpp" #include "coroutine.hpp" @@ -2455,70 +2456,45 @@ namespace stdexec { template <__completion_tag _Tag, class... _As> requires __callable<_Tag, _Receiver, _As...> friend void tag_invoke(_Tag __tag, __receiver&& __self, _As&&... __as) noexcept { - __tag((_Receiver&&) __self.__opref_().__rcvr_, (_As&&) __as...); - __self.__delete_op(); - } - - void __delete_op() noexcept { - _Operation* __op = &__opref_(); - if constexpr (__callable>) { - auto&& __env = get_env(__op->__rcvr_); - auto __alloc = get_allocator(__env); - using _Alloc = decltype(__alloc); - using _OpAlloc = typename std::allocator_traits<_Alloc>::template rebind_alloc<_Operation>; - _OpAlloc __op_alloc{__alloc}; - std::allocator_traits<_OpAlloc>::destroy(__op_alloc, __op); - std::allocator_traits<_OpAlloc>::deallocate(__op_alloc, __op, 1); - } else { - delete __op; - } + __tag((_Receiver&&) __self.__opref_().__rcvr(), (_As&&) __as...); + delete &__self.__opref_(); } // Forward all receiever queries. friend auto tag_invoke(get_env_t, const __receiver& __self) noexcept -> env_of_t<_Receiver&> { - return get_env(__self.__opref_().__rcvr_); + return get_env(__self.__opref_().__rcvr()); } }; +#if !defined(__cpp_impl_destroying_delete) || !defined(__cpp_lib_destroying_delete) +#error This library needs support for C++20's destroying delete +#endif + template - struct __operation { + struct __operation + : __mem::__with_allocator_provider< + stdexec::__t<_ReceiverId>, + __mem::__get_allocator_for<__operation<_SenderId, _ReceiverId>>> { using _Sender = stdexec::__t<_SenderId>; using _Receiver = stdexec::__t<_ReceiverId>; using __receiver_t = __receiver<__ref_t<__operation>>; - STDEXEC_ATTRIBUTE((no_unique_address)) _Receiver __rcvr_; - connect_result_t<_Sender, __receiver_t> __op_state_; + explicit __operation(_Sender&& __sndr) + : __op_state_(connect((_Sender&&) __sndr, __receiver_t{__ref(*this)})) { + } - __operation(_Sender&& __sndr, _Receiver __rcvr) - : __rcvr_((_Receiver&&) __rcvr) - , __op_state_(connect((_Sender&&) __sndr, __receiver_t{__ref(*this)})) { + _Receiver& __rcvr() & noexcept { + return this->__with_allocator_provider::__data_; } + + connect_result_t<_Sender, __receiver_t> __op_state_; }; struct __submit_t { template _Sender> void operator()(_Sender&& __sndr, _Receiver __rcvr) const noexcept(false) { - if constexpr (__callable>) { - auto&& __env = get_env(__rcvr); - auto __alloc = get_allocator(__env); - using _Alloc = decltype(__alloc); - using _Op = __operation<__id<_Sender>, __id<_Receiver>>; - using _OpAlloc = typename std::allocator_traits<_Alloc>::template rebind_alloc<_Op>; - _OpAlloc __op_alloc{__alloc}; - auto __op = std::allocator_traits<_OpAlloc>::allocate(__op_alloc, 1); - try { - std::allocator_traits<_OpAlloc>::construct( - __op_alloc, __op, (_Sender&&) __sndr, (_Receiver&&) __rcvr); - start(__op->__op_state_); - } catch (...) { - std::allocator_traits<_OpAlloc>::deallocate(__op_alloc, __op, 1); - throw; - } - } else { - start((new __operation<__id<_Sender>, __id<_Receiver>>{ - (_Sender&&) __sndr, (_Receiver&&) __rcvr}) - ->__op_state_); - } + using __operation_t = __operation<__id<_Sender>, __id<_Receiver>>; + start((new((_Receiver&&) __rcvr) __operation_t{(_Sender&&) __sndr})->__op_state_); } }; } // namespace __submit_ diff --git a/test/stdexec/algos/consumers/test_start_detached.cpp b/test/stdexec/algos/consumers/test_start_detached.cpp index 699b5b648..f314dd27c 100644 --- a/test/stdexec/algos/consumers/test_start_detached.cpp +++ b/test/stdexec/algos/consumers/test_start_detached.cpp @@ -22,6 +22,7 @@ #include #include +#include namespace ex = stdexec; @@ -160,5 +161,45 @@ namespace { exec::make_env(exec::with(ex::get_scheduler, custom_scheduler{}))); CHECK_FALSE(called); } + + struct counting_resource : std::pmr::memory_resource { + counting_resource() = default; + + std::size_t get_count() const noexcept { + return count; + } + + std::size_t get_alive() const noexcept { + return alive; + } + private: + void* do_allocate(std::size_t bytes, std::size_t alignment) override { + ++count; + ++alive; + return std::pmr::new_delete_resource()->allocate(bytes, alignment); + } + void do_deallocate(void* p, std::size_t bytes, std::size_t alignment) override { + --alive; + return std::pmr::new_delete_resource()->deallocate(p, bytes, alignment); + } + bool do_is_equal(const memory_resource& other) const noexcept override { + return this == &other; + } + + std::size_t count = 0, alive = 0; + }; + + // NOT TO SPEC + TEST_CASE("start_detached works with a custom allocator", "[consumers][start_detached]") { + bool called = false; + counting_resource res; + std::pmr::polymorphic_allocator alloc(&res); + ex::start_detached( + ex::just() | ex::then([&] { called = true; }), + exec::make_env(exec::with(ex::get_allocator, alloc))); + CHECK(called); + CHECK(res.get_count() == 1); + CHECK(res.get_alive() == 0); + } } STDEXEC_PRAGMA_POP()