@@ -70,15 +70,53 @@ static bool allowedRemainArith(Operation *op) {
70
70
}
71
71
return false ;
72
72
})
73
+ .Case <mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
74
+ [](auto op) {
75
+ // This lambda will be called for any of the matched operation types
76
+ if (auto lhsDefOp = op.getOperand (0 ).getDefiningOp ()) {
77
+ auto lshAllowed = allowedRemainArith (lhsDefOp);
78
+ if (auto rhsDefOp = op.getOperand (1 ).getDefiningOp ()) {
79
+ auto rhsAllowed = allowedRemainArith (rhsDefOp);
80
+ return lshAllowed && rhsAllowed;
81
+ }
82
+ }
83
+ return false ;
84
+ })
73
85
.Default ([](Operation *) {
74
86
// Default case for operations that don't match any of the types
75
87
return false ;
76
88
});
77
89
}
78
90
79
91
static bool hasLWEAnnotation (Operation *op) {
80
- return static_cast <bool >(
81
- op->getAttrOfType <mlir::StringAttr>(" lwe_annotation" ));
92
+ mlir::StringAttr check =
93
+ op->getAttrOfType <mlir::StringAttr>(" lwe_annotation" );
94
+
95
+ if (check) return true ;
96
+
97
+ // Check recursively if a defining op has a LWE annotation
98
+ return llvm::TypeSwitch<Operation *, bool >(op)
99
+ .Case <mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
100
+ [](auto op) {
101
+ if (auto *defOp = op.getIn ().getDefiningOp ()) {
102
+ return hasLWEAnnotation (defOp);
103
+ }
104
+ return op->template getAttrOfType <mlir::StringAttr>(
105
+ " lwe_annotation" ) != nullptr ;
106
+ })
107
+ .Case <mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
108
+ [](auto op) {
109
+ // This lambda will be called for any of the matched operation types
110
+ if (auto lhsDefOp = op.getOperand (0 ).getDefiningOp ()) {
111
+ auto lshAllowed = hasLWEAnnotation (lhsDefOp);
112
+ if (auto rhsDefOp = op.getOperand (1 ).getDefiningOp ()) {
113
+ auto rhsAllowed = hasLWEAnnotation (rhsDefOp);
114
+ return lshAllowed || rhsAllowed;
115
+ }
116
+ }
117
+ return false ;
118
+ })
119
+ .Default ([](Operation *) { return false ; });
82
120
}
83
121
84
122
static Value materializeTarget (OpBuilder &builder, Type type, ValueRange inputs,
@@ -89,10 +127,18 @@ static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
89
127
llvm_unreachable (
90
128
" Non-integer types should never be the input to a materializeTarget." );
91
129
92
- auto inValue = inputs.front ().getDefiningOp <mlir::arith::ConstantOp>();
93
- auto intAttr = cast<IntegerAttr>(inValue.getValueAttr ());
130
+ if ( auto inValue = inputs.front ().getDefiningOp <mlir::arith::ConstantOp>()) {
131
+ auto intAttr = cast<IntegerAttr>(inValue.getValueAttr ());
94
132
95
- return builder.create <cggi::CreateTrivialOp>(loc, type, intAttr);
133
+ return builder.create <cggi::CreateTrivialOp>(loc, type, intAttr);
134
+ }
135
+ // Comes from function/loop argument: Trivial encrypt through LWE
136
+ auto encoding = cast<lwe::LWECiphertextType>(type).getEncoding ();
137
+ auto ptxtTy = lwe::LWEPlaintextType::get (builder.getContext (), encoding);
138
+ return builder.create <lwe::TrivialEncryptOp>(
139
+ loc, type,
140
+ builder.create <lwe::EncodeOp>(loc, ptxtTy, inputs[0 ], encoding),
141
+ lwe::LWEParamsAttr ());
96
142
}
97
143
98
144
class ArithToCGGITypeConverter : public TypeConverter {
@@ -175,18 +221,109 @@ struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
175
221
}
176
222
};
177
223
178
- struct ConvertShRUIOp : public OpConversionPattern <mlir::arith::ShRUIOp > {
179
- ConvertShRUIOp (mlir::MLIRContext *context)
180
- : OpConversionPattern<mlir::arith::ShRUIOp >(context) {}
224
+ struct ConvertCmpOp : public OpConversionPattern <mlir::arith::CmpIOp > {
225
+ ConvertCmpOp (mlir::MLIRContext *context)
226
+ : OpConversionPattern<mlir::arith::CmpIOp >(context) {}
181
227
182
228
using OpConversionPattern::OpConversionPattern;
183
229
184
230
LogicalResult matchAndRewrite (
185
- mlir::arith::ShRUIOp op, OpAdaptor adaptor,
231
+ mlir::arith::CmpIOp op, OpAdaptor adaptor,
186
232
ConversionPatternRewriter &rewriter) const override {
187
233
ImplicitLocOpBuilder b (op.getLoc (), rewriter);
188
234
189
- auto cteShiftSizeOp = op.getRhs ().getDefiningOp <mlir::arith::ConstantOp>();
235
+ auto lweBooleanType = lwe::LWECiphertextType::get (
236
+ op->getContext (),
237
+ lwe::UnspecifiedBitFieldEncodingAttr::get (op->getContext (), 1 ),
238
+ lwe::LWEParamsAttr ());
239
+
240
+ if (auto lhsDefOp = op.getLhs ().getDefiningOp ()) {
241
+ if (!hasLWEAnnotation (lhsDefOp) && allowedRemainArith (lhsDefOp)) {
242
+ auto result = b.create <cggi::CmpOp>(lweBooleanType, op.getPredicate (),
243
+ adaptor.getRhs (), op.getLhs ());
244
+ rewriter.replaceOp (op, result);
245
+ return success ();
246
+ }
247
+ }
248
+
249
+ if (auto rhsDefOp = op.getRhs ().getDefiningOp ()) {
250
+ if (!hasLWEAnnotation (rhsDefOp) && allowedRemainArith (rhsDefOp)) {
251
+ auto result = b.create <cggi::CmpOp>(lweBooleanType, op.getPredicate (),
252
+ adaptor.getLhs (), op.getRhs ());
253
+ rewriter.replaceOp (op, result);
254
+ return success ();
255
+ }
256
+ }
257
+
258
+ auto cmpOp = b.create <cggi::CmpOp>(lweBooleanType, op.getPredicate (),
259
+ adaptor.getLhs (), adaptor.getRhs ());
260
+
261
+ rewriter.replaceOp (op, cmpOp);
262
+ return success ();
263
+ }
264
+ };
265
+
266
+ struct ConvertSubOp : public OpConversionPattern <mlir::arith::SubIOp> {
267
+ ConvertSubOp (mlir::MLIRContext *context)
268
+ : OpConversionPattern<mlir::arith::SubIOp>(context) {}
269
+
270
+ using OpConversionPattern::OpConversionPattern;
271
+
272
+ LogicalResult matchAndRewrite (
273
+ mlir::arith::SubIOp op, OpAdaptor adaptor,
274
+ ConversionPatternRewriter &rewriter) const override {
275
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
276
+
277
+ if (auto rhsDefOp = op.getRhs ().getDefiningOp ()) {
278
+ if (!hasLWEAnnotation (rhsDefOp) && allowedRemainArith (rhsDefOp)) {
279
+ auto result = b.create <cggi::SubOp>(adaptor.getLhs ().getType (),
280
+ adaptor.getLhs (), op.getRhs ());
281
+ rewriter.replaceOp (op, result);
282
+ return success ();
283
+ }
284
+ }
285
+
286
+ auto subOp = b.create <cggi::SubOp>(adaptor.getLhs ().getType (),
287
+ adaptor.getLhs (), adaptor.getRhs ());
288
+ rewriter.replaceOp (op, subOp);
289
+ return success ();
290
+ }
291
+ };
292
+
293
+ struct ConvertSelectOp : public OpConversionPattern <mlir::arith::SelectOp> {
294
+ ConvertSelectOp (mlir::MLIRContext *context)
295
+ : OpConversionPattern<mlir::arith::SelectOp>(context) {}
296
+
297
+ using OpConversionPattern::OpConversionPattern;
298
+
299
+ LogicalResult matchAndRewrite (
300
+ mlir::arith::SelectOp op, OpAdaptor adaptor,
301
+ ConversionPatternRewriter &rewriter) const override {
302
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
303
+
304
+ auto cmuxOp = b.create <cggi::SelectOp>(
305
+ adaptor.getTrueValue ().getType (), adaptor.getCondition (),
306
+ adaptor.getTrueValue (), adaptor.getFalseValue ());
307
+
308
+ rewriter.replaceOp (op, cmuxOp);
309
+ return success ();
310
+ }
311
+ };
312
+
313
+ template <typename SourceArithShOp, typename TargetCGGIShOp>
314
+ struct ConvertShOp : public OpConversionPattern <SourceArithShOp> {
315
+ ConvertShOp (mlir::MLIRContext *context)
316
+ : OpConversionPattern<SourceArithShOp>(context) {}
317
+
318
+ using OpConversionPattern<SourceArithShOp>::OpConversionPattern;
319
+
320
+ LogicalResult matchAndRewrite (
321
+ SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
322
+ ConversionPatternRewriter &rewriter) const override {
323
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
324
+
325
+ auto cteShiftSizeOp =
326
+ op.getRhs ().template getDefiningOp <mlir::arith::ConstantOp>();
190
327
191
328
if (cteShiftSizeOp) {
192
329
auto outputType = adaptor.getLhs ().getType ();
@@ -198,14 +335,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
198
335
auto inputValue =
199
336
mlir::IntegerAttr::get (rewriter.getIndexType (), (int8_t )shiftAmount);
200
337
201
- auto shiftOp = b. create <cggi::ScalarShiftRightOp>(
202
- outputType, adaptor.getLhs (), inputValue);
338
+ auto shiftOp =
339
+ b. create <TargetCGGIShOp>( outputType, adaptor.getLhs (), inputValue);
203
340
rewriter.replaceOp (op, shiftOp);
204
341
205
342
return success ();
206
343
}
207
344
208
- cteShiftSizeOp = op.getLhs ().getDefiningOp <mlir::arith::ConstantOp>();
345
+ cteShiftSizeOp =
346
+ op.getLhs ().template getDefiningOp <mlir::arith::ConstantOp>();
209
347
210
348
auto outputType = adaptor.getRhs ().getType ();
211
349
@@ -215,15 +353,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
215
353
auto inputValue =
216
354
mlir::IntegerAttr::get (rewriter.getIndexType (), shiftAmount);
217
355
218
- auto shiftOp = b. create <cggi::ScalarShiftRightOp>(
219
- outputType, adaptor.getLhs (), inputValue);
356
+ auto shiftOp =
357
+ b. create <TargetCGGIShOp>( outputType, adaptor.getLhs (), inputValue);
220
358
rewriter.replaceOp (op, shiftOp);
221
359
222
360
return success ();
223
361
}
224
362
};
225
363
226
- template <typename SourceArithOp, typename TargetModArithOp >
364
+ template <typename SourceArithOp, typename TargetCGGIOp >
227
365
struct ConvertArithBinOp : public OpConversionPattern <SourceArithOp> {
228
366
ConvertArithBinOp (mlir::MLIRContext *context)
229
367
: OpConversionPattern<SourceArithOp>(context) {}
@@ -237,24 +375,24 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
237
375
238
376
if (auto lhsDefOp = op.getLhs ().getDefiningOp ()) {
239
377
if (!hasLWEAnnotation (lhsDefOp) && allowedRemainArith (lhsDefOp)) {
240
- auto result = b.create <TargetModArithOp >(adaptor.getRhs ().getType (),
241
- adaptor.getRhs (), op.getLhs ());
378
+ auto result = b.create <TargetCGGIOp >(adaptor.getRhs ().getType (),
379
+ adaptor.getRhs (), op.getLhs ());
242
380
rewriter.replaceOp (op, result);
243
381
return success ();
244
382
}
245
383
}
246
384
247
385
if (auto rhsDefOp = op.getRhs ().getDefiningOp ()) {
248
386
if (!hasLWEAnnotation (rhsDefOp) && allowedRemainArith (rhsDefOp)) {
249
- auto result = b.create <TargetModArithOp >(adaptor.getLhs ().getType (),
250
- adaptor.getLhs (), op.getRhs ());
387
+ auto result = b.create <TargetCGGIOp >(adaptor.getLhs ().getType (),
388
+ adaptor.getLhs (), op.getRhs ());
251
389
rewriter.replaceOp (op, result);
252
390
return success ();
253
391
}
254
392
}
255
393
256
- auto result = b.create <TargetModArithOp>(
257
- adaptor. getLhs (). getType (), adaptor.getLhs (), adaptor.getRhs ());
394
+ auto result = b.create <TargetCGGIOp>(adaptor. getLhs (). getType (),
395
+ adaptor.getLhs (), adaptor.getRhs ());
258
396
rewriter.replaceOp (op, result);
259
397
return success ();
260
398
}
@@ -296,10 +434,29 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
296
434
target.addIllegalDialect <mlir::arith::ArithDialect>();
297
435
target.addLegalOp <mlir::arith::ConstantOp>();
298
436
437
+ target.addDynamicallyLegalOp <mlir::arith::SubIOp, mlir::arith::AddIOp,
438
+ mlir::arith::MulIOp>([&](Operation *op) {
439
+ if (auto *defLhsOp = op->getOperand (0 ).getDefiningOp ()) {
440
+ if (auto *defRhsOp = op->getOperand (1 ).getDefiningOp ()) {
441
+ return !hasLWEAnnotation (defLhsOp) && !hasLWEAnnotation (defRhsOp) &&
442
+ allowedRemainArith (defLhsOp) && allowedRemainArith (defRhsOp);
443
+ }
444
+ }
445
+ return false ;
446
+ });
447
+
299
448
target.addDynamicallyLegalOp <mlir::arith::ExtSIOp>([&](Operation *op) {
300
449
if (auto *defOp =
301
450
cast<mlir::arith::ExtSIOp>(op).getOperand ().getDefiningOp ()) {
302
- return hasLWEAnnotation (defOp) || allowedRemainArith (defOp);
451
+ return !hasLWEAnnotation (defOp) && allowedRemainArith (defOp);
452
+ }
453
+ return false ;
454
+ });
455
+
456
+ target.addDynamicallyLegalOp <mlir::arith::ExtUIOp>([&](Operation *op) {
457
+ if (auto *defOp =
458
+ cast<mlir::arith::ExtUIOp>(op).getOperand ().getDefiningOp ()) {
459
+ return !hasLWEAnnotation (defOp) && allowedRemainArith (defOp);
303
460
}
304
461
return false ;
305
462
});
@@ -317,14 +474,16 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
317
474
// accepts Check if there is at least one Store op that is a constants
318
475
auto containsAnyStoreOp = llvm::any_of (op->getUses (), [&](OpOperand &op) {
319
476
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner ())) {
320
- return allowedRemainArith (defOp.getValue ().getDefiningOp ());
477
+ return !hasLWEAnnotation (defOp.getValue ().getDefiningOp ()) &&
478
+ allowedRemainArith (defOp.getValue ().getDefiningOp ());
321
479
}
322
480
return false ;
323
481
});
324
482
auto allStoreOpsAreArith =
325
483
llvm::all_of (op->getUses (), [&](OpOperand &op) {
326
484
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner ())) {
327
- return allowedRemainArith (defOp.getValue ().getDefiningOp ());
485
+ return !hasLWEAnnotation (defOp.getValue ().getDefiningOp ()) &&
486
+ allowedRemainArith (defOp.getValue ().getDefiningOp ());
328
487
}
329
488
return true ;
330
489
});
@@ -390,10 +549,17 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
390
549
});
391
550
392
551
patterns.add <
393
- ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
552
+ ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertSelectOp,
553
+ ConvertCmpOp, ConvertSubOp,
554
+ ConvertShOp<mlir::arith::ShRSIOp, cggi::ScalarShiftRightOp>,
555
+ ConvertShOp<mlir::arith::ShRUIOp, cggi::ScalarShiftRightOp>,
556
+ ConvertShOp<mlir::arith::ShLIOp, cggi::ScalarShiftLeftOp>,
394
557
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
395
558
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
396
- ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
559
+ ConvertArithBinOp<mlir::arith::MaxSIOp, cggi::MaxOp>,
560
+ ConvertArithBinOp<mlir::arith::MinSIOp, cggi::MinOp>,
561
+ ConvertArithBinOp<mlir::arith::MaxUIOp, cggi::MaxOp>,
562
+ ConvertArithBinOp<mlir::arith::MinUIOp, cggi::MinOp>,
397
563
ConvertAny<memref::LoadOp>, ConvertAllocOp,
398
564
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
399
565
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
0 commit comments