From 362fb913c781f4367a3ff932c35741704765b17e Mon Sep 17 00:00:00 2001
From: Maikel Nadolski
Date: Wed, 25 Oct 2023 22:45:47 +0200
Subject: [PATCH] Improve global schedule algorithm

---
 .../__detail/__atomic_intrusive_queue.hpp   |   2 +-
 include/exec/__detail/__bwos_lifo_queue.hpp |   9 +-
 include/exec/__detail/__xorshift.hpp        |  70 +++++
 include/exec/static_thread_pool.hpp         | 277 ++++++++++++++----
 4 files changed, 289 insertions(+), 69 deletions(-)
 create mode 100644 include/exec/__detail/__xorshift.hpp

diff --git a/include/exec/__detail/__atomic_intrusive_queue.hpp b/include/exec/__detail/__atomic_intrusive_queue.hpp
index 71ac4e2c7..7a5d7fd97 100644
--- a/include/exec/__detail/__atomic_intrusive_queue.hpp
+++ b/include/exec/__detail/__atomic_intrusive_queue.hpp
@@ -23,7 +23,7 @@ namespace exec {
   class __atomic_intrusive_queue;
 
   template <class _Tp, _Tp* _Tp::*_NextPtr>
-  class __atomic_intrusive_queue<_NextPtr> {
+  class alignas(64) __atomic_intrusive_queue<_NextPtr> {
   public:
    using __node_pointer = _Tp*;
    using __atomic_node_pointer = std::atomic<_Tp*>;
diff --git a/include/exec/__detail/__bwos_lifo_queue.hpp b/include/exec/__detail/__bwos_lifo_queue.hpp
index f534409a1..9a40bc50b 100644
--- a/include/exec/__detail/__bwos_lifo_queue.hpp
+++ b/include/exec/__detail/__bwos_lifo_queue.hpp
@@ -399,7 +399,7 @@ namespace exec::bwos {
         ++back;
         ++first;
       }
-      tail_.store(back, std::memory_order_relaxed);
+      tail_.store(back, std::memory_order_release);
       return first;
     }
 
@@ -413,8 +413,9 @@ namespace exec::bwos {
       if (front == back) [[unlikely]] {
         return {lifo_queue_error_code::empty, nullptr};
       }
-      tail_.store(back - 1, std::memory_order_relaxed);
-      return {lifo_queue_error_code::success, static_cast<Tp>(ring_buffer_[back - 1])};
+      Tp value = static_cast<Tp>(ring_buffer_[back - 1]);
+      tail_.store(back - 1, std::memory_order_release);
+      return {lifo_queue_error_code::success, value};
     }
 
     template <class Tp, class Allocator>
@@ -425,7 +426,7 @@ namespace exec::bwos {
         result.status = lifo_queue_error_code::done;
         return result;
       }
-      std::uint64_t back = tail_.load(std::memory_order_relaxed);
+      std::uint64_t back = tail_.load(std::memory_order_acquire);
       if (spos == back) [[unlikely]] {
         result.status = lifo_queue_error_code::empty;
         return result;
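
The two __bwos_lifo_queue.hpp hunks tighten the queue's publication protocol: the owner now stores tail_ with release (and the steal path loads it with acquire), and the pop-back path reads the slot value before publishing the smaller tail_, so a value is no longer read from a slot after thieves have been told its position is free. The alignas(64) in __atomic_intrusive_queue.hpp pads each queue to a typical cache-line size, presumably so adjacent per-worker queues do not false-share. Below is a minimal sketch of the release/acquire contract these hunks rely on; slot_ring and its members are illustrative only, and the real bwos block additionally tracks head and steal cursors and reports lifo_queue_error_code values.

```cpp
#include <atomic>
#include <cstdint>

struct slot_ring {
  std::atomic<std::uint64_t> tail_{0};
  void* ring_buffer_[64] = {};

  // Owner: write the slot first, then publish it with a release store.
  void owner_put(void* task) {
    std::uint64_t back = tail_.load(std::memory_order_relaxed);
    ring_buffer_[back] = task;                         // (1) write the slot
    tail_.store(back + 1, std::memory_order_release);  // (2) publish it
  }

  // Thief: an acquire load that observes (2) also observes (1), so the
  // slot read below cannot see a half-written value.
  void* thief_try_read(std::uint64_t spos) {
    std::uint64_t back = tail_.load(std::memory_order_acquire);
    if (spos >= back) {
      return nullptr;  // nothing published at this position yet
    }
    return ring_buffer_[spos];
  }
};
```
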
diff --git a/include/exec/__detail/__xorshift.hpp b/include/exec/__detail/__xorshift.hpp
new file mode 100644
index 000000000..6b95b851b
--- /dev/null
+++ b/include/exec/__detail/__xorshift.hpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2023 Maikel Nadolski
+ * Copyright (c) 2023 NVIDIA Corporation
+ *
+ * Licensed under the Apache License Version 2.0 with LLVM Exceptions
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   https://llvm.org/LICENSE.txt
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* I have taken and modified this code from https://gist.github.com/Leandros/6dc334c22db135b033b57e9ee0311553 */
+/* Copyright (c) 2018 Arvid Gerstmann. */
+/* This code is licensed under MIT license. */
+
+#include <cstdint>
+#include <random>
+
+namespace exec {
+
+  class xorshift {
+   public:
+    using result_type = std::uint32_t;
+
+    static constexpr result_type(min)() {
+      return 0;
+    }
+
+    static constexpr result_type(max)() {
+      return UINT32_MAX;
+    }
+
+    friend bool operator==(xorshift const &, xorshift const &) = default;
+
+    xorshift()
+      : m_seed(0xc1f651c67c62c6e0ull) {
+    }
+
+    explicit xorshift(std::random_device &rd) {
+      seed(rd);
+    }
+
+    void seed(std::random_device &rd) {
+      m_seed = std::uint64_t(rd()) << 31 | std::uint64_t(rd());
+    }
+
+    result_type operator()() {
+      std::uint64_t result = m_seed * 0xd989bcacc137dcd5ull;
+      m_seed ^= m_seed >> 11;
+      m_seed ^= m_seed << 31;
+      m_seed ^= m_seed >> 18;
+      return std::uint32_t(result >> 32ull);
+    }
+
+    void discard(unsigned long long n) {
+      for (unsigned long long i = 0; i < n; ++i)
+        operator()();
+    }
+
+   private:
+    std::uint64_t m_seed;
+  };
+
+} // namespace exec
\ No newline at end of file
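
The new header vendors Arvid Gerstmann's xorshift variant: operator() returns the high 32 bits of the pre-advance state multiplied by a fixed constant, then advances the 64-bit state with three xor-shifts. Because the class exposes result_type, (min)(), (max)(), and operator(), it models a standard uniform random bit generator and can back std::uniform_int_distribution, which is exactly how the scheduler uses it later for picking steal victims. A small usage sketch, assuming the in-tree include path:

```cpp
#include <cstdint>
#include <iostream>
#include <random>

#include "exec/__detail/__xorshift.hpp"

int main() {
  std::random_device rd;
  exec::xorshift rng{rd};  // 64-bit state seeded from OS entropy

  // xorshift satisfies the uniform-random-bit-generator requirements,
  // so it can drive the standard distributions directly.
  std::uniform_int_distribution<std::uint32_t> dist(0, 7);
  for (int i = 0; i < 4; ++i) {
    std::cout << dist(rng) << '\n';  // uniform index in [0, 7]
  }
}
```
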
diff --git a/include/exec/static_thread_pool.hpp b/include/exec/static_thread_pool.hpp
index 4d1bac481..08e6b6904 100644
--- a/include/exec/static_thread_pool.hpp
+++ b/include/exec/static_thread_pool.hpp
@@ -1,6 +1,7 @@
 /*
  * Copyright (c) 2021-2022 Facebook, Inc. and its affiliates.
  * Copyright (c) 2021-2022 NVIDIA Corporation
+ * Copyright (c) 2023 Maikel Nadolski
  *
  * Licensed under the Apache License Version 2.0 with LLVM Exceptions
  * (the "License"); you may not use this file except in compliance with
@@ -22,6 +23,7 @@
 #include "../stdexec/__detail/__meta.hpp"
 #include "./__detail/__bwos_lifo_queue.hpp"
 #include "./__detail/__atomic_intrusive_queue.hpp"
+#include "./__detail/__xorshift.hpp"
 
 #include <atomic>
 #include <condition_variable>
@@ -39,6 +41,70 @@ namespace exec {
     void (*__execute)(task_base*, std::uint32_t tid) noexcept;
   };
 
+  struct bwos_params {
+    std::size_t numBlocks{8};
+    std::size_t blockSize{1024};
+  };
+
+  struct remote_queue {
+    remote_queue* next_{};
+    std::unique_ptr<__atomic_intrusive_queue<&task_base::next>[]> queues_{};
+    std::thread::id id_{std::this_thread::get_id()};
+  };
+
+  struct remote_queue_list {
+   private:
+    std::atomic<remote_queue*> head_;
+    remote_queue* tail_;
+    std::size_t nthreads_;
+    remote_queue this_remotes_;
+
+   public:
+    explicit remote_queue_list(std::size_t nthreads) noexcept
+      : head_{&this_remotes_}
+      , tail_{&this_remotes_}
+      , nthreads_(nthreads) {
+      this_remotes_.queues_ = std::make_unique<__atomic_intrusive_queue<&task_base::next>[]>(
+        nthreads_);
+    }
+
+    ~remote_queue_list() noexcept {
+      remote_queue* head = head_.load(std::memory_order_acquire);
+      while (head != tail_) {
+        remote_queue* tmp = std::exchange(head, head->next_);
+        delete tmp;
+      }
+    }
+
+    __intrusive_queue<&task_base::next> pop_all_reversed(std::size_t tid) noexcept {
+      remote_queue* head = head_.load(std::memory_order_acquire);
+      __intrusive_queue<&task_base::next> tasks{};
+      while (head != nullptr) {
+        tasks.append(head->queues_[tid].pop_all_reversed());
+        head = head->next_;
+      }
+      return tasks;
+    }
+
+    remote_queue* get() {
+      thread_local std::thread::id this_id = std::this_thread::get_id();
+      remote_queue* head = head_.load(std::memory_order_acquire);
+      remote_queue* queue = head;
+      while (queue != tail_) {
+        if (queue->id_ == this_id) {
+          return queue;
+        }
+        queue = queue->next_;
+      }
+      remote_queue* new_head = new remote_queue{head};
+      new_head->queues_ = std::make_unique<__atomic_intrusive_queue<&task_base::next>[]>(nthreads_);
+      while (!head_.compare_exchange_weak(head, new_head, std::memory_order_acq_rel)) {
+        new_head->next_ = head;
+      }
+      return new_head;
+    }
+  };
+
   class static_thread_pool {
     template <typename ReceiverId>
     class operation;
@@ -122,7 +188,7 @@ namespace exec {
 
    public:
     static_thread_pool();
-    static_thread_pool(std::uint32_t threadCount);
+    static_thread_pool(std::uint32_t threadCount, bwos_params params = {});
     ~static_thread_pool();
 
     struct scheduler {
@@ -144,7 +210,7 @@
       private:
        template <typename Receiver>
        auto make_operation_(Receiver r) const -> operation<stdexec::__id<Receiver>> {
-         return operation<stdexec::__id<Receiver>>{pool_, (Receiver&&) r};
+         return operation<stdexec::__id<Receiver>>{pool_, queue_, (Receiver&&) r};
        }
 
        template <class Receiver>
@@ -173,15 +239,17 @@
 
        friend struct static_thread_pool::scheduler;
 
-       explicit sender(static_thread_pool& pool) noexcept
-         : pool_(pool) {
+       explicit sender(static_thread_pool& pool, remote_queue* queue) noexcept
+         : pool_(pool)
+         , queue_(queue) {
        }
 
        static_thread_pool& pool_;
+       remote_queue* queue_;
      };
 
      sender make_sender_() const {
-       return sender{*pool_};
+       return sender{*pool_, queue_};
      }
 
      friend sender tag_invoke(stdexec::schedule_t, const scheduler& s) noexcept {
@@ -200,16 +268,22 @@
      friend class static_thread_pool;
 
      explicit scheduler(static_thread_pool& pool) noexcept
-       : pool_(&pool) {
+       : pool_(&pool)
+       , queue_{pool.get_remote_queue()} {
      }
 
      static_thread_pool* pool_;
+     remote_queue* queue_;
    };
 
    scheduler get_scheduler() noexcept {
      return scheduler{*this};
    }
 
+   remote_queue* get_remote_queue() noexcept {
+     return remotes_.get();
+   }
+
    void request_stop() noexcept;
 
    std::uint32_t available_parallelism() const {
@@ -246,10 +320,21 @@
        std::uint32_t queueIndex;
      };
 
-     pop_result try_pop();
+     explicit thread_state(
+       static_thread_pool* pool,
+       std::uint32_t index,
+       bwos_params params) noexcept
+       : local_queue_(params.numBlocks, params.blockSize)
+       , state_(state::running)
+       , index_(index)
+       , pool_(pool) {
+       std::random_device rd;
+       rng_.seed(rd);
+     }
+
      pop_result pop();
-     bool try_push(task_base* task);
      void push(task_base* task);
+     bool notify();
      void request_stop();
 
      void victims(std::vector<workstealing_victim>& victims) {
@@ -264,10 +349,6 @@
        victims_.erase(victims_.begin());
      }
 
-     void index(std::uint32_t value) {
-       index_ = value;
-     }
-
      std::uint32_t index() const noexcept {
        return index_;
      }
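
remote_queue_list is the heart of the new submission scheme: every submitting thread lazily registers one remote_queue node holding an array of MPSC intrusive queues, one slot per worker, so producers on different threads never contend on the same atomic head. scheduler and sender now carry the submitter's remote_queue* so repeated schedules skip the list walk. A sketch of the lazy registration pattern that get() uses, with the payload and reclamation elided (node and node_list are illustrative names, not the patch's types):

```cpp
#include <atomic>
#include <thread>

struct node {
  node* next_{};
  std::thread::id id_{std::this_thread::get_id()};
};

struct node_list {
  std::atomic<node*> head_{nullptr};

  node* get() {
    thread_local std::thread::id this_id = std::this_thread::get_id();
    node* head = head_.load(std::memory_order_acquire);
    for (node* n = head; n != nullptr; n = n->next_) {
      if (n->id_ == this_id) {
        return n;  // this thread already registered its node
      }
    }
    node* new_head = new node{head};
    // Retry until head_ swings from our snapshot to the new node; only
    // this thread ever inserts a node for this_id, so no duplicate race.
    while (!head_.compare_exchange_weak(head, new_head, std::memory_order_acq_rel)) {
      new_head->next_ = head;  // someone else pushed first: re-link, retry
    }
    return new_head;
  }
};
```
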
@@ -277,53 +358,71 @@
      }
 
     private:
-     bwos::lifo_queue<task_base*> local_queue_{8, 1024};
-     __atomic_intrusive_queue<&task_base::next> remote_queue_{};
+     enum state {
+       running,
+       stealing,
+       sleeping,
+       notified
+     };
+
+     pop_result try_pop();
+     pop_result try_remote();
+     pop_result try_steal();
+
+     void notify_one_sleeping();
+     void set_stealing();
+     void clear_stealing();
+
+     bwos::lifo_queue<task_base*> local_queue_;
      __intrusive_queue<&task_base::next> pending_queue_{};
      std::mutex mut_{};
      std::condition_variable cv_{};
      bool stopRequested_{false};
-     enum state {
-       running, sleeping, notified
-     };
-
-     std::atomic<state> state_;
      std::vector<workstealing_victim> victims_{};
+     std::atomic<state> state_;
      std::uint32_t index_{};
+     static_thread_pool* pool_;
+     xorshift rng_{};
    };
 
    void run(std::uint32_t index) noexcept;
    void join() noexcept;
 
    void enqueue(task_base* task) noexcept;
+   void enqueue(remote_queue& queue, task_base* task) noexcept;
    template <std::derived_from<task_base> TaskT>
    void bulk_enqueue(TaskT* task, std::uint32_t n_threads) noexcept;
 
+   alignas(64) std::atomic<std::uint32_t> nextThread_;
+   alignas(64) std::atomic<std::uint32_t> numThiefs_{};
+   alignas(64) remote_queue_list remotes_;
    std::uint32_t threadCount_;
+   std::uint32_t maxSteals_{(threadCount_ + 1) << 1};
    std::vector<std::thread> threads_;
-   std::vector<thread_state> threadStates_;
-   std::atomic<std::uint32_t> nextThread_;
+   std::vector<std::optional<thread_state>> threadStates_;
  };
 
  inline static_thread_pool::static_thread_pool()
    : static_thread_pool(std::thread::hardware_concurrency()) {
  }
 
- inline static_thread_pool::static_thread_pool(std::uint32_t threadCount)
-   : threadCount_(threadCount)
-   , threadStates_(threadCount)
-   , nextThread_(0) {
+ inline static_thread_pool::static_thread_pool(std::uint32_t threadCount, bwos_params params)
+   : nextThread_(0)
+   , remotes_(threadCount)
+   , threadCount_(threadCount)
+   , threadStates_(threadCount) {
    STDEXEC_ASSERT(threadCount > 0);
    for (std::uint32_t index = 0; index < threadCount; ++index) {
-     threadStates_[index].index(index);
+     threadStates_[index].emplace(this, index, params);
    }
    std::vector<workstealing_victim> victims{};
-   for (thread_state& state: threadStates_) {
-     victims.emplace_back(state.as_victim());
+   for (auto& state: threadStates_) {
+     victims.emplace_back(state->as_victim());
    }
-   for (thread_state& state: threadStates_) {
-     state.victims(victims);
+   for (auto& state: threadStates_) {
+     state->victims(victims);
    }
 
    threads_.reserve(threadCount);
@@ -345,7 +444,7 @@
 
  inline void static_thread_pool::request_stop() noexcept {
    for (auto& state: threadStates_) {
-     state.request_stop();
+     state->request_stop();
    }
  }
 
@@ -353,7 +452,7 @@
    STDEXEC_ASSERT(threadIndex < threadCount_);
    while (true) {
      // Make a blocking call to de-queue a task if we don't already have one.
-     auto [task, queueIndex] = threadStates_[threadIndex].pop();
+     auto [task, queueIndex] = threadStates_[threadIndex]->pop();
      if (!task) {
        return; // pop() only returns null when request_stop() was called.
      }
@@ -369,16 +468,24 @@
  }
 
  inline void static_thread_pool::enqueue(task_base* task) noexcept {
+   this->enqueue(*get_remote_queue(), task);
+ }
+
+ inline void static_thread_pool::enqueue(remote_queue& queue, task_base* task) noexcept {
    const std::uint32_t threadCount = static_cast<std::uint32_t>(threads_.size());
    const std::uint32_t startIndex =
      nextThread_.fetch_add(1, std::memory_order_relaxed) % threadCount;
-   threadStates_[startIndex].push(task);
+   queue.queues_[startIndex].push_front(task);
+   threadStates_[startIndex]->notify();
  }
 
  template <std::derived_from<task_base> TaskT>
  inline void static_thread_pool::bulk_enqueue(TaskT* task, std::uint32_t n_threads) noexcept {
+   auto& queue = *get_remote_queue();
    for (std::size_t i = 0; i < n_threads; ++i) {
-     threadStates_[i % available_parallelism()].push(task + i);
+     std::uint32_t index = i % available_parallelism();
+     queue.queues_[index].push_front(task + i);
+     threadStates_[index]->notify();
    }
  }
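
enqueue() now routes through the submitter's remote_queue: it picks a worker round-robin with a relaxed fetch_add on nextThread_, pushes into that worker's slot of the submitter-owned queue, and calls notify() to wake the worker if it sleeps. The cursor only needs atomicity, not ordering, which the toy below demonstrates (all names are illustrative):

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>
#include <thread>
#include <vector>

int main() {
  std::atomic<std::uint32_t> nextThread{0};
  constexpr std::uint32_t threadCount = 4;
  std::vector<std::atomic<std::uint32_t>> hits(threadCount);

  auto submit = [&] {
    for (int i = 0; i < 1000; ++i) {
      // Relaxed suffices: fetch_add hands out unique tickets even with
      // no ordering guarantees between concurrent submitters.
      std::uint32_t idx =
        nextThread.fetch_add(1, std::memory_order_relaxed) % threadCount;
      hits[idx].fetch_add(1, std::memory_order_relaxed);
    }
  };
  std::thread a(submit), b(submit);
  a.join();
  b.join();

  std::uint32_t total = 0;
  for (auto& h: hits) {
    total += h.load(std::memory_order_relaxed);
  }
  assert(total == 2000);  // every ticket landed on exactly one worker slot
}
```
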
@@ -391,65 +498,105 @@
     tmp.clear();
   }
 
-  inline static_thread_pool::thread_state::pop_result static_thread_pool::thread_state::try_pop() {
-    std::size_t free_capacity = local_queue_.get_free_capacity();
-    std::size_t capacity = local_queue_.get_available_capacity();
-    std::size_t threshold = capacity / 2;
-    if (free_capacity > threshold) {
-      __intrusive_queue<&task_base::next> remotes = remote_queue_.pop_all_reversed();
-      pending_queue_.append(std::move(remotes));
-      if (!pending_queue_.empty()) {
-        move_pending_to_local(pending_queue_, local_queue_);
-      }
+  inline static_thread_pool::thread_state::pop_result
+    static_thread_pool::thread_state::try_remote() {
+    pop_result result{nullptr, index_};
+    __intrusive_queue<&task_base::next> remotes = pool_->remotes_.pop_all_reversed(index_);
+    pending_queue_.append(std::move(remotes));
+    if (!pending_queue_.empty()) {
+      move_pending_to_local(pending_queue_, local_queue_);
+      result.task = local_queue_.pop_back();
     }
+    return result;
+  }
+
+  inline static_thread_pool::thread_state::pop_result
+    static_thread_pool::thread_state::try_pop() {
     pop_result result{nullptr, index_};
     result.task = local_queue_.pop_back();
     if (result.task) [[likely]] {
       return result;
     }
-    pending_queue_ = remote_queue_.pop_all_reversed();
-    if (!pending_queue_.empty()) {
-      move_pending_to_local(pending_queue_, local_queue_);
-      result.task = local_queue_.pop_back();
-      return result;
+    return try_remote();
+  }
+
+  inline static_thread_pool::thread_state::pop_result
+    static_thread_pool::thread_state::try_steal() {
+    if (victims_.empty()) {
+      return {nullptr, index_};
     }
-    for (auto& victim: victims_) {
-      result.task = victim.try_steal();
-      if (result.task) {
-        result.queueIndex = victim.index();
-        return result;
+    std::uniform_int_distribution<std::uint32_t> dist(0, victims_.size() - 1);
+    std::uint32_t victimIndex = dist(rng_);
+    auto& v = victims_[victimIndex];
+    return {v.try_steal(), v.index()};
+  }
+
+  inline void static_thread_pool::thread_state::set_stealing() {
+    pool_->numThiefs_.fetch_add(1, std::memory_order_relaxed);
+  }
+
+  inline void static_thread_pool::thread_state::clear_stealing() {
+    if (pool_->numThiefs_.fetch_sub(1, std::memory_order_relaxed) == 1) {
+      notify_one_sleeping();
+    }
+  }
+
+  inline void static_thread_pool::thread_state::notify_one_sleeping() {
+    std::uniform_int_distribution<std::uint32_t> dist(0, pool_->threadCount_ - 1);
+    std::uint32_t startIndex = dist(rng_);
+    for (std::uint32_t i = 0; i < pool_->threadCount_; ++i) {
+      std::uint32_t index = (startIndex + i) % pool_->threadCount_;
+      if (index == index_) {
+        continue;
+      }
+      if (pool_->threadStates_[index]->notify()) {
+        return;
       }
     }
-    return result;
   }
 
   inline static_thread_pool::thread_state::pop_result static_thread_pool::thread_state::pop() {
     pop_result result = try_pop();
     while (!result.task) {
-      std::unique_lock lock{mut_};
+      set_stealing();
+      for (std::size_t i = 0; i < pool_->maxSteals_; ++i) {
+        result = try_steal();
+        if (result.task) {
+          clear_stealing();
+          return result;
+        }
+      }
+      clear_stealing();
+
+      std::unique_lock lock{mut_};
       if (stopRequested_) {
        return result;
      }
      using namespace std::chrono_literals;
      // spurious wakeups are fine to look for stealing opportunities
      state expected = state::running;
-     if (!state_.compare_exchange_weak(expected, state::sleeping, std::memory_order_relaxed)) {
-       cv_.wait_for(lock, 100ms);
-     } else {
-       lock.unlock();
+     if (state_.compare_exchange_weak(expected, state::sleeping, std::memory_order_relaxed)) {
+       result = try_remote();
+       if (result.task) {
+         return result;
+       }
+       cv_.wait(lock);
      }
+     lock.unlock();
      state_.store(state::running, std::memory_order_relaxed);
      result = try_pop();
    }
    return result;
  }
 
- inline void static_thread_pool::thread_state::push(task_base* task) {
-   remote_queue_.push_front(task);
+ inline bool static_thread_pool::thread_state::notify() {
    if (state_.exchange(state::notified, std::memory_order_relaxed) == state::sleeping) {
-     { std::lock_guard lock{mut_}; }
+     {
+       std::lock_guard lock{mut_};
+     }
      cv_.notify_one();
+     return true;
    }
+   return false;
  }
 
  inline void static_thread_pool::thread_state::request_stop() {
@@ -466,10 +613,12 @@
      friend static_thread_pool::scheduler::sender;
 
      static_thread_pool& pool_;
+     remote_queue* queue_;
      Receiver receiver_;
 
-     explicit operation(static_thread_pool& pool, Receiver&& r)
+     explicit operation(static_thread_pool& pool, remote_queue* queue, Receiver&& r)
       : pool_(pool)
+      , queue_(queue)
       , receiver_((Receiver&&) r) {
       this->__execute = [](task_base* t, const std::uint32_t /* tid */) noexcept {
         auto& op = *static_cast<operation*>(t);
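
pop() now runs in three phases: drain local and remote work, become a thief (set_stealing bumps numThiefs_) and try up to maxSteals_ random victims, and only then go to sleep. The last thief to give up (fetch_sub returning 1) wakes one random sleeper, so work pushed during a stealing storm is not stranded. Crucially, a worker re-runs try_remote() after publishing state::sleeping, which closes the race with a producer whose notify() observed state::running. A toy of that handshake, mirroring the patch's relaxed-plus-mutex scheme; worker and sleep_until_notified are illustrative names:

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>

enum class state { running, stealing, sleeping, notified };

struct worker {
  std::mutex mut_;
  std::condition_variable cv_;
  std::atomic<state> state_{state::running};

  template <class TryDequeue>
  void sleep_until_notified(TryDequeue try_remote) {
    std::unique_lock lock{mut_};
    state expected = state::running;
    // If a notify() already flipped us to notified, skip sleeping entirely.
    if (state_.compare_exchange_weak(expected, state::sleeping,
                                     std::memory_order_relaxed)) {
      // Final check *after* announcing sleep: a producer that pushed before
      // it could observe state::sleeping cannot be missed now.
      if (!try_remote()) {
        cv_.wait(lock);  // caller loops, so spurious wakeups are harmless
      }
    }
    lock.unlock();
    state_.store(state::running, std::memory_order_relaxed);
  }

  bool notify() {
    if (state_.exchange(state::notified, std::memory_order_relaxed) == state::sleeping) {
      { std::lock_guard lock{mut_}; }  // order against a racing wait()
      cv_.notify_one();
      return true;  // we woke a real sleeper
    }
    return false;
  }
};
```
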
@@ -485,7 +634,7 @@ namespace exec {
      }
 
      void enqueue_(task_base* op) const {
-       pool_.enqueue(op);
+       pool_.enqueue(*queue_, op);
      }
 
      friend void tag_invoke(stdexec::start_t, operation& op) noexcept {
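
End to end, the patch surfaces as two user-visible knobs (thread count and bwos block geometry) plus the usual sender path; a sketch exercising them, assuming the in-tree include layout:

```cpp
#include <exec/static_thread_pool.hpp>
#include <stdexec/execution.hpp>

int main() {
  // 4 workers; each local bwos queue gets 16 blocks of 512 slots.
  exec::static_thread_pool pool{
    4, exec::bwos_params{.numBlocks = 16, .blockSize = 512}};

  // get_scheduler() caches this thread's remote_queue, so every
  // schedule() from here on submits without re-walking the queue list.
  auto sched = pool.get_scheduler();
  auto work = stdexec::schedule(sched)
            | stdexec::then([] { return 42; });
  auto [value] = stdexec::sync_wait(std::move(work)).value();
  return value == 42 ? 0 : 1;
}
```
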