|
| 1 | +#include "lib/Analysis/NoiseAnalysis/BGV/NoiseCanEmbModel.h" |
| 2 | + |
| 3 | +#include <algorithm> |
| 4 | +#include <cassert> |
| 5 | +#include <cmath> |
| 6 | +#include <cstddef> |
| 7 | +#include <functional> |
| 8 | +#include <iomanip> |
| 9 | +#include <ios> |
| 10 | +#include <numeric> |
| 11 | +#include <sstream> |
| 12 | +#include <string> |
| 13 | +#include <vector> |
| 14 | + |
| 15 | +namespace mlir { |
| 16 | +namespace heir { |
| 17 | +namespace bgv { |
| 18 | +// the formulae below are mainly taken from MMLGA22 |
| 19 | +// "Finding and Evaluating Parameters for BGV" |
| 20 | +// https://eprint.iacr.org/2022/706 |
| 21 | + |
| 22 | +template <bool P> |
| 23 | +using Model = NoiseCanEmbModel<P>; |
| 24 | + |
| 25 | +template <bool P> |
| 26 | +double Model<P>::toLogBound(const LocalParamType ¶m, |
| 27 | + const StateType &noise) { |
| 28 | + auto cm = getRingExpansionFactor(param); |
| 29 | + // ||a|| <= c_m * ||a||^{can} |
| 30 | + return log(cm * noise.getValue()) / log(2); |
| 31 | +} |
| 32 | + |
| 33 | +template <bool P> |
| 34 | +double Model<P>::toLogBudget(const LocalParamType ¶m, |
| 35 | + const StateType &noise) { |
| 36 | + return toLogTotal(param) - toLogBound(param, noise); |
| 37 | +} |
| 38 | + |
| 39 | +template <bool P> |
| 40 | +double Model<P>::toLogTotal(const LocalParamType ¶m) { |
| 41 | + double total = 0; |
| 42 | + auto logqi = param.getSchemeParam()->getLogqi(); |
| 43 | + for (auto i = 0; i <= param.getCurrentLevel(); ++i) { |
| 44 | + total += logqi[i]; |
| 45 | + } |
| 46 | + return total - 1.0; |
| 47 | +} |
| 48 | + |
| 49 | +template <bool P> |
| 50 | +std::string Model<P>::toLogBoundString(const LocalParamType ¶m, |
| 51 | + const StateType &noise) { |
| 52 | + auto logBound = toLogBound(param, noise); |
| 53 | + std::stringstream stream; |
| 54 | + stream << std::fixed << std::setprecision(2) << logBound; |
| 55 | + return stream.str(); |
| 56 | +} |
| 57 | + |
| 58 | +template <bool P> |
| 59 | +std::string Model<P>::toLogBudgetString(const LocalParamType ¶m, |
| 60 | + const StateType &noise) { |
| 61 | + auto logBudget = toLogBudget(param, noise); |
| 62 | + std::stringstream stream; |
| 63 | + stream << std::fixed << std::setprecision(2) << logBudget; |
| 64 | + return stream.str(); |
| 65 | +} |
| 66 | + |
| 67 | +template <bool P> |
| 68 | +std::string Model<P>::toLogTotalString(const LocalParamType ¶m) { |
| 69 | + auto logTotal = toLogTotal(param); |
| 70 | + std::stringstream stream; |
| 71 | + stream << std::fixed << std::setprecision(2) << logTotal; |
| 72 | + return stream.str(); |
| 73 | +} |
| 74 | + |
| 75 | +template <bool P> |
| 76 | +double Model<P>::getVarianceErr(const LocalParamType ¶m) { |
| 77 | + auto std0 = param.getSchemeParam()->getStd0(); |
| 78 | + return std0 * std0; |
| 79 | +} |
| 80 | + |
| 81 | +template <bool P> |
| 82 | +double Model<P>::getVarianceKey(const LocalParamType ¶m) { |
| 83 | + // assume UNIFORM_TERNARY |
| 84 | + return 2.0 / 3.0; |
| 85 | +} |
| 86 | + |
| 87 | +template <bool P> |
| 88 | +double Model<P>::getRingExpansionFactor(const LocalParamType ¶m) { |
| 89 | + [[maybe_unused]] auto N = param.getSchemeParam()->getRingDim(); |
| 90 | + // Assert that N is a power of 2 |
| 91 | + assert((N > 0) && ((N & (N - 1)) == 0) && "N must be a power of 2"); |
| 92 | + // In power-of-two rings c_m = 1 |
| 93 | + return 1.; |
| 94 | +} |
| 95 | + |
| 96 | +template <bool P> |
| 97 | +double Model<P>::getAssuranceFactor(const LocalParamType ¶m) { |
| 98 | + // probability that a exceeds its standard deviation by more than a factor of |
| 99 | + // D is roughly erfc(D) with erfc(6) = 2^-55, erfc(5) = 2^-40, erfc(4.5) = |
| 100 | + // 2^-32 |
| 101 | + return 6.; |
| 102 | +} |
| 103 | + |
| 104 | +template <bool P> |
| 105 | +double Model<P>::getBScale(const LocalParamType ¶m) { |
| 106 | + auto varianceKey = getVarianceKey(param); |
| 107 | + auto t = param.getSchemeParam()->getPlaintextModulus(); |
| 108 | + auto d = getAssuranceFactor(param); |
| 109 | + auto phi = getPhi(param); |
| 110 | + |
| 111 | + // B_scale = D * t * sqrt(phi(m)/12 * (1 + phi(m) * V_key) |
| 112 | + double innerTerm = (phi / 12.) * (1 + phi * varianceKey); |
| 113 | + return d * t * sqrt(innerTerm); |
| 114 | +} |
| 115 | + |
| 116 | +template <bool P> |
| 117 | +double Model<P>::getBKs(const LocalParamType ¶m) { |
| 118 | + auto varianceError = getVarianceErr(param); |
| 119 | + auto t = param.getSchemeParam()->getPlaintextModulus(); |
| 120 | + auto d = getAssuranceFactor(param); |
| 121 | + auto phi = getPhi(param); |
| 122 | + |
| 123 | + // B_ks = D * t * phi(m) * sqrt(V_err / 12) |
| 124 | + return d * t * phi * sqrt(varianceError / 12.); |
| 125 | +} |
| 126 | + |
| 127 | +template <bool P> |
| 128 | +double Model<P>::getPhi(const LocalParamType ¶m) { |
| 129 | + return param.getSchemeParam()->getRingDim(); |
| 130 | +} |
| 131 | + |
| 132 | +template <bool P> |
| 133 | +typename Model<P>::StateType Model<P>::evalEncryptPk( |
| 134 | + const LocalParamType ¶m) { |
| 135 | + auto varianceError = getVarianceErr(param); |
| 136 | + // uniform ternary |
| 137 | + auto varianceKey = getVarianceKey(param); |
| 138 | + auto t = param.getSchemeParam()->getPlaintextModulus(); |
| 139 | + auto d = getAssuranceFactor(param); |
| 140 | + auto phi = getPhi(param); |
| 141 | + |
| 142 | + // public key (-as + t * e, a) |
| 143 | + // public key encryption (-aus + t(u * e + e_0) + m, au + e_1) |
| 144 | + // ||m + t * (u * e + e_1 * s + e_0)|| |
| 145 | + // <= D * t * sqrt(phi(m) * (1/12 + 2 * phi(m) * V_err * V_key + V_err)) |
| 146 | + double innerTerm = |
| 147 | + phi * (1. / 12. + 2. * phi * varianceError * varianceKey + varianceKey); |
| 148 | + double fresh = d * t * sqrt(innerTerm); |
| 149 | + return StateType::of(fresh); |
| 150 | +} |
| 151 | + |
| 152 | +template <bool P> |
| 153 | +typename Model<P>::StateType Model<P>::evalEncryptSk( |
| 154 | + const LocalParamType ¶m) { |
| 155 | + auto varianceError = getVarianceErr(param); |
| 156 | + auto t = param.getSchemeParam()->getPlaintextModulus(); |
| 157 | + auto d = getAssuranceFactor(param); |
| 158 | + auto phi = getPhi(param); |
| 159 | + |
| 160 | + // secret key s |
| 161 | + // secret key encryption (-as + m + t * e, a) |
| 162 | + // ||m + t * e|| <= D * t * sqrt(phi(m) * (1/12 + V_err)) |
| 163 | + double innerTerm = phi * (1. / 12. + varianceError); |
| 164 | + double fresh = d * t * sqrt(innerTerm); |
| 165 | + return StateType::of(fresh); |
| 166 | +} |
| 167 | + |
| 168 | +template <bool P> |
| 169 | +typename Model<P>::StateType Model<P>::evalEncrypt( |
| 170 | + const LocalParamType ¶m) { |
| 171 | + // P stands for public key encryption |
| 172 | + if constexpr (P) { |
| 173 | + return evalEncryptPk(param); |
| 174 | + } else { |
| 175 | + return evalEncryptSk(param); |
| 176 | + } |
| 177 | +} |
| 178 | + |
| 179 | +template <bool P> |
| 180 | +typename Model<P>::StateType Model<P>::evalConstant( |
| 181 | + const LocalParamType ¶m) { |
| 182 | + auto t = param.getSchemeParam()->getPlaintextModulus(); |
| 183 | + auto phi = getPhi(param); |
| 184 | + |
| 185 | + // noise part of the plaintext in a pt-ct multiplication |
| 186 | + // v_const <= t * sqrt(phi(m) / 12) |
| 187 | + return StateType::of(t * sqrt(phi / 12.0)); |
| 188 | +} |
| 189 | + |
| 190 | +template <bool P> |
| 191 | +typename Model<P>::StateType Model<P>::evalAdd(const StateType &lhs, |
| 192 | + const StateType &rhs) { |
| 193 | + // v_add <= v_0 + v_1 |
| 194 | + return StateType::of(lhs.getValue() + rhs.getValue()); |
| 195 | +} |
| 196 | + |
| 197 | +template <bool P> |
| 198 | +typename Model<P>::StateType Model<P>::evalMul( |
| 199 | + const LocalParamType &resultParam, const StateType &lhs, |
| 200 | + const StateType &rhs) { |
| 201 | + // v_mul <= v_0 * v_1 |
| 202 | + return StateType::of(lhs.getValue() * rhs.getValue()); |
| 203 | +} |
| 204 | + |
| 205 | +template <bool P> |
| 206 | +typename Model<P>::StateType Model<P>::evalModReduce( |
| 207 | + const LocalParamType &inputParam, const StateType &input) { |
| 208 | + auto currentLogqi = |
| 209 | + inputParam.getSchemeParam()->getLogqi()[inputParam.getCurrentLevel()]; |
| 210 | + double modulus = pow(2.0, currentLogqi); |
| 211 | + |
| 212 | + // modulus switching is essentially a scaling operation |
| 213 | + // so the original error is scaled by the modulus |
| 214 | + // ||v_scaled|| = ||v_input|| / modulus |
| 215 | + auto scaled = input.getValue() / modulus; |
| 216 | + // in the meantime, it will introduce a rounding error |
| 217 | + // (tau_0, tau_1) to (ct_0, ct_1) |
| 218 | + // ||tau_0 + tau_1 * s|| <= D * t * sqrt(phi(m)/12 * (1 + phi(m) * V_key) = |
| 219 | + // B_scale |
| 220 | + // ||v_ms|| <= ||v_scaled|| + B_scale |
| 221 | + double bScale = getBScale(inputParam); |
| 222 | + return StateType::of(scaled + bScale); |
| 223 | +} |
| 224 | + |
| 225 | +template <bool P> |
| 226 | +typename Model<P>::StateType Model<P>::evalRelinearizeHYBRID( |
| 227 | + const LocalParamType &inputParam, const StateType &input) { |
| 228 | + // for v_input, after modup and moddown, it remains the same (with rounding). |
| 229 | + // We only need to consider the error from key switching key |
| 230 | + // and rounding error during moddown. |
| 231 | + // Check section 3.2 of MMLGA22 for more details. |
| 232 | + auto dnum = inputParam.getSchemeParam()->getDnum(); |
| 233 | + |
| 234 | + auto currentLevel = inputParam.getCurrentLevel(); |
| 235 | + auto logpi = inputParam.getSchemeParam()->getLogpi(); |
| 236 | + |
| 237 | + // TODO: prod of Pi() if Pi() is available instead of logPi() |
| 238 | + auto pi = inputParam.getSchemeParam()->getPi(); |
| 239 | + double prodPi; |
| 240 | + double maxPi; |
| 241 | + size_t k; |
| 242 | + if (pi.empty()) { |
| 243 | + // values of pi are not set in schemeParam, so we use this |
| 244 | + std::vector<double> moduliPi(logpi.size()); |
| 245 | + std::transform(logpi.begin(), logpi.end(), moduliPi.begin(), |
| 246 | + [](double value) { return pow(2.0, value); }); |
| 247 | + maxPi = *std::max_element(moduliPi.begin(), moduliPi.end()); |
| 248 | + prodPi = std::accumulate(moduliPi.begin(), moduliPi.end(), 1., |
| 249 | + std::multiplies<double>()); |
| 250 | + k = moduliPi.size(); |
| 251 | + } else { |
| 252 | + // if real values of pi are set, we use those |
| 253 | + maxPi = *std::max_element(pi.begin(), pi.end()); |
| 254 | + prodPi = |
| 255 | + std::accumulate(pi.begin(), pi.end(), 1., std::multiplies<double>()); |
| 256 | + k = pi.size(); |
| 257 | + } |
| 258 | + |
| 259 | + // v_ks = v + sqrt(dnum * (currentLevel + 1)) * p_l^(ceil(currentLevel / dnum) |
| 260 | + // * B_ks / P + sqrt(k) * B_scale |
| 261 | + double bKs = getBKs(inputParam); |
| 262 | + auto pPower = ceil(static_cast<double>(currentLevel) / dnum); |
| 263 | + auto noiseKs = sqrt(dnum * (currentLevel + 1)) * |
| 264 | + pow(static_cast<double>(maxPi), static_cast<double>(pPower)) * |
| 265 | + bKs / prodPi; |
| 266 | + double bScale = getBScale(inputParam); |
| 267 | + auto noiseScale = sqrt(k) * bScale; |
| 268 | + |
| 269 | + return StateType::of(input.getValue() + noiseKs + noiseScale); |
| 270 | +} |
| 271 | + |
| 272 | +template <bool P> |
| 273 | +typename Model<P>::StateType Model<P>::evalRelinearize( |
| 274 | + const LocalParamType &inputParam, const StateType &input) { |
| 275 | + // assume HYBRID |
| 276 | + // if we further introduce BV to SchemeParam we can have alternative |
| 277 | + // implementation. |
| 278 | + return evalRelinearizeHYBRID(inputParam, input); |
| 279 | +} |
| 280 | + |
| 281 | +// instantiate template class |
| 282 | +template class NoiseCanEmbModel<false>; |
| 283 | +template class NoiseCanEmbModel<true>; |
| 284 | + |
| 285 | +} // namespace bgv |
| 286 | +} // namespace heir |
| 287 | +} // namespace mlir |
0 commit comments