|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +/** |
| 3 | + * @file grotto_dcf.cuh |
| 4 | + * @copyright Apache License, Version 2.0. Copyright (C) 2026 Yulong Ming <i@myl7.org>. |
| 5 | + * @author Yulong Ming <i@myl7.org> |
| 6 | + * |
| 7 | + * @brief 2-party distributed comparison function (DCF) over F2 from standard DPF. |
| 8 | + * |
| 9 | + * The scheme is from the paper, [_Grotto: Screaming fast (2+1)-PC for Z_{2^n} via |
| 10 | + * (2,2)-DPFs_](https://eprint.iacr.org/2023/108) (@ref grotto_dcf "1: the published version"). |
| 11 | + * |
| 12 | + * Key generation is identical to standard BGI DPF. The comparison functionality |
| 13 | + * emerges from prefix-parity of the DPF control bits. |
| 14 | + * |
| 15 | + * Output shares are in F2 (XOR sharing). Each party holds a single bool per query. |
| 16 | + * For inputs x: share_0 XOR share_1 = 1[alpha <= x]. |
| 17 | + * |
| 18 | + * ## References |
| 19 | + * |
| 20 | + * 1. Kyle Storrier, Adithya Vadapalli, Allan Lyons, Ryan Henry: Grotto: Screaming fast (2+1)-PC |
| 21 | + * for Z_{2^n} via (2,2)-DPFs. CCS 2023. <https://eprint.iacr.org/2023/108>. @anchor grotto_dcf |
| 22 | + */ |
| 23 | + |
| 24 | +#pragma once |
| 25 | +#include <cuda_runtime.h> |
| 26 | +#include <type_traits> |
| 27 | +#include <cstddef> |
| 28 | +#include <cassert> |
| 29 | +#include <omp.h> |
| 30 | +#include <fss/dpf.cuh> |
| 31 | +#include <fss/group/bytes.cuh> |
| 32 | +#include <fss/prg.cuh> |
| 33 | +#include <fss/util.cuh> |
| 34 | + |
| 35 | +namespace fss { |
| 36 | + |
/**
 * 2-party DCF scheme over F2 from standard DPF (Grotto construction).
 *
 * Keys are plain DPF keys (see Gen()). Each party's share of 1[alpha <= x] is the
 * XOR of its leaf control bits over the prefix [0, x].
 *
 * @tparam in_bits Input domain bit size. Preprocess(), Eval() and EvalAll() index
 *  arrays of N = 2^in_bits elements and therefore additionally require, at compile
 *  time, that in_bits is smaller than the bit width of size_t.
 * @tparam Prg See Prgable. Must satisfy Prgable<Prg, 2> (same as DPF).
 * @tparam In Type for the input domain. From uint8_t to __uint128_t.
 * @tparam par_depth Number of top tree levels that are expanded as OpenMP tasks.
 *  -1 selects ceil(log2(omp_get_max_threads())), capped at in_bits.
 *  Only Preprocess() and EvalAll() use it.
 */
template <int in_bits, typename Prg, typename In = uint, int par_depth = -1>
    requires((std::is_unsigned_v<In> || std::is_same_v<In, __uint128_t>) &&
        in_bits <= sizeof(In) * 8 && Prgable<Prg, 2>)
class GrottoDcf {
    using DpfType = Dpf<in_bits, group::Bytes, Prg, In, par_depth>;

public:
    using Cw = typename DpfType::Cw;
    Prg prg;

    /**
     * Key generation method. Delegates to Dpf::Gen with beta=0.
     *
     * @param cws Pre-allocated array of Cw. Size must be in_bits + 1.
     * @param s0s 2 initial seeds. Users can randomly sample them.
     * @param a The secret comparison threshold.
     *
     * The key for party i consists of cws + s0s[i].
     */
    __host__ __device__ void Gen(Cw cws[], const int4 s0s[2], In a) {
        DpfType dpf{prg};
        // Only the control bits are consumed by this scheme, so the payload is zero.
        int4 beta = {0, 0, 0, 0};
        dpf.Gen(cws, s0s, a, beta);
    }

    /**
     * Parity segment tree over leaf control bits.
     *
     * p[0..2N-2]: level-order binary tree where N = 2^in_bits.
     * Root is p[0]. Leaf x is stored at p[x + N - 1].
     * Internal node j: p[j] = p[2j+1] XOR p[2j+2].
     *
     * b: party index. Preprocess() reads it to seed the tree expansion.
     */
    struct ParityTree {
        bool *p;
        bool b;
    };

    /**
     * Preprocess: expand DPF tree and build parity segment tree.
     *
     * Phase 1: O(N) PRG calls to expand the tree and extract leaf control bits.
     * Phase 2a: O(N) XOR operations to build the parity segment tree bottom-up.
     *
     * @param pt ParityTree with p pre-allocated to size 2*N-1 where N = 2^in_bits.
     *  pt.b must be set to the party index before calling.
     * @param s0 Initial seed of the party.
     * @param cws Correction words from Gen().
     */
    void Preprocess(ParityTree &pt, int4 s0, const Cw cws[]) {
        static_assert(in_bits >= 0 && static_cast<size_t>(in_bits) < sizeof(size_t) * 8,
            "Preprocess() requires 2^in_bits to be representable in size_t");
        constexpr size_t N = size_t{1} << in_bits;
        assert(pt.p != nullptr);

        // Phase 1: leaf control bits go to the last N slots, pt.p[N-1 .. 2N-2].
        ExpandTree(pt.b, s0, cws, pt.p + (N - 1));

        // Phase 2a: fill internal nodes N-2 down to 0 so that both children of a
        // node are final before the node itself is written. The `j-- > 0` form
        // counts down over an unsigned index without relying on wrap-around and
        // runs zero times when N == 1.
        for (size_t j = N - 1; j-- > 0;) {
            pt.p[j] = pt.p[2 * j + 1] ^ pt.p[2 * j + 2];
        }
    }

    /**
     * Prefix-parity query on the parity segment tree.
     *
     * Returns party b's share of 1[alpha <= x], i.e. the parity of the leaf
     * control bits in [0, x + 1).
     *
     * Any x >= N - 1 (including values outside the in_bits-bit domain when In is
     * wider than in_bits) covers the whole domain and yields the root parity,
     * since the threshold alpha is always below N.
     *
     * @param pt ParityTree from Preprocess().
     * @param x Query point.
     * @return bool share such that share_0 XOR share_1 = 1[alpha <= x].
     */
    __host__ __device__ static bool Eval(const ParityTree &pt, In x) {
        static_assert(in_bits >= 0 && static_cast<size_t>(in_bits) < sizeof(size_t) * 8,
            "Eval() requires 2^in_bits to be representable in size_t");
        constexpr size_t N = size_t{1} << in_bits;

        // Whole-domain prefix. Handling it up front also guarantees that x + 1
        // below neither wraps around In nor exceeds N - 1.
        if (x >= N - 1) return pt.p[0];

        // Exclusive endpoint of the prefix, 1 <= e <= N - 1.
        const In e = static_cast<In>(x + 1);

        // Descend from the root following the bits of e (MSB first). Whenever we
        // turn right, the entire left subtree lies inside [0, e), so its parity
        // is folded into the result.
        bool pi = false;
        size_t cur = 0;
        for (int i = 0; i < in_bits; ++i) {
            const bool e_bit = (e >> (in_bits - 1 - i)) & 1;
            const size_t left = 2 * cur + 1;
            if (e_bit) {
                pi ^= pt.p[left];
                cur = left + 1;
            } else {
                cur = left;
            }
        }
        return pi;
    }

    /**
     * Full domain evaluation.
     *
     * Computes party b's share of 1[alpha <= x] for all x in [0, N).
     *
     * Phase 1: O(N) PRG calls to expand the tree.
     * Phase 2b: O(N) prefix-sum (running XOR) over leaf control bits.
     *
     * @param b Party index.
     * @param s0 Initial seed of the party.
     * @param cws Correction words from Gen().
     * @param ys Pre-allocated output array of size N = 2^in_bits.
     *  ys[x] = party b's share of 1[alpha <= x].
     */
    void EvalAll(bool b, int4 s0, const Cw cws[], bool ys[]) {
        static_assert(in_bits >= 0 && static_cast<size_t>(in_bits) < sizeof(size_t) * 8,
            "EvalAll() requires 2^in_bits to be representable in size_t");
        constexpr size_t N = size_t{1} << in_bits;
        assert(ys != nullptr);

        // Phase 1: ys[x] receives leaf x's control bit.
        ExpandTree(b, s0, cws, ys);

        // Phase 2b: in-place inclusive XOR scan, so that afterwards
        // ys[x] = XOR of control bits [0..x] = share of 1[alpha <= x].
        bool acc = false;
        for (size_t x = 0; x < N; ++x) {
            acc ^= ys[x];
            ys[x] = acc;
        }
    }

private:
    /**
     * Expand the DPF tree and write leaf control bits.
     *
     * Runs the recursion inside an OpenMP parallel region; the top levels of the
     * tree are spawned as tasks (see par_depth).
     *
     * @param b Party index.
     * @param s0 Initial seed of the party.
     * @param cws Correction words from Gen().
     * @param t Output array of size N = 2^in_bits for leaf control bits.
     */
    void ExpandTree(bool b, int4 s0, const Cw cws[], bool t[]) {
        // in_bits is a template parameter, so this is checked at compile time
        // rather than by a runtime assert that disappears under NDEBUG.
        static_assert(in_bits >= 0 && static_cast<size_t>(in_bits) < sizeof(size_t) * 8,
            "ExpandTree() requires 2^in_bits to be representable in size_t");

        // The root control bit is the party index, stored in the seed's LSB.
        int4 st = s0;
        st = util::SetLsb(st, b);

        const size_t l = 0;
        const size_t r = size_t{1} << in_bits;

        int par_depth_ = par_depth;
        if constexpr (par_depth == -1) {
            // Smallest depth whose level has at least one node per thread.
            // Levels at or below in_bits are leaves and never spawn tasks, so
            // capping there changes nothing and keeps the shift in range.
            const size_t threads = static_cast<size_t>(omp_get_max_threads());
            par_depth_ = 0;
            while (par_depth_ < in_bits && (size_t{1} << par_depth_) < threads) {
                ++par_depth_;
            }
        }

#pragma omp parallel
#pragma omp single
        ExpandTreeRec(st, cws, t, l, r, 0, par_depth_);
    }

    /**
     * Recursive worker of ExpandTree().
     *
     * @param st Seed of the current node with its control bit in the LSB.
     * @param cws Correction words from Gen().
     * @param t Output array for leaf control bits.
     * @param l Inclusive left end of the leaf range covered by this node.
     * @param r Exclusive right end of the leaf range covered by this node.
     * @param i Depth of the current node; the root is 0 and leaves are in_bits.
     * @param par_depth_ Nodes with depth below this spawn their children as tasks.
     */
    void ExpandTreeRec(
        int4 st, const Cw cws[], bool t[], size_t l, size_t r, int i, int par_depth_) {
        const bool tc = util::GetLsb(st);

        // Leaf: only the control bit is needed.
        if (i == in_bits) {
            assert(l + 1 == r);
            t[l] = tc;
            return;
        }

        // Seed with the control bit cleared, as fed to the PRG.
        int4 s = st;
        s = util::SetLsb(s, false);

        // Unpack the correction word of this level: the left control-bit
        // correction is carried in the LSB of the seed correction.
        Cw cw = cws[i];
        int4 s_cw = cw.s;
        const bool tl_cw = util::GetLsb(s_cw);
        s_cw = util::SetLsb(s_cw, false);
        const bool tr_cw = cw.tr;

        auto [sl, sr] = prg.Gen(s);

        bool tl = util::GetLsb(sl);
        sl = util::SetLsb(sl, false);
        bool tr = util::GetLsb(sr);
        sr = util::SetLsb(sr, false);

        // Corrections are applied only when this node's control bit is set.
        if (tc) {
            sl = util::Xor(sl, s_cw);
            sr = util::Xor(sr, s_cw);
            tl = tl ^ tl_cw;
            tr = tr ^ tr_cw;
        }

        // Re-pack child control bits into the child seeds' LSBs.
        int4 stl = sl;
        stl = util::SetLsb(stl, tl);
        int4 str = sr;
        str = util::SetLsb(str, tr);

        const size_t mid = (l + r) / 2;

        if (i < par_depth_) {
            // The two halves write disjoint ranges of t, so they can run concurrently.
#pragma omp task
            ExpandTreeRec(stl, cws, t, l, mid, i + 1, par_depth_);
#pragma omp task
            ExpandTreeRec(str, cws, t, mid, r, i + 1, par_depth_);
#pragma omp taskwait
        } else {
            ExpandTreeRec(stl, cws, t, l, mid, i + 1, par_depth_);
            ExpandTreeRec(str, cws, t, mid, r, i + 1, par_depth_);
        }
    }
};
| 251 | + |
| 252 | +} // namespace fss |
0 commit comments