[xla:cpu] Enable xnn_threadpool test in OSS
PiperOrigin-RevId: 706830430
ezhulenev authored and Google-ML-Automation committed Dec 16, 2024
1 parent a5b94d6 commit a0c4407
Showing 6 changed files with 512 additions and 64 deletions.
3 changes: 2 additions & 1 deletion xla/backends/cpu/runtime/xnnpack/BUILD
@@ -55,7 +55,9 @@ xla_cc_test(
":parallel_loop_runner",
"//xla/tsl/concurrency:async_value",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:test",
@@ -96,7 +98,6 @@ cc_library(
xla_cc_test(
name = "xnn_threadpool_test",
srcs = ["xnn_threadpool_test.cc"],
tags = ["no_oss"],
deps = [
":parallel_loop_runner",
":xnn_threadpool",
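
With the `no_oss` tag removed, the xnn_threadpool test target can be exercised in the open-source build. Assuming a standard Bazel checkout of the repository, the invocation would look roughly like:

bazel test //xla/backends/cpu/runtime/xnnpack:xnn_threadpool_test
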
265 changes: 221 additions & 44 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -47,38 +47,165 @@ static tsl::AsyncValueRef<tsl::Chain> OkDoneEventSingleton() {
return singleton->AsRef();
}

- // Schedules tasks in the [start_index, end_index) range into the Eigen thread
- // pool using recursive work splitting. Executes the `start_index` task in the
- // caller thread.
- static void ScheduleRange(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
- Eigen::ThreadPoolDevice* device, size_t start_index,
- size_t end_index, Task task) {
ParallelLoopRunner::ParallelLoopRunner(Eigen::ThreadPoolDevice* device)
: done_event_(OkDoneEventSingleton()), device_(device) {}

size_t ParallelLoopRunner::num_threads() const {
return device_->numThreadsInPool();
}

tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
ParallelLoopRunner&& runner) {
return std::move(runner.done_event_);
}

void ParallelLoopRunner::Parallelize(
tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t start_index,
size_t end_index, ParallelTask parallel_task) {
CHECK_LT(start_index, end_index) << "Invalid task index range"; // Crash OK
while (end_index - start_index > 1) {
uint64_t mid_index = (start_index + end_index) / 2;
- device->enqueueNoNotification([device, mid_index, end_index, task,
- count_down] {
- ScheduleRange(std::move(count_down), device, mid_index, end_index, task);
device_->enqueueNoNotification([this, mid_index, end_index, parallel_task,
count_down] {
Parallelize(std::move(count_down), mid_index, end_index, parallel_task);
});
end_index = mid_index;
}
- task(start_index);
parallel_task(start_index);
count_down.CountDown();
}
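
// Illustrative sketch (not part of this commit): recursive work splitting for
// the range [0, 8). The caller repeatedly halves the range, enqueues the upper
// half into the thread pool, and keeps the lower half until one task is left:
//
//   [0, 8) -> enqueue [4, 8), keep [0, 4)
//   [0, 4) -> enqueue [2, 4), keep [0, 2)
//   [0, 2) -> enqueue [1, 2), keep [0, 1)
//   [0, 1) -> run parallel_task(0) in the caller thread, then CountDown()
//
// Each enqueued range repeats the same splitting on a pool thread, so the first
// task starts after roughly log2(num_tasks) splits rather than num_tasks
// sequential enqueues.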

- ParallelLoopRunner::ParallelLoopRunner(Eigen::ThreadPoolDevice* device)
- : done_event_(OkDoneEventSingleton()), device_(device) {}
template <typename Task>
void ParallelLoopRunner::ScheduleOne(Task&& task) {
auto event = tsl::MakeConstructedAsyncValueRef<tsl::Chain>();
done_event_.AndThen([event, task = std::forward<Task>(task)] {
task();
event.SetStateConcrete();
});
done_event_ = std::move(event);
}
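
// Illustrative sketch (not part of this commit): because ScheduleOne re-points
// `done_event_` at an event that only becomes concrete after its own task has
// run, consecutive calls execute strictly in order:
//
//   ScheduleOne([] { step_1(); });  // runs once the previous done_event_ is set
//   ScheduleOne([] { step_2(); });  // runs only after step_1() has finished
//
// `step_1` and `step_2` are hypothetical callbacks used purely for illustration.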

- size_t ParallelLoopRunner::num_threads() const {
- return device_->numThreadsInPool();
template <typename ParallelTask>
void ParallelLoopRunner::ScheduleAll(size_t num_tasks,
ParallelTask&& parallel_task) {
tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
auto count_down_done = count_down.AsRef();

done_event_.AndThen([this, num_tasks, count_down = std::move(count_down),
parallel_task =
std::forward<ParallelTask>(parallel_task)] {
Parallelize(std::move(count_down), 0, num_tasks, std::move(parallel_task));
});
done_event_ = std::move(count_down_done);
}
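
// Illustrative sketch (not part of this commit): the event returned by
// `count_down.AsRef()` becomes concrete only after `num_tasks` calls to
// CountDown(), one per executed task. For num_tasks == 4:
//
//   count_down.CountDown();  // 3 pending, done_event_ not yet available
//   count_down.CountDown();  // 2 pending
//   count_down.CountDown();  // 1 pending
//   count_down.CountDown();  // 0 pending, done_event_ becomes available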

- tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
- ParallelLoopRunner&& runner) {
- return std::move(runner.done_event_);
namespace {

// Multidimensional index types for the parallel loop runner tasks. We launch
// tasks using one-dimensional `task_index` and convert it into a
// multidimensional index type depending on the loop type.

struct Task1DTile1DIndex {
size_t offset;
size_t extent;
};

struct Task2DTile1DIndex {
size_t i;
size_t offset_j;
size_t extent_j;
};

struct Task3DTile2DIndex {
size_t i;
size_t offset_j;
size_t offset_k;
size_t extent_j;
size_t extent_k;
};

} // namespace

static Task1DTile1DIndex Delinearize(size_t task_index, size_t range,
size_t tile) {
size_t offset = task_index * tile;
size_t extent = std::min(range - offset, tile);
return {offset, extent};
}
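
// Worked example (not part of this commit): for range = 10 and tile = 4 there
// are CeilOfRatio(10, 4) = 3 tasks, which delinearize to:
//
//   task_index 0 -> {offset = 0, extent = 4}
//   task_index 1 -> {offset = 4, extent = 4}
//   task_index 2 -> {offset = 8, extent = 2}   // min(10 - 8, 4) == 2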

- void ParallelLoopRunner::Parallelize(size_t range, size_t tile, Task1D task) {
static size_t NumTasks(size_t range_i, size_t range_j, size_t tile_j) {
size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j);
size_t num_tasks = range_i * num_tile_j_tasks;
DCHECK_GT(num_tasks, 0) << "Expected at least one tile task";
return num_tasks;
}

static Task2DTile1DIndex Delinearize(size_t task_index, size_t range_i,
size_t range_j, size_t tile_j) {
size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j);
DCHECK_GT(num_tile_j_tasks, 0) << "Expected at least one tile j task";

// Compute task indices along the `i` and `j` dimensions.
size_t task_i = task_index / num_tile_j_tasks;
size_t task_j = task_index % num_tile_j_tasks;

// Convert task index into the offset and extent along the `j` dimension.
size_t offset_j = task_j * tile_j;
size_t extent_j = std::min(range_j - offset_j, tile_j);

return {task_i, offset_j, extent_j};
}
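
// Worked example (not part of this commit): for range_i = 2, range_j = 5 and
// tile_j = 2 there are 2 * CeilOfRatio(5, 2) = 6 tasks, for example:
//
//   task_index 4 -> task_i = 4 / 3 = 1, task_j = 4 % 3 = 1
//                -> {i = 1, offset_j = 2, extent_j = 2}
//   task_index 5 -> task_i = 1, task_j = 2
//                -> {i = 1, offset_j = 4, extent_j = 1}   // min(5 - 4, 2) == 1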

static size_t NumTasks(size_t range_i, size_t range_j, size_t range_k,
size_t tile_j, size_t tile_k) {
size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j);
size_t num_tile_k_tasks = tsl::MathUtil::CeilOfRatio(range_k, tile_k);
size_t num_tasks = range_i * num_tile_j_tasks * num_tile_k_tasks;
DCHECK_GT(num_tasks, 0) << "Expected at least one tile task";
return num_tasks;
}

static Task3DTile2DIndex Delinearize(size_t task_index, size_t range_i,
size_t range_j, size_t range_k,
size_t tile_j, size_t tile_k) {
size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j);
size_t num_tile_k_tasks = tsl::MathUtil::CeilOfRatio(range_k, tile_k);
size_t num_tile_tasks = num_tile_j_tasks * num_tile_k_tasks;

DCHECK_GT(num_tile_j_tasks, 0) << "Expected at least one tile j task";
DCHECK_GT(num_tile_k_tasks, 0) << "Expected at least one tile k task";

// Compute task indices along the `i`, `j` and `k` dimensions.
size_t task_i = task_index / num_tile_tasks;
task_index %= num_tile_tasks;

size_t task_j = task_index / num_tile_k_tasks;
task_index %= num_tile_k_tasks;

size_t task_k = task_index;

// Convert task indices into the offset and extent along the `j` and `k`
// dimensions.
size_t offset_j = task_j * tile_j;
size_t offset_k = task_k * tile_k;
size_t extent_j = std::min(range_j - offset_j, tile_j);
size_t extent_k = std::min(range_k - offset_k, tile_k);

return {task_i, offset_j, offset_k, extent_j, extent_k};
}
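
// Worked example (not part of this commit): for range_i = 2, range_j = 3,
// range_k = 5, tile_j = 2 and tile_k = 2 there are 2 * 2 * 3 = 12 tasks.
// For task_index = 7:
//
//   task_i = 7 / 6 = 1, remainder 1
//   task_j = 1 / 3 = 0, remainder 1
//   task_k = 1
//   -> {i = 1, offset_j = 0, offset_k = 2, extent_j = 2, extent_k = 2}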

// In the `Parallelize` implementations below:
//
// (1) If done event is already available, execute the task immediately in the
// caller thread. In this case we don't need to overwrite the done event,
// because the existing one will correctly represent the state of the
// parallel loop runner (all scheduled loops are ready).
//
// (2) If done event is not available, we have to overwrite it with a new one
// that will be set to concrete state after the task is executed.

void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
Task1DTile1D task) {
DCHECK(done_event_) << "Parallel loop runner is in moved-from state";

size_t num_tasks = tsl::MathUtil::CeilOfRatio(range, tile);
@@ -88,42 +215,92 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile, Task1D task) {
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range, tile) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
- // If done event is already available, execute the task immediately in the
- // caller thread. In this case we don't need to overwrite the done event,
- // because the existing one will correctly represent the state of the
- // parallel loop runner (all scheduled loops are ready).
task(0, range);
return;
}

// Schedule task when done event becomes available.
ScheduleOne([range, task = std::move(task)] { task(0, range); });
return;
}

// Schedule `num_tasks` into the underlying thread pool when done event
// becomes available.
auto parallel_task = [range, tile,
task = std::move(task)](size_t task_index) {
auto x = Delinearize(task_index, range, tile);
task(x.offset, x.extent);
};

ScheduleAll(num_tasks, std::move(parallel_task));
}
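
// Illustrative sketch (not part of this commit): consecutive Parallelize calls
// on the same runner are serialized through `done_event_`, so a second loop
// only starts once every task of the first loop has completed:
//
//   runner.Parallelize(1024, 256, fill);     // loop #1
//   runner.Parallelize(1024, 256, consume);  // loop #2, runs after loop #1
//
// `fill` and `consume` are hypothetical Task1DTile1D callbacks.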

void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
size_t tile_j, Task2DTile1D task) {
DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
size_t num_tasks = NumTasks(range_i, range_j, tile_j);

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile";

- } else {
- // If done event is not available, we have to overwrite it with a new one
- // that will be set to concrete state after the task is executed.
- auto done_event = tsl::MakeConstructedAsyncValueRef<tsl::Chain>();
- done_event_.AndThen([range, done_event, task = std::move(task)] {
- task(0, range);
- done_event.SetStateConcrete();
- });
- done_event_ = std::move(done_event);
// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, 0, range_j);
return;
}

// Schedule task when done event becomes available.
ScheduleOne([range_j, task = std::move(task)] { task(0, 0, range_j); });
return;
}

// Schedule `num_tasks` into the underlying thread pool when done event
// becomes available.
- tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
- auto done_event = count_down.AsRef();
-
- done_event_.AndThen([this, num_tasks, range, tile, task = std::move(task),
- count_down = std::move(count_down)] {
- ScheduleRange(std::move(count_down), device_, 0, num_tasks,
- [range, tile, task = std::move(task)](size_t task_index) {
- size_t offset = task_index * tile;
- size_t extent = std::min(range - offset, tile);
- task(offset, extent);
- });
- });
- done_event_ = std::move(done_event);
auto parallel_task = [range_i, range_j, tile_j,
task = std::move(task)](size_t task_index) {
auto x = Delinearize(task_index, range_i, range_j, tile_j);
task(x.i, x.offset_j, x.extent_j);
};

ScheduleAll(num_tasks, std::move(parallel_task));
}

void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
size_t range_k, size_t tile_j,
size_t tile_k, Task3DTile2D task) {
DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
size_t num_tasks = NumTasks(range_i, range_j, range_k, tile_j, tile_k);

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile";
DCHECK_EQ(range_k, tile_k) << "Expected range to be equal to tile";

// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0, 0, 0, range_j, range_k);
return;
}

// Schedule task when done event becomes available.
ScheduleOne([range_j, range_k, task = std::move(task)] {
task(0, 0, 0, range_j, range_k);
});
return;
}

// Schedule `num_tasks` into the underlying thread pool when done event
// becomes available.
auto parallel_task = [range_i, range_j, range_k, tile_j, tile_k,
task = std::move(task)](size_t task_index) {
auto x = Delinearize(task_index, range_i, range_j, range_k, tile_j, tile_k);
task(x.i, x.offset_j, x.offset_k, x.extent_j, x.extent_k);
};

ScheduleAll(num_tasks, std::move(parallel_task));
}

} // namespace xla::cpu
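
For reference, a minimal end-to-end sketch of the new interface is shown below. The thread-pool setup, header paths and the loop body are assumptions used for illustration and are not part of this commit; the ParallelLoopRunner calls mirror the declarations in parallel_loop_runner.h.

#define EIGEN_USE_THREADS

#include "unsupported/Eigen/CXX11/Tensor"
#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"
#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"

void RunTiledLoop() {
  // Back the runner with an Eigen thread pool device (the same pattern the
  // tests in this directory use; header paths are assumed).
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "example",
                                  /*num_threads=*/8);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());

  xla::cpu::ParallelLoopRunner runner(&device);

  // Tile a 1D range of 1000 elements into blocks of at most 128 elements.
  runner.Parallelize(1000, 128, [](size_t offset, size_t extent) {
    // Process elements [offset, offset + extent); the body is illustrative.
  });

  // Take the completion event and attach a continuation; it fires once every
  // scheduled tile has executed. (In real code the thread pool must outlive
  // the event; lifetime handling is omitted here.)
  auto done = xla::cpu::ParallelLoopRunner::TakeDoneEvent(std::move(runner));
  done.AndThen([] { /* all tiles processed */ });
}
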
48 changes: 46 additions & 2 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -51,20 +51,64 @@ class ParallelLoopRunner {
static tsl::AsyncValueRef<tsl::Chain> TakeDoneEvent(
ParallelLoopRunner&& runner);

- using Task1D = std::function<void(size_t offset, size_t extent)>;
using Task1DTile1D = std::function<void(size_t offset, size_t extent)>;

using Task2DTile1D =
std::function<void(size_t offset_i, size_t offset_j, size_t extent_j)>;

using Task3DTile2D =
std::function<void(size_t offset_i, size_t offset_j, size_t offset_k,
size_t extent_j, size_t extent_k)>;

// This function implements a parallel version of the following loop:
//
// for (size_t i = 0; i < range; i += tile)
// task(i, std::min(range - i, tile));
- void Parallelize(size_t range, size_t tile, Task1D task);
void Parallelize(size_t range, size_t tile, Task1DTile1D task);

// This function implements a parallel version of the following loop:
//
// for (size_t i = 0; i < range_i; i++)
// for (size_t j = 0; j < range_j; j += tile_j)
// task(i, j, min(range_j - j, tile_j));
void Parallelize(size_t range_i, size_t range_j, size_t tile_j,
Task2DTile1D task);
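
// Illustrative sketch (not part of this commit): tile the `j` dimension of a
// 64 x 1000 iteration space in blocks of 128 columns:
//
//   runner.Parallelize(64, 1000, 128,
//                      [&](size_t i, size_t offset_j, size_t extent_j) {
//                        // process row i, columns [offset_j, offset_j + extent_j)
//                      });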

// This function implements a parallel version of the following loop:
//
// for (size_t i = 0; i < range_i; i++)
// for (size_t j = 0; j < range_j; j += tile_j)
// for (size_t k = 0; k < range_k; k += tile_k)
// task(i, j, k, min(range_j - j, tile_j), min(range_k - k, tile_k));
void Parallelize(size_t range_i, size_t range_j, size_t range_k,
size_t tile_j, size_t tile_k, Task3DTile2D task);
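
// Illustrative sketch (not part of this commit): tile the `j` and `k`
// dimensions of a 16 x 256 x 512 iteration space in 64 x 64 blocks:
//
//   runner.Parallelize(16, 256, 512, 64, 64,
//                      [&](size_t i, size_t j, size_t k,
//                          size_t extent_j, size_t extent_k) {
//                        // process block [j, j + extent_j) x [k, k + extent_k)
//                        // of slice i
//                      });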

tsl::AsyncValueRef<tsl::Chain> done_event() const { return done_event_; }
Eigen::ThreadPoolDevice* device() const { return device_; }

size_t num_threads() const;

private:
using ParallelTask = std::function<void(size_t task_index)>;

// Schedules tasks in the [start_index, end_index) range into the Eigen thread
// pool using recursive work splitting. Executes the `start_index` task in the
// caller thread.
void Parallelize(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
size_t start_index, size_t end_index,
ParallelTask parallel_task);

// Schedules `task` as the AndThen callback of the `done_event_`. Updates
// `done_event_` to the new completion event.
template <typename Task>
void ScheduleOne(Task&& task);

// Schedules `num_tasks` invocation of the `parallel_task` into the Eigen
// thread pool when the `done_event_` becomes available. Updates `done_event_`
// to the new completion event.
template <typename ParallelTask>
void ScheduleAll(size_t num_tasks, ParallelTask&& parallel_task);

// Async value that signals completion of the last scheduled parallel loop.
tsl::AsyncValueRef<tsl::Chain> done_event_;
