Skip to content

Commit

Permalink
[xla:cpu] Optimize parallel loop runner
Browse files Browse the repository at this point in the history
Make sure that all scheduled tasks capture fewer than 24 bytes, so that the resulting `std::function` can be constructed without an extra heap allocation (the captures fit in the small-buffer storage).

BEFORE

------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations
------------------------------------------------------------------------
BM_HloModule/dot/process_time     560396 ns      6503078 ns          415

AFTER

------------------------------------------------------------------------
Benchmark                              Time             CPU   Iterations
------------------------------------------------------------------------
BM_HloModule/dot/process_time     320843 ns      3224568 ns          858

PiperOrigin-RevId: 708340973
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 20, 2024
1 parent 211ca72 commit afe7a58
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
59 changes: 50 additions & 9 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <utility>

#include "absl/base/optimization.h"
Expand Down Expand Up @@ -79,20 +80,60 @@ ParallelLoopRunner::ComputeParallelTaskConfig(size_t num_tasks) const {
return {num_tasks, parallel_task_size, num_parallel_tasks};
}

// Recursively splits the `[start_index, end_index)` task range: the upper
// half of the range is re-scheduled into the Eigen thread pool and the lower
// half is processed in the calling thread, until a single task remains.
//
// `ctx` is a heap-allocated context that owns the count-down latch, the
// thread pool device pointer, and the task callable. Ownership is shared by
// all scheduled continuations; the invocation that completes the count down
// deletes the context.
//
// `Index` is a template parameter (uint16_t or size_t, chosen by the caller)
// so that the enqueued lambda captures `{ctx, mid_index, end_index}` stay
// small enough for std::function's inline small-buffer storage — this avoids
// a heap allocation per scheduled task.
template <typename Index, typename ParallelizeContext>
static void Parallelize(ParallelizeContext* ctx, Index start_index,
                        Index end_index) {
  CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK

  // Recursively split the task into two halves and schedule the right half
  // into the thread pool.
  while (end_index - start_index > 1) {
    Index mid_index = (start_index + end_index) / 2;
    ctx->device->enqueueNoNotification([ctx, mid_index, end_index] {
      Parallelize(ctx, mid_index, end_index);
    });
    end_index = mid_index;
  }

  // Execute the `start_index` task in the caller thread.
  ctx->parallel_task(start_index);

  // If count down is completed, delete the context.
  if (ctx->count_down.CountDown()) {
    delete ctx;
  }
}

// Parallelizes `parallel_task` over the `[start_index, end_index)` range by
// packaging the count-down latch, the device, and the callable into a single
// heap-allocated context, then dispatching to the index-templated
// `Parallelize` free function which does the recursive work splitting.
//
// The context is created exactly once per loop (one heap allocation total);
// `ctx.release()` transfers ownership to the recursive splitter, which
// deletes the context when the count down completes.
template <typename ParallelTask>
void ParallelLoopRunner::Parallelize(
    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t start_index,
    size_t end_index, ParallelTask&& parallel_task) {
  CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK

  // Shared state for all recursively scheduled sub-tasks of this loop.
  struct ParallelizeContext {
    ParallelizeContext(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
                       const Eigen::ThreadPoolDevice* device,
                       ParallelTask&& parallel_task)
        : count_down(std::move(count_down)),
          device(device),
          parallel_task(std::forward<ParallelTask>(parallel_task)) {}

    tsl::CountDownAsyncValueRef<tsl::Chain> count_down;
    const Eigen::ThreadPoolDevice* device;
    ParallelTask parallel_task;
  };

  auto ctx = std::make_unique<ParallelizeContext>(
      std::move(count_down), device_,
      std::forward<ParallelTask>(parallel_task));

  // We try to use uint16_t for index type because it enables small buffer
  // optimization in the constructed `std::function` tasks.
  if (ABSL_PREDICT_TRUE(end_index <= std::numeric_limits<uint16_t>::max())) {
    xla::cpu::Parallelize<uint16_t>(ctx.release(), start_index, end_index);
  } else {
    xla::cpu::Parallelize<size_t>(ctx.release(), start_index, end_index);
  }
}

template <typename Task>
Expand Down
5 changes: 2 additions & 3 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ class ParallelLoopRunner {
size_t num_threads() const;

private:
using ParallelTask = std::function<void(size_t task_index)>;

// When parallelizing loops, we split the loop iteration space of `num_tasks`
// size into `num_parallel_tasks` parallel tasks, each of which processes
// `parallel_task_size` original tasks sequentially on a single thread. We do
Expand All @@ -120,9 +118,10 @@ class ParallelLoopRunner {
// Schedules tasks in the [start_index, end_index) range into the Eigen thread
// pool using recursive work splitting. Executes the `start_index` task in the
// caller thread.
template <typename ParallelTask>
void Parallelize(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
size_t start_index, size_t end_index,
ParallelTask parallel_task);
ParallelTask&& parallel_task);

// Schedules `task` as the AndThen callback of the `done_event_`. Updates
// `done_event_` to the new completion event.
Expand Down

0 comments on commit afe7a58

Please sign in to comment.