From a0c440768b0ad565b9425f8a618be9135ba075b5 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 16 Dec 2024 14:07:43 -0800 Subject: [PATCH] [xla:cpu] Enable xnn_threadpool test in OSS PiperOrigin-RevId: 706830430 --- xla/backends/cpu/runtime/xnnpack/BUILD | 3 +- .../runtime/xnnpack/parallel_loop_runner.cc | 265 +++++++++++++++--- .../runtime/xnnpack/parallel_loop_runner.h | 48 +++- .../xnnpack/parallel_loop_runner_test.cc | 123 +++++++- .../cpu/runtime/xnnpack/xnn_threadpool.cc | 40 ++- .../runtime/xnnpack/xnn_threadpool_test.cc | 97 ++++++- 6 files changed, 512 insertions(+), 64 deletions(-) diff --git a/xla/backends/cpu/runtime/xnnpack/BUILD b/xla/backends/cpu/runtime/xnnpack/BUILD index 5bdbacf6cca493..c9fbf122188c9d 100644 --- a/xla/backends/cpu/runtime/xnnpack/BUILD +++ b/xla/backends/cpu/runtime/xnnpack/BUILD @@ -55,7 +55,9 @@ xla_cc_test( ":parallel_loop_runner", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@tsl//tsl/platform:env", "@tsl//tsl/platform:test", @@ -96,7 +98,6 @@ cc_library( xla_cc_test( name = "xnn_threadpool_test", srcs = ["xnn_threadpool_test.cc"], - tags = ["no_oss"], deps = [ ":parallel_loop_runner", ":xnn_threadpool", diff --git a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc index 780b0ce8bde56c..1ad7f32ff8eb48 100644 --- a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc +++ b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc @@ -47,38 +47,165 @@ static tsl::AsyncValueRef OkDoneEventSingleton() { return singleton->AsRef(); } -// Schedules tasks in the [start_index, end_index) range into the Eigen thread -// pool using recursive work splitting. Executes the `start_index` task in the -// caller thread. 
-static void ScheduleRange(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
-                          Eigen::ThreadPoolDevice* device, size_t start_index,
-                          size_t end_index, Task task) {
+ParallelLoopRunner::ParallelLoopRunner(Eigen::ThreadPoolDevice* device)
+    : done_event_(OkDoneEventSingleton()), device_(device) {}
+
+size_t ParallelLoopRunner::num_threads() const {
+  return device_->numThreadsInPool();
+}
+
+tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
+    ParallelLoopRunner&& runner) {
+  return std::move(runner.done_event_);
+}
+
+void ParallelLoopRunner::Parallelize(
+    tsl::CountDownAsyncValueRef<tsl::Chain> count_down, size_t start_index,
+    size_t end_index, ParallelTask parallel_task) {
   CHECK_LT(start_index, end_index) << "Invalid task index range";  // Crash OK
   while (end_index - start_index > 1) {
     uint64_t mid_index = (start_index + end_index) / 2;
-    device->enqueueNoNotification([device, mid_index, end_index, task,
-                                   count_down] {
-      ScheduleRange(std::move(count_down), device, mid_index, end_index, task);
+    device_->enqueueNoNotification([this, mid_index, end_index, parallel_task,
+                                    count_down] {
+      Parallelize(std::move(count_down), mid_index, end_index, parallel_task);
     });
     end_index = mid_index;
   }
-  task(start_index);
+  parallel_task(start_index);
   count_down.CountDown();
 }
 
-ParallelLoopRunner::ParallelLoopRunner(Eigen::ThreadPoolDevice* device)
-    : done_event_(OkDoneEventSingleton()), device_(device) {}
+template <typename Task>
+void ParallelLoopRunner::ScheduleOne(Task&& task) {
+  auto event = tsl::MakeConstructedAsyncValueRef<tsl::Chain>();
+  done_event_.AndThen([event, task = std::forward<Task>(task)] {
+    task();
+    event.SetStateConcrete();
+  });
+  done_event_ = std::move(event);
+}
 
-size_t ParallelLoopRunner::num_threads() const {
-  return device_->numThreadsInPool();
+template <typename ParallelTask>
+void ParallelLoopRunner::ScheduleAll(size_t num_tasks,
+                                     ParallelTask&& parallel_task) {
+  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
+  auto count_down_done = count_down.AsRef();
+
+  done_event_.AndThen([this, num_tasks, count_down = std::move(count_down),
+                       parallel_task =
+                           std::forward<ParallelTask>(parallel_task)] {
+    Parallelize(std::move(count_down), 0, num_tasks, std::move(parallel_task));
+  });
+  done_event_ = std::move(count_down_done);
 }
 
-tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
-    ParallelLoopRunner&& runner) {
-  return std::move(runner.done_event_);
+namespace {
+
+// Multidimensional index types for the parallel loop runner tasks. We launch
+// tasks using one-dimensional `task_index` and convert it into a
+// multidimensional index type depending on the loop type.
+ +struct Task1DTile1DIndex { + size_t offset; + size_t extent; +}; + +struct Task2DTile1DIndex { + size_t i; + size_t offset_j; + size_t extent_j; +}; + +struct Task3DTile2DIndex { + size_t i; + size_t offset_j; + size_t offset_k; + size_t extent_j; + size_t extent_k; +}; + +} // namespace + +static Task1DTile1DIndex Delinearize(size_t task_index, size_t range, + size_t tile) { + size_t offset = task_index * tile; + size_t extent = std::min(range - offset, tile); + return {offset, extent}; } -void ParallelLoopRunner::Parallelize(size_t range, size_t tile, Task1D task) { +static size_t NumTasks(size_t range_i, size_t range_j, size_t tile_j) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + size_t num_tasks = range_i * num_tile_j_tasks; + DCHECK_GT(num_tasks, 0) << "Expected at least one tile task"; + return num_tasks; +} + +static Task2DTile1DIndex Delinearize(size_t task_index, size_t range_i, + size_t range_j, size_t tile_j) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + DCHECK_GT(num_tile_j_tasks, 0) << "Expected at least one tile j task"; + + // Compute task indices along the `i` and `j` dimensions. + size_t task_i = task_index / num_tile_j_tasks; + size_t task_j = task_index % num_tile_j_tasks; + + // Convert task index into the offset and extent along the `j` dimension. + size_t offset_j = task_j * tile_j; + size_t extent_j = std::min(range_j - offset_j, tile_j); + + return {task_i, offset_j, extent_j}; +} + +static size_t NumTasks(size_t range_i, size_t range_j, size_t range_k, + size_t tile_j, size_t tile_k) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + size_t num_tile_k_tasks = tsl::MathUtil::CeilOfRatio(range_k, tile_k); + size_t num_tasks = range_i * num_tile_j_tasks * num_tile_k_tasks; + DCHECK_GT(num_tasks, 0) << "Expected at least one tile task"; + return num_tasks; +} + +static Task3DTile2DIndex Delinearize(size_t task_index, size_t range_i, + size_t range_j, size_t range_k, + size_t tile_j, size_t tile_k) { + size_t num_tile_j_tasks = tsl::MathUtil::CeilOfRatio(range_j, tile_j); + size_t num_tile_k_tasks = tsl::MathUtil::CeilOfRatio(range_k, tile_k); + size_t num_tile_tasks = num_tile_j_tasks * num_tile_k_tasks; + + DCHECK_GT(num_tile_j_tasks, 0) << "Expected at least one tile j task"; + DCHECK_GT(num_tile_k_tasks, 0) << "Expected at least one tile k task"; + + // Compute task indices along the `i`, `j` and `k` dimensions. + size_t task_i = task_index / num_tile_tasks; + task_index %= num_tile_tasks; + + size_t task_j = task_index / num_tile_k_tasks; + task_index %= num_tile_k_tasks; + + size_t task_k = task_index; + + // Convert task indices into the offset and extent along the `j` and `k` + // dimensions. + size_t offset_j = task_j * tile_j; + size_t offset_k = task_k * tile_k; + size_t extent_j = std::min(range_j - offset_j, tile_j); + size_t extent_k = std::min(range_k - offset_k, tile_k); + + return {task_i, offset_j, offset_k, extent_j, extent_k}; +} + +// In the `Parallelize` implementations below: +// +// (1) If done event is already available, execute the task immediately in the +// caller thread. In this case we don't need to overwrite the done event, +// because the existing one will correctly represent the state of the +// parallel loop runner (all scheduled loops are ready). +// +// (2) If done event is not available, we have to overwrite it with a new one +// that will be set to concrete state after the task is executed. 
+ +void ParallelLoopRunner::Parallelize(size_t range, size_t tile, + Task1DTile1D task) { DCHECK(done_event_) << "Parallel loop runner is in moved-from state"; size_t num_tasks = tsl::MathUtil::CeilOfRatio(range, tile); @@ -88,42 +215,92 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile, Task1D task) { if (ABSL_PREDICT_TRUE(num_tasks == 1)) { DCHECK_EQ(range, tile) << "Expected range to be equal to tile"; + // Execute task in the caller thread if done event is already available. if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) { - // If done event is already available, execute the task immediately in the - // caller thread. In this case we don't need to overwrite the done event, - // because the existing one will correctly represent the state of the - // parallel loop runner (all scheduled loops are ready). task(0, range); + return; + } + + // Schedule task when done event becomes available. + ScheduleOne([range, task = std::move(task)] { task(0, range); }); + return; + } + + // Schedule `num_tasks` into the underlying thread pool when done event + // becomes available. + auto parallel_task = [range, tile, + task = std::move(task)](size_t task_index) { + auto x = Delinearize(task_index, range, tile); + task(x.offset, x.extent); + }; + + ScheduleAll(num_tasks, std::move(parallel_task)); +} + +void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j, + size_t tile_j, Task2DTile1D task) { + DCHECK(done_event_) << "Parallel loop runner is in moved-from state"; + size_t num_tasks = NumTasks(range_i, range_j, tile_j); + + // Fast path for the degenerate parallel loop with single task. + if (ABSL_PREDICT_TRUE(num_tasks == 1)) { + DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile"; - } else { - // If done event is not available, we have to overwrite it with a new one - // that will be set to concrete state after the task is executed. - auto done_event = tsl::MakeConstructedAsyncValueRef(); - done_event_.AndThen([range, done_event, task = std::move(task)] { - task(0, range); - done_event.SetStateConcrete(); - }); - done_event_ = std::move(done_event); + // Execute task in the caller thread if done event is already available. + if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) { + task(0, 0, range_j); + return; } + // Schedule task when done event becomes available. + ScheduleOne([range_j, task = std::move(task)] { task(0, 0, range_j); }); return; } // Schedule `num_tasks` into the underlying thread pool when done event // becomes available. 
-  tsl::CountDownAsyncValueRef<tsl::Chain> count_down(num_tasks);
-  auto done_event = count_down.AsRef();
-
-  done_event_.AndThen([this, num_tasks, range, tile, task = std::move(task),
-                       count_down = std::move(count_down)] {
-    ScheduleRange(std::move(count_down), device_, 0, num_tasks,
-                  [range, tile, task = std::move(task)](size_t task_index) {
-                    size_t offset = task_index * tile;
-                    size_t extent = std::min(range - offset, tile);
-                    task(offset, extent);
-                  });
-  });
-  done_event_ = std::move(done_event);
+  auto parallel_task = [range_i, range_j, tile_j,
+                        task = std::move(task)](size_t task_index) {
+    auto x = Delinearize(task_index, range_i, range_j, tile_j);
+    task(x.i, x.offset_j, x.extent_j);
+  };
+
+  ScheduleAll(num_tasks, std::move(parallel_task));
+}
+
+void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
+                                     size_t range_k, size_t tile_j,
+                                     size_t tile_k, Task3DTile2D task) {
+  DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
+  size_t num_tasks = NumTasks(range_i, range_j, range_k, tile_j, tile_k);
+
+  // Fast path for the degenerate parallel loop with single task.
+  if (ABSL_PREDICT_TRUE(num_tasks == 1)) {
+    DCHECK_EQ(range_j, tile_j) << "Expected range to be equal to tile";
+    DCHECK_EQ(range_k, tile_k) << "Expected range to be equal to tile";
+
+    // Execute task in the caller thread if done event is already available.
+    if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
+      task(0, 0, 0, range_j, range_k);
+      return;
+    }
+
+    // Schedule task when done event becomes available.
+    ScheduleOne([range_j, range_k, task = std::move(task)] {
+      task(0, 0, 0, range_j, range_k);
+    });
+    return;
+  }
+
+  // Schedule `num_tasks` into the underlying thread pool when done event
+  // becomes available.
+  auto parallel_task = [range_i, range_j, range_k, tile_j, tile_k,
+                        task = std::move(task)](size_t task_index) {
+    auto x = Delinearize(task_index, range_i, range_j, range_k, tile_j, tile_k);
+    task(x.i, x.offset_j, x.offset_k, x.extent_j, x.extent_k);
+  };
+
+  ScheduleAll(num_tasks, std::move(parallel_task));
 }
 
 }  // namespace xla::cpu
diff --git a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
index 661337c9acd072..ccaaf14157f4d5 100644
--- a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
+++ b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -51,13 +51,37 @@ class ParallelLoopRunner {
   static tsl::AsyncValueRef<tsl::Chain> TakeDoneEvent(
       ParallelLoopRunner&& runner);
 
-  using Task1D = std::function<void(size_t offset, size_t extent)>;
+  using Task1DTile1D = std::function<void(size_t offset, size_t extent)>;
+
+  using Task2DTile1D =
+      std::function<void(size_t i, size_t offset_j, size_t extent_j)>;
+
+  using Task3DTile2D =
+      std::function<void(size_t i, size_t offset_j, size_t offset_k,
+                         size_t extent_j, size_t extent_k)>;
 
   // This function implements a parallel version of a following loop:
   //
   //   for (size_t i = 0; i < range; i += tile)
   //     task(i, std::min(range - i, tile));
-  void Parallelize(size_t range, size_t tile, Task1D task);
+  void Parallelize(size_t range, size_t tile, Task1DTile1D task);
+
+  // This function implements a parallel version of a following loop:
+  //
+  //   for (size_t i = 0; i < range_i; i++)
+  //     for (size_t j = 0; j < range_j; j += tile_j)
+  //       task(i, j, min(range_j - j, tile_j));
+  void Parallelize(size_t range_i, size_t range_j, size_t tile_j,
+                   Task2DTile1D task);
+
+  // This function implements a parallel version of a following loop:
+  //
+  //   for (size_t i = 0; i < range_i; i++)
+  //     for (size_t j = 0; j < range_j; j += tile_j)
+  //       for (size_t k = 0; k < range_k; k += tile_k)
+  //         task(i, j, k, min(range_j - j, tile_j), min(range_k - k, tile_k));
+  void Parallelize(size_t range_i, size_t range_j, size_t range_k,
+                   size_t tile_j, size_t tile_k, Task3DTile2D task);
 
   tsl::AsyncValueRef<tsl::Chain> done_event() const { return done_event_; }
   Eigen::ThreadPoolDevice* device() const { return device_; }
@@ -65,6 +89,26 @@
   size_t num_threads() const;
 
  private:
+  using ParallelTask = std::function<void(size_t task_index)>;
+
+  // Schedules tasks in the [start_index, end_index) range into the Eigen thread
+  // pool using recursive work splitting. Executes the `start_index` task in the
+  // caller thread.
+  void Parallelize(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
+                   size_t start_index, size_t end_index,
+                   ParallelTask parallel_task);
+
+  // Schedules `task` as the AndThen callback of the `done_event_`. Updates
+  // `done_event_` to the new completion event.
+  template <typename Task>
+  void ScheduleOne(Task&& task);
+
+  // Schedules `num_tasks` invocations of the `parallel_task` into the Eigen
+  // thread pool when the `done_event_` becomes available. Updates `done_event_`
+  // to the new completion event.
+  template <typename ParallelTask>
+  void ScheduleAll(size_t num_tasks, ParallelTask&& parallel_task);
+
   // Async value that signals completion of the last scheduled parallel loop.
   tsl::AsyncValueRef<tsl::Chain> done_event_;
diff --git a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc
index 5069ae1664dc50..7ef43eba130ad0 100644
--- a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc
+++ b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc
@@ -18,9 +18,10 @@ limitations under the License.
 #include
 #include
 #include
-#include
 #include "absl/algorithm/container.h"
+#include "absl/cleanup/cleanup.h"
+#include "absl/types/span.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/test.h"
@@ -33,27 +34,94 @@
namespace xla::cpu { namespace { -TEST(ParallelLoopRunnerTest, BackToBack1DLoops) { +TEST(ParallelLoopRunnerTest, Parallelize1DTile1D) { tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), threads.NumThreads()); ParallelLoopRunner runner(&device); - std::vector data(1024); - auto inc_range = [&](size_t offset, size_t extent) { + constexpr int32_t d0 = 128; + + auto* data = new int32_t[d0](); + auto cleanup = absl::Cleanup([&]() { delete[] data; }); + + auto increment = [&](size_t offset, size_t extent) { for (size_t i = offset; i < offset + extent; ++i) { data[i] += 1; } }; - runner.Parallelize(1024, 1, inc_range); - runner.Parallelize(1024, 2, inc_range); - runner.Parallelize(1024, 3, inc_range); - runner.Parallelize(1024, 4, inc_range); - runner.Parallelize(1024, 5, inc_range); + runner.Parallelize(d0, 1, increment); + runner.Parallelize(d0, 2, increment); + runner.Parallelize(d0, 3, increment); + runner.Parallelize(d0, 4, increment); + runner.Parallelize(d0, 5, increment); + + tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner))); + ASSERT_TRUE(absl::c_all_of(absl::MakeSpan(&data[0], d0), + [](int32_t value) { return value == 5; })); +} + +TEST(ParallelLoopRunnerTest, Parallelize2DTile1D) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + constexpr int32_t d0 = 4; + constexpr int32_t d1 = 39; + + auto* data = new int32_t[d0][d1](); + auto cleanup = absl::Cleanup([&]() { delete[] data; }); + + auto increment = [&](size_t i, size_t offset_j, size_t extent_j) { + for (size_t j = offset_j; j < offset_j + extent_j; ++j) { + data[i][j] += 1; + } + }; + + runner.Parallelize(d0, d1, 1, increment); + runner.Parallelize(d0, d1, 2, increment); + runner.Parallelize(d0, d1, 3, increment); + runner.Parallelize(d0, d1, 4, increment); + runner.Parallelize(d0, d1, 5, increment); + + tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner))); + ASSERT_TRUE(absl::c_all_of(absl::MakeSpan(&data[0][0], d0 * d1), + [](int32_t value) { return value == 5; })); +} + +TEST(ParallelLoopRunnerTest, Parallelize3DTile2D) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + constexpr int32_t d0 = 4; + constexpr int32_t d1 = 39; + constexpr int32_t d2 = 63; + + auto* data = new int32_t[d0][d1][d2](); + auto cleanup = absl::Cleanup([&]() { delete[] data; }); + + auto increment = [&](size_t i, size_t offset_j, size_t offset_k, + size_t extent_j, size_t extent_k) { + for (size_t j = offset_j; j < offset_j + extent_j; ++j) { + for (size_t k = offset_k; k < offset_k + extent_k; ++k) { + data[i][j][k] += 1; + } + } + }; + + runner.Parallelize(d0, d1, d2, 1, 5, increment); + runner.Parallelize(d0, d1, d2, 2, 4, increment); + runner.Parallelize(d0, d1, d2, 3, 4, increment); + runner.Parallelize(d0, d1, d2, 4, 3, increment); + runner.Parallelize(d0, d1, d2, 5, 1, increment); tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner))); - ASSERT_TRUE(absl::c_all_of(data, [](int32_t value) { return value == 5; })); + ASSERT_TRUE(absl::c_all_of(absl::MakeSpan(&data[0][0][0], d0 * d1 * d2), + [](int32_t value) { return value == 5; })); } //===----------------------------------------------------------------------===// @@ 
-74,5 +142,40 @@ static void BM_SingleTask1DLoop(benchmark::State& state) { BENCHMARK(BM_SingleTask1DLoop); +static void BM_Parallelize2DTile1D(benchmark::State& state) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + size_t range = 4; + size_t tile = 1; + + for (auto _ : state) { + runner.Parallelize(range, range, tile, [](size_t, size_t, size_t) {}); + tsl::BlockUntilReady(runner.done_event()); + } +} + +BENCHMARK(BM_Parallelize2DTile1D); + +static void BM_Parallelize3DTile2D(benchmark::State& state) { + tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8); + Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(), + threads.NumThreads()); + ParallelLoopRunner runner(&device); + + size_t range = 4; + size_t tile = 1; + + for (auto _ : state) { + runner.Parallelize(range, range, range, tile, tile, + [](size_t, size_t, size_t, size_t, size_t) {}); + tsl::BlockUntilReady(runner.done_event()); + } +} + +BENCHMARK(BM_Parallelize3DTile2D); + } // namespace } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc index cc3f9d286398e3..485334286e3386 100644 --- a/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc +++ b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc @@ -127,17 +127,43 @@ static size_t GetThreadsCount(pthreadpool_t threadpool) { // NOLINT return Cast(threadpool)->runner()->num_threads(); } -static void Parallelize1dTile1d( // NOLINT +static void Parallelize1DTile1D( // NOLINT pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_t function, void* context, size_t range, size_t tile, uint32_t flags) { - ParallelLoopRunner::Task1D task = [function, context](size_t offset, - size_t extent) { + ParallelLoopRunner::Task1DTile1D task = [function, context](size_t offset, + size_t extent) { (*function)(context, offset, extent); }; Cast(threadpool)->runner()->Parallelize(range, tile, task); } +static void Parallelize2DTile1D(pthreadpool_t threadpool, // NOLINT + pthreadpool_task_2d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, + size_t tile_j, uint32_t flags) { + ParallelLoopRunner::Task2DTile1D task = + [function, context](size_t offset_i, size_t offset_j, size_t extent_j) { + (*function)(context, offset_i, offset_j, extent_j); + }; + Cast(threadpool)->runner()->Parallelize(range_i, range_j, tile_j, task); +} + +static void Parallelize3DTile2D(pthreadpool_t threadpool, // NOLINT + pthreadpool_task_3d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, + size_t range_k, size_t tile_j, size_t tile_k, + uint32_t flags) { + ParallelLoopRunner::Task3DTile2D task = + [function, context](size_t offset_i, size_t offset_j, size_t offset_k, + size_t extent_j, size_t extent_k) { + (*function)(context, offset_i, offset_j, offset_k, extent_j, extent_k); + }; + Cast(threadpool) + ->runner() + ->Parallelize(range_i, range_j, range_k, tile_j, tile_k, task); +} + } // namespace xla::cpu #if defined(XLA_CPU_USE_CUSTOM_PTHREADPOOL) @@ -177,7 +203,7 @@ extern "C" void pthreadpool_parallelize_1d_with_uarch( extern "C" void pthreadpool_parallelize_1d_tile_1d( pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_t function, void* context, size_t range, size_t tile, uint32_t flags) { - xla::cpu::Parallelize1dTile1d(threadpool, function, context, range, tile, + xla::cpu::Parallelize1DTile1D(threadpool, function, context, range, 
tile, flags); } @@ -198,7 +224,8 @@ extern "C" void pthreadpool_parallelize_2d_tile_1d( pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_t function, void* context, size_t range_i, size_t range_j, size_t tile_j, uint32_t flags) { - LOG(FATAL) << "Not implemented"; + xla::cpu::Parallelize2DTile1D(threadpool, function, context, range_i, range_j, + tile_j, flags); } extern "C" void pthreadpool_parallelize_2d_tile_1d_with_uarch( @@ -274,7 +301,8 @@ extern "C" void pthreadpool_parallelize_3d_tile_2d( pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_t function, void* context, size_t range_i, size_t range_j, size_t range_k, size_t tile_j, size_t tile_k, uint32_t flags) { - LOG(FATAL) << "Not implemented"; + xla::cpu::Parallelize3DTile2D(threadpool, function, context, range_i, range_j, + range_k, tile_j, tile_k, flags); } extern "C" void pthreadpool_parallelize_3d_tile_2d_with_uarch( diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc index 41b9127231ebe8..7cdf1dd1cb91a0 100644 --- a/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc +++ b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc @@ -83,7 +83,49 @@ static xnn_status CreateBinaryOpsSubgraph(xnn_subgraph_t subgraph, return xnn_status_success; } -TEST(XnnThreadPoolTest, BinarySubgraph) { +static xnn_status CreateDotSubgraph(xnn_subgraph_t subgraph, size_t m, size_t n, + size_t k) { + uint32_t lhs_id = XNN_INVALID_VALUE_ID; + uint32_t rhs_id = XNN_INVALID_VALUE_ID; + uint32_t out_id = XNN_INVALID_VALUE_ID; + + std::vector lhs_dims = {m, k}; + std::vector rhs_dims = {k, n}; + std::vector out_dims = {m, n}; + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, lhs_dims.size(), lhs_dims.data(), + nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, rhs_dims.size(), rhs_dims.data(), + nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id); + s != xnn_status_success) { + return s; + } + + if (auto s = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, out_dims.size(), out_dims.data(), + nullptr, + /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out_id); + s != xnn_status_success) { + return s; + } + + if (auto s = + xnn_define_batch_matrix_multiply(subgraph, lhs_id, rhs_id, out_id, + /*flags=*/0); + s != xnn_status_success) { + return s; + } + + return xnn_status_success; +} + +TEST(XnnThreadPoolTest, Binary) { pthreadpool_t threadpool = pthreadpool_create(8); ASSERT_NE(threadpool, nullptr); @@ -139,5 +181,58 @@ TEST(XnnThreadPoolTest, BinarySubgraph) { pthreadpool_destroy(threadpool); } +TEST(XnnThreadPoolTest, Dot) { + pthreadpool_t threadpool = pthreadpool_create(8); + ASSERT_NE(threadpool, nullptr); + + ASSERT_EQ(xnn_initialize(/*allocator=*/nullptr), xnn_status_success); + + xnn_workspace_t workspace = nullptr; + ASSERT_EQ(xnn_create_workspace(&workspace), xnn_status_success); + + xnn_subgraph_t subgraph = nullptr; + + ASSERT_EQ( + xnn_create_subgraph(/*external_value_ids=*/3, /*flags=*/0, &subgraph), + xnn_status_success); + + size_t m = 256, k = 256, n = 256; + CreateDotSubgraph(subgraph, m, k, n); + + std::vector lhs(m * k, 1.0f); + std::vector rhs(k * n, 1.0f); + std::vector out(m * n, 0.0f); + + xnn_runtime_t runtime = nullptr; + ASSERT_EQ(xnn_create_runtime_v4(subgraph, nullptr, workspace, threadpool, 0, + &runtime), + xnn_status_success); + + std::vector external_values 
= { + xnn_external_value{0, lhs.data()}, + xnn_external_value{1, rhs.data()}, + xnn_external_value{2, out.data()}, + }; + + ASSERT_EQ(xnn_reshape_runtime(runtime), xnn_status_success); + ASSERT_EQ(xnn_setup_runtime_v2(runtime, 3, external_values.data()), + xnn_status_success); + + ASSERT_EQ(xnn_invoke_runtime(runtime), xnn_status_success); + + if (ParallelLoopRunner* runner = GetParallelLoopRunner(threadpool)) { + tsl::BlockUntilReady(runner->done_event()); + ASSERT_TRUE(runner->done_event().IsConcrete()); + } + + ASSERT_TRUE(absl::c_all_of(out, [&](float v) { return v == k; })); + + ASSERT_EQ(xnn_delete_runtime(runtime), xnn_status_success); + ASSERT_EQ(xnn_delete_subgraph(subgraph), xnn_status_success); + ASSERT_EQ(xnn_release_workspace(workspace), xnn_status_success); + + pthreadpool_destroy(threadpool); +} + } // namespace } // namespace xla::cpu
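Not part of the patch: the standalone sketch below illustrates the task-index delinearization scheme that the patch adds to ParallelLoopRunner. It reimplements the 3D tile-2D `Delinearize` helper and a local `CeilOfRatio` so it builds with a plain C++17 compiler and no TSL/XLA dependencies (names mirror the patch but are assumptions for illustration), and it checks that the one-dimensional `task_index` space covers every (i, j, k) cell of the iteration space exactly once. That property is what lets the runner reuse its one-dimensional recursive work splitting for all loop shapes.

// delinearize_check.cc (hypothetical file name)
// Build and run with: g++ -std=c++17 delinearize_check.cc && ./a.out
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

namespace {

// Mirrors the patch's Task3DTile2DIndex: a tile of the (i, j, k) space.
struct Task3DTile2DIndex {
  size_t i, offset_j, offset_k, extent_j, extent_k;
};

size_t CeilOfRatio(size_t a, size_t b) { return (a + b - 1) / b; }

// Converts a linear task index into a (i, j-tile, k-tile) coordinate, as in
// the patch: i is not tiled, j and k are tiled with partial tiles at the end.
Task3DTile2DIndex Delinearize(size_t task_index, size_t range_i,
                              size_t range_j, size_t range_k, size_t tile_j,
                              size_t tile_k) {
  size_t num_tile_j_tasks = CeilOfRatio(range_j, tile_j);
  size_t num_tile_k_tasks = CeilOfRatio(range_k, tile_k);
  size_t num_tile_tasks = num_tile_j_tasks * num_tile_k_tasks;

  size_t task_i = task_index / num_tile_tasks;
  task_index %= num_tile_tasks;
  size_t task_j = task_index / num_tile_k_tasks;
  size_t task_k = task_index % num_tile_k_tasks;

  size_t offset_j = task_j * tile_j;
  size_t offset_k = task_k * tile_k;
  return {task_i, offset_j, offset_k, std::min(range_j - offset_j, tile_j),
          std::min(range_k - offset_k, tile_k)};
}

}  // namespace

int main() {
  constexpr size_t range_i = 4, range_j = 39, range_k = 63;
  constexpr size_t tile_j = 3, tile_k = 4;

  std::vector<int> hits(range_i * range_j * range_k, 0);

  size_t num_tasks =
      range_i * CeilOfRatio(range_j, tile_j) * CeilOfRatio(range_k, tile_k);
  for (size_t task_index = 0; task_index < num_tasks; ++task_index) {
    auto x = Delinearize(task_index, range_i, range_j, range_k, tile_j, tile_k);
    for (size_t j = x.offset_j; j < x.offset_j + x.extent_j; ++j) {
      for (size_t k = x.offset_k; k < x.offset_k + x.extent_k; ++k) {
        hits[(x.i * range_j + j) * range_k + k] += 1;
      }
    }
  }

  // Every (i, j, k) cell must be visited exactly once across all tasks.
  for (int h : hits) assert(h == 1);
  return 0;
}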