Skip to content

Commit 4c5bba0

Browse files
committed
Add Grotto DCF
1 parent 0c9688d commit 4c5bba0

File tree

3 files changed

+393
-0
lines changed

3 files changed

+393
-0
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ if(BUILD_TESTING)
7474
add_executable(half_tree_dpf_test src/half_tree_dpf_test.cu)
7575
target_link_libraries(half_tree_dpf_test GTest::gtest_main fss)
7676
gtest_discover_tests(half_tree_dpf_test)
77+
78+
add_executable(grotto_dcf_test src/grotto_dcf_test.cu)
79+
target_link_libraries(grotto_dcf_test GTest::gtest_main fss)
80+
gtest_discover_tests(grotto_dcf_test)
7781
endif()
7882

7983
option(BUILD_BENCH "Build benchmarks" OFF)

include/fss/grotto_dcf.cuh

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
/**
3+
* @file grotto_dcf.cuh
4+
* @copyright Apache License, Version 2.0. Copyright (C) 2026 Yulong Ming <i@myl7.org>.
5+
* @author Yulong Ming <i@myl7.org>
6+
*
7+
* @brief 2-party distributed comparison function (DCF) over F2 from standard DPF.
8+
*
9+
* The scheme is from the paper, [_Grotto: Screaming fast (2+1)-PC for Z_{2^n} via
10+
* (2,2)-DPFs_](https://eprint.iacr.org/2023/108) (@ref grotto_dcf "1: the published version").
11+
*
12+
* Key generation is identical to standard BGI DPF. The comparison functionality
13+
* emerges from prefix-parity of the DPF control bits.
14+
*
15+
* Output shares are in F2 (XOR sharing). Each party holds a single bool per query.
16+
* For inputs x: share_0 XOR share_1 = 1[alpha <= x].
17+
*
18+
* ## References
19+
*
20+
* 1. Kyle Storrier, Adithya Vadapalli, Allan Lyons, Ryan Henry: Grotto: Screaming fast (2+1)-PC
21+
* for Z_{2^n} via (2,2)-DPFs. CCS 2023. <https://eprint.iacr.org/2023/108>. @anchor grotto_dcf
22+
*/
23+
24+
#pragma once
25+
#include <cuda_runtime.h>
26+
#include <type_traits>
27+
#include <cstddef>
28+
#include <cassert>
29+
#include <omp.h>
30+
#include <fss/dpf.cuh>
31+
#include <fss/group/bytes.cuh>
32+
#include <fss/prg.cuh>
33+
#include <fss/util.cuh>
34+
35+
namespace fss {
36+
37+
/**
38+
* 2-party DCF scheme over F2 from standard DPF (Grotto construction).
39+
*
40+
* @tparam in_bits Input domain bit size.
41+
* @tparam Prg See Prgable. Must satisfy Prgable<Prg, 2> (same as DPF).
42+
* @tparam In Type for the input domain. From uint8_t to __uint128_t.
43+
* @tparam par_depth -1 is to use ceil(log(num of threads)).
44+
* Only Preprocess() and EvalAll() use it.
45+
*/
46+
template <int in_bits, typename Prg, typename In = uint, int par_depth = -1>
47+
requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) &&
48+
in_bits <= sizeof(In) * 8 && Prgable<Prg, 2>)
49+
class GrottoDcf {
50+
using DpfType = Dpf<in_bits, group::Bytes, Prg, In, par_depth>;
51+
52+
public:
53+
using Cw = typename DpfType::Cw;
54+
Prg prg;
55+
56+
/**
57+
* Key generation method. Delegates to Dpf::Gen with beta=0.
58+
*
59+
* @param cws Pre-allocated array of Cw. Size must be in_bits + 1.
60+
* @param s0s 2 initial seeds. Users can randomly sample them.
61+
* @param a The secret comparison threshold.
62+
*
63+
* The key for party i consists of cws + s0s[i].
64+
*/
65+
__host__ __device__ void Gen(Cw cws[], const int4 s0s[2], In a) {
66+
DpfType dpf{prg};
67+
int4 beta = {0, 0, 0, 0};
68+
dpf.Gen(cws, s0s, a, beta);
69+
}
70+
71+
/**
72+
* Parity segment tree over leaf control bits.
73+
*
74+
* p[0..2N-2]: level-order binary tree where N = 2^in_bits.
75+
* Root is p[0]. Leaf x has index p[x + N - 1].
76+
* Internal node j: p[j] = p[2j+1] XOR p[2j+2].
77+
*
78+
* b: party index, needed for reconstructing comparison results.
79+
*/
80+
struct ParityTree {
81+
bool *p;
82+
bool b;
83+
};
84+
85+
/**
86+
* Preprocess: expand DPF tree and build parity segment tree.
87+
*
88+
* Phase 1: O(N) PRG calls to expand the tree and extract leaf control bits.
89+
* Phase 2a: O(N) XOR operations to build the parity segment tree bottom-up.
90+
*
91+
* @param pt ParityTree with p pre-allocated to size 2*N-1 where N = 2^in_bits.
92+
* pt.b must be set to the party index before calling.
93+
* @param s0 Initial seed of the party.
94+
* @param cws Correction words from Gen().
95+
*/
96+
void Preprocess(ParityTree &pt, int4 s0, const Cw cws[]) {
97+
constexpr size_t N = 1ULL << in_bits;
98+
99+
// Phase 1: expand tree, write leaf control bits to pt.p[N-1 .. 2N-2]
100+
ExpandTree(pt.b, s0, cws, pt.p + (N - 1));
101+
102+
// Phase 2a: build parity segment tree bottom-up
103+
for (size_t j = N - 2; j < N - 1; --j) {
104+
pt.p[j] = pt.p[2 * j + 1] ^ pt.p[2 * j + 2];
105+
}
106+
}
107+
108+
/**
109+
* Prefix-parity query on the parity segment tree.
110+
*
111+
* Returns party b's share of 1[alpha <= x].
112+
* Internally queries endpoint e = x + 1, computing prefix-parity of [0, e).
113+
*
114+
* @param pt ParityTree from Preprocess().
115+
* @param x Query point.
116+
* @return bool share such that share_0 XOR share_1 = 1[alpha <= x].
117+
*/
118+
__host__ __device__ static bool Eval(const ParityTree &pt, In x) {
119+
constexpr size_t N = 1ULL << in_bits;
120+
In e = static_cast<In>(x) + 1;
121+
122+
// e == 0 means x + 1 overflowed, i.e., e = N (entire domain)
123+
if (e == 0 || e == N) return pt.p[0];
124+
125+
bool pi = false;
126+
size_t cur = 0;
127+
for (int i = 0; i < in_bits; ++i) {
128+
bool e_bit = (e >> (in_bits - 1 - i)) & 1;
129+
if (e_bit) {
130+
pi ^= pt.p[2 * cur + 1];
131+
cur = 2 * cur + 2;
132+
} else {
133+
cur = 2 * cur + 1;
134+
}
135+
}
136+
return pi;
137+
}
138+
139+
/**
140+
* Full domain evaluation.
141+
*
142+
* Computes party b's share of 1[alpha <= x] for all x in [0, N).
143+
*
144+
* Phase 1: O(N) PRG calls to expand the tree.
145+
* Phase 2b: O(N) prefix-sum (running XOR) over leaf control bits.
146+
*
147+
* @param b Party index.
148+
* @param s0 Initial seed of the party.
149+
* @param cws Correction words from Gen().
150+
* @param ys Pre-allocated output array of size N = 2^in_bits.
151+
* ys[x] = party b's share of 1[alpha <= x].
152+
*/
153+
void EvalAll(bool b, int4 s0, const Cw cws[], bool ys[]) {
154+
constexpr size_t N = 1ULL << in_bits;
155+
156+
// Phase 1: expand tree to get leaf control bits into ys[]
157+
ExpandTree(b, s0, cws, ys);
158+
159+
// Phase 2b: prefix-sum scan (running XOR)
160+
// ys[x] currently holds leaf x's control bit.
161+
// Transform to: ys[x] = XOR of control bits [0..x] = share of 1[alpha <= x].
162+
for (size_t x = 1; x < N; ++x) {
163+
ys[x] = ys[x] ^ ys[x - 1];
164+
}
165+
}
166+
167+
private:
168+
/**
169+
* Expand the DPF tree and write leaf control bits.
170+
*
171+
* @param b Party index.
172+
* @param s0 Initial seed of the party.
173+
* @param cws Correction words from Gen().
174+
* @param t Output array of size N = 2^in_bits for leaf control bits.
175+
*/
176+
void ExpandTree(bool b, int4 s0, const Cw cws[], bool t[]) {
177+
int4 st = s0;
178+
st = util::SetLsb(st, b);
179+
180+
assert(in_bits < sizeof(size_t) * 8);
181+
size_t l = 0;
182+
size_t r = 1ULL << in_bits;
183+
int i = 0;
184+
185+
int par_depth_ = 0;
186+
if constexpr (par_depth == -1) {
187+
int threads = omp_get_max_threads();
188+
while ((1 << par_depth_) < threads) {
189+
par_depth_++;
190+
}
191+
} else {
192+
par_depth_ = par_depth;
193+
}
194+
195+
#pragma omp parallel
196+
#pragma omp single
197+
ExpandTreeRec(st, cws, t, l, r, i, par_depth_);
198+
}
199+
200+
void ExpandTreeRec(
201+
int4 st, const Cw cws[], bool t[], size_t l, size_t r, int i, int par_depth_) {
202+
bool tc = util::GetLsb(st);
203+
int4 s = st;
204+
s = util::SetLsb(s, false);
205+
206+
if (i == in_bits) {
207+
assert(l + 1 == r);
208+
t[l] = tc;
209+
return;
210+
}
211+
212+
Cw cw = cws[i];
213+
int4 s_cw = cw.s;
214+
bool tl_cw = util::GetLsb(s_cw);
215+
s_cw = util::SetLsb(s_cw, false);
216+
bool tr_cw = cw.tr;
217+
218+
auto [sl, sr] = prg.Gen(s);
219+
220+
bool tl = util::GetLsb(sl);
221+
sl = util::SetLsb(sl, false);
222+
bool tr = util::GetLsb(sr);
223+
sr = util::SetLsb(sr, false);
224+
225+
if (tc) {
226+
sl = util::Xor(sl, s_cw);
227+
sr = util::Xor(sr, s_cw);
228+
tl = tl ^ tl_cw;
229+
tr = tr ^ tr_cw;
230+
}
231+
232+
int4 stl = sl;
233+
stl = util::SetLsb(stl, tl);
234+
int4 str = sr;
235+
str = util::SetLsb(str, tr);
236+
237+
size_t mid = (l + r) / 2;
238+
239+
if (i < par_depth_) {
240+
#pragma omp task
241+
ExpandTreeRec(stl, cws, t, l, mid, i + 1, par_depth_);
242+
#pragma omp task
243+
ExpandTreeRec(str, cws, t, mid, r, i + 1, par_depth_);
244+
#pragma omp taskwait
245+
} else {
246+
ExpandTreeRec(stl, cws, t, l, mid, i + 1, par_depth_);
247+
ExpandTreeRec(str, cws, t, mid, r, i + 1, par_depth_);
248+
}
249+
}
250+
};
251+
252+
} // namespace fss

0 commit comments

Comments
 (0)