Skip to content

Commit efb0d50

Browse files
Merge pull request #1545 from Fraunhofer-AISEC:mono-noise-analysis
PiperOrigin-RevId: 736319230
2 parents ac31b6f + 65f0182 commit efb0d50

File tree

13 files changed

+528
-4
lines changed

13 files changed

+528
-4
lines changed

lib/Analysis/NoiseAnalysis/BGV/BUILD

+15-2
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@ cc_library(
1313
deps = [
1414
":NoiseByBoundCoeffModel",
1515
":NoiseByVarianceCoeffModel",
16+
":NoiseCanEmbModel",
1617
"@heir//lib/Analysis:Utils",
1718
"@heir//lib/Analysis/DimensionAnalysis",
1819
"@heir//lib/Analysis/LevelAnalysis",
1920
"@heir//lib/Analysis/NoiseAnalysis",
20-
"@heir//lib/Analysis/NoiseAnalysis:Noise",
2121
"@heir//lib/Dialect/Mgmt/IR:Dialect",
2222
"@heir//lib/Dialect/Secret/IR:Dialect",
2323
"@heir//lib/Dialect/TensorExt/IR:Dialect",
24-
"@heir//lib/Parameters/BGV:Params",
2524
"@llvm-project//llvm:Support",
2625
"@llvm-project//mlir:ArithDialect",
2726
"@llvm-project//mlir:CallOpInterfaces",
@@ -45,6 +44,20 @@ cc_library(
4544
],
4645
)
4746

47+
cc_library(
48+
name = "NoiseCanEmbModel",
49+
srcs = [
50+
"NoiseCanEmbModel.cpp",
51+
],
52+
hdrs = [
53+
"NoiseCanEmbModel.h",
54+
],
55+
deps = [
56+
"@heir//lib/Analysis/NoiseAnalysis:Noise",
57+
"@heir//lib/Parameters/BGV:Params",
58+
],
59+
)
60+
4861
cc_library(
4962
name = "NoiseByVarianceCoeffModel",
5063
srcs = [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
#include "lib/Analysis/NoiseAnalysis/BGV/NoiseCanEmbModel.h"
2+
3+
#include <algorithm>
4+
#include <cassert>
5+
#include <cmath>
6+
#include <cstddef>
7+
#include <functional>
8+
#include <iomanip>
9+
#include <ios>
10+
#include <numeric>
11+
#include <sstream>
12+
#include <string>
13+
#include <vector>
14+
15+
namespace mlir {
16+
namespace heir {
17+
namespace bgv {
18+
// the formulae below are mainly taken from MMLGA22
19+
// "Finding and Evaluating Parameters for BGV"
20+
// https://eprint.iacr.org/2022/706
21+
22+
template <bool P>
23+
using Model = NoiseCanEmbModel<P>;
24+
25+
template <bool P>
26+
double Model<P>::toLogBound(const LocalParamType &param,
27+
const StateType &noise) {
28+
auto cm = getRingExpansionFactor(param);
29+
// ||a|| <= c_m * ||a||^{can}
30+
return log(cm * noise.getValue()) / log(2);
31+
}
32+
33+
template <bool P>
34+
double Model<P>::toLogBudget(const LocalParamType &param,
35+
const StateType &noise) {
36+
return toLogTotal(param) - toLogBound(param, noise);
37+
}
38+
39+
template <bool P>
40+
double Model<P>::toLogTotal(const LocalParamType &param) {
41+
double total = 0;
42+
auto logqi = param.getSchemeParam()->getLogqi();
43+
for (auto i = 0; i <= param.getCurrentLevel(); ++i) {
44+
total += logqi[i];
45+
}
46+
return total - 1.0;
47+
}
48+
49+
template <bool P>
50+
std::string Model<P>::toLogBoundString(const LocalParamType &param,
51+
const StateType &noise) {
52+
auto logBound = toLogBound(param, noise);
53+
std::stringstream stream;
54+
stream << std::fixed << std::setprecision(2) << logBound;
55+
return stream.str();
56+
}
57+
58+
template <bool P>
59+
std::string Model<P>::toLogBudgetString(const LocalParamType &param,
60+
const StateType &noise) {
61+
auto logBudget = toLogBudget(param, noise);
62+
std::stringstream stream;
63+
stream << std::fixed << std::setprecision(2) << logBudget;
64+
return stream.str();
65+
}
66+
67+
template <bool P>
68+
std::string Model<P>::toLogTotalString(const LocalParamType &param) {
69+
auto logTotal = toLogTotal(param);
70+
std::stringstream stream;
71+
stream << std::fixed << std::setprecision(2) << logTotal;
72+
return stream.str();
73+
}
74+
75+
template <bool P>
76+
double Model<P>::getVarianceErr(const LocalParamType &param) {
77+
auto std0 = param.getSchemeParam()->getStd0();
78+
return std0 * std0;
79+
}
80+
81+
template <bool P>
82+
double Model<P>::getVarianceKey(const LocalParamType &param) {
83+
// assume UNIFORM_TERNARY
84+
return 2.0 / 3.0;
85+
}
86+
87+
template <bool P>
88+
double Model<P>::getRingExpansionFactor(const LocalParamType &param) {
89+
[[maybe_unused]] auto N = param.getSchemeParam()->getRingDim();
90+
// Assert that N is a power of 2
91+
assert((N > 0) && ((N & (N - 1)) == 0) && "N must be a power of 2");
92+
// In power-of-two rings c_m = 1
93+
return 1.;
94+
}
95+
96+
template <bool P>
97+
double Model<P>::getAssuranceFactor(const LocalParamType &param) {
98+
// probability that a exceeds its standard deviation by more than a factor of
99+
// D is roughly erfc(D) with erfc(6) = 2^-55, erfc(5) = 2^-40, erfc(4.5) =
100+
// 2^-32
101+
return 6.;
102+
}
103+
104+
template <bool P>
105+
double Model<P>::getBScale(const LocalParamType &param) {
106+
auto varianceKey = getVarianceKey(param);
107+
auto t = param.getSchemeParam()->getPlaintextModulus();
108+
auto d = getAssuranceFactor(param);
109+
auto phi = getPhi(param);
110+
111+
// B_scale = D * t * sqrt(phi(m)/12 * (1 + phi(m) * V_key)
112+
double innerTerm = (phi / 12.) * (1 + phi * varianceKey);
113+
return d * t * sqrt(innerTerm);
114+
}
115+
116+
template <bool P>
117+
double Model<P>::getBKs(const LocalParamType &param) {
118+
auto varianceError = getVarianceErr(param);
119+
auto t = param.getSchemeParam()->getPlaintextModulus();
120+
auto d = getAssuranceFactor(param);
121+
auto phi = getPhi(param);
122+
123+
// B_ks = D * t * phi(m) * sqrt(V_err / 12)
124+
return d * t * phi * sqrt(varianceError / 12.);
125+
}
126+
127+
template <bool P>
128+
double Model<P>::getPhi(const LocalParamType &param) {
129+
return param.getSchemeParam()->getRingDim();
130+
}
131+
132+
template <bool P>
133+
typename Model<P>::StateType Model<P>::evalEncryptPk(
134+
const LocalParamType &param) {
135+
auto varianceError = getVarianceErr(param);
136+
// uniform ternary
137+
auto varianceKey = getVarianceKey(param);
138+
auto t = param.getSchemeParam()->getPlaintextModulus();
139+
auto d = getAssuranceFactor(param);
140+
auto phi = getPhi(param);
141+
142+
// public key (-as + t * e, a)
143+
// public key encryption (-aus + t(u * e + e_0) + m, au + e_1)
144+
// ||m + t * (u * e + e_1 * s + e_0)||
145+
// <= D * t * sqrt(phi(m) * (1/12 + 2 * phi(m) * V_err * V_key + V_err))
146+
double innerTerm =
147+
phi * (1. / 12. + 2. * phi * varianceError * varianceKey + varianceKey);
148+
double fresh = d * t * sqrt(innerTerm);
149+
return StateType::of(fresh);
150+
}
151+
152+
template <bool P>
153+
typename Model<P>::StateType Model<P>::evalEncryptSk(
154+
const LocalParamType &param) {
155+
auto varianceError = getVarianceErr(param);
156+
auto t = param.getSchemeParam()->getPlaintextModulus();
157+
auto d = getAssuranceFactor(param);
158+
auto phi = getPhi(param);
159+
160+
// secret key s
161+
// secret key encryption (-as + m + t * e, a)
162+
// ||m + t * e|| <= D * t * sqrt(phi(m) * (1/12 + V_err))
163+
double innerTerm = phi * (1. / 12. + varianceError);
164+
double fresh = d * t * sqrt(innerTerm);
165+
return StateType::of(fresh);
166+
}
167+
168+
template <bool P>
169+
typename Model<P>::StateType Model<P>::evalEncrypt(
170+
const LocalParamType &param) {
171+
// P stands for public key encryption
172+
if constexpr (P) {
173+
return evalEncryptPk(param);
174+
} else {
175+
return evalEncryptSk(param);
176+
}
177+
}
178+
179+
template <bool P>
180+
typename Model<P>::StateType Model<P>::evalConstant(
181+
const LocalParamType &param) {
182+
auto t = param.getSchemeParam()->getPlaintextModulus();
183+
auto phi = getPhi(param);
184+
185+
// noise part of the plaintext in a pt-ct multiplication
186+
// v_const <= t * sqrt(phi(m) / 12)
187+
return StateType::of(t * sqrt(phi / 12.0));
188+
}
189+
190+
template <bool P>
191+
typename Model<P>::StateType Model<P>::evalAdd(const StateType &lhs,
192+
const StateType &rhs) {
193+
// v_add <= v_0 + v_1
194+
return StateType::of(lhs.getValue() + rhs.getValue());
195+
}
196+
197+
template <bool P>
198+
typename Model<P>::StateType Model<P>::evalMul(
199+
const LocalParamType &resultParam, const StateType &lhs,
200+
const StateType &rhs) {
201+
// v_mul <= v_0 * v_1
202+
return StateType::of(lhs.getValue() * rhs.getValue());
203+
}
204+
205+
template <bool P>
206+
typename Model<P>::StateType Model<P>::evalModReduce(
207+
const LocalParamType &inputParam, const StateType &input) {
208+
auto currentLogqi =
209+
inputParam.getSchemeParam()->getLogqi()[inputParam.getCurrentLevel()];
210+
double modulus = pow(2.0, currentLogqi);
211+
212+
// modulus switching is essentially a scaling operation
213+
// so the original error is scaled by the modulus
214+
// ||v_scaled|| = ||v_input|| / modulus
215+
auto scaled = input.getValue() / modulus;
216+
// in the meantime, it will introduce a rounding error
217+
// (tau_0, tau_1) to (ct_0, ct_1)
218+
// ||tau_0 + tau_1 * s|| <= D * t * sqrt(phi(m)/12 * (1 + phi(m) * V_key) =
219+
// B_scale
220+
// ||v_ms|| <= ||v_scaled|| + B_scale
221+
double bScale = getBScale(inputParam);
222+
return StateType::of(scaled + bScale);
223+
}
224+
225+
template <bool P>
226+
typename Model<P>::StateType Model<P>::evalRelinearizeHYBRID(
227+
const LocalParamType &inputParam, const StateType &input) {
228+
// for v_input, after modup and moddown, it remains the same (with rounding).
229+
// We only need to consider the error from key switching key
230+
// and rounding error during moddown.
231+
// Check section 3.2 of MMLGA22 for more details.
232+
auto dnum = inputParam.getSchemeParam()->getDnum();
233+
234+
auto currentLevel = inputParam.getCurrentLevel();
235+
auto logpi = inputParam.getSchemeParam()->getLogpi();
236+
237+
// TODO: prod of Pi() if Pi() is available instead of logPi()
238+
auto pi = inputParam.getSchemeParam()->getPi();
239+
double prodPi;
240+
double maxPi;
241+
size_t k;
242+
if (pi.empty()) {
243+
// values of pi are not set in schemeParam, so we use this
244+
std::vector<double> moduliPi(logpi.size());
245+
std::transform(logpi.begin(), logpi.end(), moduliPi.begin(),
246+
[](double value) { return pow(2.0, value); });
247+
maxPi = *std::max_element(moduliPi.begin(), moduliPi.end());
248+
prodPi = std::accumulate(moduliPi.begin(), moduliPi.end(), 1.,
249+
std::multiplies<double>());
250+
k = moduliPi.size();
251+
} else {
252+
// if real values of pi are set, we use those
253+
maxPi = *std::max_element(pi.begin(), pi.end());
254+
prodPi =
255+
std::accumulate(pi.begin(), pi.end(), 1., std::multiplies<double>());
256+
k = pi.size();
257+
}
258+
259+
// v_ks = v + sqrt(dnum * (currentLevel + 1)) * p_l^(ceil(currentLevel / dnum)
260+
// * B_ks / P + sqrt(k) * B_scale
261+
double bKs = getBKs(inputParam);
262+
auto pPower = ceil(static_cast<double>(currentLevel) / dnum);
263+
auto noiseKs = sqrt(dnum * (currentLevel + 1)) *
264+
pow(static_cast<double>(maxPi), static_cast<double>(pPower)) *
265+
bKs / prodPi;
266+
double bScale = getBScale(inputParam);
267+
auto noiseScale = sqrt(k) * bScale;
268+
269+
return StateType::of(input.getValue() + noiseKs + noiseScale);
270+
}
271+
272+
template <bool P>
273+
typename Model<P>::StateType Model<P>::evalRelinearize(
274+
const LocalParamType &inputParam, const StateType &input) {
275+
// assume HYBRID
276+
// if we further introduce BV to SchemeParam we can have alternative
277+
// implementation.
278+
return evalRelinearizeHYBRID(inputParam, input);
279+
}
280+
281+
// instantiate template class
282+
template class NoiseCanEmbModel<false>;
283+
template class NoiseCanEmbModel<true>;
284+
285+
} // namespace bgv
286+
} // namespace heir
287+
} // namespace mlir

0 commit comments

Comments
 (0)