Skip to content

Commit 65f0182

Browse files
committed
Implementation of BGV mono bounds for the noise analysis framework
1 parent f8b2a3d commit 65f0182

File tree

13 files changed

+525
-1
lines changed

13 files changed

+525
-1
lines changed

lib/Analysis/NoiseAnalysis/BGV/BUILD

+15
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ cc_library(
1313
deps = [
1414
":NoiseByBoundCoeffModel",
1515
":NoiseByVarianceCoeffModel",
16+
":NoiseCanEmbModel",
1617
"@heir//lib/Analysis:Utils",
1718
"@heir//lib/Analysis/DimensionAnalysis",
1819
"@heir//lib/Analysis/LevelAnalysis",
@@ -45,6 +46,20 @@ cc_library(
4546
],
4647
)
4748

49+
cc_library(
50+
name = "NoiseCanEmbModel",
51+
srcs = [
52+
"NoiseCanEmbModel.cpp",
53+
],
54+
hdrs = [
55+
"NoiseCanEmbModel.h",
56+
],
57+
deps = [
58+
"@heir//lib/Analysis/NoiseAnalysis:Noise",
59+
"@heir//lib/Parameters/BGV:Params",
60+
],
61+
)
62+
4863
cc_library(
4964
name = "NoiseByVarianceCoeffModel",
5065
srcs = [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
#include "lib/Analysis/NoiseAnalysis/BGV/NoiseCanEmbModel.h"
2+
3+
#include <algorithm>
4+
#include <cassert>
5+
#include <cmath>
6+
#include <iomanip>
7+
#include <ios>
8+
#include <numeric>
9+
#include <sstream>
10+
#include <string>
11+
12+
namespace mlir {
13+
namespace heir {
14+
namespace bgv {
15+
// the formulae below are mainly taken from MMLGA22
16+
// "Finding and Evaluating Parameters for BGV"
17+
// https://eprint.iacr.org/2022/706
18+
19+
template <bool P>
20+
using Model = NoiseCanEmbModel<P>;
21+
22+
template <bool P>
23+
double Model<P>::toLogBound(const LocalParamType &param,
24+
const StateType &noise) {
25+
auto cm = getRingExpansionFactor(param);
26+
// ||a|| <= c_m * ||a||^{can}
27+
return log(cm * noise.getValue()) / log(2);
28+
}
29+
30+
template <bool P>
31+
double Model<P>::toLogBudget(const LocalParamType &param,
32+
const StateType &noise) {
33+
return toLogTotal(param) - toLogBound(param, noise);
34+
}
35+
36+
template <bool P>
37+
double Model<P>::toLogTotal(const LocalParamType &param) {
38+
double total = 0;
39+
auto logqi = param.getSchemeParam()->getLogqi();
40+
for (auto i = 0; i <= param.getCurrentLevel(); ++i) {
41+
total += logqi[i];
42+
}
43+
return total - 1.0;
44+
}
45+
46+
template <bool P>
47+
std::string Model<P>::toLogBoundString(const LocalParamType &param,
48+
const StateType &noise) {
49+
auto logBound = toLogBound(param, noise);
50+
std::stringstream stream;
51+
stream << std::fixed << std::setprecision(2) << logBound;
52+
return stream.str();
53+
}
54+
55+
template <bool P>
56+
std::string Model<P>::toLogBudgetString(const LocalParamType &param,
57+
const StateType &noise) {
58+
auto logBudget = toLogBudget(param, noise);
59+
std::stringstream stream;
60+
stream << std::fixed << std::setprecision(2) << logBudget;
61+
return stream.str();
62+
}
63+
64+
template <bool P>
65+
std::string Model<P>::toLogTotalString(const LocalParamType &param) {
66+
auto logTotal = toLogTotal(param);
67+
std::stringstream stream;
68+
stream << std::fixed << std::setprecision(2) << logTotal;
69+
return stream.str();
70+
}
71+
72+
template <bool P>
73+
double Model<P>::getVarianceErr(const LocalParamType &param) {
74+
auto std0 = param.getSchemeParam()->getStd0();
75+
return std0 * std0;
76+
}
77+
78+
template <bool P>
79+
double Model<P>::getVarianceKey(const LocalParamType &param) {
80+
// assume UNIFORM_TERNARY
81+
return 2.0 / 3.0;
82+
}
83+
84+
template <bool P>
85+
double Model<P>::getRingExpansionFactor(const LocalParamType &param) {
86+
auto N = param.getSchemeParam()->getRingDim();
87+
// Assert that N is a power of 2
88+
assert((N > 0) && ((N & (N - 1)) == 0) && "N must be a power of 2");
89+
// In power-of-two rings c_m = 1
90+
return 1.;
91+
}
92+
93+
template <bool P>
94+
double Model<P>::getAssuranceFactor(const LocalParamType &param) {
95+
// probability that a exceeds its standard deviation by more than a factor of
96+
// D is roughly erfc(D) with erfc(6) = 2^-55, erfc(5) = 2^-40, erfc(4.5) =
97+
// 2^-32
98+
return 6.;
99+
}
100+
101+
template <bool P>
102+
double Model<P>::getBScale(const LocalParamType &param) {
103+
auto varianceKey = getVarianceKey(param);
104+
auto t = param.getSchemeParam()->getPlaintextModulus();
105+
auto d = getAssuranceFactor(param);
106+
auto phi = getPhi(param);
107+
108+
// B_scale = D * t * sqrt(phi(m)/12 * (1 + phi(m) * V_key)
109+
double innerTerm = (phi / 12.) * (1 + phi * varianceKey);
110+
return d * t * sqrt(innerTerm);
111+
}
112+
113+
template <bool P>
114+
double Model<P>::getBKs(const LocalParamType &param) {
115+
auto varianceError = getVarianceErr(param);
116+
auto t = param.getSchemeParam()->getPlaintextModulus();
117+
auto d = getAssuranceFactor(param);
118+
auto phi = getPhi(param);
119+
120+
// B_ks = D * t * phi(m) * sqrt(V_err / 12)
121+
return d * t * phi * sqrt(varianceError / 12.);
122+
}
123+
124+
template <bool P>
125+
double Model<P>::getPhi(const LocalParamType &param) {
126+
return param.getSchemeParam()->getRingDim();
127+
}
128+
129+
template <bool P>
130+
typename Model<P>::StateType Model<P>::evalEncryptPk(
131+
const LocalParamType &param) {
132+
auto varianceError = getVarianceErr(param);
133+
// uniform ternary
134+
auto varianceKey = getVarianceKey(param);
135+
auto t = param.getSchemeParam()->getPlaintextModulus();
136+
auto d = getAssuranceFactor(param);
137+
auto phi = getPhi(param);
138+
139+
// public key (-as + t * e, a)
140+
// public key encryption (-aus + t(u * e + e_0) + m, au + e_1)
141+
// ||m + t * (u * e + e_1 * s + e_0)||
142+
// <= D * t * sqrt(phi(m) * (1/12 + 2 * phi(m) * V_err * V_key + V_err))
143+
double innerTerm =
144+
phi * (1. / 12. + 2. * phi * varianceError * varianceKey + varianceKey);
145+
double fresh = d * t * sqrt(innerTerm);
146+
return StateType::of(fresh);
147+
}
148+
149+
template <bool P>
150+
typename Model<P>::StateType Model<P>::evalEncryptSk(
151+
const LocalParamType &param) {
152+
auto varianceError = getVarianceErr(param);
153+
auto t = param.getSchemeParam()->getPlaintextModulus();
154+
auto d = getAssuranceFactor(param);
155+
auto phi = getPhi(param);
156+
157+
// secret key s
158+
// secret key encryption (-as + m + t * e, a)
159+
// ||m + t * e|| <= D * t * sqrt(phi(m) * (1/12 + V_err))
160+
double innerTerm = phi * (1. / 12. + varianceError);
161+
double fresh = d * t * sqrt(innerTerm);
162+
return StateType::of(fresh);
163+
}
164+
165+
template <bool P>
166+
typename Model<P>::StateType Model<P>::evalEncrypt(
167+
const LocalParamType &param) {
168+
// P stands for public key encryption
169+
if constexpr (P) {
170+
return evalEncryptPk(param);
171+
} else {
172+
return evalEncryptSk(param);
173+
}
174+
}
175+
176+
template <bool P>
177+
typename Model<P>::StateType Model<P>::evalConstant(
178+
const LocalParamType &param) {
179+
auto t = param.getSchemeParam()->getPlaintextModulus();
180+
auto phi = getPhi(param);
181+
182+
// noise part of the plaintext in a pt-ct multiplication
183+
// v_const <= t * sqrt(phi(m) / 12)
184+
return StateType::of(t * sqrt(phi / 12.0));
185+
}
186+
187+
template <bool P>
188+
typename Model<P>::StateType Model<P>::evalAdd(const StateType &lhs,
189+
const StateType &rhs) {
190+
// v_add <= v_0 + v_1
191+
return StateType::of(lhs.getValue() + rhs.getValue());
192+
}
193+
194+
template <bool P>
195+
typename Model<P>::StateType Model<P>::evalMul(
196+
const LocalParamType &resultParam, const StateType &lhs,
197+
const StateType &rhs) {
198+
// v_mul <= v_0 * v_1
199+
return StateType::of(lhs.getValue() * rhs.getValue());
200+
}
201+
202+
template <bool P>
203+
typename Model<P>::StateType Model<P>::evalModReduce(
204+
const LocalParamType &inputParam, const StateType &input) {
205+
auto currentLogqi =
206+
inputParam.getSchemeParam()->getLogqi()[inputParam.getCurrentLevel()];
207+
double modulus = pow(2.0, currentLogqi);
208+
209+
// modulus switching is essentially a scaling operation
210+
// so the original error is scaled by the modulus
211+
// ||v_scaled|| = ||v_input|| / modulus
212+
auto scaled = input.getValue() / modulus;
213+
// in the meantime, it will introduce a rounding error
214+
// (tau_0, tau_1) to (ct_0, ct_1)
215+
// ||tau_0 + tau_1 * s|| <= D * t * sqrt(phi(m)/12 * (1 + phi(m) * V_key) =
216+
// B_scale
217+
// ||v_ms|| <= ||v_scaled|| + B_scale
218+
double bScale = getBScale(inputParam);
219+
return StateType::of(scaled + bScale);
220+
}
221+
222+
template <bool P>
223+
typename Model<P>::StateType Model<P>::evalRelinearizeHYBRID(
224+
const LocalParamType &inputParam, const StateType &input) {
225+
// for v_input, after modup and moddown, it remains the same (with rounding).
226+
// We only need to consider the error from key switching key
227+
// and rounding error during moddown.
228+
// Check section 3.2 of MMLGA22 for more details.
229+
auto dnum = inputParam.getSchemeParam()->getDnum();
230+
231+
auto currentLevel = inputParam.getCurrentLevel();
232+
auto logpi = inputParam.getSchemeParam()->getLogpi();
233+
234+
// TODO: prod of Pi() if Pi() is available instead of logPi()
235+
auto pi = inputParam.getSchemeParam()->getPi();
236+
double prodPi;
237+
double maxPi;
238+
size_t k;
239+
if (pi.size() == 0) {
240+
// values of pi are not set in schemeParam, so we use this
241+
std::vector<double> moduliPi(logpi.size());
242+
std::transform(logpi.begin(), logpi.end(), moduliPi.begin(),
243+
[](double value) { return pow(2.0, value); });
244+
maxPi = *std::max_element(moduliPi.begin(), moduliPi.end());
245+
prodPi = std::accumulate(moduliPi.begin(), moduliPi.end(), 1.,
246+
std::multiplies<double>());
247+
k = moduliPi.size();
248+
} else {
249+
// if real values of pi are set, we use those
250+
maxPi = *std::max_element(pi.begin(), pi.end());
251+
prodPi =
252+
std::accumulate(pi.begin(), pi.end(), 1., std::multiplies<double>());
253+
k = pi.size();
254+
}
255+
256+
// v_ks = v + sqrt(dnum * (currentLevel + 1)) * p_l^(ceil(currentLevel / dnum)
257+
// * B_ks / P + sqrt(k) * B_scale
258+
double bKs = getBKs(inputParam);
259+
auto pPower = ceil(static_cast<double>(currentLevel) / dnum);
260+
auto noiseKs = sqrt(dnum * (currentLevel + 1)) *
261+
pow(static_cast<double>(maxPi), static_cast<double>(pPower)) *
262+
bKs / prodPi;
263+
double bScale = getBScale(inputParam);
264+
auto noiseScale = sqrt(k) * bScale;
265+
266+
return StateType::of(input.getValue() + noiseKs + noiseScale);
267+
}
268+
269+
template <bool P>
270+
typename Model<P>::StateType Model<P>::evalRelinearize(
271+
const LocalParamType &inputParam, const StateType &input) {
272+
// assume HYBRID
273+
// if we further introduce BV to SchemeParam we can have alternative
274+
// implementation.
275+
return evalRelinearizeHYBRID(inputParam, input);
276+
}
277+
278+
// instantiate template class
279+
template class NoiseCanEmbModel<false>;
280+
template class NoiseCanEmbModel<true>;
281+
282+
} // namespace bgv
283+
} // namespace heir
284+
} // namespace mlir
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#ifndef INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISECANEMBMODEL_H_
2+
#define INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISECANEMBMODEL_H_
3+
4+
#include <cassert>
5+
#include <string>
6+
7+
#include "lib/Analysis/NoiseAnalysis/Noise.h"
8+
#include "lib/Parameters/BGV/Params.h"
9+
10+
namespace mlir {
11+
namespace heir {
12+
namespace bgv {
13+
14+
// canonical embedding noise model from MMLGA22
15+
// see https://eprint.iacr.org/2022/706
16+
// use template here just for the sake of code reuse
17+
// P for public key
18+
template <bool P>
19+
class NoiseCanEmbModel {
20+
public:
21+
// for MMLGA22, NoiseState stores the bound ||m + t * e||^{can} for error e.
22+
using StateType = NoiseState;
23+
using SchemeParamType = bgv::SchemeParam;
24+
using LocalParamType = bgv::LocalParam;
25+
26+
private:
27+
static double getVarianceErr(const LocalParamType &param);
28+
static double getVarianceKey(const LocalParamType &param);
29+
static double getBScale(const LocalParamType &param);
30+
static double getBKs(const LocalParamType &param);
31+
static double getAssuranceFactor(const LocalParamType &param);
32+
static double getPhi(const LocalParamType &param);
33+
static double getRingExpansionFactor(const LocalParamType &param);
34+
35+
static StateType evalEncryptPk(const LocalParamType &param);
36+
static StateType evalEncryptSk(const LocalParamType &param);
37+
static StateType evalRelinearizeHYBRID(const LocalParamType &inputParam,
38+
const StateType &input);
39+
40+
public:
41+
static StateType evalEncrypt(const LocalParamType &param);
42+
static StateType evalConstant(const LocalParamType &param);
43+
static StateType evalAdd(const StateType &lhs, const StateType &rhs);
44+
static StateType evalMul(const LocalParamType &resultParam,
45+
const StateType &lhs, const StateType &rhs);
46+
static StateType evalRelinearize(const LocalParamType &inputParam,
47+
const StateType &input);
48+
static StateType evalModReduce(const LocalParamType &inputParam,
49+
const StateType &input);
50+
51+
// logTotal: log(Ql / 2)
52+
// logBound: bound on ||m + t * e|| predicted by the model
53+
// logBudget: logTotal - logBound
54+
// as ||m + t * e|| < Ql / 2 for correct decryption
55+
static double toLogBound(const LocalParamType &param, const StateType &noise);
56+
static std::string toLogBoundString(const LocalParamType &param,
57+
const StateType &noise);
58+
static double toLogBudget(const LocalParamType &param,
59+
const StateType &noise);
60+
static std::string toLogBudgetString(const LocalParamType &param,
61+
const StateType &noise);
62+
static double toLogTotal(const LocalParamType &param);
63+
static std::string toLogTotalString(const LocalParamType &param);
64+
};
65+
66+
// user-facing typedefs
67+
using NoiseCanEmbPkModel = NoiseCanEmbModel<true>;
68+
using NoiseCanEmbSkModel = NoiseCanEmbModel<false>;
69+
70+
} // namespace bgv
71+
} // namespace heir
72+
} // namespace mlir
73+
74+
#endif // INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISECANEMBMODEL_H_

0 commit comments

Comments
 (0)