From 61e10e3a1434957b1757a7ad7eae5b4d269939de Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Mon, 13 Nov 2023 15:02:30 +0800 Subject: [PATCH] Repo sync (#399) --- bazel/curve25519-donna.BUILD | 21 +++++++++++ bazel/emp-tool.BUILD | 4 +-- bazel/patches/xla.patch | 22 ------------ bazel/repositories.bzl | 30 ++++++++++------ bazel/spu.bzl | 12 ++----- libspu/compiler/front_end/hlo_importer.cc | 2 +- libspu/kernel/hal/ring.cc | 33 ++++++++++------- libspu/kernel/hal/sort.cc | 29 ++++----------- spu/libspu.cc | 26 +++++++++++--- spu/tests/link_test.py | 43 +++++++++-------------- spu/tests/psi_test.py | 4 --- 11 files changed, 112 insertions(+), 114 deletions(-) create mode 100644 bazel/curve25519-donna.BUILD delete mode 100644 bazel/patches/xla.patch diff --git a/bazel/curve25519-donna.BUILD b/bazel/curve25519-donna.BUILD new file mode 100644 index 00000000..9dd666d4 --- /dev/null +++ b/bazel/curve25519-donna.BUILD @@ -0,0 +1,21 @@ +# Copyright 2022 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "curve25519_donna", + srcs = ["curve25519.c"], + hdrs = glob(["*.h"]), + visibility = ["//visibility:public"], +) diff --git a/bazel/emp-tool.BUILD b/bazel/emp-tool.BUILD index 57a2c776..034df99d 100644 --- a/bazel/emp-tool.BUILD +++ b/bazel/emp-tool.BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@yacl//bazel:yacl.bzl", "yacl_cmake_external") +load("@spulib//bazel:spu.bzl", "spu_cmake_external") package(default_visibility = ["//visibility:public"]) @@ -21,7 +21,7 @@ filegroup( srcs = glob(["**"]), ) -yacl_cmake_external( +spu_cmake_external( name = "emp-tool", cache_entries = { "OPENSSL_ROOT_DIR": "$EXT_BUILD_DEPS/openssl", diff --git a/bazel/patches/xla.patch b/bazel/patches/xla.patch deleted file mode 100644 index dbb5a1bc..00000000 --- a/bazel/patches/xla.patch +++ /dev/null @@ -1,22 +0,0 @@ -diff --git a/third_party/tsl/workspace1.bzl b/third_party/tsl/workspace1.bzl -index 4cfb6da82..0e3774834 100644 ---- a/third_party/tsl/workspace1.bzl -+++ b/third_party/tsl/workspace1.bzl -@@ -3,7 +3,7 @@ - load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") - load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") --load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies") -+# load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies") - load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") - - # buildifier: disable=unnamed-macro -@@ -14,7 +14,7 @@ def workspace(with_rules_cc = True): - with_rules_cc: whether to load and patch rules_cc repository. - """ - native.register_toolchains("@local_config_python//:py_toolchain") -- rules_cuda_dependencies(with_rules_cc) -+ # rules_cuda_dependencies(with_rules_cc) - rules_pkg_dependencies() - - closure_repositories() \ No newline at end of file diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index b45e7734..f1b479d7 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") -load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") SECRETFLOW_GIT = "https://github.com/secretflow" -YACL_COMMIT_ID = "6be4330542e92b6503317c45a999c99e654ced58" +YACL_COMMIT_ID = "0953593df3ca6544442236f2b6d78a5b89035e24" def spu_deps(): _rules_cuda() @@ -43,6 +43,7 @@ def spu_deps(): _com_github_microsoft_kuku() _com_google_flatbuffers() _com_github_nvidia_cutlass() + _com_github_floodyberry_curve25519_donna() maybe( git_repository, @@ -159,17 +160,17 @@ def _com_github_xtensor_xtl(): ) def _com_github_openxla_xla(): - OPENXLA_COMMIT = "75a7973c2850fcc33278c84e1b62eff8f0ad35f8" - OPENXLA_SHA256 = "4534c3230853e990ac613898c2ff39626d1beacb0c3675fbea502dce3e32f620" + OPENXLA_COMMIT = "d5791b01aa7541e3400224ac0a2985cc0f6940cb" + OPENXLA_SHA256 = "82dd50e6f51d79e8da69f109a234e33b8036f7b8798e41a03831b19c0c64d6e5" SKYLIB_VERSION = "1.3.0" + SKYLIB_SHA256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506" maybe( http_archive, name = "bazel_skylib", - sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", + sha256 = SKYLIB_SHA256, urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/{version}/bazel-skylib-{version}.tar.gz".format(version = SKYLIB_VERSION), "https://github.com/bazelbuild/bazel-skylib/releases/download/{version}/bazel-skylib-{version}.tar.gz".format(version = SKYLIB_VERSION), ], ) @@ -181,10 +182,6 @@ def _com_github_openxla_xla(): sha256 = OPENXLA_SHA256, strip_prefix = "xla-" + OPENXLA_COMMIT, type = ".tar.gz", - patch_args = ["-p1"], - patches = [ - "@spulib//bazel:patches/xla.patch", - ], urls = [ "https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = OPENXLA_COMMIT), ], @@ -370,3 +367,16 @@ def _com_github_nvidia_cutlass(): sha256 = "9637961560a9d63a6bb3f407faf457c7dbc4246d3afb54ac7dc1e014dd7f172f", build_file = "@spulib//bazel:nvidia_cutlass.BUILD", ) + +def _com_github_floodyberry_curve25519_donna(): + maybe( + http_archive, + name = "com_github_floodyberry_curve25519_donna", + strip_prefix = "curve25519-donna-2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2", + sha256 = "ba57d538c241ad30ff85f49102ab2c8dd996148456ed238a8c319f263b7b149a", + type = "tar.gz", + build_file = "@spulib//bazel:curve25519-donna.BUILD", + urls = [ + "https://github.com/floodyberry/curve25519-donna/archive/2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2.tar.gz", + ], + ) diff --git a/bazel/spu.bzl b/bazel/spu.bzl index 64fc79eb..0ac3fb45 100644 --- a/bazel/spu.bzl +++ b/bazel/spu.bzl @@ -17,7 +17,7 @@ warpper bazel cc_xx to modify flags. """ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") -load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake", "configure_make") +load("@yacl//bazel:yacl.bzl", "yacl_cmake_external") WARNING_FLAGS = [ "-Wall", @@ -68,15 +68,7 @@ def spu_cc_library( **kargs ) -def spu_cmake_external(**attrs): - if "generate_args" not in attrs: - attrs["generate_args"] = ["-GNinja"] - return cmake(**attrs) - -def spu_configure_make(**attrs): - if "args" not in attrs: - attrs["args"] = ["-j 4"] - return configure_make(**attrs) +spu_cmake_external = yacl_cmake_external def _spu_version_file_impl(ctx): out = ctx.actions.declare_file(ctx.attr.filename) diff --git a/libspu/compiler/front_end/hlo_importer.cc b/libspu/compiler/front_end/hlo_importer.cc index 2a9f08e2..296406d7 100644 --- a/libspu/compiler/front_end/hlo_importer.cc +++ b/libspu/compiler/front_end/hlo_importer.cc @@ -185,7 +185,7 @@ HloImporter::parseXlaModuleFromString(const std::string &content) { break; } } - debug_options.set_xla_detailed_logging_and_dumping(true); + debug_options.set_xla_enable_dumping(true); } auto module_config = diff --git a/libspu/kernel/hal/ring.cc b/libspu/kernel/hal/ring.cc index c6eaa9c7..528d8719 100644 --- a/libspu/kernel/hal/ring.cc +++ b/libspu/kernel/hal/ring.cc @@ -346,22 +346,31 @@ Value _mmul(SPUContext* ctx, const Value& x, const Value& y) { const auto& row_blocks = ret_blocks[r]; for (int64_t c = 0; c < static_cast(row_blocks.size()); c++) { const auto& block = row_blocks[c]; - SPU_ENFORCE(block.data().isCompact()); const int64_t block_rows = block.shape()[0]; const int64_t block_cols = block.shape()[1]; - if (n_blocks == 1) { - SPU_ENFORCE(row_blocks.size() == 1); - SPU_ENFORCE(block_cols == n); - char* dst = &ret.data().at({r * m_step, 0}); - const char* src = &block.data().at({0, 0}); - size_t cp_len = block.elsize() * block.numel(); - std::memcpy(dst, src, cp_len); + if (block.data().isCompact()) { + if (n_blocks == 1) { + SPU_ENFORCE(row_blocks.size() == 1); + SPU_ENFORCE(block_cols == n); + char* dst = &ret.data().at({r * m_step, 0}); + const char* src = &block.data().at({0, 0}); + size_t cp_len = block.elsize() * block.numel(); + std::memcpy(dst, src, cp_len); + } else { + for (int64_t i = 0; i < block_rows; i++) { + char* dst = &ret.data().at({r * m_step + i, c * n_step}); + const char* src = &block.data().at({i, 0}); + size_t cp_len = block.elsize() * block_cols; + std::memcpy(dst, src, cp_len); + } + } } else { for (int64_t i = 0; i < block_rows; i++) { - char* dst = &ret.data().at({r * m_step + i, c * n_step}); - const char* src = &block.data().at({i, 0}); - size_t cp_len = block.elsize() * block_cols; - std::memcpy(dst, src, cp_len); + for (int64_t j = 0; j < block_cols; j++) { + char* dst = &ret.data().at({r * m_step + i, c * n_step + j}); + const char* src = &block.data().at({i, j}); + std::memcpy(dst, src, block.elsize()); + } } } } diff --git a/libspu/kernel/hal/sort.cc b/libspu/kernel/hal/sort.cc index 83d2cbf4..9d104f95 100644 --- a/libspu/kernel/hal/sort.cc +++ b/libspu/kernel/hal/sort.cc @@ -21,6 +21,7 @@ Value Permute1D(SPUContext *, const Value &x, const Index &indices) { return Value(x.data().linear_gather(indices), x.dtype()); } +// FIXME: move to mpc layer // Vectorized Prefix Sum // Ref: https://en.algorithmica.org/hpc/algorithms/prefix/ Value PrefixSum(SPUContext *ctx, const Value &x) { @@ -242,10 +243,9 @@ spu::Value GenInvPermByTwoBitVectors(SPUContext *ctx, const spu::Value &x, {reshape(ctx, f0, new_shape), reshape(ctx, f1, new_shape), reshape(ctx, f2, new_shape), reshape(ctx, f3, new_shape)}, 1); - auto s = f.clone(); // calculate prefix sum - auto ps = PrefixSum(ctx, s); + auto ps = PrefixSum(ctx, f); // mul f and s auto fs = _mul(ctx, f, ps); @@ -294,10 +294,9 @@ spu::Value GenInvPermByBitVector(SPUContext *ctx, const spu::Value &x) { Shape new_shape = {1, numel}; auto f = concatenate( ctx, {reshape(ctx, rev_x, new_shape), reshape(ctx, x, new_shape)}, 1); - auto s = f.clone(); // calculate prefix sum - auto ps = PrefixSum(ctx, s); + auto ps = PrefixSum(ctx, f); // mul f and s auto fs = _mul(ctx, f, ps); @@ -339,25 +338,11 @@ std::vector BitDecompose(SPUContext *ctx, const spu::Value &x, ? static_cast(valid_bits) : x_bshare.storage_type().as()->nbits(); rets.reserve(nbits); - std::vector> sub_ctxs; - for (size_t bit = 0; bit < nbits; ++bit) { - sub_ctxs.push_back(ctx->fork()); - } - std::vector> futures; - for (size_t bit = 0; bit < nbits; ++bit) { - auto async_res = std::async( - [&](size_t bit, const spu::Value &x, const spu::Value &k1) { - auto sub_ctx = sub_ctxs[bit].get(); - auto x_bshare_shift = right_shift_logical(sub_ctx, x, bit); - auto lowest_bit = _and(sub_ctx, x_bshare_shift, k1); - return _prefer_a(sub_ctx, lowest_bit); - }, - bit, x_bshare, k1); - futures.push_back(std::move(async_res)); - } for (size_t bit = 0; bit < nbits; ++bit) { - rets.emplace_back(futures[bit].get()); + auto x_bshare_shift = right_shift_logical(ctx, x_bshare, bit); + auto lowest_bit = _and(ctx, x_bshare_shift, k1); + rets.emplace_back(_prefer_a(ctx, lowest_bit)); } return rets; @@ -601,4 +586,4 @@ std::vector simple_sort1d(SPUContext *ctx, } } -} // namespace spu::kernel::hal \ No newline at end of file +} // namespace spu::kernel::hal diff --git a/spu/libspu.cc b/spu/libspu.cc index 69114dbc..ebf6af40 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -73,6 +73,7 @@ void BindLink(py::module& m) { using yacl::link::CertInfo; using yacl::link::Context; using yacl::link::ContextDesc; + using yacl::link::RetryOptions; using yacl::link::SSLOptions; using yacl::link::VerifyOptions; @@ -96,6 +97,25 @@ void BindLink(py::module& m) { .def_readwrite("ca_file_path", &VerifyOptions::ca_file_path, "the trusted CA file path"); + py::class_(m, "RetryOptions", + "The options used for channel retry") + .def_readwrite("max_retry", &RetryOptions::max_retry, "max retry count") + .def_readwrite("retry_interval_ms", &RetryOptions::retry_interval_ms, + "first retry interval") + .def_readwrite("retry_interval_incr_ms", + &RetryOptions::retry_interval_incr_ms, + "the amount of time to increase between retries") + .def_readwrite("max_retry_interval_ms", + &RetryOptions::max_retry_interval_ms, + "the max interval between retries") + .def_readwrite("error_codes", &RetryOptions::error_codes, + "retry on these error codes, if empty, retry on all codes") + .def_readwrite( + "http_codes", &RetryOptions::http_codes, + "retry on these http codes, if empty, retry on all http codes") + .def_readwrite("aggressive_retry", &RetryOptions::aggressive_retry, + "do aggressive retry"); + py::class_(m, "SSLOptions", "The options used for ssl") .def_readwrite("cert", &SSLOptions::cert, "certificate used for authentication") @@ -132,12 +152,8 @@ void BindLink(py::module& m) { .def_readwrite("enable_ssl", &ContextDesc::enable_ssl) .def_readwrite("client_ssl_opts", &ContextDesc::client_ssl_opts) .def_readwrite("server_ssl_opts", &ContextDesc::server_ssl_opts) - .def_readwrite("brpc_retry_count", &ContextDesc::brpc_retry_count) - .def_readwrite("brpc_retry_interval_ms", - &ContextDesc::brpc_retry_interval_ms) - .def_readwrite("brpc_aggressive_retry", - &ContextDesc::brpc_aggressive_retry) .def_readwrite("link_type", &ContextDesc::link_type) + .def_readwrite("retry_opts", &ContextDesc::retry_opts) .def( "add_party", [](ContextDesc& desc, std::string id, std::string host) { diff --git a/spu/tests/link_test.py b/spu/tests/link_test.py index 5e1a6be3..2854c706 100644 --- a/spu/tests/link_test.py +++ b/spu/tests/link_test.py @@ -22,17 +22,20 @@ import multiprocess import spu.libspu.link as link +from socket import socket + + +def _rand_port(): + with socket() as s: + s.bind(("localhost", 0)) + return s.getsockname()[1] class UnitTests(unittest.TestCase): def test_link_brpc(self): desc = link.Desc() - desc.add_party("alice", "127.0.0.1:9927") - desc.add_party("bob", "127.0.0.1:9928") - - # Pickle only works properly for top-level functions, so mark proc as global to workaround this limitation - # See https://stackoverflow.com/questions/56533827/pool-apply-async-nested-function-is-not-executed/56534386#56534386 - global proc + desc.add_party("alice", f"127.0.0.1:{_rand_port()}") + desc.add_party("bob", f"127.0.0.1:{_rand_port()}") def proc(rank): data = "hello" if rank == 0 else "world" @@ -87,12 +90,8 @@ def thread(rank): def test_link_send_recv(self): desc = link.Desc() - desc.add_party("alice", "127.0.0.1:9927") - desc.add_party("bob", "127.0.0.1:9928") - - # Pickle only works properly for top-level functions, so mark proc as global to workaround this limitation - # See https://stackoverflow.com/questions/56533827/pool-apply-async-nested-function-is-not-executed/56534386#56534386 - global proc + desc.add_party("alice", f"127.0.0.1:{_rand_port()}") + desc.add_party("bob", f"127.0.0.1:{_rand_port()}") def proc(rank): lctx = link.create_brpc(desc, rank) @@ -104,7 +103,7 @@ def proc(rank): lctx.stop_link() - # launch with multiprocess + # launch with MultiProcessing jobs = [ multiprocess.Process(target=proc, args=(0,)), multiprocess.Process(target=proc, args=(1,)), @@ -117,12 +116,8 @@ def proc(rank): def test_link_send_async(self): desc = link.Desc() - desc.add_party("alice", "127.0.0.1:9927") - desc.add_party("bob", "127.0.0.1:9928") - - # Pickle only works properly for top-level functions, so mark proc as global to workaround this limitation - # See https://stackoverflow.com/questions/56533827/pool-apply-async-nested-function-is-not-executed/56534386#56534386 - global proc + desc.add_party("alice", f"127.0.0.1:{_rand_port()}") + desc.add_party("bob", f"127.0.0.1:{_rand_port()}") def proc(rank): lctx = link.create_brpc(desc, rank) @@ -132,7 +127,7 @@ def proc(rank): lctx.stop_link() - # launch with multiprocess + # launch with MultiProcessing jobs = [ multiprocess.Process(target=proc, args=(0,)), multiprocess.Process(target=proc, args=(1,)), @@ -145,12 +140,8 @@ def proc(rank): def test_link_next_rank(self): desc = link.Desc() - desc.add_party("alice", "127.0.0.1:9927") - desc.add_party("bob", "127.0.0.1:9928") - - # Pickle only works properly for top-level functions, so mark proc as global to workaround this limitation - # See https://stackoverflow.com/questions/56533827/pool-apply-async-nested-function-is-not-executed/56534386#56534386 - global proc + desc.add_party("alice", f"127.0.0.1:{_rand_port()}") + desc.add_party("bob", f"127.0.0.1:{_rand_port()}") def proc(rank): lctx = link.create_brpc(desc, rank) diff --git a/spu/tests/psi_test.py b/spu/tests/psi_test.py index 3f2b028c..58c1248a 100644 --- a/spu/tests/psi_test.py +++ b/spu/tests/psi_test.py @@ -62,8 +62,6 @@ def run_streaming_psi(self, wsize, inputs, outputs, selected_fields, protocol): port = get_free_port() lctx_desc.add_party(f"id_{rank}", f"127.0.0.1:{port}") - global wrap - def wrap(rank, selected_fields, input_path, output_path, type): lctx = link.create_brpc(lctx_desc, rank) @@ -242,8 +240,6 @@ def test_ecdh_oprf_unbalanced(self): precheck_input = False server_cache_path = "server_cache.bin" - global wrap - def wrap( rank, offline_path,