Skip to content

Commit 8fb6035

Browse files
adelesunfacebook-github-bot
authored andcommitted
Implement the max function for integer type and its unit tests (#179)
Summary: Pull Request resolved: #179 Implement the max function for integer type and its unit tests Reviewed By: RuiyuZhu Differential Revision: D35390074 fbshipit-source-id: 36aa50fce720314c2d645c1e91fafc9eeee46c43
1 parent 6d623e0 commit 8fb6035

3 files changed

Lines changed: 168 additions & 0 deletions

File tree

fbpcf/frontend/Int.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,17 @@ using Integer =
279279
typename IntTypeHelper<typename T::type, IsSecret<T>::value, schedulerId>::
280280
type;
281281

282+
template <
283+
bool isSigned,
284+
int8_t width,
285+
bool isSecret1,
286+
bool isSecret2,
287+
int schedulerId,
288+
bool usingBatch>
289+
Int<isSigned, width, isSecret1 || isSecret2, schedulerId, usingBatch> max(
290+
const Int<isSigned, width, isSecret1, schedulerId, usingBatch>& left,
291+
const Int<isSigned, width, isSecret2, schedulerId, usingBatch>& right);
292+
282293
} // namespace fbpcf::frontend
283294

284295
#include "fbpcf/frontend/Int_impl.h"

fbpcf/frontend/Int_impl.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,4 +604,17 @@ Int<isSigned, width, isSecret, schedulerId, usingBatch>::unbatching(
604604
return rst;
605605
}
606606

607+
template <
608+
bool isSigned,
609+
int8_t width,
610+
bool isSecret1,
611+
bool isSecret2,
612+
int schedulerId,
613+
bool usingBatch>
614+
Int<isSigned, width, isSecret1 || isSecret2, schedulerId, usingBatch> max(
615+
const Int<isSigned, width, isSecret1, schedulerId, usingBatch>& left,
616+
const Int<isSigned, width, isSecret2, schedulerId, usingBatch>& right) {
617+
return left.mux(left < right, right);
618+
}
619+
607620
} // namespace fbpcf::frontend

fbpcf/frontend/test/IntTest.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,4 +1434,148 @@ TEST(IntTest, testRebatch) {
14341434
testVectorEq(int123.at(2).openToParty(partyId).getValue(), v3);
14351435
}
14361436

1437+
TEST(IntTest, testMax) {
1438+
const int8_t width = 64;
1439+
1440+
int64_t largestSigned = std::numeric_limits<int64_t>().max();
1441+
int64_t smallestSigned = std::numeric_limits<int64_t>().min();
1442+
uint64_t largestUnsigned = std::numeric_limits<uint64_t>().max();
1443+
1444+
scheduler::SchedulerKeeper<0>::setScheduler(
1445+
std::make_unique<scheduler::PlaintextScheduler>(
1446+
scheduler::WireKeeper::createWithUnorderedMap()));
1447+
using secSignedInt = Integer<Secret<Signed<width>>, 0>;
1448+
using pubSignedInt = Integer<Public<Signed<width>>, 0>;
1449+
using secUnsignedInt = Integer<Secret<Unsigned<width>>, 0>;
1450+
using pubUnsignedInt = Integer<Public<Unsigned<width>>, 0>;
1451+
1452+
int partyId = 2;
1453+
1454+
std::random_device rd;
1455+
std::mt19937_64 e(rd());
1456+
std::uniform_int_distribution<int64_t> dist1(smallestSigned, largestSigned);
1457+
1458+
std::uniform_int_distribution<uint64_t> dist2(0, largestUnsigned);
1459+
1460+
for (int i = 0; i < 100; i++) {
1461+
int64_t v1 = dist1(e);
1462+
int64_t v2 = dist1(e);
1463+
1464+
secSignedInt int1(v1, partyId);
1465+
secSignedInt int2(v2, partyId);
1466+
pubSignedInt int3(v1);
1467+
pubSignedInt int4(v2);
1468+
1469+
auto r1 = max(int1, int2);
1470+
auto r2 = max(int1, int4);
1471+
auto r3 = max(int3, int4);
1472+
1473+
auto expectedValue = v1 < v2 ? v2 : v1;
1474+
1475+
EXPECT_EQ(r1.openToParty(partyId).getValue(), expectedValue);
1476+
EXPECT_EQ(r2.openToParty(partyId).getValue(), expectedValue);
1477+
EXPECT_EQ(r3.getValue(), expectedValue);
1478+
}
1479+
1480+
for (int i = 0; i < 100; i++) {
1481+
uint64_t v1 = dist2(e);
1482+
uint64_t v2 = dist2(e);
1483+
1484+
secUnsignedInt int1(v1, partyId);
1485+
secUnsignedInt int2(v2, partyId);
1486+
pubUnsignedInt int3(v1);
1487+
pubUnsignedInt int4(v2);
1488+
1489+
auto r1 = max(int1, int2);
1490+
auto r2 = max(int1, int4);
1491+
auto r3 = max(int3, int4);
1492+
1493+
auto expectedValue = v1 < v2 ? v2 : v1;
1494+
1495+
EXPECT_EQ(r1.openToParty(partyId).getValue(), expectedValue);
1496+
EXPECT_EQ(r2.openToParty(partyId).getValue(), expectedValue);
1497+
EXPECT_EQ(r3.getValue(), expectedValue);
1498+
}
1499+
}
1500+
1501+
TEST(IntTest, testMaxBatch) {
1502+
const int8_t width = 64;
1503+
1504+
int64_t largestSigned = std::numeric_limits<int64_t>().max();
1505+
int64_t smallestSigned = std::numeric_limits<int64_t>().min();
1506+
uint64_t largestUnsigned = std::numeric_limits<uint64_t>().max();
1507+
1508+
scheduler::SchedulerKeeper<0>::setScheduler(
1509+
std::make_unique<scheduler::PlaintextScheduler>(
1510+
scheduler::WireKeeper::createWithUnorderedMap()));
1511+
using secSignedIntBatch = Integer<Secret<Batch<Signed<width>>>, 0>;
1512+
using pubSignedIntBatch = Integer<Public<Batch<Signed<width>>>, 0>;
1513+
using secUnsignedIntBatch = Integer<Secret<Batch<Unsigned<width>>>, 0>;
1514+
using pubUnsignedIntBatch = Integer<Public<Batch<Unsigned<width>>>, 0>;
1515+
1516+
size_t batchSize = 9;
1517+
1518+
int partyId = 2;
1519+
1520+
std::random_device rd;
1521+
std::mt19937_64 e(rd());
1522+
std::uniform_int_distribution<int64_t> dist1(smallestSigned, largestSigned);
1523+
1524+
std::uniform_int_distribution<uint64_t> dist2(0, largestUnsigned);
1525+
1526+
for (int i = 0; i < 100; i++) {
1527+
std::vector<int64_t> v1(batchSize);
1528+
std::vector<int64_t> v2(batchSize);
1529+
for (size_t j = 0; j < batchSize; j++) {
1530+
v1[j] = dist1(e);
1531+
v2[j] = dist1(e);
1532+
}
1533+
1534+
secSignedIntBatch int1(v1, partyId);
1535+
secSignedIntBatch int2(v2, partyId);
1536+
pubSignedIntBatch int3(v1);
1537+
pubSignedIntBatch int4(v2);
1538+
1539+
auto r1 = max(int1, int2);
1540+
auto r2 = max(int1, int4);
1541+
auto r3 = max(int3, int4);
1542+
1543+
std::vector<int64_t> expectedValue(batchSize);
1544+
for (size_t j = 0; j < batchSize; j++) {
1545+
expectedValue[j] = v1[j] < v2[j] ? v2[j] : v1[j];
1546+
}
1547+
1548+
testVectorEq(r1.openToParty(partyId).getValue(), expectedValue);
1549+
testVectorEq(r2.openToParty(partyId).getValue(), expectedValue);
1550+
testVectorEq(r3.getValue(), expectedValue);
1551+
}
1552+
1553+
for (int i = 0; i < 100; i++) {
1554+
std::vector<uint64_t> v1(batchSize);
1555+
std::vector<uint64_t> v2(batchSize);
1556+
for (size_t j = 0; j < batchSize; j++) {
1557+
v1[j] = dist2(e);
1558+
v2[j] = dist2(e);
1559+
}
1560+
1561+
secUnsignedIntBatch int1(v1, partyId);
1562+
secUnsignedIntBatch int2(v2, partyId);
1563+
pubUnsignedIntBatch int3(v1);
1564+
pubUnsignedIntBatch int4(v2);
1565+
1566+
auto r1 = max(int1, int2);
1567+
auto r2 = max(int1, int4);
1568+
auto r3 = max(int3, int4);
1569+
1570+
std::vector<uint64_t> expectedValue(batchSize);
1571+
for (size_t j = 0; j < batchSize; j++) {
1572+
expectedValue[j] = v1[j] < v2[j] ? v2[j] : v1[j];
1573+
}
1574+
1575+
testVectorEq(r1.openToParty(partyId).getValue(), expectedValue);
1576+
testVectorEq(r2.openToParty(partyId).getValue(), expectedValue);
1577+
testVectorEq(r3.getValue(), expectedValue);
1578+
}
1579+
}
1580+
14371581
} // namespace fbpcf::frontend

0 commit comments

Comments
 (0)