Skip to content

Commit

Permalink
Repo sync (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Apr 6, 2023
1 parent f84c843 commit 2b29087
Show file tree
Hide file tree
Showing 46 changed files with 2,101 additions and 682 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
## staging
> please add your unreleased change here.
## 20230328
## 20230406
- [SPU] 0.3.2 release
- [Feature] Add TrustedThirdParty beaver provider for semi2k
- [Feature] Expose ssl/tls options
Expand All @@ -19,10 +19,12 @@
- [Feature] Improve shift performance
- [Feature] Support shift by secret number of bits
- [Feature] Support secret indexing
- [Feature] Add PIR python binding
- [bugfix] Fix boolean ConstantOp
- [bugfix] Fix jnp.median
- [bugfix] Fix jnp.sort on floating point inputs
- [bugfix] Fix secret sort with public payloads
- [bugfix] Reenable secret GatherOp support
- [3p] Relax TensorFlow version in requirements.txt
- [3p] Move to OpenXLA
- [API] Move C++ API from spu to libspu
Expand Down
18 changes: 3 additions & 15 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")

SECRETFLOW_GIT = "https://github.com/secretflow"

YACL_COMMIT_ID = "8108513d1203cc2d7e5b1d2e4429af2fee5d1d4d"
YACL_COMMIT_ID = "8aef7c9076ee75e5467a921fad1c4a286eb2dc3a"

def spu_deps():
_rule_python()
_bazel_platform()
_upb()
_com_github_xtensor_xtensor()
Expand Down Expand Up @@ -63,17 +62,6 @@ def spu_deps():
path = "/opt/homebrew/opt/libomp/",
)

def _rule_python():
maybe(
http_archive,
name = "rules_python",
sha256 = "29a801171f7ca190c543406f9894abf2d483c206e14d6acbd695623662320097",
strip_prefix = "rules_python-0.18.1",
urls = [
"https://github.com/bazelbuild/rules_python/releases/download/0.18.1/rules_python-0.18.1.tar.gz",
],
)

def _bazel_platform():
http_archive(
name = "platforms",
Expand Down Expand Up @@ -151,8 +139,8 @@ def _com_github_xtensor_xtl():
)

def _com_github_openxla_xla():
OPENXLA_COMMIT = "6aee72ed08290623ff68742e146750ea0e7ddf8c"
OPENXLA_SHA256 = "dc13148b1e27d8fbc5a60bf4055595f3f7708bb339400a0ba19d9923b6f642fc"
OPENXLA_COMMIT = "da6b60c1a1f31bf1194bcdfb138841902e413704"
OPENXLA_SHA256 = "cec02e7c0af001fd08ce89a47f21bbedb1908abf1144070f1228f525398d280b"

SKYLIB_VERSION = "1.3.0"

Expand Down
71 changes: 61 additions & 10 deletions docs/development/pir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,23 @@ First build pir examples.

.. code-block:: bash
bazel build //examples/cpp/pir -c opt
bazel build //examples/cpp/pir/... -c opt
setup phase
>>>>>>>>>>>

Start server's terminal.
Generate an oprf_key.bin file for testing.

.. code-block:: bash
dd if=/dev/urandom of=oprf_key.bin bs=32 count=1
Start server's terminal.

.. code-block:: bash
./keyword_pir_setup -in_path psi_server_data.csv -oprfkey_path secret_key.bin \\
-key_columns id -label_columns label -data_per_query 256 -label_max_len 40 \\
-out_path pir_setup_dir -params_path psi_params.bin
./bazel-bin/examples/cpp/pir/keyword_pir_setup -in_path examples/data/pir_server_data.csv \\
-key_columns id -label_columns label -count_per_query 256 -max_label_length 40 \\
-oprf_key_path oprf_key.bin -setup_path pir_setup_dir
query phase
>>>>>>>>>>>
Expand All @@ -45,17 +49,64 @@ In the server's terminal.

.. code-block:: bash
keyword_pir_server -rank 0 -setup_path pir_setup_dir \\
-oprfkey_path secret_key.bin -data_per_query 256 -label_max_len 40 \\
-params_path psi_params.bin -label_columns label
./bazel-bin/examples/cpp/pir/keyword_pir_server -rank 0 -setup_path pir_setup_dir \\
-oprf_key_path oprf_key.bin
In the client's terminal.

.. code-block:: bash
./keyword_pir_client -rank 1 -in_path psi_client_data.csv.csv \\
-key_columns id -data_per_query 256 -out_path pir_out.csv
./bazel-bin/examples/cpp/pir/keyword_pir_client -rank 1 \\
-in_path examples/data/pir_client_data.csv.csv \\
-key_columns id -out_path pir_out.csv
The PIR query results are written to pir_out.csv.
To run the examples on two hosts, add '-parties ip1:port1,ip2:port2'.

Run keyword PIR python example
------------------------------

First, build the SPU Python wheel package, or install it from the network.

.. code-block:: bash
bash build_wheel_entrypoint.sh
pip install dist/spu-*.whl
setup phase
>>>>>>>>>>>

Start server's terminal.


.. code-block:: bash
python examples/python/pir/pir_setup.py --in_path examples/data/pir_server_data.csv \\
--oprf_key_path oprf_key.bin --key_columns id --label_columns label \\
--count_per_query 256 --max_label_length 40 \\
--setup_path pir_setup_dir
query phase
>>>>>>>>>>>

Start two terminals.

In the server's terminal.

.. code-block:: bash
python examples/python/pir/pir_server.py --rank 0 --setup_path pir_setup_dir \\
--oprf_key_path oprf_key.bin
In the client's terminal.

.. code-block:: bash
python examples/python/pir/pir_client.py --rank 1 \\
--in_path examples/data/pir_client_data.csv.csv \\
--key_columns id --out_path pir_out.csv
The PIR query results are written to pir_out.csv.
To run the examples on two hosts, add '--party_ips ip1:port1,ip2:port2'.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ pytablewriter==0.64.2
linkify-it-py==2.0.0
mdutils==1.4.0
spu>=0.3.1b0
sf-pydata-sphinx-theme
pydata-sphinx-theme
17 changes: 3 additions & 14 deletions examples/cpp/pir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,15 @@ load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library")

package(default_visibility = ["//visibility:public"])

spu_cc_library(
name = "keyword_pir_utils",
srcs = ["keyword_pir_utils.cc"],
hdrs = ["keyword_pir_utils.h"],
deps = [
"//libspu/psi/core/labeled_psi:labeled_psi",
"//libspu/psi/io:io",
"//libspu/psi/utils:batch_provider",
],
)

spu_cc_binary(
name = "keyword_pir_setup",
srcs = ["keyword_pir_setup.cc"],
data = [
"//examples/data",
],
deps = [
":keyword_pir_utils",
"//examples/cpp:utils",
"//libspu/pir:pir",
"//libspu/psi/core/labeled_psi:labeled_psi",
"//libspu/psi/utils:cipher_store",
"@com_google_absl//absl/strings",
Expand All @@ -53,8 +42,8 @@ spu_cc_binary(
"//examples/data",
],
deps = [
":keyword_pir_utils",
"//examples/cpp:utils",
"//libspu/pir:pir",
"//libspu/psi/core/labeled_psi:labeled_psi",
"//libspu/psi/utils:cipher_store",
"//libspu/psi/utils:serialize",
Expand All @@ -74,8 +63,8 @@ spu_cc_binary(
"//examples/data",
],
deps = [
":keyword_pir_utils",
"//examples/cpp:utils",
"//libspu/pir:pir",
"//libspu/psi/core/labeled_psi:labeled_psi",
"//libspu/psi/utils:cipher_store",
"//libspu/psi/utils:serialize",
Expand Down
2 changes: 1 addition & 1 deletion examples/cpp/pir/generate_pir_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ int main(int argc, char **argv) {
psi2_out_file << "id,id1" << '\r' << std::endl;

for (size_t idx = 0; idx < alice_item_size; idx++) {
std::string a_item = fmt::format("{:010d}", idx);
std::string a_item = fmt::format("{:010d}{:08d}", idx, idx + 900000000);
std::string b_item;
if (dist1(rand)) {
psi2_out_file << a_item << "," << id1_data << '\r' << std::endl;
Expand Down
127 changes: 11 additions & 116 deletions examples/cpp/pir/keyword_pir_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,25 @@
//
// To run the example, start terminals:
// > ./keyword_pir_client -rank 1 -in_path ../../data/psi_client_data.csv.csv
// -key_columns id -data_per_query 256 -out_path pir_out.csv
// > -key_columns id -out_path pir_out.csv
// clang-format on

#include <chrono>
#include <filesystem>
#include <string>

#include "examples/cpp/pir/keyword_pir_utils.h"
#include "examples/cpp/utils.h"
#include "yacl/io/rw/csv_writer.h"

#include "libspu/pir/pir.h"
#include "libspu/psi/core/labeled_psi/psi_params.h"
#include "libspu/psi/core/labeled_psi/receiver.h"
#include "libspu/psi/utils/batch_provider.h"
#include "libspu/psi/utils/serialize.h"
#include "libspu/psi/utils/utils.h"

#include "libspu/pir/pir.pb.h"

using DurationMillis = std::chrono::duration<double, std::milli>;

llvm::cl::opt<std::string> InPathOpt("in_path", llvm::cl::init("data.csv"),
Expand All @@ -43,9 +45,6 @@ llvm::cl::opt<std::string> InPathOpt("in_path", llvm::cl::init("data.csv"),
llvm::cl::opt<std::string> KeyColumnsOpt("key_columns", llvm::cl::init("id"),
llvm::cl::desc("key columns"));

llvm::cl::opt<int> DataPerQueryOpt("data_per_query", llvm::cl::init(256),
llvm::cl::desc("data count per query"));

llvm::cl::opt<std::string> OutPathOpt(
"out_path", llvm::cl::init("."),
llvm::cl::desc("[out] pir query output path for db setup data"));
Expand All @@ -58,121 +57,17 @@ int main(int argc, char **argv) {

std::vector<std::string> ids = absl::StrSplit(KeyColumnsOpt.getValue(), ',');

// recv label columns
yacl::Buffer label_columns_buffer = link_ctx->Recv(
link_ctx->NextRank(), fmt::format("recv label columns name"));
std::vector<std::string> label_columns_name;
spu::psi::utils::DeserializeStrItems(label_columns_buffer,
&label_columns_name);

std::shared_ptr<spu::psi::IBatchProvider> query_batch_provider =
std::make_shared<spu::psi::CsvBatchProvider>(InPathOpt.getValue(), ids);

yacl::io::Schema s;
for (size_t i = 0; i < ids.size(); ++i) {
s.feature_types.push_back(yacl::io::Schema::STRING);
}
for (size_t i = 0; i < label_columns_name.size(); ++i) {
s.feature_types.push_back(yacl::io::Schema::STRING);
}

s.feature_names = ids;
s.feature_names.insert(s.feature_names.end(), label_columns_name.begin(),
label_columns_name.end());

yacl::io::WriterOptions w_op;
w_op.file_schema = s;

auto out = spu::psi::io::BuildOutputStream(
spu::psi::io::FileIoOptions(OutPathOpt.getValue()));
yacl::io::CsvWriter writer(w_op, std::move(out));
writer.Init();

size_t nr = DataPerQueryOpt.getValue();

// recv psi params
yacl::Buffer params_buffer =
link_ctx->Recv(link_ctx->NextRank(), fmt::format("recv psi params"));

apsi::PSIParams psi_params = spu::psi::ParsePsiParamsProto(params_buffer);

spu::psi::LabelPsiReceiver receiver(psi_params, true);

const auto total_query_start = std::chrono::system_clock::now();

size_t query_count = 0;

while (true) {
auto query_batch_items = query_batch_provider->ReadNextBatch(nr);

spu::psi::AllGatherItemsSize(link_ctx, query_batch_items.size());

if (query_batch_items.empty()) {
break;
}

const auto oprf_start = std::chrono::system_clock::now();
std::pair<std::vector<apsi::HashedItem>, std::vector<apsi::LabelKey>>
items_oprf = receiver.RequestOPRF(query_batch_items, link_ctx);

const auto oprf_end = std::chrono::system_clock::now();
const DurationMillis oprf_duration = oprf_end - oprf_start;
SPDLOG_INFO("*** server oprf duration:{}", oprf_duration.count());

const auto query_start = std::chrono::system_clock::now();
std::pair<std::vector<size_t>, std::vector<std::string>> query_result =
receiver.RequestQuery(items_oprf.first, items_oprf.second, link_ctx);

const auto query_end = std::chrono::system_clock::now();
const DurationMillis query_duration = query_end - query_start;
SPDLOG_INFO("*** server query duration:{}", query_duration.count());

SPDLOG_INFO("query_result size:{}", query_result.first.size());

yacl::io::ColumnVectorBatch batch;

std::vector<std::vector<std::string>> query_id_results(ids.size());
std::vector<std::vector<std::string>> query_label_results(
label_columns_name.size());

for (size_t i = 0; i < query_result.first.size(); ++i) {
std::vector<std::string> result_ids =
absl::StrSplit(query_batch_items[query_result.first[i]], ',');

SPU_ENFORCE(result_ids.size() == ids.size());

std::vector<std::string> result_labels =
absl::StrSplit(query_result.second[i], ',');
SPU_ENFORCE(result_labels.size() == label_columns_name.size());

for (size_t j = 0; j < result_ids.size(); ++j) {
query_id_results[j].push_back(result_ids[j]);
}
for (size_t j = 0; j < result_labels.size(); ++j) {
query_label_results[j].push_back(result_labels[j]);
}
}

for (size_t i = 0; i < ids.size(); ++i) {
batch.AppendCol(query_id_results[i]);
}
for (size_t i = 0; i < label_columns_name.size(); ++i) {
batch.AppendCol(query_label_results[i]);
}

writer.Add(batch);
spu::pir::PirClientConfig config;

query_count++;
}
config.set_pir_protocol(spu::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI);

writer.Close();
config.set_input_path(InPathOpt.getValue());
config.mutable_key_columns()->Add(ids.begin(), ids.end());
config.set_output_path(OutPathOpt.getValue());

SPDLOG_INFO("query_count:{}", query_count);
spu::pir::PirResultReport report = spu::pir::PirClient(link_ctx, config);

const auto total_query_end = std::chrono::system_clock::now();
const DurationMillis total_query_duration =
total_query_end - total_query_start;
SPDLOG_INFO("*** total query duration:{}", total_query_duration.count());
SPDLOG_INFO("data count:{}", report.data_count());

return 0;
}
Loading

0 comments on commit 2b29087

Please sign in to comment.