Skip to content

Commit

Permalink
[xla:cpu] Add an xnn_threadpool for wrapping ParallelLoopRunner as pthreadpool API
Browse files Browse the repository at this point in the history

Update xnnpack version to the latest one required by XLA:CPU

PiperOrigin-RevId: 706816929
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 16, 2024
1 parent 23907df commit a5b94d6
Show file tree
Hide file tree
Showing 8 changed files with 633 additions and 9 deletions.
18 changes: 9 additions & 9 deletions third_party/tsl/workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def _tf_repositories():
# LINT.IfChange
# XNNPACK pinned to the commit required by XLA:CPU. The diff-flattened page
# carried both the old and new sha256/strip_prefix/urls keyword arguments in
# one call, which is a duplicate-keyword error in Starlark; only the updated
# pins (commit 983d0133...) are kept here.
tf_http_archive(
    name = "XNNPACK",
    sha256 = "3306f4178c8594b689165d385e644f03a3154c3be044f6ae36dd170fbf182cf5",
    strip_prefix = "XNNPACK-983d013300f19fd3f4e33220b6401408e97a8d12",
    urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/983d013300f19fd3f4e33220b6401408e97a8d12.zip"),
)
# LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake)

Expand All @@ -126,16 +126,16 @@ def _tf_repositories():

# pthreadpool pinned to the commit required by the updated XNNPACK. Stale
# duplicate sha256/strip_prefix/urls keyword arguments (left over from diff
# flattening) removed — duplicates are a Starlark error; the updated pins
# (commit 4fe0e1e1...) are kept.
tf_http_archive(
    name = "pthreadpool",
    sha256 = "a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95",
    strip_prefix = "pthreadpool-4fe0e1e183925bf8cfa6aae24237e724a96479b8",
    urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip"),
)

# cpuinfo pinned to the commit required by the updated XNNPACK. Stale
# duplicate keyword arguments from diff flattening removed (the old entry
# also used a .tar.gz archive with a different checksum); the updated pins
# (commit cebb0933..., .zip archive) are kept.
tf_http_archive(
    name = "cpuinfo",
    sha256 = "52e0ffd7998d8cb3a927d8a6e1145763744d866d2be09c4eccea27fc157b6bb0",
    strip_prefix = "cpuinfo-cebb0933058d7f181c979afd50601dc311e1bf8c",
    urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/cebb0933058d7f181c979afd50601dc311e1bf8c.zip"),
)

tf_http_archive(
Expand Down
30 changes: 30 additions & 0 deletions workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,42 @@ def _tf_repositories():
# curl -L <url> | sha256sum
# and update the sha256 with the result.

# XNNPACK and its transitive dependencies (KleidiAI, FXdiv, cpuinfo,
# pthreadpool), pinned to the commits required by XLA:CPU. Each pin must be
# updated as a set: XNNPACK's build expects these exact dependency revisions.
# LINT.IfChange
tf_http_archive(
name = "XNNPACK",
sha256 = "3306f4178c8594b689165d385e644f03a3154c3be044f6ae36dd170fbf182cf5",
strip_prefix = "XNNPACK-983d013300f19fd3f4e33220b6401408e97a8d12",
urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/983d013300f19fd3f4e33220b6401408e97a8d12.zip"),
)
# Keep the CMake build's XNNPACK pin in sync with the archive above.
# LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake)

# Arm KleidiAI micro-kernel library, consumed by XNNPACK.
tf_http_archive(
name = "KleidiAI",
sha256 = "ad37707084a6d4ff41be10cbe8540c75bea057ba79d0de6c367c1bfac6ba0852",
strip_prefix = "kleidiai-40a926833857fb64786e02f97703e42b1537cb57",
urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/40a926833857fb64786e02f97703e42b1537cb57/kleidiai-40a926833857fb64786e02f97703e42b1537cb57.zip"),
)

# FXdiv header-only integer-division helpers, an XNNPACK dependency.
tf_http_archive(
name = "FXdiv",
sha256 = "3d7b0e9c4c658a84376a1086126be02f9b7f753caa95e009d9ac38d11da444db",
strip_prefix = "FXdiv-63058eff77e11aa15bf531df5dd34395ec3017c8",
urls = tf_mirror_urls("https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip"),
)

# CPU feature detection library, an XNNPACK dependency.
tf_http_archive(
name = "cpuinfo",
sha256 = "52e0ffd7998d8cb3a927d8a6e1145763744d866d2be09c4eccea27fc157b6bb0",
strip_prefix = "cpuinfo-cebb0933058d7f181c979afd50601dc311e1bf8c",
urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/cebb0933058d7f181c979afd50601dc311e1bf8c.zip"),
)

# Thread pool used by XNNPACK; xnn_threadpool wraps ParallelLoopRunner
# behind this library's API.
tf_http_archive(
name = "pthreadpool",
sha256 = "a4cf06de57bfdf8d7b537c61f1c3071bce74e57524fe053e0bbd2332feca7f95",
strip_prefix = "pthreadpool-4fe0e1e183925bf8cfa6aae24237e724a96479b8",
urls = tf_mirror_urls("https://github.com/Maratyszcza/pthreadpool/archive/4fe0e1e183925bf8cfa6aae24237e724a96479b8.zip"),
)

tf_http_archive(
name = "jsoncpp_git",
Expand Down
41 changes: 41 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,44 @@ cc_library(
"@XNNPACK",
],
)

# Adapter exposing ParallelLoopRunner through the pthreadpool C API so that
# XNNPACK can schedule work on XLA:CPU's runtime (see the commit description).
cc_library(
name = "xnn_threadpool",
srcs = ["xnn_threadpool.cc"],
hdrs = ["xnn_threadpool.h"],
# copybara:uncomment_begin(google-only)
# Define XLA_CPU_USE_CUSTOM_PTHREADPOOL only when pthreadpool is built in
# header-only mode, i.e. when our custom implementation provides the symbols.
# local_defines = select({
# "@pthreadpool:pthreadpool_header_only_explicit_true": [
# "XLA_CPU_USE_CUSTOM_PTHREADPOOL",
# ],
# "//conditions:default": [],
# }),
# copybara:uncomment_end
deps = [
":parallel_loop_runner",
"@eigen_archive//:eigen3",
"@pthreadpool",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:logging",
],
)

# Unit tests for the xnn_threadpool pthreadpool adapter. Tagged "no_oss":
# not run in the open-source CI (presumably because the custom-pthreadpool
# local_defines path above is google-only — confirm before removing the tag).
xla_cc_test(
name = "xnn_threadpool_test",
srcs = ["xnn_threadpool_test.cc"],
tags = ["no_oss"],
deps = [
":parallel_loop_runner",
":xnn_threadpool",
"//xla/tsl/concurrency:async_value",
"@XNNPACK",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/synchronization",
"@eigen_archive//:eigen3",
"@pthreadpool",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_benchmark",
"@tsl//tsl/platform:test_main",
],
)
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ static void ScheduleRange(tsl::CountDownAsyncValueRef<tsl::Chain> count_down,
ParallelLoopRunner::ParallelLoopRunner(Eigen::ThreadPoolDevice* device)
: done_event_(OkDoneEventSingleton()), device_(device) {}

// Returns the number of threads in the underlying Eigen thread pool device.
// NOTE(review): assumes `device_` is non-null — the constructor takes it as a
// raw pointer without checking; confirm callers never pass nullptr.
size_t ParallelLoopRunner::num_threads() const {
return device_->numThreadsInPool();
}

tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
ParallelLoopRunner&& runner) {
return std::move(runner.done_event_);
Expand Down
2 changes: 2 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class ParallelLoopRunner {
// Returns the async value that signals completion of the last scheduled
// parallel loop.
tsl::AsyncValueRef<tsl::Chain> done_event() const { return done_event_; }
// Returns the (non-owning) Eigen thread pool device backing this runner.
Eigen::ThreadPoolDevice* device() const { return device_; }

// Number of threads in the underlying thread pool device.
size_t num_threads() const;

private:
// Async value that signals completion of the last scheduled parallel loop.
tsl::AsyncValueRef<tsl::Chain> done_event_;
Expand Down
Loading

0 comments on commit a5b94d6

Please sign in to comment.