Skip to content

Commit 6383eb9

Browse files
committed
clean code
1 parent a26e60d commit 6383eb9

8 files changed

+139
-107
lines changed

examples/gc/BUILD.bazel

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 Ant Group Co., Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
load("@yacl//bazel:yacl.bzl", "yacl_cc_binary", "yacl_cc_library", "yacl_cc_test")
216

317
package(default_visibility = ["//visibility:public"])

examples/gc/aes_128_evaluator.h

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2024 Ant Group Co., Ltd.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
#pragma once
216
#include <vector>
317

@@ -38,7 +52,6 @@ class EvaluatorAES {
3852
yacl::io::BFCircuit circ_;
3953
std::shared_ptr<yacl::link::Context> lctx;
4054

41-
// 根据电路改
4255
uint128_t table[36663][2];
4356
uint128_t input;
4457
int num_ot = 128;
@@ -49,7 +62,6 @@ class EvaluatorAES {
4962
OtRecvStore(num_ot, yacl::crypto::OtStoreType::Normal);
5063

5164
void setup() {
52-
// 通信环境初始化
5365
size_t world_size = 2;
5466
yacl::link::ContextDesc ctx_desc;
5567

@@ -69,8 +81,9 @@ class EvaluatorAES {
6981
kernel1.init(lctx);
7082
kernel1.eval_rot(lctx, num_ot, &ot_recv);
7183

84+
// delta, inv_constant, start_point
7285
uint128_t tmp[3];
73-
// delta, inv_constant, start_point 接收
86+
7487
yacl::Buffer r = lctx->Recv(0, "tmp");
7588
const uint128_t* buffer_data = r.data<const uint128_t>();
7689
memcpy(tmp, buffer_data, sizeof(uint128_t) * 3);
@@ -80,7 +93,6 @@ class EvaluatorAES {
8093
inv_constant = tmp[1];
8194
start_point = tmp[2];
8295

83-
// 秘钥生成
8496
mitccrh.setS(start_point);
8597
}
8698

@@ -89,14 +101,12 @@ class EvaluatorAES {
89101
gb_value.resize(circ_.nw);
90102
wires_.resize(circ_.nw);
91103

92-
// 输入位数有关
93104
yacl::dynamic_bitset<uint128_t> bi_val;
94105

95-
// 输入位数有关
96106
input = yacl::crypto::FastRandU128();
97107
std::cout << "input of evaluator:" << input << std::endl;
98-
bi_val.append(input); // 直接转换为二进制 输入线路在前64位
99-
// 接收garbler混淆值
108+
bi_val.append(input);
109+
100110
yacl::Buffer r = lctx->Recv(0, "garbleInput1");
101111

102112
const uint128_t* buffer_data = r.data<const uint128_t>();
@@ -105,19 +115,6 @@ class EvaluatorAES {
105115

106116
std::cout << "recvInput1" << std::endl;
107117

108-
// 对evaluator自己的输入值进行混淆
109-
// r = lctx->Recv(0, "garbleInput2");
110-
// buffer_data = r.data<const uint128_t>();
111-
// for (int i = 0; i < circ_.niw[1]; i++) {
112-
// wires_[i + circ_.niw[0]] =
113-
// buffer_data[i] ^ (select_mask[bi_val[i]] & delta);
114-
115-
// }
116-
// std::cout << "recvInput2" << std::endl;
117-
118-
// onLineOT();
119-
120-
// 输入位数有关
121118
lctx->Send(0, yacl::ByteContainerView(&input, sizeof(uint128_t)), "Input1");
122119

123120
return input;
@@ -136,7 +133,6 @@ class EvaluatorAES {
136133
std::cout << "recvTable" << std::endl;
137134
}
138135

139-
// 未检查
140136
uint128_t EVAND(uint128_t A, uint128_t B, const uint128_t* table_item,
141137
MITCCRH<8>* mitccrh_pointer) {
142138
uint128_t HA, HB, W;
@@ -164,7 +160,7 @@ class EvaluatorAES {
164160
auto gate = circ_.gates[i];
165161
switch (gate.op) {
166162
case yacl::io::BFCircuit::Op::XOR: {
167-
const auto& iw0 = wires_.operator[](gate.iw[0]); // 取到具体值
163+
const auto& iw0 = wires_.operator[](gate.iw[0]);
168164
const auto& iw1 = wires_.operator[](gate.iw[1]);
169165
wires_[gate.ow[0]] = iw0 ^ iw1;
170166
break;

examples/gc/aes_128_garbler.h

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2024 Ant Group Co., Ltd.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
#pragma once
216

317
#include <vector>
@@ -21,8 +35,6 @@
2135
using namespace std;
2236
using namespace yacl;
2337

24-
25-
2638
inline uint128_t Aes128(uint128_t k, uint128_t m) {
2739
crypto::SymmetricCrypto enc(crypto::SymmetricCrypto::CryptoType::AES128_ECB,
2840
k);
@@ -50,22 +62,19 @@ class GarblerAES {
5062
std::vector<uint128_t> wires_;
5163
std::vector<uint128_t> gb_value;
5264
yacl::io::BFCircuit circ_;
53-
// 根据电路改
65+
5466
uint128_t table[36663][2];
5567

56-
// 输入数据类型需要修改
5768
uint128_t input;
5869
uint128_t input_EV;
5970

60-
// num_ot根据输入改
6171
int num_ot = 128;
6272
uint128_t all_one_uint128_t_ = ~static_cast<__uint128_t>(0);
6373
uint128_t select_mask_[2] = {0, all_one_uint128_t_};
6474
yacl::crypto::OtSendStore ot_send =
6575
OtSendStore(num_ot, yacl::crypto::OtStoreType::Normal);
6676

6777
void setup() {
68-
// 通信环境初始化
6978
size_t world_size = 2;
7079
yacl::link::ContextDesc ctx_desc;
7180

@@ -85,11 +94,10 @@ class GarblerAES {
8594
kernel0.init(lctx);
8695
kernel0.eval_rot(lctx, num_ot, &ot_send);
8796

88-
// delta, inv_constant, start_point 初始化并发送给evaluator
97+
// delta, inv_constant, start_point
8998
uint128_t tmp[3];
9099

91-
// random_uint128_t(tmp, 3);
92-
for(int i = 0; i < 3; i++){
100+
for (int i = 0; i < 3; i++) {
93101
std::random_device rd;
94102
std::mt19937_64 eng(rd());
95103
std::uniform_int_distribution<uint64_t> distr;
@@ -107,32 +115,26 @@ class GarblerAES {
107115
inv_constant = tmp[1] ^ delta;
108116
start_point = tmp[2];
109117

110-
// 秘钥生成
111118
mitccrh.setS(start_point);
112119
}
113120

114-
// 包扩 输入值生成和混淆,garbler混淆值的发送
115121
uint128_t inputProcess(yacl::io::BFCircuit param_circ_) {
116122
circ_ = param_circ_;
117123
gb_value.resize(circ_.nw);
118124
wires_.resize(circ_.nw);
119125

120-
// 输入位数有关
121126
input = yacl::crypto::FastRandU128();
122127
std::cout << "input of garbler:" << input << std::endl;
123128

124-
// 输入位数有关
125129
yacl::dynamic_bitset<uint128_t> bi_val;
126-
bi_val.append(input); // 直接转换为二进制 输入线路在前64位
130+
bi_val.append(input);
127131

128-
// 混淆过程
129132
int num_of_input_wires = 0;
130133
for (size_t i = 0; i < circ_.niv; ++i) {
131134
num_of_input_wires += circ_.niw[i];
132135
}
133136

134-
// random_uint128_t(gb_value.data(), num_of_input_wires);
135-
for(int i = 0; i < num_of_input_wires; i++){
137+
for (int i = 0; i < num_of_input_wires; i++) {
136138
std::random_device rd;
137139
std::mt19937_64 eng(rd());
138140
std::uniform_int_distribution<uint64_t> distr;
@@ -143,7 +145,6 @@ class GarblerAES {
143145
gb_value[i] = MakeUint128(high, low);
144146
}
145147

146-
// 前64位 直接置换 garbler
147148
for (size_t i = 0; i < circ_.niw[0]; i++) {
148149
wires_[i] = gb_value[i] ^ (select_mask_[bi_val[i]] & delta);
149150
}
@@ -158,7 +159,6 @@ class GarblerAES {
158159

159160
yacl::Buffer r = lctx->Recv(1, "Input1");
160161

161-
// 输入位数有关
162162
const uint128_t* buffer_data = r.data<const uint128_t>();
163163
input_EV = *buffer_data;
164164

@@ -204,7 +204,7 @@ class GarblerAES {
204204
auto gate = circ_.gates[i];
205205
switch (gate.op) {
206206
case yacl::io::BFCircuit::Op::XOR: {
207-
const auto& iw0 = gb_value.operator[](gate.iw[0]); // 取到具体值
207+
const auto& iw0 = gb_value.operator[](gate.iw[0]);
208208
const auto& iw1 = gb_value.operator[](gate.iw[1]);
209209
gb_value[gate.ow[0]] = iw0 ^ iw1;
210210
break;
@@ -247,7 +247,6 @@ class GarblerAES {
247247
std::cout << "sendTable" << std::endl;
248248
}
249249
uint128_t decode() {
250-
// 现接收计算结果
251250
size_t index = wires_.size();
252251
int start = index - circ_.now[0];
253252

@@ -258,31 +257,24 @@ class GarblerAES {
258257
std::cout << "recvOutput" << std::endl;
259258

260259
// decode
261-
262-
// 线路有关 输出位数
263260
std::vector<uint128_t> result(1);
264261
finalize(absl::MakeSpan(result));
265262
std::cout << "MPC结果:" << ReverseBytes(result[0]) << std::endl;
266263
std::cout << "明文结果:"
267264
<< Aes128(ReverseBytes(input), ReverseBytes(input_EV))
268-
<< std::endl; // 待修改
265+
<< std::endl;
269266
return result[0];
270267
}
271268

272269
template <typename T>
273270
void finalize(absl::Span<T> outputs) {
274-
// YACL_ENFORCE(outputs.size() >= circ_->nov);
275-
276271
size_t index = wires_.size();
277272

278273
for (size_t i = 0; i < circ_.nov; ++i) {
279274
yacl::dynamic_bitset<T> result(circ_.now[i]);
280275
for (size_t j = 0; j < circ_.now[i]; ++j) {
281276
int wire_index = index - circ_.now[i] + j;
282-
result[j] = getLSB(wires_[wire_index]) ^
283-
getLSB(gb_value[wire_index]); // 得到的是逆序的二进制值
284-
// 对应的混淆电路计算为LSB ^
285-
// d 输出线路在后xx位
277+
result[j] = getLSB(wires_[wire_index]) ^ getLSB(gb_value[wire_index]);
286278
}
287279

288280
outputs[circ_.nov - i - 1] = *(T*)result.data();

examples/gc/gc_test.cc

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
1+
// Copyright 2024 Ant Group Co., Ltd.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
214

315
#include <vector>
416

@@ -32,7 +44,7 @@ uint128_t ReverseBytes(uint128_t x) {
3244

3345
TEST(GCTest, SHA256Test) {
3446
std::shared_ptr<yacl::io::BFCircuit> circ_;
35-
// 初始化
47+
3648
GarblerSHA256* garbler = new GarblerSHA256();
3749
EvaluatorSHA256* evaluator = new EvaluatorSHA256();
3850

@@ -41,34 +53,27 @@ TEST(GCTest, SHA256Test) {
4153
thread1.get();
4254
thread2.get();
4355

44-
// 电路读取
4556
std::string pth =
4657
fmt::format("{0}/yacl/io/circuit/data/{1}.txt",
4758
std::filesystem::current_path().string(), "sha256");
4859
yacl::io::CircuitReader reader(pth);
4960
reader.ReadMeta();
5061
reader.ReadAllGates();
51-
circ_ = reader.StealCirc(); // 指针
52-
53-
// 输入处理
54-
// garbler->inputProcess(*circ_);
62+
circ_ = reader.StealCirc();
5563

5664
vector<uint8_t> sha256_result;
5765
thread1 = std::async([&] { sha256_result = garbler->inputProcess(*circ_); });
5866
thread2 = std::async([&] { evaluator->inputProcess(*circ_); });
5967
thread1.get();
6068
thread2.get();
6169

62-
// 混淆方对整个电路进行混淆, 并将混淆表发送给evaluator
6370
garbler->GB();
6471
garbler->sendTable();
6572

6673
evaluator->recvTable();
6774

68-
// // 计算方进行计算 按拓扑顺序进行计算
6975
evaluator->EV();
7076

71-
// // // evaluator发送计算结果 garbler进行DE操作
7277
evaluator->sendOutput();
7378

7479
vector<uint8_t> gc_result = garbler->decode();
@@ -80,7 +85,7 @@ TEST(GCTest, SHA256Test) {
8085

8186
TEST(GCTest, AESTest) {
8287
std::shared_ptr<yacl::io::BFCircuit> circ_;
83-
// 初始化
88+
8489
GarblerAES* garbler = new GarblerAES();
8590
EvaluatorAES* evaluator = new EvaluatorAES();
8691

@@ -89,17 +94,13 @@ TEST(GCTest, AESTest) {
8994
thread1.get();
9095
thread2.get();
9196

92-
// 电路读取
9397
std::string pth =
9498
fmt::format("{0}/yacl/io/circuit/data/{1}.txt",
9599
std::filesystem::current_path().string(), "aes_128");
96100
yacl::io::CircuitReader reader(pth);
97101
reader.ReadMeta();
98102
reader.ReadAllGates();
99-
circ_ = reader.StealCirc(); // 指针
100-
101-
// 输入处理
102-
// garbler->inputProcess(*circ_);
103+
circ_ = reader.StealCirc();
103104

104105
uint128_t key;
105106
uint128_t message;
@@ -114,16 +115,13 @@ TEST(GCTest, AESTest) {
114115
thread1.get();
115116
thread2.get();
116117

117-
// 混淆方对整个电路进行混淆, 并将混淆表发送给evaluator
118118
garbler->GB();
119119
garbler->sendTable();
120120

121121
evaluator->recvTable();
122122

123-
// // 计算方进行计算 按拓扑顺序进行计算
124123
evaluator->EV();
125124

126-
// // // evaluator发送计算结果 garbler进行DE操作
127125
evaluator->sendOutput();
128126

129127
uint128_t gc_result = garbler->decode();

0 commit comments

Comments
 (0)