@@ -51,15 +51,53 @@ static bool allowedRemainArith(Operation *op) {
51
51
}
52
52
return false ;
53
53
})
54
+ .Case <mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
55
+ [](auto op) {
56
+ // This lambda will be called for any of the matched operation types
57
+ if (auto lhsDefOp = op.getOperand (0 ).getDefiningOp ()) {
58
+ auto lshAllowed = allowedRemainArith (lhsDefOp);
59
+ if (auto rhsDefOp = op.getOperand (1 ).getDefiningOp ()) {
60
+ auto rhsAllowed = allowedRemainArith (rhsDefOp);
61
+ return lshAllowed && rhsAllowed;
62
+ }
63
+ }
64
+ return false ;
65
+ })
54
66
.Default ([](Operation *) {
55
67
// Default case for operations that don't match any of the types
56
68
return false ;
57
69
});
58
70
}
59
71
60
72
static bool hasLWEAnnotation (Operation *op) {
61
- return static_cast <bool >(
62
- op->getAttrOfType <mlir::StringAttr>(" lwe_annotation" ));
73
+ auto check =
74
+ static_cast <bool >(op->getAttrOfType <mlir::StringAttr>(" lwe_annotation" ));
75
+
76
+ if (check) return check;
77
+
78
+ // Check recursively if a defining op has a LWE annotation
79
+ return llvm::TypeSwitch<Operation *, bool >(op)
80
+ .Case <mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
81
+ [](auto op) {
82
+ if (auto *defOp = op.getIn ().getDefiningOp ()) {
83
+ return hasLWEAnnotation (defOp);
84
+ }
85
+ return static_cast <bool >(
86
+ op->template getAttrOfType <mlir::StringAttr>(" lwe_annotation" ));
87
+ })
88
+ .Case <mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
89
+ [](auto op) {
90
+ // This lambda will be called for any of the matched operation types
91
+ if (auto lhsDefOp = op.getOperand (0 ).getDefiningOp ()) {
92
+ auto lshAllowed = hasLWEAnnotation (lhsDefOp);
93
+ if (auto rhsDefOp = op.getOperand (1 ).getDefiningOp ()) {
94
+ auto rhsAllowed = hasLWEAnnotation (rhsDefOp);
95
+ return lshAllowed || rhsAllowed;
96
+ }
97
+ }
98
+ return false ;
99
+ })
100
+ .Default ([](Operation *) { return false ; });
63
101
}
64
102
65
103
static Value materializeTarget (OpBuilder &builder, Type type, ValueRange inputs,
@@ -70,10 +108,18 @@ static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
70
108
llvm_unreachable (
71
109
" Non-integer types should never be the input to a materializeTarget." );
72
110
73
- auto inValue = inputs.front ().getDefiningOp <mlir::arith::ConstantOp>();
74
- auto intAttr = cast<IntegerAttr>(inValue.getValueAttr ());
111
+ if ( auto inValue = inputs.front ().getDefiningOp <mlir::arith::ConstantOp>()) {
112
+ auto intAttr = cast<IntegerAttr>(inValue.getValueAttr ());
75
113
76
- return builder.create <cggi::CreateTrivialOp>(loc, type, intAttr);
114
+ return builder.create <cggi::CreateTrivialOp>(loc, type, intAttr);
115
+ }
116
+ // Comes from function/loop argument: Trivial encrypt through LWE
117
+ auto encoding = cast<lwe::LWECiphertextType>(type).getEncoding ();
118
+ auto ptxtTy = lwe::LWEPlaintextType::get (builder.getContext (), encoding);
119
+ return builder.create <lwe::TrivialEncryptOp>(
120
+ loc, type,
121
+ builder.create <lwe::EncodeOp>(loc, ptxtTy, inputs[0 ], encoding),
122
+ lwe::LWEParamsAttr ());
77
123
}
78
124
79
125
class ArithToCGGITypeConverter : public TypeConverter {
@@ -156,18 +202,109 @@ struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
156
202
}
157
203
};
158
204
159
- struct ConvertShRUIOp : public OpConversionPattern <mlir::arith::ShRUIOp > {
160
- ConvertShRUIOp (mlir::MLIRContext *context)
161
- : OpConversionPattern<mlir::arith::ShRUIOp >(context) {}
205
+ struct ConvertCmpOp : public OpConversionPattern <mlir::arith::CmpIOp > {
206
+ ConvertCmpOp (mlir::MLIRContext *context)
207
+ : OpConversionPattern<mlir::arith::CmpIOp >(context) {}
162
208
163
209
using OpConversionPattern::OpConversionPattern;
164
210
165
211
LogicalResult matchAndRewrite (
166
- mlir::arith::ShRUIOp op, OpAdaptor adaptor,
212
+ mlir::arith::CmpIOp op, OpAdaptor adaptor,
167
213
ConversionPatternRewriter &rewriter) const override {
168
214
ImplicitLocOpBuilder b (op.getLoc (), rewriter);
169
215
170
- auto cteShiftSizeOp = op.getRhs ().getDefiningOp <mlir::arith::ConstantOp>();
216
+ auto lweBooleanType = lwe::LWECiphertextType::get (
217
+ op->getContext (),
218
+ lwe::UnspecifiedBitFieldEncodingAttr::get (op->getContext (), 1 ),
219
+ lwe::LWEParamsAttr ());
220
+
221
+ if (auto lhsDefOp = op.getLhs ().getDefiningOp ()) {
222
+ if (!hasLWEAnnotation (lhsDefOp) && allowedRemainArith (lhsDefOp)) {
223
+ auto result = b.create <cggi::CmpOp>(lweBooleanType, op.getPredicate (),
224
+ adaptor.getRhs (), op.getLhs ());
225
+ rewriter.replaceOp (op, result);
226
+ return success ();
227
+ }
228
+ }
229
+
230
+ if (auto rhsDefOp = op.getRhs ().getDefiningOp ()) {
231
+ if (!hasLWEAnnotation (rhsDefOp) && allowedRemainArith (rhsDefOp)) {
232
+ auto result = b.create <cggi::CmpOp>(lweBooleanType, op.getPredicate (),
233
+ adaptor.getLhs (), op.getRhs ());
234
+ rewriter.replaceOp (op, result);
235
+ return success ();
236
+ }
237
+ }
238
+
239
+ auto cmpOp = b.create <cggi::CmpOp>(lweBooleanType, op.getPredicate (),
240
+ adaptor.getLhs (), adaptor.getRhs ());
241
+
242
+ rewriter.replaceOp (op, cmpOp);
243
+ return success ();
244
+ }
245
+ };
246
+
247
+ struct ConvertSubOp : public OpConversionPattern <mlir::arith::SubIOp> {
248
+ ConvertSubOp (mlir::MLIRContext *context)
249
+ : OpConversionPattern<mlir::arith::SubIOp>(context) {}
250
+
251
+ using OpConversionPattern::OpConversionPattern;
252
+
253
+ LogicalResult matchAndRewrite (
254
+ mlir::arith::SubIOp op, OpAdaptor adaptor,
255
+ ConversionPatternRewriter &rewriter) const override {
256
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
257
+
258
+ if (auto rhsDefOp = op.getRhs ().getDefiningOp ()) {
259
+ if (!hasLWEAnnotation (rhsDefOp) && allowedRemainArith (rhsDefOp)) {
260
+ auto result = b.create <cggi::SubOp>(adaptor.getLhs ().getType (),
261
+ adaptor.getLhs (), op.getRhs ());
262
+ rewriter.replaceOp (op, result);
263
+ return success ();
264
+ }
265
+ }
266
+
267
+ auto subOp = b.create <cggi::SubOp>(adaptor.getLhs ().getType (),
268
+ adaptor.getLhs (), adaptor.getRhs ());
269
+ rewriter.replaceOp (op, subOp);
270
+ return success ();
271
+ }
272
+ };
273
+
274
+ struct ConvertSelectOp : public OpConversionPattern <mlir::arith::SelectOp> {
275
+ ConvertSelectOp (mlir::MLIRContext *context)
276
+ : OpConversionPattern<mlir::arith::SelectOp>(context) {}
277
+
278
+ using OpConversionPattern::OpConversionPattern;
279
+
280
+ LogicalResult matchAndRewrite (
281
+ mlir::arith::SelectOp op, OpAdaptor adaptor,
282
+ ConversionPatternRewriter &rewriter) const override {
283
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
284
+
285
+ auto cmuxOp = b.create <cggi::SelectOp>(
286
+ adaptor.getTrueValue ().getType (), adaptor.getCondition (),
287
+ adaptor.getTrueValue (), adaptor.getFalseValue ());
288
+
289
+ rewriter.replaceOp (op, cmuxOp);
290
+ return success ();
291
+ }
292
+ };
293
+
294
+ template <typename SourceArithShOp, typename TargetCGGIShOp>
295
+ struct ConvertShOp : public OpConversionPattern <SourceArithShOp> {
296
+ ConvertShOp (mlir::MLIRContext *context)
297
+ : OpConversionPattern<SourceArithShOp>(context) {}
298
+
299
+ using OpConversionPattern<SourceArithShOp>::OpConversionPattern;
300
+
301
+ LogicalResult matchAndRewrite (
302
+ SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
303
+ ConversionPatternRewriter &rewriter) const override {
304
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
305
+
306
+ auto cteShiftSizeOp =
307
+ op.getRhs ().template getDefiningOp <mlir::arith::ConstantOp>();
171
308
172
309
if (cteShiftSizeOp) {
173
310
auto outputType = adaptor.getLhs ().getType ();
@@ -179,14 +316,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
179
316
auto inputValue =
180
317
mlir::IntegerAttr::get (rewriter.getIndexType (), (int8_t )shiftAmount);
181
318
182
- auto shiftOp = b. create <cggi::ScalarShiftRightOp>(
183
- outputType, adaptor.getLhs (), inputValue);
319
+ auto shiftOp =
320
+ b. create <TargetCGGIShOp>( outputType, adaptor.getLhs (), inputValue);
184
321
rewriter.replaceOp (op, shiftOp);
185
322
186
323
return success ();
187
324
}
188
325
189
- cteShiftSizeOp = op.getLhs ().getDefiningOp <mlir::arith::ConstantOp>();
326
+ cteShiftSizeOp =
327
+ op.getLhs ().template getDefiningOp <mlir::arith::ConstantOp>();
190
328
191
329
auto outputType = adaptor.getRhs ().getType ();
192
330
@@ -196,15 +334,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
196
334
auto inputValue =
197
335
mlir::IntegerAttr::get (rewriter.getIndexType (), shiftAmount);
198
336
199
- auto shiftOp = b. create <cggi::ScalarShiftRightOp>(
200
- outputType, adaptor.getLhs (), inputValue);
337
+ auto shiftOp =
338
+ b. create <TargetCGGIShOp>( outputType, adaptor.getLhs (), inputValue);
201
339
rewriter.replaceOp (op, shiftOp);
202
340
203
341
return success ();
204
342
}
205
343
};
206
344
207
- template <typename SourceArithOp, typename TargetModArithOp >
345
+ template <typename SourceArithOp, typename TargetCGGIOp >
208
346
struct ConvertArithBinOp : public OpConversionPattern <SourceArithOp> {
209
347
ConvertArithBinOp (mlir::MLIRContext *context)
210
348
: OpConversionPattern<SourceArithOp>(context) {}
@@ -218,24 +356,24 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
218
356
219
357
if (auto lhsDefOp = op.getLhs ().getDefiningOp ()) {
220
358
if (!hasLWEAnnotation (lhsDefOp) && allowedRemainArith (lhsDefOp)) {
221
- auto result = b.create <TargetModArithOp >(adaptor.getRhs ().getType (),
222
- adaptor.getRhs (), op.getLhs ());
359
+ auto result = b.create <TargetCGGIOp >(adaptor.getRhs ().getType (),
360
+ adaptor.getRhs (), op.getLhs ());
223
361
rewriter.replaceOp (op, result);
224
362
return success ();
225
363
}
226
364
}
227
365
228
366
if (auto rhsDefOp = op.getRhs ().getDefiningOp ()) {
229
367
if (!hasLWEAnnotation (rhsDefOp) && allowedRemainArith (rhsDefOp)) {
230
- auto result = b.create <TargetModArithOp >(adaptor.getLhs ().getType (),
231
- adaptor.getLhs (), op.getRhs ());
368
+ auto result = b.create <TargetCGGIOp >(adaptor.getLhs ().getType (),
369
+ adaptor.getLhs (), op.getRhs ());
232
370
rewriter.replaceOp (op, result);
233
371
return success ();
234
372
}
235
373
}
236
374
237
- auto result = b.create <TargetModArithOp>(
238
- adaptor. getLhs (). getType (), adaptor.getLhs (), adaptor.getRhs ());
375
+ auto result = b.create <TargetCGGIOp>(adaptor. getLhs (). getType (),
376
+ adaptor.getLhs (), adaptor.getRhs ());
239
377
rewriter.replaceOp (op, result);
240
378
return success ();
241
379
}
@@ -277,10 +415,29 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
277
415
target.addIllegalDialect <mlir::arith::ArithDialect>();
278
416
target.addLegalOp <mlir::arith::ConstantOp>();
279
417
418
+ target.addDynamicallyLegalOp <mlir::arith::SubIOp, mlir::arith::AddIOp,
419
+ mlir::arith::MulIOp>([&](Operation *op) {
420
+ if (auto *defLhsOp = op->getOperand (0 ).getDefiningOp ()) {
421
+ if (auto *defRhsOp = op->getOperand (1 ).getDefiningOp ()) {
422
+ return !hasLWEAnnotation (defLhsOp) && !hasLWEAnnotation (defRhsOp) &&
423
+ allowedRemainArith (defLhsOp) && allowedRemainArith (defRhsOp);
424
+ }
425
+ }
426
+ return false ;
427
+ });
428
+
280
429
target.addDynamicallyLegalOp <mlir::arith::ExtSIOp>([&](Operation *op) {
281
430
if (auto *defOp =
282
431
cast<mlir::arith::ExtSIOp>(op).getOperand ().getDefiningOp ()) {
283
- return hasLWEAnnotation (defOp) || allowedRemainArith (defOp);
432
+ return !hasLWEAnnotation (defOp) && allowedRemainArith (defOp);
433
+ }
434
+ return false ;
435
+ });
436
+
437
+ target.addDynamicallyLegalOp <mlir::arith::ExtUIOp>([&](Operation *op) {
438
+ if (auto *defOp =
439
+ cast<mlir::arith::ExtUIOp>(op).getOperand ().getDefiningOp ()) {
440
+ return !hasLWEAnnotation (defOp) && allowedRemainArith (defOp);
284
441
}
285
442
return false ;
286
443
});
@@ -298,14 +455,16 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
298
455
// accepts Check if there is at least one Store op that is a constants
299
456
auto containsAnyStoreOp = llvm::any_of (op->getUses (), [&](OpOperand &op) {
300
457
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner ())) {
301
- return allowedRemainArith (defOp.getValue ().getDefiningOp ());
458
+ return !hasLWEAnnotation (defOp.getValue ().getDefiningOp ()) &&
459
+ allowedRemainArith (defOp.getValue ().getDefiningOp ());
302
460
}
303
461
return false ;
304
462
});
305
463
auto allStoreOpsAreArith =
306
464
llvm::all_of (op->getUses (), [&](OpOperand &op) {
307
465
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner ())) {
308
- return allowedRemainArith (defOp.getValue ().getDefiningOp ());
466
+ return !hasLWEAnnotation (defOp.getValue ().getDefiningOp ()) &&
467
+ allowedRemainArith (defOp.getValue ().getDefiningOp ());
309
468
}
310
469
return true ;
311
470
});
@@ -371,10 +530,17 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
371
530
});
372
531
373
532
patterns.add <
374
- ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
533
+ ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertSelectOp,
534
+ ConvertCmpOp, ConvertSubOp,
535
+ ConvertShOp<mlir::arith::ShRSIOp, cggi::ScalarShiftRightOp>,
536
+ ConvertShOp<mlir::arith::ShRUIOp, cggi::ScalarShiftRightOp>,
537
+ ConvertShOp<mlir::arith::ShLIOp, cggi::ScalarShiftLeftOp>,
375
538
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
376
539
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
377
- ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
540
+ ConvertArithBinOp<mlir::arith::MaxSIOp, cggi::MaxOp>,
541
+ ConvertArithBinOp<mlir::arith::MinSIOp, cggi::MinOp>,
542
+ ConvertArithBinOp<mlir::arith::MaxUIOp, cggi::MaxOp>,
543
+ ConvertArithBinOp<mlir::arith::MinUIOp, cggi::MinOp>,
378
544
ConvertAny<memref::LoadOp>, ConvertAllocOp,
379
545
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
380
546
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
0 commit comments