Skip to content

Commit 0f006af

Browse files
Ruiyu Zhufacebook-github-bot
authored andcommitted
Add rebatching API for bitstring (#117)
Summary: Pull Request resolved: #117 For UDP protocol, we need to add a new type of gate. Rebatching gate. This type of gate allows to break a batch of values into smaller batches or combine several batches into a larger one. This diff adds the APIs for bit string type to batching/unbatching Reviewed By: elliottlawrence Differential Revision: D34914178 fbshipit-source-id: 411e1a60f2afdcfd72bb89a41e809ea9712781e2
1 parent 350f616 commit 0f006af

3 files changed

Lines changed: 123 additions & 0 deletions

File tree

fbpcf/frontend/BitString.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ class BitString : public scheduler::SchedulerKeeper<schedulerId> {
121121
mux(const Bit<isSecretChoice, schedulerId, usingBatch>& choice,
122122
const BitString<isSecretOther, schedulerId, usingBatch>& other) const;
123123

124+
BitString<isSecret, schedulerId, usingBatch> batchingWith(
125+
const std::vector<BitString<isSecret, schedulerId, usingBatch>>& others)
126+
const;
127+
128+
std::vector<BitString<isSecret, schedulerId, usingBatch>> unbatching(
129+
std::shared_ptr<std::vector<uint32_t>> unbatchingStrategy) const;
130+
124131
private:
125132
std::vector<Bit<isSecret, schedulerId, usingBatch>> data_;
126133

fbpcf/frontend/BitString_impl.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,50 @@ BitString<isSecret, schedulerId, usingBatch>::mux(
160160
return rst;
161161
}
162162

163+
template <bool isSecret, int schedulerId, bool usingBatch>
164+
BitString<isSecret, schedulerId, usingBatch>
165+
BitString<isSecret, schedulerId, usingBatch>::batchingWith(
166+
const std::vector<BitString<isSecret, schedulerId, usingBatch>>& others)
167+
const {
168+
static_assert(usingBatch, "Only batch values needs to rebatch!");
169+
170+
for (auto& item : others) {
171+
if (item.data_.size() != data_.size()) {
172+
throw std::runtime_error(
173+
"The BitStrings need to have the same length to batch together.");
174+
}
175+
}
176+
177+
BitString<isSecret, schedulerId, usingBatch> rst(data_.size());
178+
size_t batchSize = others.size();
179+
std::vector<Bit<true, schedulerId, usingBatch>> bits(batchSize);
180+
for (size_t i = 0; i < data_.size(); i++) {
181+
for (size_t j = 0; j < batchSize; j++) {
182+
bits[j] = others.at(j).data_.at(i);
183+
}
184+
rst.data_[i] = data_.at(i).batchingWith(bits);
185+
}
186+
return rst;
187+
}
188+
189+
template <bool isSecret, int schedulerId, bool usingBatch>
190+
std::vector<BitString<isSecret, schedulerId, usingBatch>>
191+
BitString<isSecret, schedulerId, usingBatch>::unbatching(
192+
std::shared_ptr<std::vector<uint32_t>> unbatchingStrategy) const {
193+
static_assert(usingBatch, "Only batch values needs to rebatch!");
194+
std::vector<BitString<isSecret, schedulerId, usingBatch>> rst(
195+
unbatchingStrategy->size());
196+
for (auto& item : rst) {
197+
item.resize(data_.size());
198+
}
199+
200+
for (size_t i = 0; i < data_.size(); i++) {
201+
auto bitVec = data_.at(i).unbatching(unbatchingStrategy);
202+
for (size_t j = 0; j < unbatchingStrategy->size(); j++) {
203+
rst.at(j).data_.at(i) = bitVec.at(j);
204+
}
205+
}
206+
return rst;
207+
}
208+
163209
} // namespace fbpcf::frontend

fbpcf/frontend/test/BitStringTest.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "fbpcf/test/TestHelper.h"
1717

1818
namespace fbpcf::frontend {
19+
1920
TEST(StringTest, testInputAndOutput) {
2021
std::random_device rd;
2122
std::mt19937_64 e(rd());
@@ -453,4 +454,73 @@ TEST(StringTest, testResizeWithAND) {
453454
}
454455
}
455456

457+
TEST(StringTest, testRebatching) {
458+
std::random_device rd;
459+
std::mt19937_64 e(rd());
460+
std::uniform_int_distribution<uint32_t> dSize(1, 1024);
461+
462+
std::uniform_int_distribution<uint8_t> dBool(0, 1);
463+
464+
scheduler::SchedulerKeeper<0>::setScheduler(
465+
std::make_unique<scheduler::PlaintextScheduler>(
466+
scheduler::WireKeeper::createWithUnorderedMap()));
467+
468+
using SecBatchString = BitString<true, 0, true>;
469+
470+
std::vector<bool> testValue(dSize(e));
471+
for (size_t i = 0; i < testValue.size(); i++) {
472+
testValue[i] = dBool(e);
473+
}
474+
uint32_t length = dSize(e);
475+
uint32_t batchSize1 = dSize(e);
476+
uint32_t batchSize2 = dSize(e);
477+
uint32_t batchSize3 = dSize(e);
478+
std::vector<std::vector<bool>> testBatchValue1(
479+
length, std::vector<bool>(batchSize1));
480+
std::vector<std::vector<bool>> testBatchValue2(
481+
length, std::vector<bool>(batchSize2));
482+
std::vector<std::vector<bool>> testBatchValue3(
483+
length, std::vector<bool>(batchSize3));
484+
485+
for (size_t i = 0; i < length; i++) {
486+
for (size_t j = 0; j < batchSize1; j++) {
487+
testBatchValue1[i][j] = dBool(e);
488+
}
489+
for (size_t j = 0; j < batchSize2; j++) {
490+
testBatchValue2[i][j] = dBool(e);
491+
}
492+
for (size_t j = 0; j < batchSize3; j++) {
493+
testBatchValue3[i][j] = dBool(e);
494+
}
495+
}
496+
497+
SecBatchString v1(testBatchValue1, 0);
498+
SecBatchString v2(testBatchValue2, 0);
499+
SecBatchString v3(testBatchValue3, 0);
500+
501+
auto v4 = v1.batchingWith({v2, v3});
502+
auto v123 = v4.unbatching(std::make_shared<std::vector<uint32_t>>(
503+
std::vector<uint32_t>({batchSize1, batchSize2, batchSize3})));
504+
505+
auto t4 = v4.openToParty(0).getValue();
506+
auto t5 = v123.at(0).openToParty(0).getValue();
507+
auto t6 = v123.at(1).openToParty(0).getValue();
508+
auto t7 = v123.at(2).openToParty(0).getValue();
509+
510+
for (size_t i = 0; i < length; i++) {
511+
testVectorEq(t5.at(i), testBatchValue1.at(i));
512+
testVectorEq(t6.at(i), testBatchValue2.at(i));
513+
testVectorEq(t7.at(i), testBatchValue3.at(i));
514+
testBatchValue1[i].insert(
515+
testBatchValue1[i].end(),
516+
testBatchValue2[i].begin(),
517+
testBatchValue2[i].end());
518+
testBatchValue1[i].insert(
519+
testBatchValue1[i].end(),
520+
testBatchValue3[i].begin(),
521+
testBatchValue3[i].end());
522+
testVectorEq(t4.at(i), testBatchValue1.at(i));
523+
}
524+
}
525+
456526
} // namespace fbpcf::frontend

0 commit comments

Comments
 (0)