22
33#include < string>
44
5+ #include " lib/Dialect/Debug/IR/DebugOps.h"
56#include " lib/Dialect/LWE/IR/LWETypes.h"
67#include " lib/Utils/TransformUtils.h"
78#include " llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
1112#include " mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
1213#include " mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
1314#include " mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
15+ #include " mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
1416#include " mlir/include/mlir/IR/Operation.h" // from @llvm-project
1517#include " mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
1618#include " mlir/include/mlir/IR/Types.h" // from @llvm-project
@@ -66,84 +68,98 @@ func::FuncOp getOrCreateExternalDebugFunc(
6668 return funcOp;
6769}
6870
69- LogicalResult insertExternalCall (func::FuncOp op, Type lwePrivateKeyType,
70- int messageSize) {
71- auto module = op->getParentOfType <ModuleOp>();
71+ void insertValidationOps (func::FuncOp op) {
72+ int count = 0 ;
73+ auto insertValidate = [&](Value value, OpBuilder& b) {
74+ Type valueType = value.getType ();
75+ if (mlir::isa<LWECiphertextType>(valueType)) {
76+ b.create <debug::ValidateOp>(value.getLoc (), value,
77+ " heir_debug_" + std::to_string (count++),
78+ nullptr );
79+ }
80+ };
7281
73- // map ciphertext type to unique int
74- DenseMap<Type, int > typeToInt;
82+ Block& entryBlock = op.getBody ().getBlocks ().front ();
83+ OpBuilder argBuilder (&entryBlock, entryBlock.begin ());
84+ for (auto arg : op.getArguments ()) {
85+ insertValidate (arg, argBuilder);
86+ }
7587
76- // implicit assumption the first argument is private key
77- auto privateKey = op.getArgument (0 );
88+ op.walk ([&](Operation* walkOp) {
89+ if (walkOp == op.getOperation () ||
90+ walkOp->hasTrait <OpTrait::IsTerminator>())
91+ return ;
92+ OpBuilder opBuilder (walkOp->getBlock (), ++walkOp->getIterator ());
93+ for (Value result : walkOp->getResults ()) {
94+ insertValidate (result, opBuilder);
95+ }
96+ });
97+ }
7898
79- ImplicitLocOpBuilder b = ImplicitLocOpBuilder::atBlockBegin (
80- op.getLoc (), &op.getBody ().getBlocks ().front ());
99+ LogicalResult lowerValidationOps (func::FuncOp op, Type lwePrivateKeyType,
100+ int messageSize) {
101+ auto module = op->getParentOfType <ModuleOp>();
102+ DenseMap<Type, int > typeToInt;
103+ Value privateKey = op.getArgument (0 );
81104
82- auto insertCall = [&](Value value) {
105+ auto walkResult = op.walk ([&](debug::ValidateOp validateOp) {
106+ Value value = validateOp.getInput ();
83107 Type valueType = value.getType ();
84- // NOTE: this won't work for shaped input like tensor<2x!lwe.ciphertext>
85108 if (auto lweCiphertextType = dyn_cast<LWECiphertextType>(valueType)) {
86- // update typeToInt
87109 if (!typeToInt.count (valueType)) {
88110 typeToInt[valueType] = typeToInt.size ();
89111 }
90112
91- // get attribute associated with value
113+ ImplicitLocOpBuilder b (validateOp. getLoc (), validateOp);
92114 SmallVector<NamedAttribute> attrs;
93- if (auto blockArg = dyn_cast<BlockArgument>(value)) {
94- auto * parentOp = blockArg.getOwner ()->getParentOp ();
95- auto funcOp = dyn_cast<FunctionOpInterface>(parentOp);
96- if (funcOp) {
97- // always dialect attr
98- for (auto namedAttr : funcOp.getArgAttrs (blockArg.getArgNumber ())) {
99- attrs.push_back (namedAttr);
100- }
101- }
102- } else {
103- auto * parentOp = value.getDefiningOp ();
104- for (auto namedAttr : parentOp->getDialectAttrs ()) {
105- attrs.push_back (namedAttr);
106- }
115+ // Transfer metadata from validateOp to CallOp
116+ attrs.push_back (b.getNamedAttr (" debug.name" , validateOp.getNameAttr ()));
117+ if (validateOp.getMetadata ()) {
118+ attrs.push_back (
119+ b.getNamedAttr (" debug.metadata" , validateOp.getMetadataAttr ()));
107120 }
108-
109121 attrs.push_back (b.getNamedAttr (
110122 " message.size" , b.getStringAttr (std::to_string (messageSize))));
111123
112- func::CallOp::create (
113- b,
124+ auto callOp = b.create <func::CallOp>(
114125 getOrCreateExternalDebugFunc (module , lwePrivateKeyType,
115126 lweCiphertextType, typeToInt),
116- ArrayRef<Value>{privateKey, value})
117- ->setDialectAttrs (attrs);
118- }
119- };
127+ ArrayRef<Value>{privateKey, value});
128+ callOp->setDialectAttrs (attrs);
120129
121- // insert for each argument
122- for (auto arg : op.getArguments ()) {
123- insertCall (arg);
124- }
125-
126- // insert after each HE op
127- op.walk ([&](Operation* op) {
128- b.setInsertionPointAfter (op);
129- for (Value result : op->getResults ()) {
130- insertCall (result);
130+ validateOp.erase ();
131131 }
132132 return WalkResult::advance ();
133133 });
134- return success ();
134+
135+ return walkResult.wasInterrupted () ? failure () : success ();
135136}
136137
137- LogicalResult convertFunc (func::FuncOp op, int messageSize) {
138+ LogicalResult convertFunc (func::FuncOp op, int messageSize,
139+ bool insertDebugAfterEveryOp) {
140+ if (insertDebugAfterEveryOp) {
141+ insertValidationOps (op);
142+ }
143+
144+ // Check if there are any validation ops to lower before adding private key
145+ bool hasValidationOps = false ;
146+ op.walk ([&](debug::ValidateOp) {
147+ hasValidationOps = true ;
148+ return WalkResult::interrupt ();
149+ });
150+
151+ if (!hasValidationOps) return success ();
152+
138153 auto type = getPrivateKeyType (op);
139154 if (failed (type)) return op.emitError (" failed to get private key type" );
140155 auto lwePrivateKeyType = type.value ();
141156
142157 if (failed (op.insertArgument (0 , lwePrivateKeyType, nullptr , op.getLoc ()))) {
143158 return op.emitError (" failed to insert private key argument" );
144159 }
145- if (failed (insertExternalCall (op, lwePrivateKeyType, messageSize))) {
146- return op.emitError (" failed to insert external call" );
160+
161+ if (failed (lowerValidationOps (op, lwePrivateKeyType, messageSize))) {
162+ return op.emitError (" failed to lower validation ops" );
147163 }
148164 return success ();
149165}
@@ -154,8 +170,9 @@ struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
154170 void runOnOperation () override {
155171 auto funcOp =
156172 detectEntryFunction (cast<ModuleOp>(getOperation ()), entryFunction);
157- if (funcOp && failed (convertFunc (funcOp, messageSize))) {
158- funcOp->emitError (" Failed to configure the crypto context for func" );
173+ if (funcOp &&
174+ failed (convertFunc (funcOp, messageSize, insertDebugAfterEveryOp))) {
175+ funcOp->emitError (" Failed to add debug port for func" );
159176 signalPassFailure ();
160177 }
161178 }
0 commit comments