Skip to content

Commit 350f616

Browse files
Ruiyu Zhufacebook-github-bot
authored andcommitted
Add rebatching API for int (#116)
Summary: Pull Request resolved: #116 For UDP protocol, we need to add a new type of gate. Rebatching gate. This type of gate allows to break a batch of values into smaller batches or combine several batches into a larger one. This diff adds the APIs for int type to batching/unbatching Reviewed By: elliottlawrence Differential Revision: D34914176 fbshipit-source-id: e8038f51cc8b7f5e25de9bb9198e931ef43c35e0
1 parent bfd98a7 commit 350f616

3 files changed

Lines changed: 100 additions & 0 deletions

File tree

fbpcf/frontend/Int.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,14 @@ class Int {
204204
typename Int<isSigned, width, true, schedulerId, usingBatch>::ExtractedInt
205205
extractIntShare() const;
206206

207+
Int<isSigned, width, isSecret, schedulerId, usingBatch> batchingWith(
208+
const std::vector<
209+
Int<isSigned, width, isSecret, schedulerId, usingBatch>>& others)
210+
const;
211+
212+
std::vector<Int<isSigned, width, isSecret, schedulerId, usingBatch>>
213+
unbatching(std::shared_ptr<std::vector<uint32_t>> unbatchingStrategy) const;
214+
207215
private:
208216
template <typename T>
209217
std::vector<UnitIntType> convertTo64BitIntVector(

fbpcf/frontend/Int_impl.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,4 +560,48 @@ Int<isSigned, width, isSecret, schedulerId, usingBatch>::convertBitsToInt(
560560
}
561561
}
562562

563+
template <
564+
bool isSigned,
565+
int8_t width,
566+
bool isSecret,
567+
int schedulerId,
568+
bool usingBatch>
569+
Int<isSigned, width, isSecret, schedulerId, usingBatch>
570+
Int<isSigned, width, isSecret, schedulerId, usingBatch>::batchingWith(
571+
const std::vector<Int<isSigned, width, isSecret, schedulerId, usingBatch>>&
572+
others) const {
573+
static_assert(usingBatch, "Only batch values needs to rebatch!");
574+
Int<isSigned, width, isSecret, schedulerId, usingBatch> rst;
575+
size_t batchSize = others.size();
576+
std::vector<Bit<true, schedulerId, usingBatch>> bits(batchSize);
577+
for (size_t i = 0; i < width; i++) {
578+
for (size_t j = 0; j < batchSize; j++) {
579+
bits[j] = others.at(j).data_.at(i);
580+
}
581+
rst.data_[i] = data_.at(i).batchingWith(bits);
582+
}
583+
return rst;
584+
}
585+
586+
template <
587+
bool isSigned,
588+
int8_t width,
589+
bool isSecret,
590+
int schedulerId,
591+
bool usingBatch>
592+
std::vector<Int<isSigned, width, isSecret, schedulerId, usingBatch>>
593+
Int<isSigned, width, isSecret, schedulerId, usingBatch>::unbatching(
594+
std::shared_ptr<std::vector<uint32_t>> unbatchingStrategy) const {
595+
static_assert(usingBatch, "Only batch values needs to rebatch!");
596+
std::vector<Int<isSigned, width, isSecret, schedulerId, usingBatch>> rst(
597+
unbatchingStrategy->size());
598+
for (size_t i = 0; i < width; i++) {
599+
auto bitVec = data_.at(i).unbatching(unbatchingStrategy);
600+
for (size_t j = 0; j < unbatchingStrategy->size(); j++) {
601+
rst.at(j).data_.at(i) = bitVec.at(j);
602+
}
603+
}
604+
return rst;
605+
}
606+
563607
} // namespace fbpcf::frontend

fbpcf/frontend/test/IntTest.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,4 +1386,52 @@ TEST(IntTest, testSubscriptBatch) {
13861386
}
13871387
}
13881388

1389+
TEST(IntTest, testRebatch) {
1390+
const int8_t width = 15;
1391+
scheduler::SchedulerKeeper<0>::setScheduler(
1392+
std::make_unique<scheduler::PlaintextScheduler>(
1393+
scheduler::WireKeeper::createWithUnorderedMap()));
1394+
using secSignedIntBatch = Integer<Secret<Batch<Signed<width>>>, 0>;
1395+
1396+
int partyId = 2;
1397+
1398+
int size1 = 5;
1399+
int size2 = 2;
1400+
int size3 = 3;
1401+
1402+
std::vector<int64_t> v1(size1, (int64_t(1) << (width - 1)) - 1);
1403+
std::vector<int64_t> v2(size2, -1 - v1[0]);
1404+
std::vector<int64_t> v3(size3, 3);
1405+
1406+
secSignedIntBatch int1(v1, partyId);
1407+
secSignedIntBatch int2(v2, partyId);
1408+
secSignedIntBatch int3(v3, partyId);
1409+
1410+
std::vector<int64_t> expectedV(size1 + size2 + size3);
1411+
for (size_t i = 0; i < size1; i++) {
1412+
expectedV[i] = v1.at(i);
1413+
}
1414+
for (size_t i = 0; i < size2; i++) {
1415+
expectedV[size1 + i] = v2.at(i);
1416+
}
1417+
1418+
for (size_t i = 0; i < size3; i++) {
1419+
expectedV[size1 + size2 + i] = v3.at(i);
1420+
}
1421+
1422+
auto int4 = int1.batchingWith({int2, int3});
1423+
1424+
testVectorEq(int4.openToParty(partyId).getValue(), expectedV);
1425+
1426+
auto int123 = int4.unbatching(
1427+
std::make_shared<std::vector<uint32_t>>(std::vector<uint32_t>(
1428+
{static_cast<unsigned int>(size1),
1429+
static_cast<unsigned int>(size2),
1430+
static_cast<unsigned int>(size3)})));
1431+
1432+
testVectorEq(int123.at(0).openToParty(partyId).getValue(), v1);
1433+
testVectorEq(int123.at(1).openToParty(partyId).getValue(), v2);
1434+
testVectorEq(int123.at(2).openToParty(partyId).getValue(), v3);
1435+
}
1436+
13891437
} // namespace fbpcf::frontend

0 commit comments

Comments
 (0)