Skip to content

Commit

Permalink
Merge pull request #1343 from ZenithalHourlyRate:bgv-worst-noise-anal…
Browse files Browse the repository at this point in the history
…ysis

PiperOrigin-RevId: 725783991
  • Loading branch information
copybara-github committed Feb 11, 2025
2 parents c751044 + e834ba9 commit d7fbf06
Show file tree
Hide file tree
Showing 27 changed files with 1,387 additions and 6 deletions.
1 change: 1 addition & 0 deletions lib/Analysis/DimensionAnalysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
Expand Down
27 changes: 27 additions & 0 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h"

#include <algorithm>
#include <cassert>
#include <functional>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/Mgmt/IR/MgmtDialect.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

Expand Down Expand Up @@ -106,6 +113,26 @@ int getDimension(Value value, DataFlowSolver *solver) {
return lattice->getValue().getDimension();
}

int getDimensionFromMgmtAttr(Value value) {
Attribute attr;
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
auto *parentOp = blockArg.getOwner()->getParentOp();
auto genericOp = dyn_cast<secret::GenericOp>(parentOp);
if (genericOp) {
attr = genericOp.getArgAttr(blockArg.getArgNumber(),
mgmt::MgmtDialect::kArgMgmtAttrName);
}
} else {
auto *parentOp = value.getDefiningOp();
attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName);
}
if (!mlir::isa<mgmt::MgmtAttr>(attr)) {
assert(false && "MgmtAttr not found");
}
auto mgmtAttr = mlir::cast<mgmt::MgmtAttr>(attr);
return mgmtAttr.getDimension();
}

void annotateDimension(Operation *top, DataFlowSolver *solver) {
auto getIntegerAttr = [&](int dimension) {
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), dimension);
Expand Down
2 changes: 2 additions & 0 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class DimensionAnalysis
// initialized
DimensionState::DimensionType getDimension(Value value, DataFlowSolver *solver);

DimensionState::DimensionType getDimensionFromMgmtAttr(Value value);

void annotateDimension(Operation *top, DataFlowSolver *solver);

} // namespace heir
Expand Down
3 changes: 1 addition & 2 deletions lib/Analysis/LevelAnalysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)
30 changes: 26 additions & 4 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h"

#include <algorithm>
#include <cassert>
#include <functional>

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/Mgmt/IR/MgmtDialect.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -142,5 +144,25 @@ void annotateLevel(Operation *top, DataFlowSolver *solver) {
});
}

LevelState::LevelType getLevelFromMgmtAttr(Value value) {
Attribute attr;
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
auto *parentOp = blockArg.getOwner()->getParentOp();
auto genericOp = dyn_cast<secret::GenericOp>(parentOp);
if (genericOp) {
attr = genericOp.getArgAttr(blockArg.getArgNumber(),
mgmt::MgmtDialect::kArgMgmtAttrName);
}
} else {
auto *parentOp = value.getDefiningOp();
attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName);
}
if (!mlir::isa<mgmt::MgmtAttr>(attr)) {
assert(false && "MgmtAttr not found");
}
auto mgmtAttr = mlir::cast<mgmt::MgmtAttr>(attr);
return mgmtAttr.getLevel();
}

} // namespace heir
} // namespace mlir
2 changes: 2 additions & 0 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class LevelAnalysis
}
};

LevelState::LevelType getLevelFromMgmtAttr(Value value);

void annotateLevel(Operation *top, DataFlowSolver *solver);

} // namespace heir
Expand Down
44 changes: 44 additions & 0 deletions lib/Analysis/NoiseAnalysis/BGV/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "NoiseByBoundCoeffModel",
srcs = [
"NoiseByBoundCoeffModel.cpp",
"NoiseByBoundCoeffModelAnalysis.cpp",
],
hdrs = [
"NoiseByBoundCoeffModel.h",
],
deps = [
":Noise",
"@heir//lib/Analysis:Utils",
"@heir//lib/Analysis/DimensionAnalysis",
"@heir//lib/Analysis/LevelAnalysis",
"@heir//lib/Analysis/NoiseAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Parameters/BGV:Params",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)

cc_library(
name = "Noise",
srcs = ["Noise.cpp"],
hdrs = [
"Noise.h",
],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
32 changes: 32 additions & 0 deletions lib/Analysis/NoiseAnalysis/BGV/Noise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include "lib/Analysis/NoiseAnalysis/BGV/Noise.h"

#include <cmath>
#include <string>

#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace bgv {

std::string NoiseState::toString() const {
switch (noiseType) {
case (NoiseType::UNINITIALIZED):
return "NoiseState(uninitialized)";
case (NoiseType::SET):
return "NoiseState(" + std::to_string(log(getValue()) / log(2)) + ") ";
}
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const NoiseState &noise) {
return os << noise.toString();
}

Diagnostic &operator<<(Diagnostic &diagnostic, const NoiseState &noise) {
return diagnostic << noise.toString();
}

} // namespace bgv
} // namespace heir
} // namespace mlir
88 changes: 88 additions & 0 deletions lib/Analysis/NoiseAnalysis/BGV/Noise.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#ifndef INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_
#define INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_

#include <optional>

#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace bgv {

// This class could be shared among all noise models that tracks the noise by a
// single value. Noise model could have different interpretation of the value.
// In BGV world, most noise model just use a single value, either as bound or
// as variance.
class NoiseState {
public:
enum NoiseType {
// A min value for the lattice, discarable when joined with anything else.
UNINITIALIZED,
// A known value for the lattice, when noise can be inferred.
SET,
};

static NoiseState uninitialized() {
return NoiseState(NoiseType::UNINITIALIZED, std::nullopt);
}
static NoiseState of(double value) {
return NoiseState(NoiseType::SET, value);
}

/// Create an integer value range lattice value.
/// The default constructor must be equivalent to the "entry state" of the
/// lattice, i.e., an uninitialized noise.
NoiseState(NoiseType noiseType = NoiseType::UNINITIALIZED,
std::optional<double> value = std::nullopt)
: noiseType(noiseType), value(value) {}

bool isKnown() const { return noiseType == NoiseType::SET; }

bool isInitialized() const { return noiseType != NoiseType::UNINITIALIZED; }

const double &getValue() const {
assert(isKnown());
return *value;
}

bool operator==(const NoiseState &rhs) const {
return noiseType == rhs.noiseType && value == rhs.value;
}

static NoiseState join(const NoiseState &lhs, const NoiseState &rhs) {
// Uninitialized noises correspond to values that are not secret,
// which may be the inputs to an encryption operation.
if (lhs.noiseType == NoiseType::UNINITIALIZED) {
return rhs;
}
if (rhs.noiseType == NoiseType::UNINITIALIZED) {
return lhs;
}

assert(lhs.noiseType == NoiseType::SET && rhs.noiseType == NoiseType::SET);
return NoiseState::of(std::max(lhs.getValue(), rhs.getValue()));
}

void print(llvm::raw_ostream &os) const { os << value; }

std::string toString() const;

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const NoiseState &noise);

friend Diagnostic &operator<<(Diagnostic &diagnostic,
const NoiseState &noise);

private:
NoiseType noiseType;
// notice that when level becomes large (e.g. 17), the max Q could be like
// 3523 bits and could not be represented in double.
std::optional<double> value;
};

} // namespace bgv
} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_
Loading

0 comments on commit d7fbf06

Please sign in to comment.