Skip to content

Commit d7fbf06

Browse files
Merge pull request #1343 from ZenithalHourlyRate:bgv-worst-noise-analysis
PiperOrigin-RevId: 725783991
2 parents c751044 + e834ba9 commit d7fbf06

27 files changed

+1387
-6
lines changed

lib/Analysis/DimensionAnalysis/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
"@llvm-project//llvm:Support",
1616
"@llvm-project//mlir:Analysis",
1717
"@llvm-project//mlir:ArithDialect",
18+
"@llvm-project//mlir:CallOpInterfaces",
1819
"@llvm-project//mlir:FunctionInterfaces",
1920
"@llvm-project//mlir:IR",
2021
"@llvm-project//mlir:Support",

lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"
22

33
#include <algorithm>
4+
#include <cassert>
45
#include <functional>
56

67
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
78
#include "lib/Analysis/Utils.h"
9+
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
10+
#include "lib/Dialect/Mgmt/IR/MgmtDialect.h"
811
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
912
#include "lib/Dialect/Secret/IR/SecretOps.h"
1013
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
1114
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
1215
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
16+
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
17+
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
18+
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
1319
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
1420
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
1521
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
22+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
1623
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
1724
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1825

@@ -106,6 +113,26 @@ int getDimension(Value value, DataFlowSolver *solver) {
106113
return lattice->getValue().getDimension();
107114
}
108115

116+
int getDimensionFromMgmtAttr(Value value) {
117+
Attribute attr;
118+
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
119+
auto *parentOp = blockArg.getOwner()->getParentOp();
120+
auto genericOp = dyn_cast<secret::GenericOp>(parentOp);
121+
if (genericOp) {
122+
attr = genericOp.getArgAttr(blockArg.getArgNumber(),
123+
mgmt::MgmtDialect::kArgMgmtAttrName);
124+
}
125+
} else {
126+
auto *parentOp = value.getDefiningOp();
127+
attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName);
128+
}
129+
if (!mlir::isa<mgmt::MgmtAttr>(attr)) {
130+
assert(false && "MgmtAttr not found");
131+
}
132+
auto mgmtAttr = mlir::cast<mgmt::MgmtAttr>(attr);
133+
return mgmtAttr.getDimension();
134+
}
135+
109136
void annotateDimension(Operation *top, DataFlowSolver *solver) {
110137
auto getIntegerAttr = [&](int dimension) {
111138
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), dimension);

lib/Analysis/DimensionAnalysis/DimensionAnalysis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ class DimensionAnalysis
9393
// initialized
9494
DimensionState::DimensionType getDimension(Value value, DataFlowSolver *solver);
9595

96+
DimensionState::DimensionType getDimensionFromMgmtAttr(Value value);
97+
9698
void annotateDimension(Operation *top, DataFlowSolver *solver);
9799

98100
} // namespace heir

lib/Analysis/LevelAnalysis/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ cc_library(
1515
"@llvm-project//llvm:Support",
1616
"@llvm-project//mlir:Analysis",
1717
"@llvm-project//mlir:ArithDialect",
18-
"@llvm-project//mlir:FunctionInterfaces",
18+
"@llvm-project//mlir:CallOpInterfaces",
1919
"@llvm-project//mlir:IR",
2020
"@llvm-project//mlir:Support",
21-
"@llvm-project//mlir:TensorDialect",
2221
],
2322
)

lib/Analysis/LevelAnalysis/LevelAnalysis.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h"
22

33
#include <algorithm>
4+
#include <cassert>
45
#include <functional>
56

67
#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
78
#include "lib/Analysis/Utils.h"
9+
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
10+
#include "lib/Dialect/Mgmt/IR/MgmtDialect.h"
811
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
912
#include "lib/Dialect/Secret/IR/SecretOps.h"
1013
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
1114
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
12-
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
13-
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
15+
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
1416
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
1517
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
1618
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
1719
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
1820
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
19-
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
20-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
21+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
22+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
2123

2224
namespace mlir {
2325
namespace heir {
@@ -142,5 +144,25 @@ void annotateLevel(Operation *top, DataFlowSolver *solver) {
142144
});
143145
}
144146

147+
LevelState::LevelType getLevelFromMgmtAttr(Value value) {
148+
Attribute attr;
149+
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
150+
auto *parentOp = blockArg.getOwner()->getParentOp();
151+
auto genericOp = dyn_cast<secret::GenericOp>(parentOp);
152+
if (genericOp) {
153+
attr = genericOp.getArgAttr(blockArg.getArgNumber(),
154+
mgmt::MgmtDialect::kArgMgmtAttrName);
155+
}
156+
} else {
157+
auto *parentOp = value.getDefiningOp();
158+
attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName);
159+
}
160+
if (!mlir::isa<mgmt::MgmtAttr>(attr)) {
161+
assert(false && "MgmtAttr not found");
162+
}
163+
auto mgmtAttr = mlir::cast<mgmt::MgmtAttr>(attr);
164+
return mgmtAttr.getLevel();
165+
}
166+
145167
} // namespace heir
146168
} // namespace mlir

lib/Analysis/LevelAnalysis/LevelAnalysis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ class LevelAnalysis
9393
}
9494
};
9595

96+
LevelState::LevelType getLevelFromMgmtAttr(Value value);
97+
9698
void annotateLevel(Operation *top, DataFlowSolver *solver);
9799

98100
} // namespace heir

lib/Analysis/NoiseAnalysis/BGV/BUILD

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package(
2+
default_applicable_licenses = ["@heir//:license"],
3+
default_visibility = ["//visibility:public"],
4+
)
5+
6+
cc_library(
7+
name = "NoiseByBoundCoeffModel",
8+
srcs = [
9+
"NoiseByBoundCoeffModel.cpp",
10+
"NoiseByBoundCoeffModelAnalysis.cpp",
11+
],
12+
hdrs = [
13+
"NoiseByBoundCoeffModel.h",
14+
],
15+
deps = [
16+
":Noise",
17+
"@heir//lib/Analysis:Utils",
18+
"@heir//lib/Analysis/DimensionAnalysis",
19+
"@heir//lib/Analysis/LevelAnalysis",
20+
"@heir//lib/Analysis/NoiseAnalysis",
21+
"@heir//lib/Dialect/Mgmt/IR:Dialect",
22+
"@heir//lib/Dialect/Secret/IR:Dialect",
23+
"@heir//lib/Dialect/TensorExt/IR:Dialect",
24+
"@heir//lib/Parameters/BGV:Params",
25+
"@llvm-project//llvm:Support",
26+
"@llvm-project//mlir:ArithDialect",
27+
"@llvm-project//mlir:CallOpInterfaces",
28+
"@llvm-project//mlir:IR",
29+
"@llvm-project//mlir:Support",
30+
"@llvm-project//mlir:TensorDialect",
31+
],
32+
)
33+
34+
cc_library(
35+
name = "Noise",
36+
srcs = ["Noise.cpp"],
37+
hdrs = [
38+
"Noise.h",
39+
],
40+
deps = [
41+
"@llvm-project//llvm:Support",
42+
"@llvm-project//mlir:IR",
43+
],
44+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "lib/Analysis/NoiseAnalysis/BGV/Noise.h"
2+
3+
#include <cmath>
4+
#include <string>
5+
6+
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
7+
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
8+
9+
namespace mlir {
10+
namespace heir {
11+
namespace bgv {
12+
13+
std::string NoiseState::toString() const {
14+
switch (noiseType) {
15+
case (NoiseType::UNINITIALIZED):
16+
return "NoiseState(uninitialized)";
17+
case (NoiseType::SET):
18+
return "NoiseState(" + std::to_string(log(getValue()) / log(2)) + ") ";
19+
}
20+
}
21+
22+
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const NoiseState &noise) {
23+
return os << noise.toString();
24+
}
25+
26+
Diagnostic &operator<<(Diagnostic &diagnostic, const NoiseState &noise) {
27+
return diagnostic << noise.toString();
28+
}
29+
30+
} // namespace bgv
31+
} // namespace heir
32+
} // namespace mlir
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#ifndef INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_
2+
#define INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_
3+
4+
#include <optional>
5+
6+
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
7+
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
8+
9+
namespace mlir {
10+
namespace heir {
11+
namespace bgv {
12+
13+
// This class could be shared among all noise models that tracks the noise by a
14+
// single value. Noise model could have different interpretation of the value.
15+
// In BGV world, most noise model just use a single value, either as bound or
16+
// as variance.
17+
class NoiseState {
18+
public:
19+
enum NoiseType {
20+
// A min value for the lattice, discarable when joined with anything else.
21+
UNINITIALIZED,
22+
// A known value for the lattice, when noise can be inferred.
23+
SET,
24+
};
25+
26+
static NoiseState uninitialized() {
27+
return NoiseState(NoiseType::UNINITIALIZED, std::nullopt);
28+
}
29+
static NoiseState of(double value) {
30+
return NoiseState(NoiseType::SET, value);
31+
}
32+
33+
/// Create an integer value range lattice value.
34+
/// The default constructor must be equivalent to the "entry state" of the
35+
/// lattice, i.e., an uninitialized noise.
36+
NoiseState(NoiseType noiseType = NoiseType::UNINITIALIZED,
37+
std::optional<double> value = std::nullopt)
38+
: noiseType(noiseType), value(value) {}
39+
40+
bool isKnown() const { return noiseType == NoiseType::SET; }
41+
42+
bool isInitialized() const { return noiseType != NoiseType::UNINITIALIZED; }
43+
44+
const double &getValue() const {
45+
assert(isKnown());
46+
return *value;
47+
}
48+
49+
bool operator==(const NoiseState &rhs) const {
50+
return noiseType == rhs.noiseType && value == rhs.value;
51+
}
52+
53+
static NoiseState join(const NoiseState &lhs, const NoiseState &rhs) {
54+
// Uninitialized noises correspond to values that are not secret,
55+
// which may be the inputs to an encryption operation.
56+
if (lhs.noiseType == NoiseType::UNINITIALIZED) {
57+
return rhs;
58+
}
59+
if (rhs.noiseType == NoiseType::UNINITIALIZED) {
60+
return lhs;
61+
}
62+
63+
assert(lhs.noiseType == NoiseType::SET && rhs.noiseType == NoiseType::SET);
64+
return NoiseState::of(std::max(lhs.getValue(), rhs.getValue()));
65+
}
66+
67+
void print(llvm::raw_ostream &os) const { os << value; }
68+
69+
std::string toString() const;
70+
71+
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
72+
const NoiseState &noise);
73+
74+
friend Diagnostic &operator<<(Diagnostic &diagnostic,
75+
const NoiseState &noise);
76+
77+
private:
78+
NoiseType noiseType;
79+
// notice that when level becomes large (e.g. 17), the max Q could be like
80+
// 3523 bits and could not be represented in double.
81+
std::optional<double> value;
82+
};
83+
84+
} // namespace bgv
85+
} // namespace heir
86+
} // namespace mlir
87+
88+
#endif // INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_

0 commit comments

Comments
 (0)