diff --git a/.bazelignore b/.bazelignore index 1e107f52..8b137891 100644 --- a/.bazelignore +++ b/.bazelignore @@ -1 +1 @@ -examples + diff --git a/docs/update_po.sh b/docs/update_po.sh old mode 100755 new mode 100644 index c6b21d18..eb8f4548 --- a/docs/update_po.sh +++ b/docs/update_po.sh @@ -17,4 +17,4 @@ mkdir -p _build/gettext && make gettext && sphinx-intl update -p _build/gettext -l zh_CN && -echo "po files has been updated. Please update po files in locales folder." +echo "po files has been updated. Please update po files in locales folder." \ No newline at end of file diff --git a/examples/MODULE.bazel.lock b/examples/MODULE.bazel.lock index b9b80d4d..2f4b509c 100644 --- a/examples/MODULE.bazel.lock +++ b/examples/MODULE.bazel.lock @@ -1,5 +1,5 @@ { - "lockFileVersion": 11, + "lockFileVersion": 13, "registryFileHashes": { "https://bcr.bazel.build/bazel_registry.json": "8a28e4aff06ee60aed2a8c281907fb8bcbf3b753c91fb5a5c57da3215d5b3497", "https://bcr.bazel.build/modules/abseil-cpp/20210324.2/MODULE.bazel": "7cd0312e064fde87c8d1cd79ba06c876bd23630c83466e9500321be55c96ace2", diff --git a/examples/gc/BUILD.bazel b/examples/gc/BUILD.bazel new file mode 100644 index 00000000..4f6d369a --- /dev/null +++ b/examples/gc/BUILD.bazel @@ -0,0 +1,138 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@yacl//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test", "yacl_cc_binary",) + +package(default_visibility = ["//visibility:public"]) + +yacl_cc_library( + name = "mitccrh", + hdrs = [ + "mitccrh.h", + ], + deps = [ + ":utils", + "//yacl/crypto/aes:aes_opt", + ], +) + +yacl_cc_library( + name = "utils", + hdrs = ["utils.h"], + deps = [ + "//yacl/base:byte_container_view", + "//yacl/base:int128", + ], +) + +yacl_cc_library( + name = "aes_128_garbler", + hdrs = [ + "aes_128_garbler.h", + ], + deps = [ + ":mitccrh", + "//yacl/base:byte_container_view", + "//yacl/io/circuit:bristol_fashion", + "//yacl/kernel:ot_kernel", + ], +) + +yacl_cc_library( + name = "sha256_garbler", + hdrs = [ + "sha256_garbler.h", + ], + deps = [ + ":mitccrh", + "//yacl/base:byte_container_view", + "//yacl/io/circuit:bristol_fashion", + ], +) + +yacl_cc_library( + name = "sha256_evaluator", + hdrs = [ + "sha256_evaluator.h", + ], + deps = [ + ":mitccrh", + "//yacl/base:byte_container_view", + "//yacl/io/circuit:bristol_fashion", + ], +) + +yacl_cc_library( + name = "aes_128_evaluator", + hdrs = [ + "aes_128_evaluator.h", + ], + deps = [ + ":mitccrh", + "//yacl/base:byte_container_view", + "//yacl/io/circuit:bristol_fashion", + "//yacl/kernel:ot_kernel", + ], +) + +yacl_cc_test( + name = "gc_test", + srcs = ["gc_test.cc"], + copts = [ + "-mavx", + "-maes", + "-mpclmul", + ], + data = ["//yacl/io/circuit:circuit_data"], + deps = [ + ":aes_128_evaluator", + ":aes_128_garbler", + ":sha256_evaluator", + ":sha256_garbler", + ], +) + +cc_binary( + name = "sha_run", + srcs = ["sha_run.cc"], + copts = [ + "-mavx", + "-maes", + "-mpclmul", + ], + data = ["//yacl/io/circuit:circuit_data"], + deps = [ + ":aes_128_evaluator", + ":aes_128_garbler", + ":sha256_garbler", + ":sha256_evaluator", + "//yacl/crypto/block_cipher:symmetric_crypto", + ], +) + +cc_binary( + name = "aes_run", + srcs = ["aes_run.cc"], + copts = [ + "-mavx", + "-maes", + "-mpclmul", + ], + data = ["//yacl/io/circuit:circuit_data"], + deps = [ + ":aes_128_evaluator", + ":aes_128_garbler", + "//yacl/crypto/block_cipher:symmetric_crypto", + ], +) \ No newline at end of file diff --git a/examples/gc/aes_128_evaluator.h b/examples/gc/aes_128_evaluator.h new file mode 100644 index 00000000..61373ffe --- /dev/null +++ b/examples/gc/aes_128_evaluator.h @@ -0,0 +1,223 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "examples/gc/mitccrh.h" +#include "fmt/format.h" + +#include "yacl/base/byte_container_view.h" +#include "yacl/base/dynamic_bitset.h" +#include "yacl/base/int128.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/io/circuit/bristol_fashion.h" +#include "yacl/kernel/ot_kernel.h" +#include "yacl/link/context.h" +#include "yacl/link/factory.h" + +using namespace std; +using namespace yacl; +using namespace yacl::crypto; + +namespace { + +using OtMsg = uint128_t; +using OtMsgPair = std::array; +using OtChoices = dynamic_bitset; + +} // namespace + +class EvaluatorAES { + public: + uint128_t delta; + uint128_t inv_constant; + uint128_t start_point; + MITCCRH<8> mitccrh; + + std::vector wires_; + std::vector gb_value; + yacl::io::BFCircuit circ_; + std::shared_ptr lctx; + + // The number of and gate is 6400 + uint128_t table[6400][2]; + uint128_t input; + int num_ot = 128; // input bit of evaluator + int send_bytes = 0; + uint128_t all_one_uint128_t = ~static_cast(0); + uint128_t select_mask[2] = {0, all_one_uint128_t}; + + yacl::crypto::OtRecvStore ot_recv = + OtRecvStore(num_ot, yacl::crypto::OtStoreType::Normal); + + void setup() { + size_t world_size = 2; + yacl::link::ContextDesc ctx_desc; + + for (size_t rank = 0; rank < world_size; rank++) { + const auto id = fmt::format("id-{}", rank); + const auto host = fmt::format("127.0.0.1:{}", 10010 + rank); + ctx_desc.parties.push_back({id, host}); + } + + lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, 1); + lctx->ConnectToMesh(); + + // OT off-line + const auto ext_algorithm = yacl::crypto::OtKernel::ExtAlgorithm::SoftSpoken; + yacl::crypto::OtKernel kernel1(yacl::crypto::OtKernel::Role::Receiver, + ext_algorithm); + kernel1.init(lctx); + kernel1.eval_rot(lctx, num_ot, &ot_recv); + + // delta, inv_constant, start_point + uint128_t tmp[3]; + + yacl::Buffer r = lctx->Recv(0, "tmp"); + const uint128_t* buffer_data = r.data(); + memcpy(tmp, buffer_data, sizeof(uint128_t) * 3); + + delta = tmp[0]; + inv_constant = tmp[1]; + start_point = tmp[2]; + + mitccrh.setS(start_point); + } + + uint128_t inputProcess(yacl::io::BFCircuit param_circ_) { + circ_ = param_circ_; + gb_value.resize(circ_.nw); + wires_.resize(circ_.nw); + + yacl::dynamic_bitset bi_val; + + input = yacl::crypto::FastRandU128(); + + bi_val.append(input); + + yacl::Buffer r = lctx->Recv(0, "garbleInput1"); + + const uint128_t* buffer_data = r.data(); + + memcpy(wires_.data(), buffer_data, sizeof(uint128_t) * num_ot); + + return input; + } + void recvTable() { + yacl::Buffer r = lctx->Recv(0, "table"); + const uint128_t* buffer_data = r.data(); + int k = 0; + for (size_t i = 0; i < 6400; i++) { + for (int j = 0; j < 2; j++) { + table[i][j] = buffer_data[k]; + k++; + } + } + } + + uint128_t EVAND(uint128_t A, uint128_t B, const uint128_t* table_item, + MITCCRH<8>* mitccrh_pointer) { + uint128_t HA, HB, W; + int sa, sb; + + sa = getLSB(A); + sb = getLSB(B); + + uint128_t H[2]; + H[0] = A; + H[1] = B; + mitccrh_pointer->hash<2, 1>(H); + HA = H[0]; + HB = H[1]; + + W = HA ^ HB; + W = W ^ (select_mask[sa] & table_item[0]); + W = W ^ (select_mask[sb] & table_item[1]); + W = W ^ (select_mask[sb] & A); + return W; + } + + void EV() { + int and_num = 0; + for (size_t i = 0; i < circ_.gates.size(); i++) { + auto gate = circ_.gates[i]; + switch (gate.op) { + case yacl::io::BFCircuit::Op::XOR: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + const auto& iw1 = wires_.operator[](gate.iw[1]); + wires_[gate.ow[0]] = iw0 ^ iw1; + break; + } + case yacl::io::BFCircuit::Op::AND: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + const auto& iw1 = wires_.operator[](gate.iw[1]); + wires_[gate.ow[0]] = EVAND(iw0, iw1, table[and_num], &mitccrh); + and_num++; + break; + } + case yacl::io::BFCircuit::Op::INV: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + wires_[gate.ow[0]] = iw0 ^ inv_constant; + break; + } + case yacl::io::BFCircuit::Op::EQ: { + wires_[gate.ow[0]] = gate.iw[0]; + break; + } + case yacl::io::BFCircuit::Op::EQW: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + wires_[gate.ow[0]] = iw0; + break; + } + case yacl::io::BFCircuit::Op::MAND: { /* multiple ANDs */ + YACL_THROW("Unimplemented MAND gate"); + break; + } + default: + YACL_THROW("Unknown Gate Type: {}", (int)gate.op); + } + } + } + void sendOutput() { + size_t index = wires_.size(); + int start = index - circ_.now[0]; + lctx->Send(0, + yacl::ByteContainerView(wires_.data() + start, + sizeof(uint128_t) * num_ot), + "output"); + send_bytes += sizeof(uint128_t) * num_ot; + } + void onLineOT() { + yacl::dynamic_bitset choices; + choices.append(input); + + yacl::dynamic_bitset ot = ot_recv.CopyBitBuf(); + ot.resize(choices.size()); + + yacl::dynamic_bitset masked_choices = ot ^ choices; + lctx->Send( + 0, yacl::ByteContainerView(masked_choices.data(), sizeof(uint128_t)), + "masked_choice"); + send_bytes += sizeof(uint128_t); + + auto buf = lctx->Recv(lctx->NextRank(), ""); + std::vector batch_recv(num_ot); + std::memcpy(batch_recv.data(), buf.data(), buf.size()); + for (int j = 0; j < num_ot; ++j) { + auto idx = num_ot + j; + wires_[idx] = batch_recv[j][choices[j]] ^ ot_recv.GetBlock(j); + } + } +}; \ No newline at end of file diff --git a/examples/gc/aes_128_garbler.h b/examples/gc/aes_128_garbler.h new file mode 100644 index 00000000..46cefc1c --- /dev/null +++ b/examples/gc/aes_128_garbler.h @@ -0,0 +1,271 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "absl/types/span.h" +#include "examples/gc/mitccrh.h" +#include "fmt/format.h" + +#include "yacl/base/byte_container_view.h" +#include "yacl/base/dynamic_bitset.h" +#include "yacl/base/int128.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/io/circuit/bristol_fashion.h" +#include "yacl/kernel/ot_kernel.h" +#include "yacl/link/context.h" +#include "yacl/link/factory.h" + +using namespace std; +using namespace yacl; + +inline uint128_t Aes128(uint128_t k, uint128_t m) { + crypto::SymmetricCrypto enc(crypto::SymmetricCrypto::CryptoType::AES128_ECB, + k); + return enc.Encrypt(m); +} + +class GarblerAES { + public: + std::shared_ptr lctx; + uint128_t delta; + uint128_t inv_constant; + uint128_t start_point; + MITCCRH<8> mitccrh; + + std::vector wires_; + std::vector gb_value; + yacl::io::BFCircuit circ_; + + // The number of and gate is 6400 + uint128_t table[6400][2]; + + uint128_t input; + uint128_t input_EV; + int send_bytes = 0; + int num_ot = 128; // input bit of evaluator + uint128_t all_one_uint128_t_ = ~static_cast<__uint128_t>(0); + uint128_t select_mask_[2] = {0, all_one_uint128_t_}; + yacl::crypto::OtSendStore ot_send = + OtSendStore(num_ot, yacl::crypto::OtStoreType::Normal); + + void setup() { + size_t world_size = 2; + yacl::link::ContextDesc ctx_desc; + + for (size_t rank = 0; rank < world_size; rank++) { + const auto id = fmt::format("id-{}", rank); + const auto host = fmt::format("127.0.0.1:{}", 10010 + rank); + ctx_desc.parties.push_back({id, host}); + } + + lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, 0); + lctx->ConnectToMesh(); + + // OT off-line + const auto ext_algorithm = yacl::crypto::OtKernel::ExtAlgorithm::SoftSpoken; + yacl::crypto::OtKernel kernel0(yacl::crypto::OtKernel::Role::Sender, + ext_algorithm); + kernel0.init(lctx); + kernel0.eval_rot(lctx, num_ot, &ot_send); + + // delta, inv_constant, start_point + auto tmp = yacl::crypto::SecureRandVec(3); + tmp[0] = tmp[0] | 1; + lctx->Send(1, + yacl::ByteContainerView(static_cast(tmp.data()), + sizeof(uint128_t) * 3), + "tmp"); + send_bytes += sizeof(uint128_t) * 3; + + delta = tmp[0]; + inv_constant = tmp[1] ^ delta; + start_point = tmp[2]; + + mitccrh.setS(start_point); + } + + uint128_t inputProcess(yacl::io::BFCircuit param_circ_) { + circ_ = param_circ_; + gb_value.resize(circ_.nw); + wires_.resize(circ_.nw); + + input = yacl::crypto::FastRandU128(); + + yacl::dynamic_bitset bi_val; + bi_val.append(input); + + int num_of_input_wires = 0; + for (size_t i = 0; i < circ_.niv; ++i) { + num_of_input_wires += circ_.niw[i]; + } + + auto rands = yacl::crypto::SecureRandVec(num_of_input_wires); + for (int i = 0; i < num_of_input_wires; i++) { + gb_value[i] = rands[i]; + } + + for (size_t i = 0; i < circ_.niw[0]; i++) { + wires_[i] = gb_value[i] ^ (select_mask_[bi_val[i]] & delta); + } + + lctx->Send( + 1, yacl::ByteContainerView(wires_.data(), sizeof(uint128_t) * num_ot), + "garbleInput1"); + send_bytes += sizeof(uint128_t) * num_ot; + + return input; + } + + uint128_t GBAND(uint128_t LA0, uint128_t A1, uint128_t LB0, uint128_t B1, + uint128_t* table_item, MITCCRH<8>* mitccrh_pointer) { + bool pa = getLSB(LA0); + bool pb = getLSB(LB0); + + uint128_t HLA0, HA1, HLB0, HB1; + uint128_t tmp, W0; + uint128_t H[4]; + + H[0] = LA0; + H[1] = A1; + H[2] = LB0; + H[3] = B1; + + mitccrh_pointer->hash<2, 2>(H); + + HLA0 = H[0]; + HA1 = H[1]; + HLB0 = H[2]; + HB1 = H[3]; + + table_item[0] = HLA0 ^ HA1; + table_item[0] = table_item[0] ^ (select_mask_[pb] & delta); + + W0 = HLA0; + W0 = W0 ^ (select_mask_[pa] & table_item[0]); + + tmp = HLB0 ^ HB1; + table_item[1] = tmp ^ LA0; + + W0 = W0 ^ HLB0; + W0 = W0 ^ (select_mask_[pb] & tmp); + return W0; + } + void GB() { + int and_num = 0; + for (size_t i = 0; i < circ_.gates.size(); i++) { + auto gate = circ_.gates[i]; + switch (gate.op) { + case yacl::io::BFCircuit::Op::XOR: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + const auto& iw1 = gb_value.operator[](gate.iw[1]); + gb_value[gate.ow[0]] = iw0 ^ iw1; + break; + } + case yacl::io::BFCircuit::Op::AND: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + const auto& iw1 = gb_value.operator[](gate.iw[1]); + gb_value[gate.ow[0]] = GBAND(iw0, iw0 ^ delta, iw1, iw1 ^ delta, + table[and_num], &mitccrh); + and_num++; + break; + } + case yacl::io::BFCircuit::Op::INV: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + gb_value[gate.ow[0]] = iw0 ^ inv_constant; + break; + } + case yacl::io::BFCircuit::Op::EQ: { + gb_value[gate.ow[0]] = gate.iw[0]; + break; + } + case yacl::io::BFCircuit::Op::EQW: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + gb_value[gate.ow[0]] = iw0; + break; + } + case yacl::io::BFCircuit::Op::MAND: { /* multiple ANDs */ + YACL_THROW("Unimplemented MAND gate"); + break; + } + default: + YACL_THROW("Unknown Gate Type: {}", (int)gate.op); + } + } + } + + void sendTable() { + lctx->Send(1, yacl::ByteContainerView(table, sizeof(uint128_t) * 2 * 6400), + "table"); + send_bytes += sizeof(uint128_t) * 2 * 6400; + } + uint128_t decode() { + size_t index = wires_.size(); + int start = index - circ_.now[0]; + + yacl::Buffer r = lctx->Recv(1, "output"); + const uint128_t* buffer_data = r.data(); + + memcpy(wires_.data() + start, buffer_data, sizeof(uint128_t) * num_ot); + + // decode + std::vector result(1); + finalize(absl::MakeSpan(result)); + + return result[0]; + } + + template + void finalize(absl::Span outputs) { + size_t index = wires_.size(); + + for (size_t i = 0; i < circ_.nov; ++i) { + yacl::dynamic_bitset result(circ_.now[i]); + for (size_t j = 0; j < circ_.now[i]; ++j) { + int wire_index = index - circ_.now[i] + j; + result[j] = getLSB(wires_[wire_index]) ^ getLSB(gb_value[wire_index]); + } + + outputs[circ_.nov - i - 1] = *(T*)result.data(); + index -= circ_.now[i]; + } + } + void onlineOT() { + auto buf = lctx->Recv(1, "masked_choice"); + + dynamic_bitset masked_choices(num_ot); + std::memcpy(masked_choices.data(), buf.data(), buf.size()); + + std::vector batch_send(num_ot); + + for (int j = 0; j < num_ot; ++j) { + auto idx = num_ot + j; + if (!masked_choices[j]) { + batch_send[j][0] = ot_send.GetBlock(j, 0) ^ gb_value[idx]; + batch_send[j][1] = ot_send.GetBlock(j, 1) ^ gb_value[idx] ^ delta; + } else { + batch_send[j][0] = ot_send.GetBlock(j, 1) ^ gb_value[idx]; + batch_send[j][1] = ot_send.GetBlock(j, 0) ^ gb_value[idx] ^ delta; + } + } + + lctx->SendAsync( + lctx->NextRank(), + ByteContainerView(batch_send.data(), sizeof(uint128_t) * num_ot * 2), + ""); + send_bytes += sizeof(uint128_t) * num_ot * 2; + } +}; \ No newline at end of file diff --git a/examples/gc/aes_run.cc b/examples/gc/aes_run.cc new file mode 100644 index 00000000..75f5e2a2 --- /dev/null +++ b/examples/gc/aes_run.cc @@ -0,0 +1,90 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "examples/gc/aes_128_evaluator.h" +#include "examples/gc/aes_128_garbler.h" +#include "fmt/format.h" + +#include "yacl/crypto/block_cipher/symmetric_crypto.h" + +using namespace std; + +int aes_garbler_send_bytes = 0; +int aes_evaluator_send_bytes = 0; +int aes_compute_time = 0; + +void aes_performance() { + std::shared_ptr circ_; + + GarblerAES* garbler = new GarblerAES(); + EvaluatorAES* evaluator = new EvaluatorAES(); + + std::future thread1 = std::async([&] { garbler->setup(); }); + std::future thread2 = std::async([&] { evaluator->setup(); }); + thread1.get(); + thread2.get(); + + std::string pth = fmt::format("yacl/io/circuit/data/{0}.txt", "aes_128"); + yacl::io::CircuitReader reader(pth); + reader.ReadMeta(); + reader.ReadAllGates(); + circ_ = reader.StealCirc(); + + for (int i = 0; i < 1; i++) { + auto start1 = clock_start(); + uint128_t key; + uint128_t message; + thread1 = std::async([&] { key = garbler->inputProcess(*circ_); }); + thread2 = std::async([&] { message = evaluator->inputProcess(*circ_); }); + thread1.get(); + thread2.get(); + + // OT + thread1 = std::async([&] { evaluator->onLineOT(); }); + thread2 = std::async([&] { garbler->onlineOT(); }); + thread1.get(); + thread2.get(); + + garbler->GB(); + garbler->sendTable(); + + evaluator->recvTable(); + + evaluator->EV(); + + evaluator->sendOutput(); + + uint128_t gc_result = garbler->decode(); + aes_compute_time += time_from(start1); + aes_garbler_send_bytes += garbler->send_bytes; + aes_evaluator_send_bytes += evaluator->send_bytes; + } + delete garbler; + delete evaluator; +} + +int main() { + aes_performance(); + cout << "AES_performance:" << endl; + std::cout << "Garbler send: " << aes_garbler_send_bytes << " bytes" << " " + << endl; + std::cout << "Evaluator send: " << aes_evaluator_send_bytes << " bytes" + << " " << endl; + cout << "Time for Computing: " << aes_compute_time << "us" << endl; +} \ No newline at end of file diff --git a/examples/gc/emp_benchmark/ABY_aes_test.patch b/examples/gc/emp_benchmark/ABY_aes_test.patch new file mode 100644 index 00000000..c2e0e465 --- /dev/null +++ b/examples/gc/emp_benchmark/ABY_aes_test.patch @@ -0,0 +1,182 @@ +diff --git a/run_aes.sh b/run_aes.sh +new file mode 100644 +index 0000000..da61da4 +--- /dev/null ++++ b/run_aes.sh +@@ -0,0 +1,72 @@ ++#!/bin/bash ++ ++cd build ++# 输出文件路径 ++output_file="aes_test_output.txt" ++ ++# 清空输出文件 ++echo "Running AES Test..." > $output_file ++echo "----------------------------------------" >> $output_file ++ ++# 统计总和的变量 ++total_computing_time_0=0 ++total_computing_time_1=0 ++total_received_data_0=0 ++total_received_data_1=0 ++total_sent_data_0=0 ++total_sent_data_1=0 ++ ++echo "--------------- AES Test Batch 1 ------------------------" ++# 运行 100 次,并添加进度条 ++for i in {1..100} ++do ++ # 显示进度条 ++ progress=$((i * 100 / 100)) # 计算进度 ++ printf "\rProgress: [" ++ for ((j=0; j> $output_file ++ echo " Party 0: Computing Time = $(echo "$computing_time_0" | bc) ms, Received Data = $rev_data_0 bytes, Sent Data = $send_data_0 bytes" >> $output_file ++ echo " Party 1: Computing Time = $(echo "$computing_time_1" | bc) ms, Received Data = $rev_data_1 bytes, Sent Data = $send_data_1 bytes" >> $output_file ++ echo "----------------------------------------" >> $output_file ++ ++ # 延时以模拟每次测试的时间,方便查看进度条(可以根据实际运行时间调整) ++ # sleep 0.1 ++done ++ ++# 输出总的统计结果到文件,时间单位转化为毫秒 ++echo "----------------------------------------" >> $output_file ++echo "Total Results:" >> $output_file ++echo "Total Computing Time for Party 0: $(echo "$total_computing_time_0" | bc) ms" >> $output_file ++echo "Total Computing Time for Party 1: $(echo "$total_computing_time_1" | bc) ms" >> $output_file ++echo "Total Received Data for Party 0: $total_received_data_0 bytes" >> $output_file ++echo "Total Received Data for Party 1: $total_received_data_1 bytes" >> $output_file ++echo "Total Sent Data for Party 0: $total_sent_data_0 bytes" >> $output_file ++echo "Total Sent Data for Party 1: $total_sent_data_1 bytes" >> $output_file ++ ++# 提示完成 ++echo "Test completed. Results saved to $output_file" +diff --git a/src/examples/aes/aes_test.cpp b/src/examples/aes/aes_test.cpp +index 63acd6f..4ebdd6d 100644 +--- a/src/examples/aes/aes_test.cpp ++++ b/src/examples/aes/aes_test.cpp +@@ -69,7 +69,7 @@ int32_t read_test_options(int32_t* argcp, char*** argvp, e_role* role, uint32_t* + + int main(int argc, char** argv) { + e_role role; +- uint32_t bitlen = 32, nvals = 1, secparam = 128, nthreads = 1; ++ uint32_t bitlen = 128, nvals = 8, secparam = 128, nthreads = 1; + uint16_t port = 7766; + std::string address = "127.0.0.1"; + bool verbose = false; +diff --git a/src/examples/aes/common/aescircuit.cpp b/src/examples/aes/common/aescircuit.cpp +index 7c89fd8..2a12240 100644 +--- a/src/examples/aes/common/aescircuit.cpp ++++ b/src/examples/aes/common/aescircuit.cpp +@@ -15,6 +15,7 @@ + along with this program. If not, see . + \brief Implementation of AESCiruit + */ ++#include + #include "aescircuit.h" + #include "../../../abycore/circuit/booleancircuits.h" + #include "../../../abycore/sharing/sharing.h" +@@ -26,14 +27,13 @@ static uint32_t* pos_odd; + + int32_t test_aes_circuit(e_role role, const std::string& address, uint16_t port, seclvl seclvl, uint32_t nvals, uint32_t nthreads, + e_mt_gen_alg mt_alg, e_sharing sharing, [[maybe_unused]] bool verbose, bool use_vec_ands, bool expand_in_sfe, bool client_only) { +- uint32_t bitlen = 32; ++ uint32_t bitlen = 128; + uint32_t aes_key_bits; + ABYParty* party = new ABYParty(role, address, port, seclvl, bitlen, nthreads, mt_alg, 4000000); + std::vector& sharings = party->GetSharings(); + + crypto* crypt = new crypto(seclvl.symbits, (uint8_t*) const_seed); + CBitVector input, key, verify; +- + //ids that are required for the vector_and optimization + if(use_vec_ands) { + pos_even = (uint32_t*) malloc(sizeof(uint32_t) * nvals); +@@ -45,14 +45,16 @@ int32_t test_aes_circuit(e_role role, const std::string& address, uint16_t port, + } + + aes_key_bits = crypt->get_aes_key_bytes() * 8; ++ // std::cout << aes_key_bits << std::endl; + input.Create(AES_BITS * nvals, crypt); ++ std::cout << input.GetSize() << std::endl; + verify.Create(AES_BITS * nvals); + key.CreateBytes(AES_EXP_KEY_BYTES); + +- uint8_t aes_test_key[AES_KEY_BYTES]; +- srand(7438); +- for(uint32_t i = 0; i < AES_KEY_BYTES; i++) { +- aes_test_key[i] = (uint8_t) (rand() % 256); ++ uint8_t aes_test_key[AES_KEY_BYTES * 8]; ++ // srand(7438); ++ for(uint32_t i = 0; i < AES_KEY_BYTES * 8; i++) { ++ aes_test_key[i] = (uint8_t) (rand() % 2); + } + uint8_t expanded_key[AES_EXP_KEY_BYTES]; + ExpandKey(expanded_key, aes_test_key); +@@ -90,7 +92,6 @@ int32_t test_aes_circuit(e_role role, const std::string& address, uint16_t port, + if(nyao_rev_circs > 0) { + s_ciphertext_yao_rev = BuildAESCircuit(s_in_yao_rev, s_key_yao_rev, (BooleanCircuit*) yao_rev_circ, use_vec_ands); + } +- + party->ExecCircuit(); + + output = s_ciphertext_yao->get_clear_value_ptr(); +@@ -131,7 +132,6 @@ int32_t test_aes_circuit(e_role role, const std::string& address, uint16_t port, + + out.SetBytes(output, 0L, (uint64_t) AES_BYTES * nvals); + } +- + verify_AES_encryption(input.GetArr(), key.GetArr(), nvals, verify.GetArr(), crypt); + + #ifndef BATCH +@@ -155,13 +155,18 @@ int32_t test_aes_circuit(e_role role, const std::string& address, uint16_t port, + #ifndef BATCH + std::cout << "all tests succeeded" << std::endl; + #else +- std::cout << party->GetTiming(P_SETUP) << "\t" << party->GetTiming(P_GARBLE) << "\t" << party->GetTiming(P_ONLINE) << "\t" << party->GetTiming(P_TOTAL) << +- "\t" << party->GetSentData(P_TOTAL) + party->GetReceivedData(P_TOTAL) << "\t"; +- if(sharing == S_YAO_REV) { +- std::cout << sharings[S_YAO]->GetNumNonLinearOperations() +sharings[S_YAO_REV]->GetNumNonLinearOperations() << "\t" << sharings[S_YAO]->GetMaxCommunicationRounds()<< std::endl; +- } else { +- std::cout << sharings[sharing]->GetNumNonLinearOperations() << "\t" << sharings[sharing]->GetMaxCommunicationRounds()<< std::endl; +- } ++ std::cout << "role: " << role << std::endl; ++ // std::cout << party->GetTiming(P_SETUP) << "\t" << party->GetTiming(P_GARBLE) << "\t" << party->GetTiming(P_ONLINE) << "\t" << party->GetTiming(P_TOTAL) << ++ // "\t" << party->GetSentData(P_TOTAL) + party->GetReceivedData(P_TOTAL) << std::endl; ++ std::cout << role << " computing time: " << party->GetTiming(P_GARBLE) + party->GetTiming(P_ONLINE) << std::endl; ++ std::cout << role << " rev data: " << party->GetReceivedData(P_TOTAL) << std::endl; ++ std::cout << role << " send data: " << party->GetSentData(P_TOTAL) << std::endl; ++ ++ // if(sharing == S_YAO_REV) { ++ // std::cout << sharings[S_YAO]->GetNumNonLinearOperations() +sharings[S_YAO_REV]->GetNumNonLinearOperations() << "\t" << sharings[S_YAO]->GetMaxCommunicationRounds()<< std::endl; ++ // } else { ++ // std::cout << sharings[sharing]->GetNumNonLinearOperations() << "\t" << sharings[sharing]->GetMaxCommunicationRounds()<< std::endl; ++ // } + #endif + delete crypt; + delete party; diff --git a/examples/gc/emp_benchmark/batchDualEx_test.patch b/examples/gc/emp_benchmark/batchDualEx_test.patch new file mode 100644 index 00000000..c727ccd5 --- /dev/null +++ b/examples/gc/emp_benchmark/batchDualEx_test.patch @@ -0,0 +1,455 @@ +diff --git a/FrontEnd/Main.cpp b/FrontEnd/Main.cpp +index ac09818..2dbb8d0 100644 +--- a/FrontEnd/Main.cpp ++++ b/FrontEnd/Main.cpp +@@ -37,11 +37,12 @@ using namespace osuCrypto; + void Eval(Circuit& cir, u64 numExe, u64 bucketSize, u64 numOpened, u64 numConcurrentSetups, u64 numConcurrentEvals, u64 numThreadsPerEval, + bool v, + bool timefiles, +- Timer& timer); ++ Timer& timer, int role); + + + void pingTest(Endpoint& netMgr, Role role) +-{ ++{ ++ std::cout << "PINGTEST" << std::endl; + u64 count = 100; + std::array oneMB; + +@@ -174,7 +175,7 @@ void commandLineMain(int argc, const char** argv) + ); + + opt.add( +- "128", ++ "1", + 0, + 1, + 0, +@@ -184,7 +185,7 @@ void commandLineMain(int argc, const char** argv) + ); + + opt.add( +- "4", ++ "1", + 0, + 1, + 0, +@@ -206,7 +207,7 @@ void commandLineMain(int argc, const char** argv) + + + opt.add( +- "./circuits/AES-non-expanded.txt", ++ "./circuits/AES-expanded.txt", + 0, + 1, + 0, +@@ -224,7 +225,7 @@ void commandLineMain(int argc, const char** argv) + "--ping"); + + opt.add( +- "4", ++ "2", + 0, + 1, + 0, +@@ -242,7 +243,7 @@ void commandLineMain(int argc, const char** argv) + "--evalConcurrently"); + + opt.add( +- "", ++ "1", + 0, + 1, + 0, +@@ -296,7 +297,7 @@ void commandLineMain(int argc, const char** argv) + opt.get("-f")->getString(file); + verbose = opt.get("-v")->isSet; + timefiles = opt.get("-l")->isSet; +- ++ // std::cout << file << std::endl; + if (opt.get("-c")->isSet) + { + opt.get("-c")->getInt(temp); numThreadsPerEval = static_cast(temp); +@@ -373,7 +374,8 @@ void commandLineMain(int argc, const char** argv) + numThreadsPerEval, + verbose, + timefiles, +- timer); ++ timer, ++ (int)role); + + return; + } +@@ -402,7 +404,7 @@ void commandLineMain(int argc, const char** argv) + + PRNG prng(_mm_set_epi64x(0, role)); + +- std::cout << "Initializing..." << std::endl; ++ // std::cout << "Initializing..." << std::endl; + + auto initStart = timer.setTimePoint("Init Start"); + +@@ -485,12 +487,13 @@ void commandLineMain(int argc, const char** argv) + actor.printTimes("./timeFile"); + + //std::cout << "initTime " << std::chrono::duration_cast(initFinish - initStart).count() << " ms" << std::endl; +- std::cout << "Done. " << std::endl << std::endl; +- std::cout << "total offline = " << offlineTotal / 1000.0 << " ms" << std::endl; +- std::cout << "total online = " << onlineTotal / 1000.0 << " ms" << std::endl; +- std::cout << "offline / eval = " << offlineTotal / numExec / 1000.0 << " ms" << std::endl; +- std::cout << "online / eval = " << onlineTotal / numExec / 1000.0 << " ms" << std::endl; +- ++ // std::cout << "Done. " << std::endl << std::endl; ++ std::cout << "party " << role << " total offline = " << offlineTotal / 1000.0 << " ms" << std::endl; ++ std::cout << "party " << role << " total online = " << onlineTotal / 1000.0 << " ms" << std::endl; ++ std::cout << "party " << role << " offline / eval = " << offlineTotal / numExec / 1000.0 << " ms" << std::endl; ++ std::cout << "party " << role << " online / eval = " << onlineTotal / numExec / 1000.0 << " ms" << std::endl; ++ std::cout << "party " << role << " send = " << actor.mTotalBytesSent << " bytes" << std::endl; ++ std::cout << "party " << role << " rev = " << actor.mTotalBytesRecv << " bytes" << std::endl; + if (verbose) + { + +@@ -538,7 +541,8 @@ void Eval( + u64 numThreadsPerEval, + bool verbose, + bool timefiles, +- Timer& timer) ++ Timer& timer, ++ int role) + { + u64 psiSecParam = 40; + +@@ -643,10 +647,10 @@ void Eval( + thrd.join(); + + std::cout << "Done. " << std::endl << std::endl; +- std::cout << "total offline = " << std::chrono::duration_cast(initFinish - initStart).count() / 1000.0 << " ms" << std::endl; +- std::cout << "total online = " << std::chrono::duration_cast(finished - initFinish).count() / 1000.0 << " ms" << std::endl; +- std::cout << "time/eval = " << std::chrono::duration_cast(finished - initFinish).count() / numExe / 1000.0 << " ms" << std::endl; +- std::cout << "min eval time = " << std::chrono::duration_cast(min).count() / 1000.0 << " ms" << std::endl << std::endl << std::endl; ++ std::cout << role << "total offline = " << std::chrono::duration_cast(initFinish - initStart).count() / 1000.0 << " ms" << std::endl; ++ std::cout << role << "total online = " << std::chrono::duration_cast(finished - initFinish).count() / 1000.0 << " ms" << std::endl; ++ std::cout << role << "time/eval = " << std::chrono::duration_cast(finished - initFinish).count() / numExe / 1000.0 << " ms" << std::endl; ++ std::cout << role << "min eval time = " << std::chrono::duration_cast(min).count() / 1000.0 << " ms" << std::endl << std::endl << std::endl; + + if (verbose) + { +diff --git a/batchDualEx_aes.sh b/batchDualEx_aes.sh +new file mode 100644 +index 0000000..5c3546b +--- /dev/null ++++ b/batchDualEx_aes.sh +@@ -0,0 +1,66 @@ ++#!/bin/bash ++ ++EXECUTABLE="./bin/frontend.exe" ++RUNS=100 # 改成你想要运行的次数 ++ ++# 初始化总计变量 ++total_send_0=0 ++total_recv_0=0 ++total_time_0=0 ++ ++total_send_1=0 ++total_recv_1=0 ++total_time_1=0 ++ ++echo -e "=== AES TEST $RUNS Runs Batch 1 ===" ++ ++# 输出进度条函数 ++print_progress() { ++ local progress=$1 ++ local total=$2 ++ local width=50 ++ local filled=$((progress * width / total)) ++ local empty=$((width - filled)) ++ printf "\rProgress: [" ++ for ((i = 0; i < filled; i++)); do printf "#"; done ++ for ((i = 0; i < empty; i++)); do printf " "; done ++ printf "] %d/%d" "$progress" "$total" ++} ++ ++for i in $(seq 1 $RUNS); do ++ print_progress "$i" "$RUNS" ++ OUTPUT=$( ( $EXECUTABLE -r 0 & $EXECUTABLE -r 1; wait ) 2>&1 ) ++ ++ # 提取 party 0 数据 ++ recv_0=$(echo "$OUTPUT" | grep "party 0 rev" | cut -d '=' -f2 | awk '{print $1}') ++ send_0=$(echo "$OUTPUT" | grep "party 0 send" | cut -d '=' -f2 | awk '{print $1}') ++ time_0=$(echo "$OUTPUT" | grep "party 0 total offline" | awk '{print $(NF-1)}') ++ time_on_0=$(echo "$OUTPUT" | grep "party 0 total online" | awk '{print $(NF-1)}') ++ ++ # 提取 party 1 数据 ++ recv_1=$(echo "$OUTPUT" | grep "party 1 rev" | cut -d '=' -f2 | awk '{print $1}') ++ send_1=$(echo "$OUTPUT" | grep "party 1 send" | cut -d '=' -f2 | awk '{print $1}') ++ time_1=$(echo "$OUTPUT" | grep "party 1 total offline" | awk '{print $(NF-1)}') ++ time_on_1=$(echo "$OUTPUT" | grep "party 0 total online" | awk '{print $(NF-1)}') ++ ++ # 累加 ++ total_send_0=$((total_send_0 + send_0)) ++ total_recv_0=$((total_recv_0 + recv_0)) ++ total_time_0=$(echo "$total_time_0 + $time_0 + $time_on_0" | bc) ++ ++ total_send_1=$((total_send_1 + send_1)) ++ total_recv_1=$((total_recv_1 + recv_1)) ++ total_time_1=$(echo "$total_time_1 + $time_1 + $time_on_1" | bc) ++done ++ ++echo -e "\n\n=== AES TEST After $RUNS Runs ===" ++echo "Party 0: Total Send = ${total_send_0} bytes, Total Recv = ${total_recv_0} bytes, Total Time = ${total_time_0} ms" ++echo "Party 1: Total Send = ${total_recv_0} bytes, Total Recv = ${total_send_0} bytes, Total Time = ${total_time_1} ms" ++ ++{ ++ echo -e "=== AES TEST $RUNS Runs Batch 1 ===" ++ echo "Party 0: Total Send = ${total_send_0} bytes, Total Recv = ${total_recv_0} bytes, Total Time = ${total_time_0} ms" ++ echo "Party 1: Total Send = ${total_recv_0} bytes, Total Recv = ${total_send_0} bytes, Total Time = ${total_time_1} ms" ++} > ./bin/aes_result.log ++ ++echo "Result is saved in ./bin/aes_result.log" +\ No newline at end of file +diff --git a/batchDualEx_sha256.sh b/batchDualEx_sha256.sh +new file mode 100644 +index 0000000..ce5d822 +--- /dev/null ++++ b/batchDualEx_sha256.sh +@@ -0,0 +1,66 @@ ++#!/bin/bash ++ ++EXECUTABLE="./bin/frontend.exe" ++RUNS=100 # 改成你想要运行的次数 ++ ++# 初始化总计变量 ++total_send_0=0 ++total_recv_0=0 ++total_time_0=0 ++ ++total_send_1=0 ++total_recv_1=0 ++total_time_1=0 ++ ++echo -e "\n=== SHA256 TEST $RUNS Runs Batch 1 ===" ++ ++# 输出进度条函数 ++print_progress() { ++ local progress=$1 ++ local total=$2 ++ local width=50 ++ local filled=$((progress * width / total)) ++ local empty=$((width - filled)) ++ printf "\rProgress: [" ++ for ((i = 0; i < filled; i++)); do printf "#"; done ++ for ((i = 0; i < empty; i++)); do printf " "; done ++ printf "] %d/%d" "$progress" "$total" ++} ++ ++for i in $(seq 1 $RUNS); do ++ print_progress "$i" "$RUNS" ++ OUTPUT=$( ( $EXECUTABLE -r 0 & $EXECUTABLE -r 1; wait ) 2>&1 ) ++ ++ # 提取 party 0 数据 ++ recv_0=$(echo "$OUTPUT" | grep "party 0 rev" | cut -d '=' -f2 | awk '{print $1}') ++ send_0=$(echo "$OUTPUT" | grep "party 0 send" | cut -d '=' -f2 | awk '{print $1}') ++ time_0=$(echo "$OUTPUT" | grep "party 0 total offline" | awk '{print $(NF-1)}') ++ time_on_0=$(echo "$OUTPUT" | grep "party 0 total online" | awk '{print $(NF-1)}') ++ ++ # 提取 party 1 数据 ++ recv_1=$(echo "$OUTPUT" | grep "party 1 rev" | cut -d '=' -f2 | awk '{print $1}') ++ send_1=$(echo "$OUTPUT" | grep "party 1 send" | cut -d '=' -f2 | awk '{print $1}') ++ time_1=$(echo "$OUTPUT" | grep "party 1 total offline" | awk '{print $(NF-1)}') ++ time_on_1=$(echo "$OUTPUT" | grep "party 0 total online" | awk '{print $(NF-1)}') ++ ++ # 累加 ++ total_send_0=$((total_send_0 + send_0)) ++ total_recv_0=$((total_recv_0 + recv_0)) ++ total_time_0=$(echo "$total_time_0 + $time_0 + $time_on_0" | bc) ++ ++ total_send_1=$((total_send_1 + send_1)) ++ total_recv_1=$((total_recv_1 + recv_1)) ++ total_time_1=$(echo "$total_time_1 + $time_1 + $time_on_1" | bc) ++done ++ ++echo -e "\n\n=== SHA256 TEST After $RUNS Runs ===" ++echo "Party 0: Total Send = ${total_send_0} bytes, Total Recv = ${total_recv_0} bytes, Total Time = ${total_time_0} ms" ++echo "Party 1: Total Send = ${total_recv_0} bytes, Total Recv = ${total_send_0} bytes, Total Time = ${total_time_1} ms" ++ ++{ ++ echo -e "\n=== SHA256 TEST $RUNS Runs Batch 1 ===" ++ echo "Party 0: Total Send = ${total_send_0} bytes, Total Recv = ${total_recv_0} bytes, Total Time = ${total_time_0} ms" ++ echo "Party 1: Total Send = ${total_recv_0} bytes, Total Recv = ${total_send_0} bytes, Total Time = ${total_time_1} ms" ++} > ./bin/sha256_result.log ++ ++echo "Result is saved in ./bin/sha256_result.log" +\ No newline at end of file +diff --git a/libBDX/DualEx/Bucket.cpp b/libBDX/DualEx/Bucket.cpp +index 8758056..065ebd4 100644 +--- a/libBDX/DualEx/Bucket.cpp ++++ b/libBDX/DualEx/Bucket.cpp +@@ -70,7 +70,11 @@ namespace osuCrypto + + mOutputs[i].resize(cir.Outputs().size()); + } +- std::shuffle(mPSIInputPermutes.begin(), mPSIInputPermutes.end(), prng); ++ // std::shuffle(mPSIInputPermutes.begin(), mPSIInputPermutes.end(), prng); ++ for (u64 i = mPSIInputPermutes.size() - 1; i > 0; --i) { ++ u64 j = prng.get() % (i + 1); ++ std::swap(mPSIInputPermutes[i], mPSIInputPermutes[j]); ++ } + + mTheirPermutes.resize(bucketSize - 1); + BitVector delta(theirKProbe.encodingSize()); +diff --git a/libBDX/DualEx/DualExActor.cpp b/libBDX/DualEx/DualExActor.cpp +index c71b5a8..8f0801f 100644 +--- a/libBDX/DualEx/DualExActor.cpp ++++ b/libBDX/DualEx/DualExActor.cpp +@@ -115,6 +115,8 @@ namespace osuCrypto + PRNG prng2(seed); + Channel chl = mNetMgr.addChannel("OTRecv" + ToString(t), "OTSend" + ToString(t)); + mOTRecv[t].Extend(bases[t], numOTExtPer, prng2, chl, mOTRecvDoneIdx[t]); ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + }); + } +@@ -125,7 +127,8 @@ namespace osuCrypto + bases[numInit - 1][i][1] = extenders[i][1].get(); + } + mOTRecv[numInit - 1].Extend(bases[numInit - 1], numOTExtPer, prng, chl, mOTRecvDoneIdx[numInit - 1]); +- ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + + for (auto& thrd : thrds) +@@ -169,6 +172,8 @@ namespace osuCrypto + PRNG prng2(seed); + Channel chl = mNetMgr.addChannel("OTSend" + ToString(t), "OTRecv" + ToString(t)); + mOTSend[t].Extend(bases[t], baseOTsReceiver_inputs, numOTExtPer, prng2, chl, mOTSendDoneIdx[t]); ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + }); + } +@@ -178,6 +183,8 @@ namespace osuCrypto + bases[numInit - 1][i] = extenders[i].get(); + } + mOTSend[numInit - 1].Extend(bases[numInit - 1], baseOTsReceiver_inputs, numOTExtPer, prng, chl, mOTSendDoneIdx[numInit - 1]); ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + + for (auto& thrd : thrds) +@@ -392,7 +399,8 @@ namespace osuCrypto + #endif + + mOnlineProm.set_value(); +- ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + + } +@@ -455,6 +463,8 @@ namespace osuCrypto + evalThreadLoop(i, numThreadsPerEval, j, numParallelEval, prng, chl);// , mPSIRecvChls[j], mPsiChannelLocks[0][j].get()); + + // all done :) ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + }); + +@@ -470,6 +480,8 @@ namespace osuCrypto + sendCircuitInputLoop(i, numThreadsPerEval, j, numParallelEval, chl); + + // all done :) ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + }); + } +@@ -656,6 +668,8 @@ namespace osuCrypto + //std::cout << "receive " + ToString(initThrdIdx) + " kprobe Done" << std::endl; + + // all done :) ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + + timer.setTimePoint("bucketing_done"); +@@ -791,7 +805,8 @@ namespace osuCrypto + { + mBuckets[bcktIdx].initKProbeInputSend(mCircuit, chl, mTheirKProbe, prng, mRole, mIndexArray); + } +- ++ mTotalBytesSent += chl.getTotalDataSent(); ++ mTotalBytesRecv += chl.getTotalDataRecv(); + chl.close(); + } + +diff --git a/libBDX/DualEx/DualExActor.h b/libBDX/DualEx/DualExActor.h +index 39ef744..7e623d9 100644 +--- a/libBDX/DualEx/DualExActor.h ++++ b/libBDX/DualEx/DualExActor.h +@@ -45,6 +45,8 @@ namespace osuCrypto + const u64 mNumExe, mBucketSize, mNumOpened,mNumCircuits, mPsiSecParam; + std::atomic mCnCCommitRecvDone; + std::vector mBuckets; ++ u64 mTotalBytesSent = 0; ++ u64 mTotalBytesRecv = 0; + //std::atomic mEvalIdx; + + //PRNG mPrng; +diff --git a/libBDXTests/AsyncPSI_Tests.cpp b/libBDXTests/AsyncPSI_Tests.cpp +index f4e0bb8..3dafb13 100644 +--- a/libBDXTests/AsyncPSI_Tests.cpp ++++ b/libBDXTests/AsyncPSI_Tests.cpp +@@ -156,7 +156,11 @@ void AsyncPsi_FullSet_Test_Impl() + sendSet[i] = recvSet[i] = prng.get(); + } + +- std::shuffle(sendSet.begin(), sendSet.end(), prng); ++ // std::shuffle(sendSet.begin(), sendSet.end(), prng); ++ for (u64 i = sendSet.size() - 1; i > 0; --i) { ++ u64 j = prng.get() % (i + 1); ++ std::swap(sendSet[i], sendSet[j]); ++ } + + + std::string name("psi"); +diff --git a/libBDXTests/PSI_Tests.cpp b/libBDXTests/PSI_Tests.cpp +index d0f506f..eb1dcd9 100644 +--- a/libBDXTests/PSI_Tests.cpp ++++ b/libBDXTests/PSI_Tests.cpp +@@ -159,8 +159,11 @@ void Psi_FullSet_Test_Impl() + sendSet[i] = recvSet[i] = prng.get(); + } + +- std::shuffle(sendSet.begin(), sendSet.end(), prng); +- ++ // std::shuffle(sendSet.begin(), sendSet.end(), prng); ++ for (u64 i = sendSet.size() - 1; i > 0; --i) { ++ u64 j = prng.get() % (i + 1); ++ std::swap(sendSet[i], sendSet[j]); ++ } + + std::string name("psi"); + //NetworkManager netMgr0("localhost", 1212, 4, true); +diff --git a/run_batchDualEx.sh b/run_batchDualEx.sh +new file mode 100644 +index 0000000..b5fb23f +--- /dev/null ++++ b/run_batchDualEx.sh +@@ -0,0 +1,13 @@ ++#!/bin/bash ++ ++# 要修改的 C++ 文件 ++# ./FrontEnd/Main.cpp ++bash batchDualEx_aes.sh ++target_file="./FrontEnd/Main.cpp" # 替换为你的目标文件名 ++sed -i 's|"\./circuits/AES-expanded\.txt"|"./circuits/sha-256.txt"|g' "$target_file" ++echo "替换完成:./circuits/AES-expanded.txt -> sha-256.txt" ++make > tmp.log ++rm tmp.log ++bash batchDualEx_sha256.sh ++ ++ diff --git a/examples/gc/emp_benchmark/communication_cost_sh2pc.patch b/examples/gc/emp_benchmark/communication_cost_sh2pc.patch new file mode 100644 index 00000000..b7011481 --- /dev/null +++ b/examples/gc/emp_benchmark/communication_cost_sh2pc.patch @@ -0,0 +1,471 @@ +diff --git a/aes_run.sh b/aes_run.sh +new file mode 100644 +index 0000000..f3f672d +--- /dev/null ++++ b/aes_run.sh +@@ -0,0 +1,126 @@ ++#!/bin/bash ++cd build ++ ++# 初始化变量 ++total_send_rounds_1=0 ++total_recv_rounds_1=0 ++total_send_bytes_1=0 ++total_recv_bytes_1=0 ++total_send_rounds_2=0 ++total_recv_rounds_2=0 ++total_send_bytes_2=0 ++total_recv_bytes_2=0 ++total_computation_time_1=0 ++total_reading_time_1=0 ++total_computation_time_2=0 ++total_reading_time_2=0 ++ ++ ++# 运行 100 次循环 ++total=100 ++PROGRESS_WIDTH=20 ++echo "==================================AES BEGIN=========================================" ++echo "Total Iteration: $total, Batch Size 1" ++for ((i=1; i<=total; i++)); do ++ # 计算进度条 ++ percent=$((i * 100 / total)) ++ progress=$(( (i * PROGRESS_WIDTH) / total )) ++ remaining=$((PROGRESS_WIDTH - progress)) ++ bar=$(printf "%-${PROGRESS_WIDTH}s" "$(printf '█%.0s' $(seq 1 $progress))") ++ echo -ne "\r[ ${bar// / } ] $percent% ($i/$total) Running iteration $i... " ++ ++ # 启动进程 ++ port=$((1234 + i)) ++ # echo -ne "\r[$bar] $percent% ($i/$total) Running iteration $i..." ++ ./bin/test_circuit_file_aes 1 1234 > output1.log& ++ ./bin/test_circuit_file_aes 2 1234 > output2.log 2>&1 ++ wait ++ ++ # 解析 party 1 数据(清理分号) ++ computation_time_1=$(grep "Time for Computation:" output1.log | awk '{print $4}' | tr -d ';') ++ reading_time_1=$(grep "Time for Reading File and Creating Circuits:" output1.log | awk '{print $8}' | tr -d ';') ++ send_rounds_1=$(grep "party 1: send rounds:" output1.log | awk '{print $5}' | tr -d ';') ++ recv_rounds_1=$(grep "recv rounds:" output1.log | awk '{print $8}' | tr -d ';') ++ send_bytes_1=$(grep "party 1: send bytes:" output1.log | awk '{print $5}' | tr -d ';') ++ recv_bytes_1=$(grep "recv bytes:" output1.log | awk '{print $8}' | tr -d ';') ++ ++ # 解析 party 2 数据(清理分号) ++ send_rounds_2=$(grep "party 2: send rounds:" output2.log | awk '{print $5}' | tr -d ';') ++ recv_rounds_2=$(grep "recv rounds:" output2.log | awk '{print $8}' | tr -d ';') ++ send_bytes_2=$(grep "party 2: send bytes:" output2.log | awk '{print $5}' | tr -d ';') ++ recv_bytes_2=$(grep "recv bytes:" output2.log | awk '{print $8}' | tr -d ';') ++ computation_time_2=$(grep "Time for Computation:" output2.log | awk '{print $4}' | tr -d ';') ++ reading_time_2=$(grep "Time for Reading File and Creating Circuits:" output2.log | awk '{print $8}' | tr -d ';') ++ ++ # 累加数据 ++ total_send_rounds_1=$((total_send_rounds_1 + send_rounds_1)) ++ total_recv_rounds_1=$((total_recv_rounds_1 + recv_rounds_1)) ++ total_send_bytes_1=$((total_send_bytes_1 + send_bytes_1)) ++ total_recv_bytes_1=$((total_recv_bytes_1 + recv_bytes_1)) ++ total_computation_time_1=$((total_computation_time_1 + computation_time_1)) ++ total_reading_time_1=$((total_reading_time_1 + reading_time_1)) ++ ++ total_send_rounds_2=$((total_send_rounds_2 + send_rounds_2)) ++ total_recv_rounds_2=$((total_recv_rounds_2 + recv_rounds_2)) ++ total_send_bytes_2=$((total_send_bytes_2 + send_bytes_2)) ++ total_recv_bytes_2=$((total_recv_bytes_2 + recv_bytes_2)) ++ total_computation_time_2=$((total_computation_time_2 + computation_time_2)) ++ total_reading_time_2=$((total_reading_time_2 + reading_time_2)) ++done ++ ++convert_to_kb() { ++ local bytes=$1 ++ if [ "$bytes" -gt 1024 ]; then ++ echo "$((bytes / 1024)) KB" ++ else ++ echo "$bytes B" ++ fi ++} ++ ++convert_to_ms() { ++ local times=$1 ++ if [ "$times" -gt 100000000 ]; then ++ echo "$((times / 1000000)) ms" ++ else ++ echo "$times ns" ++ fi ++} ++ ++ ++echo "Emp-tool AES Total Results after 100 Iterations:" ++echo "Party 1:" ++echo " Send Rounds: $total_send_rounds_1" ++echo " Recv Rounds: $total_recv_rounds_1" ++echo " Send Bytes: $(convert_to_kb $total_send_bytes_1)" ++echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_1)" ++echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_1)" ++echo " Time for Computation: $(convert_to_ms $total_computation_time_1)" ++echo "Party 2:" ++echo " Send Rounds: $total_send_rounds_2" ++echo " Recv Rounds: $total_recv_rounds_2" ++echo " Send Bytes: $(convert_to_kb $total_send_bytes_2)" ++echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_2)" ++echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_2)" ++echo " Time for Computation: $(convert_to_ms $total_computation_time_2)" ++ ++{ ++ echo "Emp-tool AES Total Results after 100 Iterations:" ++ echo "Party 1:" ++ echo " Send Rounds: $total_send_rounds_1" ++ echo " Recv Rounds: $total_recv_rounds_1" ++ echo " Send Bytes: $(convert_to_kb $total_send_bytes_1)" ++ echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_1)" ++ echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_1)" ++ echo " Time for Computation: $(convert_to_ms $total_computation_time_1)" ++ echo "Party 2:" ++ echo " Send Rounds: $total_send_rounds_2" ++ echo " Recv Rounds: $total_recv_rounds_2" ++ echo " Send Bytes: $(convert_to_kb $total_send_bytes_2)" ++ echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_2)" ++ echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_2)" ++ echo " Time for Computation: $(convert_to_ms $total_computation_time_2)" ++} > aes_result.log ++ ++rm output1.log ++rm output2.log ++echo "Results saved to ./build/aes_result.log" +diff --git a/sha256_run.sh b/sha256_run.sh +new file mode 100644 +index 0000000..b262352 +--- /dev/null ++++ b/sha256_run.sh +@@ -0,0 +1,127 @@ ++#!/bin/bash ++cd build ++ ++# 初始化变量 ++total_send_rounds_1=0 ++total_recv_rounds_1=0 ++total_send_bytes_1=0 ++total_recv_bytes_1=0 ++total_send_rounds_2=0 ++total_recv_rounds_2=0 ++total_send_bytes_2=0 ++total_recv_bytes_2=0 ++total_computation_time_1=0 ++total_reading_time_1=0 ++total_computation_time_2=0 ++total_reading_time_2=0 ++ ++ ++# 运行 100 次循环 ++total=100 ++PROGRESS_WIDTH=20 ++echo "==================================SHA256 BEGIN=========================================" ++echo "Total Iteration: $total, Batch Size 1" ++for ((i=1; i<=total; i++)); do ++ # 计算进度条 ++ percent=$((i * 100 / total)) ++ progress=$(( (i * PROGRESS_WIDTH) / total )) ++ remaining=$((PROGRESS_WIDTH - progress)) ++ bar=$(printf "%-${PROGRESS_WIDTH}s" "$(printf '█%.0s' $(seq 1 $progress))") ++ echo -ne "\r[ ${bar// / } ] $percent% ($i/$total) Running iteration $i... " ++ ++ # 启动进程 ++ port=$((1234 + i)) ++ # echo -ne "\r[$bar] $percent% ($i/$total) Running iteration $i..." ++ ./bin/test_circuit_file_sha256 1 1234 > output1.log& ++ ./bin/test_circuit_file_sha256 2 1234 > output2.log 2>&1 ++ wait ++ ++ # 解析 party 1 数据(清理分号) ++ computation_time_1=$(grep "Time for Computation:" output1.log | awk '{print $4}' | tr -d ';') ++ reading_time_1=$(grep "Time for Reading File and Creating Circuits:" output1.log | awk '{print $8}' | tr -d ';') ++ send_rounds_1=$(grep "party 1: send rounds:" output1.log | awk '{print $5}' | tr -d ';') ++ recv_rounds_1=$(grep "recv rounds:" output1.log | awk '{print $8}' | tr -d ';') ++ send_bytes_1=$(grep "party 1: send bytes:" output1.log | awk '{print $5}' | tr -d ';') ++ recv_bytes_1=$(grep "recv bytes:" output1.log | awk '{print $8}' | tr -d ';') ++ ++ # 解析 party 2 数据(清理分号) ++ send_rounds_2=$(grep "party 2: send rounds:" output2.log | awk '{print $5}' | tr -d ';') ++ recv_rounds_2=$(grep "recv rounds:" output2.log | awk '{print $8}' | tr -d ';') ++ send_bytes_2=$(grep "party 2: send bytes:" output2.log | awk '{print $5}' | tr -d ';') ++ recv_bytes_2=$(grep "recv bytes:" output2.log | awk '{print $8}' | tr -d ';') ++ computation_time_2=$(grep "Time for Computation:" output2.log | awk '{print $4}' | tr -d ';') ++ reading_time_2=$(grep "Time for Reading File and Creating Circuits:" output2.log | awk '{print $8}' | tr -d ';') ++ ++ # 累加数据 ++ total_send_rounds_1=$((total_send_rounds_1 + send_rounds_1)) ++ total_recv_rounds_1=$((total_recv_rounds_1 + recv_rounds_1)) ++ total_send_bytes_1=$((total_send_bytes_1 + send_bytes_1)) ++ total_recv_bytes_1=$((total_recv_bytes_1 + recv_bytes_1)) ++ total_computation_time_1=$((total_computation_time_1 + computation_time_1)) ++ total_reading_time_1=$((total_reading_time_1 + reading_time_1)) ++ ++ total_send_rounds_2=$((total_send_rounds_2 + send_rounds_2)) ++ total_recv_rounds_2=$((total_recv_rounds_2 + recv_rounds_2)) ++ total_send_bytes_2=$((total_send_bytes_2 + send_bytes_2)) ++ total_recv_bytes_2=$((total_recv_bytes_2 + recv_bytes_2)) ++ total_computation_time_2=$((total_computation_time_2 + computation_time_2)) ++ total_reading_time_2=$((total_reading_time_2 + reading_time_2)) ++done ++ ++# 转换 bytes 为 KB(如果超过 1024) ++convert_to_kb() { ++ local bytes=$1 ++ if [ "$bytes" -gt 1024 ]; then ++ echo "$((bytes / 1024)) KB" ++ else ++ echo "$bytes B" ++ fi ++} ++ ++convert_to_ms() { ++ local times=$1 ++ if [ "$times" -gt 100000000 ]; then ++ echo "$((times / 1000000)) ms" ++ else ++ echo "$times ns" ++ fi ++} ++ ++ ++echo "Emp-tool SHA256 Total Results after 100 Iterations:" ++echo "Party 1:" ++echo " Send Rounds: $total_send_rounds_1" ++echo " Recv Rounds: $total_recv_rounds_1" ++echo " Send Bytes: $(convert_to_kb $total_send_bytes_1)" ++echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_1)" ++echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_1)" ++echo " Time for Computation: $(convert_to_ms $total_computation_time_1)" ++echo "Party 2:" ++echo " Send Rounds: $total_send_rounds_2" ++echo " Recv Rounds: $total_recv_rounds_2" ++echo " Send Bytes: $(convert_to_kb $total_send_bytes_2)" ++echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_2)" ++echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_2)" ++echo " Time for Computation: $(convert_to_ms $total_computation_time_2)" ++ ++{ ++ echo "Emp-tool AES Total Results after 100 Iterations:" ++ echo "Party 1:" ++ echo " Send Rounds: $total_send_rounds_1" ++ echo " Recv Rounds: $total_recv_rounds_1" ++ echo " Send Bytes: $(convert_to_kb $total_send_bytes_1)" ++ echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_1)" ++ echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_1)" ++ echo " Time for Computation: $(convert_to_ms $total_computation_time_1)" ++ echo "Party 2:" ++ echo " Send Rounds: $total_send_rounds_2" ++ echo " Recv Rounds: $total_recv_rounds_2" ++ echo " Send Bytes: $(convert_to_kb $total_send_bytes_2)" ++ echo " Recv Bytes: $(convert_to_kb $total_recv_bytes_2)" ++ echo " Time for Reading File and Creating Circuits: $(convert_to_ms $total_reading_time_2)" ++ echo " Time for Computation: $(convert_to_ms $total_computation_time_2)" ++} > sha256_result.log ++ ++rm output1.log ++rm output2.log ++echo "Results saved to ./build/sha256_result.log" +\ No newline at end of file +diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt +index d4d81ac..213a101 100644 +--- a/test/CMakeLists.txt ++++ b/test/CMakeLists.txt +@@ -22,5 +22,7 @@ IF(${ENABLE_FLOAT}) + add_test_case_with_run(float) + ENDIF(${ENABLE_FLOAT}) + add_test_case_with_run(circuit_file) ++add_test_case_with_run(circuit_file_aes) ++add_test_case_with_run(circuit_file_sha256) + add_test_case_with_run(example) + add_test_case_with_run(repeat) +diff --git a/test/circuit_file_aes.cpp b/test/circuit_file_aes.cpp +new file mode 100644 +index 0000000..7c44ad3 +--- /dev/null ++++ b/test/circuit_file_aes.cpp +@@ -0,0 +1,74 @@ ++#include "emp-sh2pc/emp-sh2pc.h" ++using namespace emp; ++using namespace std; ++const string circuit_file_location = macro_xstr(EMP_CIRCUIT_PATH); ++ ++int port, party; ++string file = circuit_file_location + "/bristol_fashion/aes_128.txt"; ++ ++vector cat_vector(vector key, vector plaintext) ++{ ++ vector result = key; ++ result.insert(result.end(), plaintext.begin(), plaintext.end()); ++ return result; ++} ++ ++string bits2string(vector bits) ++{ ++ string result; ++ for (const Bit &b : bits) ++ { ++ result += b.reveal() ? '1' : '0'; ++ } ++ return result; ++} ++ ++string reverse_string(string str) ++{ ++ reverse(str.begin(), str.end()); ++ return str; ++} ++ ++void test() ++{ ++ std::srand(static_cast( ++ std::chrono::system_clock::now().time_since_epoch().count())); ++ auto start1 = clock_start(); ++ BristolFashion cf(file.c_str()); ++ cout << "Time for Reading File and Creating Circuits: " << time_from(start1) << endl; ++ ++ vector key(128); ++ for (auto &bit : key) ++ { ++ bit = (rand() % 2) == 1; ++ } ++ ++ vector plaintext(128); ++ for (auto &bit : plaintext) ++ { ++ bit = (rand() % 2) == 1; ++ } ++ cout << "key : " << bits2string(key) << endl; ++ cout << "plaintext: " << bits2string(plaintext) << endl; ++ ++ vector bit_vec = cat_vector(key, plaintext); ++ Integer a(bit_vec); ++ ++ Integer c(128, 1, PUBLIC); ++ auto start2 = clock_start(); ++ cf.compute((block *)c.bits.data(), (block *)a.bits.data()); ++ cout << "ciphertext: " << reverse_string(c.reveal()) << endl; ++ cout << "Time for Computation: " << time_from(start2) << endl; ++} ++int main(int argc, char **argv) ++{ ++ parse_party_and_port(argv, &party, &port); ++ NetIO *io = new NetIO(party == ALICE ? nullptr : "127.0.0.1", port); ++ ++ setup_semi_honest(io, party); ++ test(); ++ cout << "party " << party << ": send rounds: " << io->send_rounds << "; recv rounds: " << io->recv_rounds << endl; ++ cout << "party " << party << ": send bytes: " << io->send_bytes << "; recv bytes: " << io->recv_bytes << endl; ++ finalize_semi_honest(); ++ delete io; ++} +diff --git a/test/circuit_file_sha256.cpp b/test/circuit_file_sha256.cpp +new file mode 100644 +index 0000000..7925096 +--- /dev/null ++++ b/test/circuit_file_sha256.cpp +@@ -0,0 +1,107 @@ ++#include ++#include ++#include "emp-sh2pc/emp-sh2pc.h" ++using namespace emp; ++using namespace std; ++const string circuit_file_location = macro_xstr(EMP_CIRCUIT_PATH); ++ ++int port, party; ++string file = circuit_file_location + "/bristol_fashion/sha256.txt"; ++ ++vector cat_vector(vector key, vector plaintext) ++{ ++ vector result = key; ++ result.insert(result.end(), plaintext.begin(), plaintext.end()); ++ return result; ++} ++ ++string bits2string(vector bits) ++{ ++ string result; ++ for (const Bit &b : bits) ++ { ++ result += b.reveal() ? '1' : '0'; ++ } ++ return result; ++} ++ ++string reverse_string(string str) ++{ ++ reverse(str.begin(), str.end()); ++ return str; ++} ++ ++string bits2hexString(vector bits) ++{ ++ stringstream ss; ++ for (size_t i = 0; i < bits.size(); i += 4) ++ { ++ int hex_value = 0; ++ for (size_t j = 0; j < 4 && i + j < bits.size(); ++j) ++ { ++ hex_value = (hex_value << 1) | bits[i + j].reveal(); ++ } ++ ss << hex << hex_value; ++ } ++ return ss.str(); ++} ++ ++string biString2hexString(string str) ++{ ++ stringstream ss; ++ int len = str.length(); ++ while (len % 4 != 0) ++ { ++ str = '0' + str; ++ len++; ++ } ++ for (int i = 0; i < len; i += 4) ++ { ++ string byte_str = str.substr(i, 4); ++ int byte_val = stoi(byte_str, nullptr, 2); ++ ss << hex << setw(1) << setfill('0') << byte_val; ++ } ++ ++ return ss.str(); ++} ++ ++void test() ++{ ++ std::srand(static_cast( ++ std::chrono::system_clock::now().time_since_epoch().count())); ++ auto start1 = clock_start(); ++ BristolFashion cf(file.c_str()); ++ cout << "Time for Reading File and Creating Circuits: " << time_from(start1) << endl; ++ vector message_block(512); ++ vector hash_state(256); ++ for (auto &bit : message_block) ++ { ++ bit = (rand() % 2) == 1; ++ } ++ for (auto &bit : hash_state) ++ { ++ bit = (rand() % 2) == 1; ++ } ++ ++ cout << "message block: 0x" << bits2hexString(message_block) << endl; ++ cout << "hash state : 0x" << bits2hexString(hash_state) << endl; ++ Integer a(cat_vector(message_block, hash_state)); ++ ++ Integer c(256, 0, PUBLIC); ++ auto start2 = clock_start(); ++ cf.compute((block *)c.bits.data(), (block *)a.bits.data()); ++ cout << "ciphertext : 0x" << biString2hexString(c.reveal()) << endl; ++ cout << "Time for Computation: " << time_from(start2) << endl; ++} ++int main(int argc, char **argv) ++{ ++ parse_party_and_port(argv, &party, &port); ++ NetIO *io = new NetIO(party == ALICE ? nullptr : "127.0.0.1", port); ++ ++ setup_semi_honest(io, party); ++ test(); ++ cout << "party " << party << ": send rounds: " << io->send_rounds << "; recv rounds: " << io->recv_rounds << endl; ++ cout << "party " << party << ": send bytes: " << io->send_bytes << "; recv bytes: " << io->recv_bytes << endl; ++ finalize_semi_honest(); ++ delete io; ++} diff --git a/examples/gc/emp_benchmark/communication_cost_tool.patch b/examples/gc/emp_benchmark/communication_cost_tool.patch new file mode 100644 index 00000000..252707c2 --- /dev/null +++ b/examples/gc/emp_benchmark/communication_cost_tool.patch @@ -0,0 +1,33 @@ +diff --git a/emp-tool/io/net_io_channel.h b/emp-tool/io/net_io_channel.h +index 6566564..c049d16 100644 +--- a/emp-tool/io/net_io_channel.h ++++ b/emp-tool/io/net_io_channel.h +@@ -28,6 +28,10 @@ class NetIO: public IOChannel { public: + bool has_sent = false; + string addr; + int port; ++ int send_rounds = 0; ++ int recv_rounds = 0; ++ int send_bytes = 0; ++ int recv_bytes = 0; + NetIO(const char * address, int port, bool quiet = false) { + if (port <0 || port > 65535) { + throw std::runtime_error("Invalid port number!"); +@@ -128,6 +132,8 @@ class NetIO: public IOChannel { public: + error("net_send_data\n"); + } + has_sent = true; ++ send_rounds += 1; ++ send_bytes += len; + } + + void recv_data_internal(void * data, size_t len) { +@@ -142,6 +148,8 @@ class NetIO: public IOChannel { public: + else + error("net_recv_data\n"); + } ++ recv_rounds += 1; ++ recv_bytes += len; + } + }; + diff --git a/examples/gc/emp_benchmark/emp-readme-install.patch b/examples/gc/emp_benchmark/emp-readme-install.patch new file mode 100644 index 00000000..69f1dd92 --- /dev/null +++ b/examples/gc/emp_benchmark/emp-readme-install.patch @@ -0,0 +1,77 @@ +diff --git a/scripts/install.py b/scripts/install.py +index bca7d4d..901a34a 100644 +--- a/scripts/install.py ++++ b/scripts/install.py +@@ -1,6 +1,7 @@ + #!/usr/python + import subprocess +-install_packages = ''' ++ ++install_packages = """ + if [ "$(uname)" == "Darwin" ]; then + brew list openssl || brew install openssl + brew list pkg-config || brew install pkg-config +@@ -16,34 +17,44 @@ else + echo "System not supported yet!" + fi + fi +-''' ++""" + +-install_template = ''' +-git clone https://github.com/emp-toolkit/X.git --branch Y ++install_template = """ ++git clone https://github.com/emp-toolkit/X.git + cd X ++git checkout Y + cmake . + make -j4 + sudo make install + cd .. +-''' ++""" + + import argparse ++ + parser = argparse.ArgumentParser() +-parser.add_argument('-install', '--install', action='store_true') +-parser.add_argument('-deps', '--deps', action='store_true') +-parser.add_argument('--tool', nargs='?', const='master') +-parser.add_argument('--ot', nargs='?', const='master') +-parser.add_argument('--sh2pc', nargs='?', const='master') +-parser.add_argument('--ag2pc', nargs='?', const='master') +-parser.add_argument('--agmpc', nargs='?', const='master') +-parser.add_argument('--zk', nargs='?', const='master') ++parser.add_argument("-install", "--install", action="store_true") ++parser.add_argument("-deps", "--deps", action="store_true") ++parser.add_argument( ++ "--tool", nargs="?", const="8052d95ddf56b519a671b774865bb13157b3b4e0" ++) ++parser.add_argument("--ot", nargs="?", const="0342af547fa80477e866c56b5e2632315ae51721") ++parser.add_argument( ++ "--sh2pc", nargs="?", const="61589f52111a26015b2bb8ab359dc457f8a246eb" ++) ++parser.add_argument( ++ "--ag2pc", nargs="?", const="61589f52111a26015b2bb8ab359dc457f8a246eb" ++) ++parser.add_argument( ++ "--agmpc", nargs="?", const="0add81ed517ac5b83d3a6576572b8daa0d236303" ++) ++parser.add_argument("--zk", nargs="?", const="4a0d717f5e3d18b408db422b845ccb18e24a853b") + args = parser.parse_args() + +-if vars(args)['install'] or vars(args)['deps']: +- subprocess.call(["bash", "-c", install_packages]) ++print(vars(args)) ++ + +-for k in ['tool', 'ot', 'zk', 'sh2pc', 'ag2pc', 'agmpc']: +- if vars(args)[k]: +- template = install_template.replace("X", "emp-"+k).replace("Y", vars(args)[k]) +- print(template) +- subprocess.call(["bash", "-c", template]) ++for k in ["tool", "ot", "zk", "sh2pc", "ag2pc", "agmpc"]: ++ if vars(args)[k]: ++ template = install_template.replace("X", "emp-" + k).replace("Y", vars(args)[k]) ++ print(template) ++ subprocess.call(["bash", "-c", template]) diff --git a/examples/gc/emp_benchmark/libOte_cryptoTools.patch b/examples/gc/emp_benchmark/libOte_cryptoTools.patch new file mode 100644 index 00000000..bea5c35e --- /dev/null +++ b/examples/gc/emp_benchmark/libOte_cryptoTools.patch @@ -0,0 +1,18 @@ +diff --git a/thirdparty/linux/boost.get b/thirdparty/linux/boost.get +index 3e86ed8..bcd2f12 100644 +--- a/thirdparty/linux/boost.get ++++ b/thirdparty/linux/boost.get +@@ -3,10 +3,10 @@ + set -e + + if [ ! -d boost ]; then +- wget -c 'http://sourceforge.net/projects/boost/files/boost/1.59.0/boost_1_59_0.tar.bz2/download' -O ./boost_1_59_0.tar.bz2 +- tar xfj boost_1_59_0.tar.bz2 ++ wget -c 'http://sourceforge.net/projects/boost/files/boost/1.59.0/boost_1_59_0.tar.gz/download' -O ./boost_1_59_0.tar.gz ++ tar -xf boost_1_59_0.tar.gz + mv boost_1_59_0 boost +- rm boost_1_59_0.tar.bz2 ++ rm boost_1_59_0.tar.gz + fi + + cd ./boost diff --git a/examples/gc/emp_benchmark/run.sh b/examples/gc/emp_benchmark/run.sh new file mode 100644 index 00000000..025f4378 --- /dev/null +++ b/examples/gc/emp_benchmark/run.sh @@ -0,0 +1,103 @@ +#!/bin/bash +set -e + +echo -e "Please run \033[32mconda deactivate\033[0m to deactivate the conda environment and prevent Boost from being overridden." +sleep 1 + +OS_TYPE="$(uname)" + +if [ "$OS_TYPE" == "Darwin" ]; then + # macOS 系统 + brew list openssl || brew install openssl + brew list pkg-config || brew install pkg-config + brew list cmake || brew install cmake + brew list boost || brew install boost + brew list gmp || brew install gmp + +elif [ "$OS_TYPE" == "Linux" ]; then + if command -v apt-get >/dev/null; then + # Ubuntu/Debian 系统 + sudo apt-get update + sudo apt-get install -y software-properties-common + sudo apt-get install -y \ + cmake \ + git \ + build-essential \ + libssl-dev \ + pkg-config \ + libgmp-dev \ + libboost-all-dev + + elif command -v yum >/dev/null; then + # RHEL / CentOS / Fedora + sudo yum install -y \ + python3 \ + gcc \ + make \ + git \ + cmake \ + gcc-c++ \ + openssl-devel \ + gmp-devel \ + boost-devel + + else + echo "当前 Linux 发行版不受支持,请手动安装 cmake、git、libssl-dev、libgmp-dev 和 libboost" + exit 1 + fi +else + echo "当前系统 ($OS_TYPE) 不支持!" + exit 1 +fi + +mkdir emp_toolkit +cd emp_toolkit +git clone https://github.com/emp-toolkit/emp-tool.git +cp ../communication_cost_tool.patch emp-tool/ +cd emp-tool/ +git checkout 8052d95ddf56b519a671b774865bb13157b3b4e0 +git apply communication_cost_tool.patch +cmake . +make -j4 +sudo make install +cd .. + +git clone https://github.com/emp-toolkit/emp-readme.git +cp ../emp-readme-install.patch emp-readme/ +cd emp-readme/ +git checkout 28ed3ab07be2edda6d7841692be2c552d22d7cf5 +git apply emp-readme-install.patch +cp scripts/install.py ../ +cd .. + +python install.py --ot --sh2pc +cd .. + +git clone https://github.com/emp-toolkit/emp-sh2pc.git +cp ./communication_cost_sh2pc.patch emp-sh2pc +cd emp-sh2pc +git checkout 61589f52111a26015b2bb8ab359dc457f8a246eb +git apply --reject --whitespace=fix communication_cost_sh2pc.patch + +mkdir build +cd build +cmake .. +make -j4 +cd .. + +bash aes_run.sh +bash sha256_run.sh + +echo "ABY Test" + +git clone https://github.com/encryptogroup/ABY.git +cd ABY +git checkout d8e69414d091cafc007e65a03ef30768ebaf723d +cp ../ABY_aes_test.patch ./ +git apply ABY_aes_test.patch +mkdir build +cd build +cmake .. -DABY_BUILD_EXE=On +make +cd .. +bash run_aes.sh \ No newline at end of file diff --git a/examples/gc/emp_benchmark/run_batchEualEx.sh b/examples/gc/emp_benchmark/run_batchEualEx.sh new file mode 100644 index 00000000..7fe5e793 --- /dev/null +++ b/examples/gc/emp_benchmark/run_batchEualEx.sh @@ -0,0 +1,39 @@ +#!/bin/bash +set -e +echo -e "To ensure all dependencies are installed correctly, please run \033[32mbash ./run.sh\033[0m first." +echo -e "Please run \033[32mconda deactivate\033[0m to deactivate the conda environment and prevent Boost from being overridden." + +mkdir BatchDualEx +cd BatchDualEx + + +git clone --recursive https://github.com/osu-crypto/libOTe.git +cd libOTe +git reset --hard e0727fe6dcfdd4 +git submodule update --recursive +cp ../../libOte_cryptoTools.patch ./cryptoTools/ +cd cryptoTools +git apply libOte_cryptoTools.patch +cd thirdparty/linux +bash all.get +cd ../../.. + +cmake -G "Unix Makefiles" +make + +cd ../.. +# pwd +git clone https://github.com/osu-crypto/batchDualEx.git +cd ./batchDualEx +git checkout ffb7508342fc6d3e9288d6a79a74afbda0bd51d2 +cp ../../batchDualEx_test.patch ./ +git apply batchDualEx_test.patch +cd ./thirdparty/linux +bash ./ntl.get +cd ../.. + +cmake -G "Unix Makefiles" -DBOOST_ROOT=../libOTe/cryptoTools/thirdparty/linux/boost -DBoost_NO_SYSTEM_PATHS=ON -DBoost_USE_STATIC_LIBS=ON +make + +bash run_batchDualEx.sh + diff --git a/examples/gc/gc_test.cc b/examples/gc/gc_test.cc new file mode 100644 index 00000000..1771fc8c --- /dev/null +++ b/examples/gc/gc_test.cc @@ -0,0 +1,123 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "examples/gc/aes_128_evaluator.h" +#include "examples/gc/aes_128_garbler.h" +#include "examples/gc/sha256_evaluator.h" +#include "examples/gc/sha256_garbler.h" +#include "fmt/format.h" +#include "gtest/gtest.h" + +#include "yacl/crypto/block_cipher/symmetric_crypto.h" + +namespace examples::gc { + +inline uint128_t Aes128(uint128_t k, uint128_t m) { + crypto::SymmetricCrypto enc(crypto::SymmetricCrypto::CryptoType::AES128_ECB, + k); + return enc.Encrypt(m); +} + +TEST(GCTest, SHA256Test) { + std::shared_ptr circ_; + + GarblerSHA256* garbler = new GarblerSHA256(); + EvaluatorSHA256* evaluator = new EvaluatorSHA256(); + + std::future thread1 = std::async([&] { garbler->setup(); }); + std::future thread2 = std::async([&] { evaluator->setup(); }); + thread1.get(); + thread2.get(); + + std::string pth = fmt::format("yacl/io/circuit/data/{0}.txt", "sha256"); + yacl::io::CircuitReader reader(pth); + reader.ReadMeta(); + reader.ReadAllGates(); + circ_ = reader.StealCirc(); + + vector sha256_result; + thread1 = std::async([&] { sha256_result = garbler->inputProcess(*circ_); }); + thread2 = std::async([&] { evaluator->inputProcess(*circ_); }); + thread1.get(); + thread2.get(); + + garbler->GB(); + garbler->sendTable(); + + evaluator->recvTable(); + + evaluator->EV(); + + evaluator->sendOutput(); + + vector gc_result = garbler->decode(); + + EXPECT_EQ(sha256_result.size(), gc_result.size()); + EXPECT_TRUE( + std::equal(gc_result.begin(), gc_result.end(), sha256_result.begin())); + delete garbler; + delete evaluator; +} + +TEST(GCTest, AESTest) { + std::shared_ptr circ_; + + GarblerAES* garbler = new GarblerAES(); + EvaluatorAES* evaluator = new EvaluatorAES(); + + std::future thread1 = std::async([&] { garbler->setup(); }); + std::future thread2 = std::async([&] { evaluator->setup(); }); + thread1.get(); + thread2.get(); + + std::string pth = fmt::format("yacl/io/circuit/data/{0}.txt", "aes_128"); + yacl::io::CircuitReader reader(pth); + reader.ReadMeta(); + reader.ReadAllGates(); + circ_ = reader.StealCirc(); + + uint128_t key; + uint128_t message; + thread1 = std::async([&] { key = garbler->inputProcess(*circ_); }); + thread2 = std::async([&] { message = evaluator->inputProcess(*circ_); }); + thread1.get(); + thread2.get(); + + // OT + thread1 = std::async([&] { evaluator->onLineOT(); }); + thread2 = std::async([&] { garbler->onlineOT(); }); + thread1.get(); + thread2.get(); + + garbler->GB(); + garbler->sendTable(); + + evaluator->recvTable(); + + evaluator->EV(); + + evaluator->sendOutput(); + + uint128_t gc_result = garbler->decode(); + auto aes = Aes128(ReverseBytes(key), ReverseBytes(message)); + EXPECT_EQ(ReverseBytes(gc_result), aes); + delete garbler; + delete evaluator; +} + +} // namespace examples::gc diff --git a/examples/gc/mitccrh.h b/examples/gc/mitccrh.h new file mode 100644 index 00000000..28f591a0 --- /dev/null +++ b/examples/gc/mitccrh.h @@ -0,0 +1,78 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "utils.h" + +#include "yacl/crypto/aes/aes_opt.h" + +/* + * [REF] Implementation of "Better Concrete Security for Half-Gates Garbling (in + * the Multi-Instance Setting)" https://eprint.iacr.org/2019/1168.pdf + */ + +using block = __uint128_t; + +inline uint128_t Sigma(uint128_t x) { + auto _x = _mm_loadu_si128(reinterpret_cast<__m128i*>(&x)); + auto exchange = _mm_shuffle_epi32(_x, 0b01001110); + auto left = _mm_unpackhi_epi64(_x, _mm_setzero_si128()); + return reinterpret_cast(_mm_xor_si128(exchange, left)); +} + +template +class MITCCRH { + public: + yacl::crypto::AES_KEY scheduled_key[BatchSize]; + block keys[BatchSize]; + int key_used = BatchSize; + block start_point; + uint64_t gid = 0; + + void setS(block sin) { this->start_point = sin; } + + void renew_ks(uint64_t gid) { + this->gid = gid; + renew_ks(); + } + + void renew_ks() { + for (int i = 0; i < BatchSize; ++i) + keys[i] = start_point ^ yacl::MakeUint128(gid++, (uint64_t)0); + yacl::crypto::AES_opt_key_schedule(keys, scheduled_key); + key_used = 0; + } + + template + void hash_cir(block* blks) { + for (int i = 0; i < K * H; ++i) blks[i] = Sigma(blks[i]); + hash(blks); + } + + template + void hash(block* blks, bool used = false) { + assert(K <= BatchSize); + assert(BatchSize % K == 0); + if (key_used == BatchSize) renew_ks(); + + block tmp[K * H]; + for (int i = 0; i < K * H; ++i) tmp[i] = blks[i]; + + yacl::crypto::ParaEnc(tmp, scheduled_key + key_used); + if (used) key_used += K; + + for (int i = 0; i < K * H; ++i) blks[i] = blks[i] ^ tmp[i]; + } +}; diff --git a/examples/gc/run_gc_examples.sh b/examples/gc/run_gc_examples.sh new file mode 100644 index 00000000..50817ce1 --- /dev/null +++ b/examples/gc/run_gc_examples.sh @@ -0,0 +1,3 @@ +bash ycal_run_gc_aes.sh +pwd +bash ycal_run_gc_sha.sh \ No newline at end of file diff --git a/examples/gc/sha256_evaluator.h b/examples/gc/sha256_evaluator.h new file mode 100644 index 00000000..4fd91610 --- /dev/null +++ b/examples/gc/sha256_evaluator.h @@ -0,0 +1,174 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "examples/gc/mitccrh.h" +#include "fmt/format.h" + +#include "yacl/base/byte_container_view.h" +#include "yacl/base/int128.h" +#include "yacl/io/circuit/bristol_fashion.h" +#include "yacl/link/context.h" +#include "yacl/link/factory.h" + +using namespace std; +using namespace yacl; +using namespace yacl::crypto; + +class EvaluatorSHA256 { + public: + uint128_t delta; + uint128_t inv_constant; + uint128_t start_point; + MITCCRH<8> mitccrh; + + std::vector wires_; + std::vector gb_value; + yacl::io::BFCircuit circ_; + std::shared_ptr lctx; + + // The number of and gate is 22573 + uint128_t table[22573][2]; + uint128_t input; + int num_ot = 768; // input bit + int send_bytes = 0; + + uint128_t all_one_uint128_t = ~static_cast<__uint128_t>(0); + uint128_t select_mask[2] = {0, all_one_uint128_t}; + void setup() { + size_t world_size = 2; + yacl::link::ContextDesc ctx_desc; + + for (size_t rank = 0; rank < world_size; rank++) { + const auto id = fmt::format("id-{}", rank); + const auto host = fmt::format("127.0.0.1:{}", 10086 + rank); + ctx_desc.parties.push_back({id, host}); + } + + lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, 1); + lctx->ConnectToMesh(); + + // delta, inv_constant, start_point + uint128_t tmp[3]; + yacl::Buffer r = lctx->Recv(0, "tmp"); + const uint128_t* buffer_data = r.data(); + memcpy(tmp, buffer_data, sizeof(uint128_t) * 3); + + delta = tmp[0]; + inv_constant = tmp[1]; + start_point = tmp[2]; + + mitccrh.setS(start_point); + } + + void inputProcess(yacl::io::BFCircuit param_circ_) { + circ_ = param_circ_; + gb_value.resize(circ_.nw); + wires_.resize(circ_.nw); + + yacl::Buffer r = lctx->Recv(0, "garbleInput1"); + + const uint128_t* buffer_data = r.data(); + + memcpy(wires_.data(), buffer_data, sizeof(uint128_t) * num_ot); + } + void recvTable() { + yacl::Buffer r = lctx->Recv(0, "table"); + const uint128_t* buffer_data = r.data(); + int k = 0; + for (size_t i = 0; i < 22573; i++) { + for (int j = 0; j < 2; j++) { + table[i][j] = buffer_data[k]; + k++; + } + } + } + + uint128_t EVAND(uint128_t A, uint128_t B, const uint128_t* table_item, + MITCCRH<8>* mitccrh_pointer) { + uint128_t HA, HB, W; + int sa, sb; + + sa = getLSB(A); + sb = getLSB(B); + + uint128_t H[2]; + H[0] = A; + H[1] = B; + mitccrh_pointer->hash<2, 1>(H); + HA = H[0]; + HB = H[1]; + + W = HA ^ HB; + W = W ^ (select_mask[sa] & table_item[0]); + W = W ^ (select_mask[sb] & table_item[1]); + W = W ^ (select_mask[sb] & A); + return W; + } + + void EV() { + int table_cursor = 0; + for (size_t i = 0; i < circ_.gates.size(); i++) { + auto gate = circ_.gates[i]; + switch (gate.op) { + case yacl::io::BFCircuit::Op::XOR: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + const auto& iw1 = wires_.operator[](gate.iw[1]); + wires_[gate.ow[0]] = iw0 ^ iw1; + break; + } + case yacl::io::BFCircuit::Op::AND: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + const auto& iw1 = wires_.operator[](gate.iw[1]); + wires_[gate.ow[0]] = EVAND(iw0, iw1, table[table_cursor], &mitccrh); + table_cursor++; + break; + } + case yacl::io::BFCircuit::Op::INV: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + wires_[gate.ow[0]] = iw0 ^ inv_constant; + break; + } + case yacl::io::BFCircuit::Op::EQ: { + wires_[gate.ow[0]] = gate.iw[0]; + break; + } + case yacl::io::BFCircuit::Op::EQW: { + const auto& iw0 = wires_.operator[](gate.iw[0]); + wires_[gate.ow[0]] = iw0; + break; + } + case yacl::io::BFCircuit::Op::MAND: { /* multiple ANDs */ + YACL_THROW("Unimplemented MAND gate"); + break; + } + default: + YACL_THROW("Unknown Gate Type: {}", (int)gate.op); + } + } + } + void sendOutput() { + size_t index = wires_.size(); + int start = index - circ_.now[0]; + lctx->Send( + 0, + yacl::ByteContainerView(wires_.data() + start, sizeof(uint128_t) * 256), + "output"); + + send_bytes = sizeof(uint128_t) * 256; + } +}; \ No newline at end of file diff --git a/examples/gc/sha256_garbler.h b/examples/gc/sha256_garbler.h new file mode 100644 index 00000000..7a01318c --- /dev/null +++ b/examples/gc/sha256_garbler.h @@ -0,0 +1,249 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "absl/types/span.h" +#include "examples/gc/mitccrh.h" +#include "fmt/format.h" + +#include "yacl/base/byte_container_view.h" +#include "yacl/base/dynamic_bitset.h" +#include "yacl/base/int128.h" +#include "yacl/crypto/hash/ssl_hash.h" +#include "yacl/crypto/rand/rand.h" +#include "yacl/io/circuit/bristol_fashion.h" +#include "yacl/link/context.h" +#include "yacl/link/factory.h" + +using namespace std; +using namespace yacl; + +class GarblerSHA256 { + public: + std::shared_ptr lctx; + uint128_t delta; + uint128_t inv_constant; + uint128_t start_point; + MITCCRH<8> mitccrh; + + std::vector wires_; + std::vector gb_value; + yacl::io::BFCircuit circ_; + + // The number of and gate is 22573 + uint128_t table[22573][2]; + + uint128_t input; + uint128_t input_EV; + vector message; + + int num_ot = 768; // input bit + int send_bytes = 0; + uint128_t all_one_uint128_t_ = ~static_cast<__uint128_t>(0); + uint128_t select_mask_[2] = {0, all_one_uint128_t_}; + + yacl::crypto::OtSendStore ot_send = + yacl::crypto::OtSendStore(num_ot, yacl::crypto::OtStoreType::Normal); + + void setup() { + size_t world_size = 2; + yacl::link::ContextDesc ctx_desc; + + for (size_t rank = 0; rank < world_size; rank++) { + const auto id = fmt::format("id-{}", rank); + const auto host = fmt::format("127.0.0.1:{}", 10086 + rank); + ctx_desc.parties.push_back({id, host}); + } + + lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, 0); + lctx->ConnectToMesh(); + + // delta, inv_constant, start_point + auto tmp = yacl::crypto::SecureRandVec(3); + tmp[0] = tmp[0] | 1; + lctx->Send(1, + yacl::ByteContainerView(static_cast(tmp.data()), + sizeof(uint128_t) * 3), + "tmp"); + send_bytes += sizeof(uint128_t) * 3; + + delta = tmp[0]; + inv_constant = tmp[1] ^ delta; + start_point = tmp[2]; + + mitccrh.setS(start_point); + } + + vector inputProcess(yacl::io::BFCircuit param_circ_) { + circ_ = param_circ_; + gb_value.resize(circ_.nw); + wires_.resize(circ_.nw); + + message = crypto::FastRandBytes(crypto::RandLtN(32)); + auto in_buf = io::BuiltinBFCircuit::PrepareSha256Input(message); + auto sha256_result = crypto::Sha256Hash().Update(message).CumulativeHash(); + + dynamic_bitset bi_val; + bi_val.resize(circ_.nw); + std::memcpy(bi_val.data(), in_buf.data(), in_buf.size()); + + int num_of_input_wires = 0; + for (size_t i = 0; i < circ_.niv; ++i) { + num_of_input_wires += circ_.niw[i]; + } + + auto rands = yacl::crypto::SecureRandVec(num_of_input_wires); + for (int i = 0; i < num_of_input_wires; i++) { + gb_value[i] = rands[i]; + } + + for (int i = 0; i < 768; i++) { + wires_[i] = gb_value[i] ^ (select_mask_[bi_val[i]] & delta); + } + + lctx->Send( + 1, yacl::ByteContainerView(wires_.data(), sizeof(uint128_t) * num_ot), + "garbleInput1"); + send_bytes += sizeof(uint128_t) * num_ot; + + return sha256_result; + } + + uint128_t GBAND(uint128_t LA0, uint128_t A1, uint128_t LB0, uint128_t B1, + uint128_t* table_item, MITCCRH<8>* mitccrh_pointer) { + bool pa = getLSB(LA0); + bool pb = getLSB(LB0); + + uint128_t HLA0, HA1, HLB0, HB1; + uint128_t tmp, W0; + uint128_t H[4]; + + H[0] = LA0; + H[1] = A1; + H[2] = LB0; + H[3] = B1; + + mitccrh_pointer->hash<2, 2>(H); + + HLA0 = H[0]; + HA1 = H[1]; + HLB0 = H[2]; + HB1 = H[3]; + + table_item[0] = HLA0 ^ HA1; + table_item[0] = table_item[0] ^ (select_mask_[pb] & delta); + + W0 = HLA0; + W0 = W0 ^ (select_mask_[pa] & table_item[0]); + + tmp = HLB0 ^ HB1; + table_item[1] = tmp ^ LA0; + + W0 = W0 ^ HLB0; + W0 = W0 ^ (select_mask_[pb] & tmp); + return W0; + } + void GB() { + int table_cursor = 0; + for (size_t i = 0; i < circ_.gates.size(); i++) { + auto gate = circ_.gates[i]; + switch (gate.op) { + case yacl::io::BFCircuit::Op::XOR: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + const auto& iw1 = gb_value.operator[](gate.iw[1]); + gb_value[gate.ow[0]] = iw0 ^ iw1; + break; + } + case yacl::io::BFCircuit::Op::AND: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + const auto& iw1 = gb_value.operator[](gate.iw[1]); + gb_value[gate.ow[0]] = GBAND(iw0, iw0 ^ delta, iw1, iw1 ^ delta, + table[table_cursor], &mitccrh); + table_cursor++; + break; + } + case yacl::io::BFCircuit::Op::INV: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + gb_value[gate.ow[0]] = iw0 ^ inv_constant; + break; + } + case yacl::io::BFCircuit::Op::EQ: { + gb_value[gate.ow[0]] = gate.iw[0]; + break; + } + case yacl::io::BFCircuit::Op::EQW: { + const auto& iw0 = gb_value.operator[](gate.iw[0]); + gb_value[gate.ow[0]] = iw0; + break; + } + case yacl::io::BFCircuit::Op::MAND: { /* multiple ANDs */ + YACL_THROW("Unimplemented MAND gate"); + break; + } + default: + YACL_THROW("Unknown Gate Type: {}", (int)gate.op); + } + } + } + + void sendTable() { + lctx->Send(1, yacl::ByteContainerView(table, sizeof(uint128_t) * 2 * 22573), + "table"); + send_bytes += sizeof(uint128_t) * 2 * 22573; + } + + vector decode() { + size_t index = wires_.size(); + int start = index - circ_.now[0]; + + yacl::Buffer r = lctx->Recv(1, "output"); + + memcpy(wires_.data() + start, r.data(), sizeof(uint128_t) * 256); + + const auto out_size = 32; + std::vector out(out_size); + + for (size_t i = 0; i < out_size; ++i) { + dynamic_bitset result(8); + for (size_t j = 0; j < 8; ++j) { + result[j] = + getLSB(wires_[index - 8 + j]) ^ getLSB(gb_value[index - 8 + j]); + } + out[out_size - i - 1] = *(static_cast(result.data())); + index -= 8; + } + std::reverse(out.begin(), out.end()); + + return out; + } + + template + void finalize(absl::Span outputs) { + size_t index = wires_.size(); + + for (size_t i = 0; i < circ_.nov; ++i) { + yacl::dynamic_bitset result(circ_.now[i]); + for (size_t j = 0; j < circ_.now[i]; ++j) { + int wire_index = index - circ_.now[i] + j; + result[j] = getLSB(wires_[wire_index]) ^ getLSB(gb_value[wire_index]); + } + + outputs[circ_.nov - i - 1] = *(T*)result.data(); + index -= circ_.now[i]; + } + } +}; \ No newline at end of file diff --git a/examples/gc/sha_run.cc b/examples/gc/sha_run.cc new file mode 100644 index 00000000..da7f052e --- /dev/null +++ b/examples/gc/sha_run.cc @@ -0,0 +1,89 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "examples/gc/aes_128_evaluator.h" +#include "examples/gc/aes_128_garbler.h" +#include "examples/gc/sha256_evaluator.h" +#include "examples/gc/sha256_garbler.h" +#include "fmt/format.h" + +#include "yacl/crypto/block_cipher/symmetric_crypto.h" + +using namespace std; + +int sha_garbler_send_bytes = 0; +int sha_evaluator_send_bytes = 0; + +int sha_compute_time = 0; + +void sha_performance() { + std::shared_ptr circ_; + + GarblerSHA256* garbler = new GarblerSHA256(); + EvaluatorSHA256* evaluator = new EvaluatorSHA256(); + + std::future thread1 = std::async([&] { garbler->setup(); }); + std::future thread2 = std::async([&] { evaluator->setup(); }); + thread1.get(); + thread2.get(); + + std::string pth = fmt::format("yacl/io/circuit/data/{0}.txt", "sha256"); + + yacl::io::CircuitReader reader(pth); + reader.ReadMeta(); + reader.ReadAllGates(); + circ_ = reader.StealCirc(); + + for (int i = 0; i < 1; i++) { + vector sha256_result; + + auto start1 = clock_start(); + sha256_result = garbler->inputProcess(*circ_); + evaluator->inputProcess(*circ_); + + garbler->GB(); + garbler->sendTable(); + + evaluator->recvTable(); + + evaluator->EV(); + + evaluator->sendOutput(); + + vector gc_result = garbler->decode(); + sha_compute_time += time_from(start1); + sha_garbler_send_bytes += garbler->send_bytes; + sha_evaluator_send_bytes += evaluator->send_bytes; + } + + delete garbler; + delete evaluator; +} + +int main() { + sha_performance(); + + cout << "SHA_performance:" << endl; + std::cout << "Garbler send: " << sha_garbler_send_bytes << " bytes" << " " + << endl; + std::cout << "Evaluator send: " << sha_evaluator_send_bytes << " bytes" + << " " << endl; + cout << "Time for Computing: " << sha_compute_time << "us" << endl; + cout << endl; +} diff --git a/examples/gc/utils.h b/examples/gc/utils.h new file mode 100644 index 00000000..fd16da0c --- /dev/null +++ b/examples/gc/utils.h @@ -0,0 +1,44 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "yacl/base/byte_container_view.h" +#include "yacl/base/int128.h" + +using std::chrono::high_resolution_clock; +using std::chrono::time_point; + +// get the Least Significant Bit of uint128_t +inline bool getLSB(const uint128_t& x) { return (x & 1) == 1; } + +uint128_t ReverseBytes(uint128_t x) { + auto byte_view = yacl::ByteContainerView(&x, sizeof(x)); + uint128_t ret = 0; + auto buf = std::vector(sizeof(ret)); + for (size_t i = 0; i < byte_view.size(); ++i) { + buf[byte_view.size() - i - 1] = byte_view[i]; + } + std::memcpy(&ret, buf.data(), buf.size()); + return ret; +} + +inline time_point clock_start() { + return high_resolution_clock::now(); +} + +inline double time_from(const time_point& s) { + return std::chrono::duration_cast( + high_resolution_clock::now() - s) + .count(); +} \ No newline at end of file diff --git a/examples/gc/ycal_run_gc_aes.sh b/examples/gc/ycal_run_gc_aes.sh new file mode 100644 index 00000000..0a54ee76 --- /dev/null +++ b/examples/gc/ycal_run_gc_aes.sh @@ -0,0 +1,52 @@ +cd ../.. +bazel build examples/gc:aes_run +bazel-bin/examples/gc/aes_run > examples/gc/outputs_tmp.txt +total_garbler_send=0 +total_evaluator_send=0 +total_time=0 + +# 总次数 +total=100 + +# 清空旧输出 +> examples/gc/outputs_tmp.txt +echo "=========================== AES Epoch 100 Batch 1 ==================================" +# 运行 100 次 +for ((i=1; i<=total; i++)); do + output=$(bazel-bin/examples/gc/aes_run) + echo "$output" >> examples/gc/outputs_tmp.txt + + # 提取信息 + garbler_bytes=$(echo "$output" | grep "Garbler send" | awk '{print $3}') + evaluator_bytes=$(echo "$output" | grep "Evaluator send" | awk '{print $3}') + time_us=$(echo "$output" | grep "Time for Computing" | awk '{print $4}' | sed 's/us//') + + # 累加 + total_garbler_send=$((total_garbler_send + garbler_bytes)) + total_evaluator_send=$((total_evaluator_send + evaluator_bytes)) + total_time=$((total_time + time_us)) + + # 进度条 + progress=$((i * 100 / total)) + echo -ne "Running [$i/$total] [" + for ((j=0; j examples/gc/aes_summary.txt +echo "Total Evaluator send: $((total_evaluator_send / 1024)) KB" >> examples/gc/aes_summary.txt +echo "Total Time for Computing: ${total_time} us" >> examples/gc/aes_summary.txt + + + +rm examples/gc/outputs_tmp.txt \ No newline at end of file diff --git a/examples/gc/ycal_run_gc_sha.sh b/examples/gc/ycal_run_gc_sha.sh new file mode 100644 index 00000000..4686b451 --- /dev/null +++ b/examples/gc/ycal_run_gc_sha.sh @@ -0,0 +1,51 @@ +cd ../.. +bazel build examples/gc:sha_run +bazel-bin/examples/gc/sha_run > examples/gc/outputs_tmp.txt +total_garbler_send=0 +total_evaluator_send=0 +total_time=0 + +# 总次数 +total=100 + +# 清空旧输出 +> examples/gc/outputs_tmp.txt +echo "=========================== SHA Epoch 100 Batch 1 ==================================" +# 运行 100 次 +for ((i=1; i<=total; i++)); do + output=$(bazel-bin/examples/gc/sha_run) + echo "$output" >> examples/gc/outputs_tmp.txt + + # 提取信息 + garbler_bytes=$(echo "$output" | grep "Garbler send" | awk '{print $3}') + evaluator_bytes=$(echo "$output" | grep "Evaluator send" | awk '{print $3}') + time_us=$(echo "$output" | grep "Time for Computing" | awk '{print $4}' | sed 's/us//') + + # 累加 + total_garbler_send=$((total_garbler_send + garbler_bytes)) + total_evaluator_send=$((total_evaluator_send + evaluator_bytes)) + total_time=$((total_time + time_us)) + + # 进度条 + progress=$((i * 100 / total)) + echo -ne "Running [$i/$total] [" + for ((j=0; j examples/gc/sha_summary.txt +echo "Total Evaluator send: $((total_evaluator_send / 1024)) KB" >> examples/gc/sha_summary.txt +echo "Total Time for Computing: ${total_time} us" >> examples/gc/sha_summary.txt + + +rm examples/gc/outputs_tmp.txt \ No newline at end of file