diff --git a/third_party/tsl/workspace2.bzl b/third_party/tsl/workspace2.bzl
index 993450dc31f61..0b1bd5ee69785 100644
--- a/third_party/tsl/workspace2.bzl
+++ b/third_party/tsl/workspace2.bzl
@@ -111,9 +111,9 @@ def _tf_repositories():
     # LINT.IfChange
     tf_http_archive(
         name = "XNNPACK",
-        sha256 = "ca3a5316b8161214f8f22a578fb638f1fccd0585eee40301363ffd026310379a",
-        strip_prefix = "XNNPACK-a50369c0fdd15f0f35b1a91c964644327a88d480",
-        urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/a50369c0fdd15f0f35b1a91c964644327a88d480.zip"),
+        sha256 = "3306f4178c8594b689165d385e644f03a3154c3be044f6ae36dd170fbf182cf5",
+        strip_prefix = "XNNPACK-983d013300f19fd3f4e33220b6401408e97a8d12",
+        urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/983d013300f19fd3f4e33220b6401408e97a8d12.zip"),
     )
     # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake)
 
@@ -126,16 +126,16 @@ def _tf_repositories():
 
     tf_http_archive(
         name = "pthreadpool",
-        sha256 = "b96413b10dd8edaa4f6c0a60c6cf5ef55eebeef78164d5d69294c8173457f0ec",
-        strip_prefix = "pthreadpool-b8374f80e42010941bda6c85b0e3f1a1bd77a1e0",
-        urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/b8374f80e42010941bda6c85b0e3f1a1bd77a1e0.zip"),
+        sha256 = "a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95",
+        strip_prefix = "pthreadpool-4fe0e1e183925bf8cfa6aae24237e724a96479b8",
+        urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip"),
     )
 
     tf_http_archive(
         name = "cpuinfo",
-        strip_prefix = "cpuinfo-5e63739504f0f8e18e941bd63b2d6d42536c7d90",
-        sha256 = "18eca9bc8d9c4ce5496d0d2be9f456d55cbbb5f0639a551ce9c8bac2e84d85fe",
-        urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/5e63739504f0f8e18e941bd63b2d6d42536c7d90.tar.gz"),
+        sha256 = "52e0ffd7998d8cb3a927d8a6e1145763744d866d2be09c4eccea27fc157b6bb0",
+        strip_prefix = "cpuinfo-cebb0933058d7f181c979afd50601dc311e1bf8c",
+        urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/cebb0933058d7f181c979afd50601dc311e1bf8c.zip"),
     )
 
     tf_http_archive(
diff --git a/workspace2.bzl b/workspace2.bzl
index bf6f4088e9743..20ffe0cf82606 100644
--- a/workspace2.bzl
+++ b/workspace2.bzl
@@ -42,12 +42,42 @@ def _tf_repositories():
     #   curl -L | sha256sum
     # and update the sha256 with the result.
+    # LINT.IfChange
     tf_http_archive(
         name = "XNNPACK",
         sha256 = "3306f4178c8594b689165d385e644f03a3154c3be044f6ae36dd170fbf182cf5",
         strip_prefix = "XNNPACK-983d013300f19fd3f4e33220b6401408e97a8d12",
         urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/983d013300f19fd3f4e33220b6401408e97a8d12.zip"),
     )
+    # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake)
+
+    tf_http_archive(
+        name = "KleidiAI",
+        sha256 = "ad37707084a6d4ff41be10cbe8540c75bea057ba79d0de6c367c1bfac6ba0852",
+        strip_prefix = "kleidiai-40a926833857fb64786e02f97703e42b1537cb57",
+        urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/40a926833857fb64786e02f97703e42b1537cb57/kleidiai-40a926833857fb64786e02f97703e42b1537cb57.zip"),
+    )
+
+    tf_http_archive(
+        name = "FXdiv",
+        sha256 = "3d7b0e9c4c658a84376a1086126be02f9b7f753caa95e009d9ac38d11da444db",
+        strip_prefix = "FXdiv-63058eff77e11aa15bf531df5dd34395ec3017c8",
+        urls = tf_mirror_urls("https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip"),
+    )
+
+    tf_http_archive(
+        name = "cpuinfo",
+        sha256 = "52e0ffd7998d8cb3a927d8a6e1145763744d866d2be09c4eccea27fc157b6bb0",
+        strip_prefix = "cpuinfo-cebb0933058d7f181c979afd50601dc311e1bf8c",
+        urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/cebb0933058d7f181c979afd50601dc311e1bf8c.zip"),
+    )
+
+    tf_http_archive(
+        name = "pthreadpool",
+        sha256 = "a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95",
+        strip_prefix = "pthreadpool-4fe0e1e183925bf8cfa6aae24237e724a96479b8",
+        urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip"),
+    )
 
     tf_http_archive(
         name = "jsoncpp_git",
diff --git a/xla/backends/cpu/runtime/xnnpack/BUILD b/xla/backends/cpu/runtime/xnnpack/BUILD
index 96447df288c58..5bdbacf6cca49 100644
--- a/xla/backends/cpu/runtime/xnnpack/BUILD
+++ b/xla/backends/cpu/runtime/xnnpack/BUILD
@@ -71,3 +71,44 @@ cc_library(
         "@XNNPACK",
     ],
 )
+
+cc_library(
+    name = "xnn_threadpool",
+    srcs = ["xnn_threadpool.cc"],
+    hdrs = ["xnn_threadpool.h"],
+    # copybara:uncomment_begin(google-only)
+    # local_defines = select({
+    #     "@pthreadpool:pthreadpool_header_only_explicit_true": [
+    #         "XLA_CPU_USE_CUSTOM_PTHREADPOOL",
+    #     ],
+    #     "//conditions:default": [],
+    # }),
+    # copybara:uncomment_end
+    deps = [
+        ":parallel_loop_runner",
+        "@eigen_archive//:eigen3",
+        "@pthreadpool",
+        "@tsl//tsl/platform:env",
+        "@tsl//tsl/platform:logging",
+    ],
+)
+
+xla_cc_test(
+    name = "xnn_threadpool_test",
+    srcs = ["xnn_threadpool_test.cc"],
+    tags = ["no_oss"],
+    deps = [
+        ":parallel_loop_runner",
+        ":xnn_threadpool",
+        "//xla/tsl/concurrency:async_value",
+        "@XNNPACK",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/synchronization",
+        "@eigen_archive//:eigen3",
+        "@pthreadpool",
+        "@tsl//tsl/platform:env",
+        "@tsl//tsl/platform:test",
+        "@tsl//tsl/platform:test_benchmark",
+        "@tsl//tsl/platform:test_main",
+    ],
+)
diff --git a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
index c8dbda535a637..780b0ce8bde56 100644
--- a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
+++ b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
@@ -69,6 +69,10 @@ static void ScheduleRange(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
 ParallelLoopRunner::ParallelLoopRunner(Eigen::ThreadPoolDevice* device)
     : done_event_(OkDoneEventSingleton()), device_(device) {}
 
+size_t ParallelLoopRunner::num_threads() const {
+  return device_->numThreadsInPool();
+}
+
 tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
     ParallelLoopRunner&& runner) {
   return std::move(runner.done_event_);
 }
diff --git a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
index 76e28f3b48743..661337c9acd07 100644
--- a/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
+++ b/xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
@@ -62,6 +62,8 @@ class ParallelLoopRunner {
   tsl::AsyncValueRef<tsl::Chain> done_event() const { return done_event_; }
   Eigen::ThreadPoolDevice* device() const { return device_; }
 
+  size_t num_threads() const;
+
  private:
   // Async value that signals completion of the last scheduled parallel loop.
   tsl::AsyncValueRef<tsl::Chain> done_event_;
diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc
new file mode 100644
index 0000000000000..cc3f9d286398e
--- /dev/null
+++ b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc
@@ -0,0 +1,367 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+
+#include "pthreadpool.h"
+#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"
+#include "tsl/platform/env.h"
+#include "tsl/platform/logging.h"
+#include "tsl/platform/threadpool.h"
+
+#define EIGEN_USE_THREADS
+#include "unsupported/Eigen/CXX11/Tensor"
+
+// `pthreadpool` API implementation on top of ParallelLoopRunner.
+//
+// When building with `pthreadpool_header_only` config, `pthreadpool` becomes a
+// header-only library, and we implement the API on top of ParallelLoopRunner.
+//
+// At link time, `pthreadpool` symbols are resolved to our own implementation.
+// This is a temporary hack around the fact that it's impossible to customize
+// the `pthreadpool` implementation at run time. The downside is that it's
+// impossible to have two `pthreadpool` implementations linked into the same
+// binary.
+//
+// WARNING: This is under construction and implements only the subset of the
+// API surface needed by XNNPACK use inside XLA.
+
+namespace xla::cpu {
+
+bool IsCustomPthreadpoolEnabled() {
+#if defined(XLA_CPU_USE_CUSTOM_PTHREADPOOL)
+  return true;
+#else
+  return false;
+#endif  // XLA_CPU_USE_CUSTOM_PTHREADPOOL
+}
+
+namespace {
+
+class Pthreadpool {
+ public:
+  virtual ~Pthreadpool() = default;
+  virtual ParallelLoopRunner* runner() = 0;
+};
+
+// Wraps a user-provided parallel loop runner into the custom pthreadpool.
+class WrappedParallelLoopRunner : public Pthreadpool {
+ public:
+  explicit WrappedParallelLoopRunner(ParallelLoopRunner* runner)
+      : runner_(runner) {}
+  ParallelLoopRunner* runner() final { return runner_; }
+
+ private:
+  ParallelLoopRunner* runner_;
+};
+
+// Wraps a newly created thread pool into the custom pthreadpool.
+class OwnedParallelLoopRunner : public Pthreadpool {
+ public:
+  explicit OwnedParallelLoopRunner(size_t threads_count)
+      : thread_pool_(tsl::Env::Default(), "xnn_threadpool", threads_count),
+        device_(thread_pool_.AsEigenThreadPool(), threads_count),
+        runner_(&device_) {}
+
+  ParallelLoopRunner* runner() final { return &runner_; }
+
+ private:
+  tsl::thread::ThreadPool thread_pool_;
+  Eigen::ThreadPoolDevice device_;
+  ParallelLoopRunner runner_;
+};
+
+}  // namespace
+
+pthreadpool_t CreatePthreadpool(ParallelLoopRunner* runner) {
+  if (IsCustomPthreadpoolEnabled()) {
+    return reinterpret_cast<pthreadpool_t>(
+        std::make_unique<WrappedParallelLoopRunner>(runner).release());
+  }
+  LOG(FATAL) << "To use custom pthreadpool, build with "
+                "`--define pthreadpool_header_only=true`";
+}
+
+static pthreadpool_t CreatePthreadpool(size_t threads_count) {  // NOLINT
+  if (IsCustomPthreadpoolEnabled()) {
+    return reinterpret_cast<pthreadpool_t>(
+        std::make_unique<OwnedParallelLoopRunner>(threads_count).release());
+  }
+  LOG(FATAL) << "To use custom pthreadpool, build with "
+                "`--define pthreadpool_header_only=true`";
+}
+
+static Pthreadpool* Cast(pthreadpool_t threadpool) {
+  return reinterpret_cast<Pthreadpool*>(threadpool);
+}
+
+xla::cpu::ParallelLoopRunner* GetParallelLoopRunner(pthreadpool_t threadpool) {
+  return IsCustomPthreadpoolEnabled() ? Cast(threadpool)->runner() : nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// C++ implementation of the subset of `pthreadpool` C API.
+//===----------------------------------------------------------------------===//
+
+static void DestroyPthreadpool(pthreadpool_t threadpool) {  // NOLINT
+  delete Cast(threadpool);
+}
+
+static size_t GetThreadsCount(pthreadpool_t threadpool) {  // NOLINT
+  return Cast(threadpool)->runner()->num_threads();
+}
+
+static void Parallelize1dTile1d(  // NOLINT
+    pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_t function,
+    void* context, size_t range, size_t tile, uint32_t flags) {
+  ParallelLoopRunner::Task1D task = [function, context](size_t offset,
+                                                        size_t extent) {
+    (*function)(context, offset, extent);
+  };
+
+  Cast(threadpool)->runner()->Parallelize(range, tile, task);
+}
+
+}  // namespace xla::cpu
+
+#if defined(XLA_CPU_USE_CUSTOM_PTHREADPOOL)
+
+extern "C" pthreadpool_t pthreadpool_create(size_t threads_count) {
+  return xla::cpu::CreatePthreadpool(threads_count);
+}
+
+extern "C" void pthreadpool_destroy(pthreadpool_t threadpool) {
+  xla::cpu::DestroyPthreadpool(threadpool);
+}
+
+extern "C" size_t pthreadpool_get_threads_count(pthreadpool_t threadpool) {
+  return xla::cpu::GetThreadsCount(threadpool);
+}
+
+extern "C" void pthreadpool_parallelize_1d(pthreadpool_t threadpool,
+                                           pthreadpool_task_1d_t function,
+                                           void* context, size_t range,
+                                           uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_1d_with_thread(
+    pthreadpool_t threadpool, pthreadpool_task_1d_with_thread_t function,
+    void* context, size_t range, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_1d_with_uarch(
+    pthreadpool_t threadpool, pthreadpool_task_1d_with_id_t function,
+    void* context, uint32_t default_uarch_index, uint32_t max_uarch_index,
+    size_t range, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_1d_tile_1d(
+    pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_t function,
+    void* context, size_t range, size_t tile, uint32_t flags) {
+  xla::cpu::Parallelize1dTile1d(threadpool, function, context, range, tile,
+                                flags);
+}
+
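Reviewer context, not part of the patch: `Parallelize1dTile1d` above is the only parallelization entry point the shim implements so far. It adapts pthreadpool's C callback convention (a function pointer plus an opaque `context`) into a captured `ParallelLoopRunner::Task1D`. A minimal caller-side sketch of that contract, where `MyContext` and `ScaleTile` are hypothetical names, not from the patch:

```cpp
// Sketch of how XNNPACK-style code drives the one implemented entry point.
// Assumes `threadpool` came from this patch's pthreadpool_create().
#include <cstddef>
#include <cstdint>

#include "pthreadpool.h"

struct MyContext {
  float* data;
};

// Matches pthreadpool_task_1d_tile_1d_t: (context, tile start, tile extent).
static void ScaleTile(void* context, size_t offset, size_t extent) {
  auto* ctx = static_cast<MyContext*>(context);
  for (size_t i = offset; i < offset + extent; ++i) ctx->data[i] *= 2.0f;
}

void ScaleAll(pthreadpool_t threadpool, float* data, size_t n) {
  MyContext ctx{data};
  // The shim forwards this to ParallelLoopRunner::Parallelize(n, 256, task),
  // scheduling tiles of up to 256 elements onto the runner.
  pthreadpool_parallelize_1d_tile_1d(threadpool, ScaleTile, &ctx,
                                     /*range=*/n, /*tile=*/256, /*flags=*/0);
}
```

Note one semantic difference from stock pthreadpool: the runner schedules loops asynchronously, so the parallelize call may return before all tiles finish; completion is observed through the runner's `done_event()` (see the test below).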
+extern "C" void pthreadpool_parallelize_2d(pthreadpool_t threadpool, + pthreadpool_task_2d_t function, + void* context, size_t range_i, + size_t range_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_with_thread( + pthreadpool_t threadpool, pthreadpool_task_2d_with_thread_t function, + void* context, size_t range_i, size_t range_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, size_t tile_j, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_1d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_1d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t tile_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_1d_with_uarch_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_2d_tile_1d_with_id_with_thread_t function, void* context, + uint32_t default_uarch_index, uint32_t max_uarch_index, size_t range_i, + size_t range_j, size_t tile_j, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_2d( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_2d_t function, + void* context, size_t range_i, size_t range_j, size_t tile_i, size_t tile_j, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_2d_tile_2d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_2d_tile_2d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t tile_i, size_t tile_j, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d(pthreadpool_t threadpool, + pthreadpool_task_3d_t function, + void* context, size_t range_i, + size_t range_j, size_t range_k, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d( + pthreadpool_t threadpool, pthreadpool_task_3d_tile_1d_t function, + void* context, size_t range_i, size_t range_j, size_t range_k, + size_t tile_k, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_with_thread_t function, void* context, + size_t range_i, size_t range_j, size_t range_k, size_t tile_k, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d_with_uarch( + pthreadpool_t threadpool, pthreadpool_task_3d_tile_1d_with_id_t function, + void* context, uint32_t default_uarch_index, uint32_t max_uarch_index, + size_t range_i, size_t range_j, size_t range_k, size_t tile_k, + uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_1d_with_uarch_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_with_id_with_thread_t function, void* context, + uint32_t default_uarch_index, uint32_t max_uarch_index, size_t range_i, + size_t range_j, size_t range_k, size_t tile_k, uint32_t flags) { + LOG(FATAL) << "Not implemented"; +} + +extern "C" void pthreadpool_parallelize_3d_tile_2d( + pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_t 
+    void* context, size_t range_i, size_t range_j, size_t range_k,
+    size_t tile_j, size_t tile_k, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_3d_tile_2d_with_uarch(
+    pthreadpool_t threadpool, pthreadpool_task_3d_tile_2d_with_id_t function,
+    void* context, uint32_t default_uarch_index, uint32_t max_uarch_index,
+    size_t range_i, size_t range_j, size_t range_k, size_t tile_j,
+    size_t tile_k, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_4d(pthreadpool_t threadpool,
+                                           pthreadpool_task_4d_t function,
+                                           void* context, size_t range_i,
+                                           size_t range_j, size_t range_k,
+                                           size_t range_l, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_4d_tile_1d(
+    pthreadpool_t threadpool, pthreadpool_task_4d_tile_1d_t function,
+    void* context, size_t range_i, size_t range_j, size_t range_k,
+    size_t range_l, size_t tile_l, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_4d_tile_2d(
+    pthreadpool_t threadpool, pthreadpool_task_4d_tile_2d_t function,
+    void* context, size_t range_i, size_t range_j, size_t range_k,
+    size_t range_l, size_t tile_k, size_t tile_l, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_4d_tile_2d_with_uarch(
+    pthreadpool_t threadpool, pthreadpool_task_4d_tile_2d_with_id_t function,
+    void* context, uint32_t default_uarch_index, uint32_t max_uarch_index,
+    size_t range_i, size_t range_j, size_t range_k, size_t range_l,
+    size_t tile_k, size_t tile_l, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_5d(pthreadpool_t threadpool,
+                                           pthreadpool_task_5d_t function,
+                                           void* context, size_t range_i,
+                                           size_t range_j, size_t range_k,
+                                           size_t range_l, size_t range_m,
+                                           uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_5d_tile_1d(
+    pthreadpool_t threadpool, pthreadpool_task_5d_tile_1d_t function,
+    void* context, size_t range_i, size_t range_j, size_t range_k,
+    size_t range_l, size_t range_m, size_t tile_m, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_5d_tile_2d(
+    pthreadpool_t threadpool, pthreadpool_task_5d_tile_2d_t function,
+    void* context, size_t range_i, size_t range_j, size_t range_k,
+    size_t range_l, size_t range_m, size_t tile_l, size_t tile_m,
+    uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_6d(pthreadpool_t threadpool,
+                                           pthreadpool_task_6d_t function,
+                                           void* context, size_t range_i,
+                                           size_t range_j, size_t range_k,
+                                           size_t range_l, size_t range_m,
+                                           size_t range_n, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_6d_tile_1d(
+    pthreadpool_t threadpool, pthreadpool_task_6d_tile_1d_t function,
+    void* context, size_t range_i, size_t range_j, size_t range_k,
+    size_t range_l, size_t range_m, size_t range_n, size_t tile_n,
+    uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+extern "C" void pthreadpool_parallelize_6d_tile_2d(
+    pthreadpool_t threadpool, pthreadpool_task_6d_tile_2d_t function,
+    void* context, size_t range_i, size_t range_j, size_t range_k,
+    size_t range_l, size_t range_m, size_t range_n, size_t tile_m,
+    size_t tile_n, uint32_t flags) {
+  LOG(FATAL) << "Not implemented";
+}
+
+#endif  // XLA_CPU_USE_CUSTOM_PTHREADPOOL
diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h
new file mode 100644
index 0000000000000..94afb6b6499e7
--- /dev/null
+++ b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h
@@ -0,0 +1,37 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_THREADPOOL_H_
+#define XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_THREADPOOL_H_
+
+#include "pthreadpool.h"
+#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"
+
+namespace xla::cpu {
+
+// Returns true if the custom pthreadpool is enabled.
+bool IsCustomPthreadpoolEnabled();
+
+// Creates a `pthreadpool` that uses the given `runner` to execute work.
+pthreadpool_t CreatePthreadpool(xla::cpu::ParallelLoopRunner* runner);
+
+// Returns the parallel loop runner associated with the given `pthreadpool`. If
+// the `pthreadpool` is not associated with a parallel loop runner, returns
+// nullptr.
+xla::cpu::ParallelLoopRunner* GetParallelLoopRunner(pthreadpool_t threadpool);
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_RUNTIME_XNNPACK_XNN_THREADPOOL_H_
diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc
new file mode 100644
index 0000000000000..41b9127231ebe
--- /dev/null
+++ b/xla/backends/cpu/runtime/xnnpack/xnn_threadpool_test.cc
@@ -0,0 +1,143 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <vector>
+
+#include "xnnpack.h"
+#include "absl/algorithm/container.h"
+#include "pthreadpool.h"
+#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"
+#include "xla/tsl/concurrency/async_value_ref.h"
+#include "tsl/platform/test.h"
+
+namespace xla::cpu {
+namespace {
+
+static xnn_status CreateBinaryOpsSubgraph(xnn_subgraph_t subgraph,
+                                          std::vector<size_t> dims) {
+  uint32_t lhs_id = XNN_INVALID_VALUE_ID;
+  uint32_t rhs_id = XNN_INVALID_VALUE_ID;
+  uint32_t out0_id = XNN_INVALID_VALUE_ID;
+  uint32_t out1_id = XNN_INVALID_VALUE_ID;
+
+  if (auto s = xnn_define_tensor_value(subgraph, xnn_datatype_fp32, dims.size(),
+                                       dims.data(), nullptr, /*external_id=*/0,
+                                       XNN_VALUE_FLAG_EXTERNAL_INPUT, &lhs_id);
+      s != xnn_status_success) {
+    return s;
+  }
+
+  if (auto s = xnn_define_tensor_value(subgraph, xnn_datatype_fp32, dims.size(),
+                                       dims.data(), nullptr, /*external_id=*/1,
+                                       XNN_VALUE_FLAG_EXTERNAL_INPUT, &rhs_id);
+      s != xnn_status_success) {
+    return s;
+  }
+
+  if (auto s = xnn_define_tensor_value(
+          subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr,
+          /*external_id=*/2, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out0_id);
+      s != xnn_status_success) {
+    return s;
+  }
+
+  if (auto s = xnn_define_tensor_value(
+          subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr,
+          /*external_id=*/3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &out1_id);
+      s != xnn_status_success) {
+    return s;
+  }
+
+  xnn_binary_params params = {-std::numeric_limits<float>::infinity(),
+                              std::numeric_limits<float>::infinity()};
+
+  if (auto s = xnn_define_binary(subgraph, xnn_binary_add, &params, lhs_id,
+                                 rhs_id, out0_id, /*flags=*/0);
+      s != xnn_status_success) {
+    return s;
+  }
+
+  if (auto s = xnn_define_binary(subgraph, xnn_binary_multiply, &params, lhs_id,
+                                 rhs_id, out1_id, /*flags=*/0);
+      s != xnn_status_success) {
+    return s;
+  }
+
+  return xnn_status_success;
+}
+
+TEST(XnnThreadPoolTest, BinarySubgraph) {
+  pthreadpool_t threadpool = pthreadpool_create(8);
+  ASSERT_NE(threadpool, nullptr);
+
+  ASSERT_EQ(xnn_initialize(/*allocator=*/nullptr), xnn_status_success);
+
+  xnn_workspace_t workspace = nullptr;
+  ASSERT_EQ(xnn_create_workspace(&workspace), xnn_status_success);
+
+  xnn_subgraph_t subgraph = nullptr;
+
+  ASSERT_EQ(
+      xnn_create_subgraph(/*external_value_ids=*/4, /*flags=*/0, &subgraph),
+      xnn_status_success);
+
+  size_t d0 = 1024;
+  CreateBinaryOpsSubgraph(subgraph, {d0, d0});
+
+  std::vector<float> lhs(d0 * d0, 2.0f);
+  std::vector<float> rhs(d0 * d0, 3.0f);
+  std::vector<float> out0(d0 * d0, 0.0f);
+  std::vector<float> out1(d0 * d0, 0.0f);
+
+  xnn_runtime_t runtime = nullptr;
+  ASSERT_EQ(xnn_create_runtime_v4(subgraph, nullptr, workspace, threadpool, 0,
+                                  &runtime),
+            xnn_status_success);
+
+  std::vector<xnn_external_value> external_values = {
+      xnn_external_value{0, lhs.data()},
+      xnn_external_value{1, rhs.data()},
+      xnn_external_value{2, out0.data()},
+      xnn_external_value{3, out1.data()},
+  };
+
+  ASSERT_EQ(xnn_reshape_runtime(runtime), xnn_status_success);
+  ASSERT_EQ(xnn_setup_runtime_v2(runtime, 4, external_values.data()),
+            xnn_status_success);
+
+  ASSERT_EQ(xnn_invoke_runtime(runtime), xnn_status_success);
+
+  if (ParallelLoopRunner* runner = GetParallelLoopRunner(threadpool)) {
+    tsl::BlockUntilReady(runner->done_event());
+    ASSERT_TRUE(runner->done_event().IsConcrete());
+  }
+
+  ASSERT_TRUE(absl::c_all_of(out0, [](float v) { return v == 5.0f; }));
+  ASSERT_TRUE(absl::c_all_of(out1, [](float v) { return v == 6.0f; }));
+
+  ASSERT_EQ(xnn_delete_runtime(runtime), xnn_status_success);
+  ASSERT_EQ(xnn_delete_subgraph(subgraph), xnn_status_success);
+  ASSERT_EQ(xnn_release_workspace(workspace), xnn_status_success);
+
+  pthreadpool_destroy(threadpool);
+}
+
+}  // namespace
+}  // namespace xla::cpu
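One property worth calling out for reviewers: the test only exercises the owned-pool path (`pthreadpool_create(8)`), and it blocks on the runner's `done_event()` before checking results because the shim schedules loops asynchronously. A hedged sketch of the other construction path, wrapping a caller-owned `ParallelLoopRunner` via this patch's `CreatePthreadpool(runner)`. It assumes a build with `XLA_CPU_USE_CUSTOM_PTHREADPOOL` defined and a caller-provided Eigen device; the XNNPACK runtime setup (as in the test above) is elided:

```cpp
// Sketch: back a pthreadpool_t with an existing ParallelLoopRunner instead of
// letting pthreadpool_create() spin up an owned thread pool.
#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

#include "pthreadpool.h"
#include "xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h"
#include "xla/backends/cpu/runtime/xnnpack/xnn_threadpool.h"
#include "xla/tsl/concurrency/async_value_ref.h"

void RunWithWrappedRunner(Eigen::ThreadPoolDevice* device) {
  xla::cpu::ParallelLoopRunner runner(device);
  // LOG(FATAL)s if the custom pthreadpool is not compiled in.
  pthreadpool_t threadpool = xla::cpu::CreatePthreadpool(&runner);

  // ... create/reshape/setup/invoke an XNNPACK runtime with `threadpool` ...

  // Work is scheduled asynchronously onto `runner`; outputs are only safe to
  // read once the done event becomes available.
  tsl::BlockUntilReady(runner.done_event());

  // Frees only the wrapper; the caller-owned runner and device outlive it.
  pthreadpool_destroy(threadpool);
}
```

The design choice mirrors the two `Pthreadpool` subclasses in xnn_threadpool.cc: `WrappedParallelLoopRunner` borrows a runner owned by the caller (XLA's thunk executor), while `OwnedParallelLoopRunner` exists so that plain `pthreadpool_create()` calls made from inside XNNPACK still work under the custom implementation.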