Skip to content

Commit e69fcea

Browse files
j2kuncopybara-github
authored andcommitted
lower debug.validate to add-debug-port calls
PiperOrigin-RevId: 872500167
1 parent 6d04bee commit e69fcea

32 files changed

Lines changed: 520 additions & 108 deletions

lib/Dialect/Debug/Transforms/BUILD

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")
2+
load("@rules_cc//cc:cc_library.bzl", "cc_library")
3+
4+
package(
5+
default_applicable_licenses = ["@heir//:license"],
6+
default_visibility = ["//visibility:public"],
7+
)
8+
9+
cc_library(
10+
name = "Transforms",
11+
hdrs = [
12+
"Passes.h",
13+
],
14+
deps = [
15+
":ValidateNames",
16+
":pass_inc_gen",
17+
"@heir//lib/Dialect/Debug/IR:Dialect",
18+
],
19+
)
20+
21+
cc_library(
22+
name = "ValidateNames",
23+
srcs = ["ValidateNames.cpp"],
24+
hdrs = [
25+
"ValidateNames.h",
26+
],
27+
deps = [
28+
":pass_inc_gen",
29+
"@heir//lib/Dialect/Debug/IR:Dialect",
30+
"@llvm-project//llvm:Support",
31+
"@llvm-project//mlir:IR",
32+
"@llvm-project//mlir:Pass",
33+
"@llvm-project//mlir:Support",
34+
],
35+
)
36+
37+
add_heir_transforms(
38+
header_filename = "Passes.h.inc",
39+
pass_name = "Debug",
40+
td_file = "Passes.td",
41+
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef LIB_DIALECT_DEBUG_TRANSFORMS_PASSES_H_
2+
#define LIB_DIALECT_DEBUG_TRANSFORMS_PASSES_H_
3+
4+
#include "lib/Dialect/Debug/IR/DebugDialect.h"
5+
#include "lib/Dialect/Debug/Transforms/ValidateNames.h"
6+
7+
namespace mlir {
8+
namespace heir {
9+
namespace debug {
10+
11+
#define GEN_PASS_REGISTRATION
12+
#include "lib/Dialect/Debug/Transforms/Passes.h.inc"
13+
14+
} // namespace debug
15+
} // namespace heir
16+
} // namespace mlir
17+
18+
#endif // LIB_DIALECT_DEBUG_TRANSFORMS_PASSES_H_
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef LIB_DIALECT_DEBUG_TRANSFORMS_PASSES_TD_
2+
#define LIB_DIALECT_DEBUG_TRANSFORMS_PASSES_TD_
3+
4+
include "mlir/Pass/PassBase.td"
5+
6+
def DebugValidateNames : Pass<"debug-validate-names", "mlir::ModuleOp"> {
7+
let summary = "Validates that debug.validate names are unique";
8+
let description = [{
9+
This pass walks the IR and ensures that each `debug.validate` operation
10+
has a unique `name` attribute. If any duplicates are found, the pass fails.
11+
}];
12+
let dependentDialects = ["mlir::heir::debug::DebugDialect"];
13+
}
14+
15+
#endif // LIB_DIALECT_DEBUG_TRANSFORMS_PASSES_TD_
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include "lib/Dialect/Debug/Transforms/ValidateNames.h"
2+
3+
#include "lib/Dialect/Debug/IR/DebugOps.h"
4+
#include "llvm/include/llvm/ADT/StringSet.h" // from @llvm-project
5+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
6+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
7+
8+
namespace mlir {
9+
namespace heir {
10+
namespace debug {
11+
12+
#define GEN_PASS_DEF_DEBUGVALIDATENAMES
13+
#include "lib/Dialect/Debug/Transforms/Passes.h.inc"
14+
15+
struct DebugValidateNames
16+
: public impl::DebugValidateNamesBase<DebugValidateNames> {
17+
void runOnOperation() override {
18+
ModuleOp module = getOperation();
19+
llvm::StringSet<> names;
20+
WalkResult result = module.walk([&](ValidateOp op) {
21+
StringRef name = op.getName();
22+
if (!names.insert(name).second) {
23+
op.emitError() << "duplicate debug.validate name: " << name;
24+
return WalkResult::interrupt();
25+
}
26+
return WalkResult::advance();
27+
});
28+
29+
if (result.wasInterrupted()) {
30+
signalPassFailure();
31+
}
32+
}
33+
};
34+
35+
} // namespace debug
36+
} // namespace heir
37+
} // namespace mlir
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef LIB_DIALECT_DEBUG_TRANSFORMS_VALIDATENAMES_H_
2+
#define LIB_DIALECT_DEBUG_TRANSFORMS_VALIDATENAMES_H_
3+
4+
#include "lib/Dialect/Debug/IR/DebugOps.h"
5+
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
6+
7+
namespace mlir {
8+
namespace heir {
9+
namespace debug {
10+
11+
#define GEN_PASS_DECL_DEBUGVALIDATENAMES
12+
#include "lib/Dialect/Debug/Transforms/Passes.h.inc"
13+
14+
} // namespace debug
15+
} // namespace heir
16+
} // namespace mlir
17+
18+
#endif // LIB_DIALECT_DEBUG_TRANSFORMS_VALIDATENAMES_H_

lib/Dialect/LWE/Transforms/AddDebugPort.cpp

Lines changed: 67 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <string>
44

5+
#include "lib/Dialect/Debug/IR/DebugOps.h"
56
#include "lib/Dialect/LWE/IR/LWETypes.h"
67
#include "lib/Utils/TransformUtils.h"
78
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
@@ -11,6 +12,7 @@
1112
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
1213
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
1314
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
15+
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
1416
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
1517
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
1618
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
@@ -66,84 +68,98 @@ func::FuncOp getOrCreateExternalDebugFunc(
6668
return funcOp;
6769
}
6870

69-
LogicalResult insertExternalCall(func::FuncOp op, Type lwePrivateKeyType,
70-
int messageSize) {
71-
auto module = op->getParentOfType<ModuleOp>();
71+
void insertValidationOps(func::FuncOp op) {
72+
int count = 0;
73+
auto insertValidate = [&](Value value, OpBuilder& b) {
74+
Type valueType = value.getType();
75+
if (mlir::isa<LWECiphertextType>(valueType)) {
76+
b.create<debug::ValidateOp>(value.getLoc(), value,
77+
"heir_debug_" + std::to_string(count++),
78+
nullptr);
79+
}
80+
};
7281

73-
// map ciphertext type to unique int
74-
DenseMap<Type, int> typeToInt;
82+
Block& entryBlock = op.getBody().getBlocks().front();
83+
OpBuilder argBuilder(&entryBlock, entryBlock.begin());
84+
for (auto arg : op.getArguments()) {
85+
insertValidate(arg, argBuilder);
86+
}
7587

76-
// implicit assumption the first argument is private key
77-
auto privateKey = op.getArgument(0);
88+
op.walk([&](Operation* walkOp) {
89+
if (walkOp == op.getOperation() ||
90+
walkOp->hasTrait<OpTrait::IsTerminator>())
91+
return;
92+
OpBuilder opBuilder(walkOp->getBlock(), ++walkOp->getIterator());
93+
for (Value result : walkOp->getResults()) {
94+
insertValidate(result, opBuilder);
95+
}
96+
});
97+
}
7898

79-
ImplicitLocOpBuilder b = ImplicitLocOpBuilder::atBlockBegin(
80-
op.getLoc(), &op.getBody().getBlocks().front());
99+
LogicalResult lowerValidationOps(func::FuncOp op, Type lwePrivateKeyType,
100+
int messageSize) {
101+
auto module = op->getParentOfType<ModuleOp>();
102+
DenseMap<Type, int> typeToInt;
103+
Value privateKey = op.getArgument(0);
81104

82-
auto insertCall = [&](Value value) {
105+
auto walkResult = op.walk([&](debug::ValidateOp validateOp) {
106+
Value value = validateOp.getInput();
83107
Type valueType = value.getType();
84-
// NOTE: this won't work for shaped input like tensor<2x!lwe.ciphertext>
85108
if (auto lweCiphertextType = dyn_cast<LWECiphertextType>(valueType)) {
86-
// update typeToInt
87109
if (!typeToInt.count(valueType)) {
88110
typeToInt[valueType] = typeToInt.size();
89111
}
90112

91-
// get attribute associated with value
113+
ImplicitLocOpBuilder b(validateOp.getLoc(), validateOp);
92114
SmallVector<NamedAttribute> attrs;
93-
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
94-
auto* parentOp = blockArg.getOwner()->getParentOp();
95-
auto funcOp = dyn_cast<FunctionOpInterface>(parentOp);
96-
if (funcOp) {
97-
// always dialect attr
98-
for (auto namedAttr : funcOp.getArgAttrs(blockArg.getArgNumber())) {
99-
attrs.push_back(namedAttr);
100-
}
101-
}
102-
} else {
103-
auto* parentOp = value.getDefiningOp();
104-
for (auto namedAttr : parentOp->getDialectAttrs()) {
105-
attrs.push_back(namedAttr);
106-
}
115+
// Transfer metadata from validateOp to CallOp
116+
attrs.push_back(b.getNamedAttr("debug.name", validateOp.getNameAttr()));
117+
if (validateOp.getMetadata()) {
118+
attrs.push_back(
119+
b.getNamedAttr("debug.metadata", validateOp.getMetadataAttr()));
107120
}
108-
109121
attrs.push_back(b.getNamedAttr(
110122
"message.size", b.getStringAttr(std::to_string(messageSize))));
111123

112-
func::CallOp::create(
113-
b,
124+
auto callOp = b.create<func::CallOp>(
114125
getOrCreateExternalDebugFunc(module, lwePrivateKeyType,
115126
lweCiphertextType, typeToInt),
116-
ArrayRef<Value>{privateKey, value})
117-
->setDialectAttrs(attrs);
118-
}
119-
};
127+
ArrayRef<Value>{privateKey, value});
128+
callOp->setDialectAttrs(attrs);
120129

121-
// insert for each argument
122-
for (auto arg : op.getArguments()) {
123-
insertCall(arg);
124-
}
125-
126-
// insert after each HE op
127-
op.walk([&](Operation* op) {
128-
b.setInsertionPointAfter(op);
129-
for (Value result : op->getResults()) {
130-
insertCall(result);
130+
validateOp.erase();
131131
}
132132
return WalkResult::advance();
133133
});
134-
return success();
134+
135+
return walkResult.wasInterrupted() ? failure() : success();
135136
}
136137

137-
LogicalResult convertFunc(func::FuncOp op, int messageSize) {
138+
LogicalResult convertFunc(func::FuncOp op, int messageSize,
139+
bool insertDebugAfterEveryOp) {
140+
if (insertDebugAfterEveryOp) {
141+
insertValidationOps(op);
142+
}
143+
144+
// Check if there are any validation ops to lower before adding private key
145+
bool hasValidationOps = false;
146+
op.walk([&](debug::ValidateOp) {
147+
hasValidationOps = true;
148+
return WalkResult::interrupt();
149+
});
150+
151+
if (!hasValidationOps) return success();
152+
138153
auto type = getPrivateKeyType(op);
139154
if (failed(type)) return op.emitError("failed to get private key type");
140155
auto lwePrivateKeyType = type.value();
141156

142157
if (failed(op.insertArgument(0, lwePrivateKeyType, nullptr, op.getLoc()))) {
143158
return op.emitError("failed to insert private key argument");
144159
}
145-
if (failed(insertExternalCall(op, lwePrivateKeyType, messageSize))) {
146-
return op.emitError("failed to insert external call");
160+
161+
if (failed(lowerValidationOps(op, lwePrivateKeyType, messageSize))) {
162+
return op.emitError("failed to lower validation ops");
147163
}
148164
return success();
149165
}
@@ -154,8 +170,9 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
154170
void runOnOperation() override {
155171
auto funcOp =
156172
detectEntryFunction(cast<ModuleOp>(getOperation()), entryFunction);
157-
if (funcOp && failed(convertFunc(funcOp, messageSize))) {
158-
funcOp->emitError("Failed to configure the crypto context for func");
173+
if (funcOp &&
174+
failed(convertFunc(funcOp, messageSize, insertDebugAfterEveryOp))) {
175+
funcOp->emitError("Failed to add debug port for func");
159176
signalPassFailure();
160177
}
161178
}

lib/Dialect/LWE/Transforms/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ cc_library(
2626
],
2727
deps = [
2828
":pass_inc_gen",
29+
"@heir//lib/Dialect/Debug/IR:Dialect",
2930
"@heir//lib/Dialect/LWE/IR:Dialect",
30-
"@heir//lib/Dialect/TensorExt/IR:Dialect",
3131
"@heir//lib/Utils:TransformUtils",
3232
"@llvm-project//llvm:Support",
3333
"@llvm-project//mlir:FuncDialect",

lib/Dialect/LWE/Transforms/Passes.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ def AddDebugPort : Pass<"lwe-add-debug-port"> {
1212
function. The debug ports are declarations and user should provide functions with
1313
the same name in their code.
1414

15+
If the option `insert-debug-after-every-op` is set to true, it will insert a `debug.validate`
16+
op after every homomorphic operation.
17+
18+
Regardless of the `insert-debug-after-every-op` option, this pass will lower all
19+
`debug.validate` ops it encounters to function calls.
20+
1521
For example, if the function is called "foo", the secret key is added to its
1622
arguments, and the debug port is called after each homomorphic operation:
1723
```mlir
@@ -29,13 +35,18 @@ def AddDebugPort : Pass<"lwe-add-debug-port"> {
2935
}
3036
```
3137
}];
32-
let dependentDialects = ["mlir::heir::lwe::LWEDialect"];
38+
let dependentDialects = [
39+
"mlir::heir::lwe::LWEDialect",
40+
"mlir::heir::debug::DebugDialect"
41+
];
3342
let options = [
3443
Option<"entryFunction", "entry-function", "std::string",
3544
/*default=*/"", "Default entry function "
3645
"name of entry function.">,
3746
Option<"messageSize", "message-size", "int",
3847
/*default=*/"1", "The size of the message in the ciphertext.">,
48+
Option<"insertDebugAfterEveryOp", "insert-debug-after-every-op", "bool",
49+
/*default=*/"false", "Whether to add debug ports after every op">
3950
];
4051
}
4152

0 commit comments

Comments
 (0)