diff --git a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
index d3405aace8d44..b2597fad8f818 100644
--- a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
+++ b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include
 #include
 #include
+#include <limits>
 #include
 
 #include "absl/base/optimization.h"
@@ -79,20 +80,60 @@ ParallelLoopRunner::ComputeParallelTaskConfig(size_t num_tasks) const {
   return {num_tasks, parallel_task_size, num_parallel_tasks};
 }
 
-void ParallelLoopRunner::Parallelize(
-    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t start_index,
-    size_t end_index, ParallelTask parallel_task) {
+template <typename Index, typename ParallelizeContext>
+static void Parallelize(ParallelizeContext* ctx, Index start_index,
+                        Index end_index) {
   CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK
+
+  // Recursively split the task into two halves and schedule the right half into
+  // the thread pool.
   while (end_index - start_index > 1) {
-    uint64_t mid_index = (start_index + end_index) / 2;
-    device_.load()->enqueueNoNotification([this, mid_index, end_index,
-                                           parallel_task, count_down] {
-      Parallelize(std::move(count_down), mid_index, end_index, parallel_task);
+    Index mid_index = (start_index + end_index) / 2;
+    ctx->device->enqueueNoNotification([ctx, mid_index, end_index] {
+      Parallelize(ctx, mid_index, end_index);
     });
     end_index = mid_index;
   }
-  parallel_task(start_index);
-  count_down.CountDown();
+
+  // Execute the `start_index` task in the caller thread.
+  ctx->parallel_task(start_index);
+
+  // If count down is completed, delete the context.
+  if (ctx->count_down.CountDown()) {
+    delete ctx;
+  }
+}
+
+template <typename ParallelTask>
+void ParallelLoopRunner::Parallelize(
+    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t start_index,
+    size_t end_index, ParallelTask&& parallel_task) {
+  CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK
+
+  struct ParallelizeContext {
+    ParallelizeContext(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
+                       const Eigen::ThreadPoolDevice* device,
+                       ParallelTask&& parallel_task)
+        : count_down(std::move(count_down)),
+          device(device),
+          parallel_task(std::forward<ParallelTask>(parallel_task)) {}
+
+    tsl::CountDownAsyncValueRef<tsl::Chain> count_down;
+    const Eigen::ThreadPoolDevice* device;
+    ParallelTask parallel_task;
+  };
+
+  auto ctx = std::make_unique<ParallelizeContext>(
+      std::move(count_down), device_,
+      std::forward<ParallelTask>(parallel_task));
+
+  // We try to use uint16_t for index type because it enables small buffer
+  // optimization in the constructed `std::function` tasks.
+  if (ABSL_PREDICT_TRUE(end_index <= std::numeric_limits<uint16_t>::max())) {
+    xla::cpu::Parallelize<uint16_t>(ctx.release(), start_index, end_index);
+  } else {
+    xla::cpu::Parallelize<size_t>(ctx.release(), start_index, end_index);
+  }
 }
 
 template
diff --git a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
index b8c70b6310443..58adc1b5f39b9 100644
--- a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
+++ b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -96,8 +96,6 @@ class ParallelLoopRunner {
   size_t num_threads() const;
 
  private:
-  using ParallelTask = std::function<void(size_t)>;
-
   // When parallelizing loops, we split the loop iteration space of `num_tasks`
   // size into `num_parallel_tasks` parallel tasks, each of which processes
   // `parallel_task_size` original tasks sequentially on a single thread. We do
@@ -120,9 +118,10 @@ class ParallelLoopRunner {
   // Schedules tasks in the [start_index, end_index) range into the Eigen thread
   // pool using recursive work splitting. Executes the `start_index` task in the
   // caller thread.
+  template <typename ParallelTask>
   void Parallelize(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
                    size_t start_index, size_t end_index,
-                   ParallelTask parallel_task);
+                   ParallelTask&& parallel_task);
 
   // Schedules `task` as the AndThen callback of the `done_event_`. Updates
   // `done_event_` to the new completion event.
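
The .cc change above replaces the old capture-heavy recursive lambda with a single heap-allocated context that is freed by whichever task performs the final count-down. Below is a minimal, self-contained sketch of that recursive work-splitting pattern, not the XLA implementation: `ToyThreadPool`, the `ParallelizeContext` struct, and the raw `std::atomic` counter are hypothetical stand-ins for `Eigen::ThreadPoolDevice` and `tsl::CountDownAsyncValueRef<tsl::Chain>`.

```cpp
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

// Toy stand-in for the Eigen thread pool device: every enqueued closure gets
// its own std::thread, and the destructor drains and joins them (workers may
// enqueue more work while we drain, so we loop until the queue is empty).
class ToyThreadPool {
 public:
  void Enqueue(std::function<void()> fn) {
    std::lock_guard<std::mutex> lock(mu_);
    threads_.emplace_back(std::move(fn));
  }

  ~ToyThreadPool() {
    for (;;) {
      std::thread t;
      {
        std::lock_guard<std::mutex> lock(mu_);
        if (threads_.empty()) break;
        t = std::move(threads_.back());
        threads_.pop_back();
      }
      t.join();
    }
  }

 private:
  std::mutex mu_;
  std::vector<std::thread> threads_;
};

// Heap-allocated state shared by all split tasks; whoever performs the final
// decrement deletes it (the role CountDownAsyncValueRef plays in the patch).
struct ParallelizeContext {
  std::atomic<size_t> count_down;
  ToyThreadPool* pool;
  std::function<void(size_t)> parallel_task;
};

// Recursive binary work splitting: push the right half of [start, end) into
// the pool and keep halving the left half until a single task is left for the
// calling thread.
template <typename Index>
static void Parallelize(ParallelizeContext* ctx, Index start_index,
                        Index end_index) {
  while (end_index - start_index > 1) {
    Index mid_index = (start_index + end_index) / 2;
    ctx->pool->Enqueue(
        [ctx, mid_index, end_index] { Parallelize(ctx, mid_index, end_index); });
    end_index = mid_index;
  }

  ctx->parallel_task(start_index);

  // The task that performs the last decrement owns and frees the context.
  if (ctx->count_down.fetch_sub(1) == 1) delete ctx;
}

int main() {
  constexpr uint16_t kNumTasks = 32;
  std::atomic<long> sum(0);
  {
    ToyThreadPool pool;
    auto* ctx = new ParallelizeContext{
        kNumTasks, &pool,
        [&sum](size_t i) { sum.fetch_add(static_cast<long>(i)); }};
    Parallelize<uint16_t>(ctx, uint16_t{0}, kNumTasks);
  }  // Pool destructor joins all workers, so every task has run here.
  return sum.load() == 31 * 32 / 2 ? 0 : 1;
}
```

As in the patch, each call pushes the right half of the index range to the pool and keeps halving the left half until one task remains for the calling thread, so no task ever blocks waiting for its children.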
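
The `uint16_t` branch in the patch is motivated by the size of the closure that `enqueueNoNotification` wraps into a `std::function`: a pointer plus two narrow indices has a better chance of fitting the function object's small-buffer storage than a pointer plus two `size_t`. A rough way to see the difference in capture size (the SBO threshold itself is implementation-defined, and `Ctx` here is just a hypothetical placeholder):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical context type standing in for the pointer each task captures.
struct Ctx;

int main() {
  Ctx* ctx = nullptr;
  uint16_t a16 = 0, b16 = 0;
  size_t a64 = 0, b64 = 0;

  auto small = [ctx, a16, b16] { (void)ctx; (void)a16; (void)b16; };
  auto large = [ctx, a64, b64] { (void)ctx; (void)a64; (void)b64; };

  // Typically 16 bytes vs 24 bytes on LP64; whether either avoids a heap
  // allocation inside std::function depends on the standard library.
  std::printf("uint16_t capture: %zu bytes, size_t capture: %zu bytes\n",
              sizeof(small), sizeof(large));
  return 0;
}
```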