Skip to content

Commit fae5e95

Browse files
Merge pull request #1403 from WoutLegiest:signed
PiperOrigin-RevId: 726985149
2 parents b320e5f + a119c47 commit fae5e95

File tree

8 files changed

+559
-333
lines changed

8 files changed

+559
-333
lines changed

lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp

Lines changed: 193 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,53 @@ static bool allowedRemainArith(Operation *op) {
7070
}
7171
return false;
7272
})
73+
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
74+
[](auto op) {
75+
// This lambda will be called for any of the matched operation types
76+
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
77+
auto lshAllowed = allowedRemainArith(lhsDefOp);
78+
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
79+
auto rhsAllowed = allowedRemainArith(rhsDefOp);
80+
return lshAllowed && rhsAllowed;
81+
}
82+
}
83+
return false;
84+
})
7385
.Default([](Operation *) {
7486
// Default case for operations that don't match any of the types
7587
return false;
7688
});
7789
}
7890

7991
static bool hasLWEAnnotation(Operation *op) {
80-
return static_cast<bool>(
81-
op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
92+
mlir::StringAttr check =
93+
op->getAttrOfType<mlir::StringAttr>("lwe_annotation");
94+
95+
if (check) return true;
96+
97+
// Check recursively if a defining op has a LWE annotation
98+
return llvm::TypeSwitch<Operation *, bool>(op)
99+
.Case<mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
100+
[](auto op) {
101+
if (auto *defOp = op.getIn().getDefiningOp()) {
102+
return hasLWEAnnotation(defOp);
103+
}
104+
return op->template getAttrOfType<mlir::StringAttr>(
105+
"lwe_annotation") != nullptr;
106+
})
107+
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
108+
[](auto op) {
109+
// This lambda will be called for any of the matched operation types
110+
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
111+
auto lshAllowed = hasLWEAnnotation(lhsDefOp);
112+
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
113+
auto rhsAllowed = hasLWEAnnotation(rhsDefOp);
114+
return lshAllowed || rhsAllowed;
115+
}
116+
}
117+
return false;
118+
})
119+
.Default([](Operation *) { return false; });
82120
}
83121

84122
static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
@@ -89,10 +127,18 @@ static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
89127
llvm_unreachable(
90128
"Non-integer types should never be the input to a materializeTarget.");
91129

92-
auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>();
93-
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());
130+
if (auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>()) {
131+
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());
94132

95-
return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
133+
return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
134+
}
135+
// Comes from function/loop argument: Trivial encrypt through LWE
136+
auto encoding = cast<lwe::LWECiphertextType>(type).getEncoding();
137+
auto ptxtTy = lwe::LWEPlaintextType::get(builder.getContext(), encoding);
138+
return builder.create<lwe::TrivialEncryptOp>(
139+
loc, type,
140+
builder.create<lwe::EncodeOp>(loc, ptxtTy, inputs[0], encoding),
141+
lwe::LWEParamsAttr());
96142
}
97143

98144
class ArithToCGGITypeConverter : public TypeConverter {
@@ -175,18 +221,109 @@ struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
175221
}
176222
};
177223

178-
struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
179-
ConvertShRUIOp(mlir::MLIRContext *context)
180-
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
224+
struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
225+
ConvertCmpOp(mlir::MLIRContext *context)
226+
: OpConversionPattern<mlir::arith::CmpIOp>(context) {}
181227

182228
using OpConversionPattern::OpConversionPattern;
183229

184230
LogicalResult matchAndRewrite(
185-
mlir::arith::ShRUIOp op, OpAdaptor adaptor,
231+
mlir::arith::CmpIOp op, OpAdaptor adaptor,
186232
ConversionPatternRewriter &rewriter) const override {
187233
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
188234

189-
auto cteShiftSizeOp = op.getRhs().getDefiningOp<mlir::arith::ConstantOp>();
235+
auto lweBooleanType = lwe::LWECiphertextType::get(
236+
op->getContext(),
237+
lwe::UnspecifiedBitFieldEncodingAttr::get(op->getContext(), 1),
238+
lwe::LWEParamsAttr());
239+
240+
if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
241+
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
242+
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
243+
adaptor.getRhs(), op.getLhs());
244+
rewriter.replaceOp(op, result);
245+
return success();
246+
}
247+
}
248+
249+
if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
250+
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
251+
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
252+
adaptor.getLhs(), op.getRhs());
253+
rewriter.replaceOp(op, result);
254+
return success();
255+
}
256+
}
257+
258+
auto cmpOp = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
259+
adaptor.getLhs(), adaptor.getRhs());
260+
261+
rewriter.replaceOp(op, cmpOp);
262+
return success();
263+
}
264+
};
265+
266+
struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
267+
ConvertSubOp(mlir::MLIRContext *context)
268+
: OpConversionPattern<mlir::arith::SubIOp>(context) {}
269+
270+
using OpConversionPattern::OpConversionPattern;
271+
272+
LogicalResult matchAndRewrite(
273+
mlir::arith::SubIOp op, OpAdaptor adaptor,
274+
ConversionPatternRewriter &rewriter) const override {
275+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
276+
277+
if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
278+
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
279+
auto result = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
280+
adaptor.getLhs(), op.getRhs());
281+
rewriter.replaceOp(op, result);
282+
return success();
283+
}
284+
}
285+
286+
auto subOp = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
287+
adaptor.getLhs(), adaptor.getRhs());
288+
rewriter.replaceOp(op, subOp);
289+
return success();
290+
}
291+
};
292+
293+
struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
294+
ConvertSelectOp(mlir::MLIRContext *context)
295+
: OpConversionPattern<mlir::arith::SelectOp>(context) {}
296+
297+
using OpConversionPattern::OpConversionPattern;
298+
299+
LogicalResult matchAndRewrite(
300+
mlir::arith::SelectOp op, OpAdaptor adaptor,
301+
ConversionPatternRewriter &rewriter) const override {
302+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
303+
304+
auto cmuxOp = b.create<cggi::SelectOp>(
305+
adaptor.getTrueValue().getType(), adaptor.getCondition(),
306+
adaptor.getTrueValue(), adaptor.getFalseValue());
307+
308+
rewriter.replaceOp(op, cmuxOp);
309+
return success();
310+
}
311+
};
312+
313+
template <typename SourceArithShOp, typename TargetCGGIShOp>
314+
struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
315+
ConvertShOp(mlir::MLIRContext *context)
316+
: OpConversionPattern<SourceArithShOp>(context) {}
317+
318+
using OpConversionPattern<SourceArithShOp>::OpConversionPattern;
319+
320+
LogicalResult matchAndRewrite(
321+
SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
322+
ConversionPatternRewriter &rewriter) const override {
323+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
324+
325+
auto cteShiftSizeOp =
326+
op.getRhs().template getDefiningOp<mlir::arith::ConstantOp>();
190327

191328
if (cteShiftSizeOp) {
192329
auto outputType = adaptor.getLhs().getType();
@@ -198,14 +335,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
198335
auto inputValue =
199336
mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount);
200337

201-
auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
202-
outputType, adaptor.getLhs(), inputValue);
338+
auto shiftOp =
339+
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
203340
rewriter.replaceOp(op, shiftOp);
204341

205342
return success();
206343
}
207344

208-
cteShiftSizeOp = op.getLhs().getDefiningOp<mlir::arith::ConstantOp>();
345+
cteShiftSizeOp =
346+
op.getLhs().template getDefiningOp<mlir::arith::ConstantOp>();
209347

210348
auto outputType = adaptor.getRhs().getType();
211349

@@ -215,15 +353,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
215353
auto inputValue =
216354
mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount);
217355

218-
auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
219-
outputType, adaptor.getLhs(), inputValue);
356+
auto shiftOp =
357+
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
220358
rewriter.replaceOp(op, shiftOp);
221359

222360
return success();
223361
}
224362
};
225363

226-
template <typename SourceArithOp, typename TargetModArithOp>
364+
template <typename SourceArithOp, typename TargetCGGIOp>
227365
struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
228366
ConvertArithBinOp(mlir::MLIRContext *context)
229367
: OpConversionPattern<SourceArithOp>(context) {}
@@ -237,24 +375,24 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
237375

238376
if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
239377
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
240-
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
241-
adaptor.getRhs(), op.getLhs());
378+
auto result = b.create<TargetCGGIOp>(adaptor.getRhs().getType(),
379+
adaptor.getRhs(), op.getLhs());
242380
rewriter.replaceOp(op, result);
243381
return success();
244382
}
245383
}
246384

247385
if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
248386
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
249-
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
250-
adaptor.getLhs(), op.getRhs());
387+
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
388+
adaptor.getLhs(), op.getRhs());
251389
rewriter.replaceOp(op, result);
252390
return success();
253391
}
254392
}
255393

256-
auto result = b.create<TargetModArithOp>(
257-
adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs());
394+
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
395+
adaptor.getLhs(), adaptor.getRhs());
258396
rewriter.replaceOp(op, result);
259397
return success();
260398
}
@@ -296,10 +434,29 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
296434
target.addIllegalDialect<mlir::arith::ArithDialect>();
297435
target.addLegalOp<mlir::arith::ConstantOp>();
298436

437+
target.addDynamicallyLegalOp<mlir::arith::SubIOp, mlir::arith::AddIOp,
438+
mlir::arith::MulIOp>([&](Operation *op) {
439+
if (auto *defLhsOp = op->getOperand(0).getDefiningOp()) {
440+
if (auto *defRhsOp = op->getOperand(1).getDefiningOp()) {
441+
return !hasLWEAnnotation(defLhsOp) && !hasLWEAnnotation(defRhsOp) &&
442+
allowedRemainArith(defLhsOp) && allowedRemainArith(defRhsOp);
443+
}
444+
}
445+
return false;
446+
});
447+
299448
target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
300449
if (auto *defOp =
301450
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
302-
return hasLWEAnnotation(defOp) || allowedRemainArith(defOp);
451+
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
452+
}
453+
return false;
454+
});
455+
456+
target.addDynamicallyLegalOp<mlir::arith::ExtUIOp>([&](Operation *op) {
457+
if (auto *defOp =
458+
cast<mlir::arith::ExtUIOp>(op).getOperand().getDefiningOp()) {
459+
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
303460
}
304461
return false;
305462
});
@@ -317,14 +474,16 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
317474
// accepts Check if there is at least one Store op that is a constants
318475
auto containsAnyStoreOp = llvm::any_of(op->getUses(), [&](OpOperand &op) {
319476
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
320-
return allowedRemainArith(defOp.getValue().getDefiningOp());
477+
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
478+
allowedRemainArith(defOp.getValue().getDefiningOp());
321479
}
322480
return false;
323481
});
324482
auto allStoreOpsAreArith =
325483
llvm::all_of(op->getUses(), [&](OpOperand &op) {
326484
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
327-
return allowedRemainArith(defOp.getValue().getDefiningOp());
485+
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
486+
allowedRemainArith(defOp.getValue().getDefiningOp());
328487
}
329488
return true;
330489
});
@@ -390,10 +549,17 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
390549
});
391550

392551
patterns.add<
393-
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
552+
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertSelectOp,
553+
ConvertCmpOp, ConvertSubOp,
554+
ConvertShOp<mlir::arith::ShRSIOp, cggi::ScalarShiftRightOp>,
555+
ConvertShOp<mlir::arith::ShRUIOp, cggi::ScalarShiftRightOp>,
556+
ConvertShOp<mlir::arith::ShLIOp, cggi::ScalarShiftLeftOp>,
394557
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
395558
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
396-
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
559+
ConvertArithBinOp<mlir::arith::MaxSIOp, cggi::MaxOp>,
560+
ConvertArithBinOp<mlir::arith::MinSIOp, cggi::MinOp>,
561+
ConvertArithBinOp<mlir::arith::MaxUIOp, cggi::MaxOp>,
562+
ConvertArithBinOp<mlir::arith::MinUIOp, cggi::MinOp>,
397563
ConvertAny<memref::LoadOp>, ConvertAllocOp,
398564
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
399565
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,

lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,17 +434,17 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
434434
auto z01_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
435435
auto z01_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
436436
auto z01_m = b.create<cggi::MulOp>(elemTy, z01_p1, z01_p2);
437-
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
438-
auto z01 = b.create<cggi::SubOp>(z01_s, z02);
437+
auto z01_s = b.create<cggi::SubOp>(elemTy, z01_m, z00);
438+
auto z01 = b.create<cggi::SubOp>(elemTy, z01_s, z02);
439439

440440
// Second part I of Karatsuba algorithm
441441
auto z1a0 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[2]);
442442
auto z1a2 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[3]);
443443
auto z1a1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
444444
auto z1a1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[2], splitRhs[3]);
445445
auto z1a1_m = b.create<cggi::MulOp>(elemTy, z1a1_p1, z1a1_p2);
446-
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
447-
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);
446+
auto z1a1_s = b.create<cggi::SubOp>(elemTy, z1a1_m, z1a0);
447+
auto z1a1 = b.create<cggi::SubOp>(elemTy, z1a1_s, z1a2);
448448

449449
// Second part II of Karatsuba algorithm
450450
auto z1b0 = b.create<cggi::MulOp>(elemTy, splitLhs[2], splitRhs[0]);
@@ -453,7 +453,7 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
453453
auto z1b1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
454454
auto z1b1_m = b.create<cggi::MulOp>(elemTy, z1b1_p1, z1b1_p2);
455455
auto z1b1_s = b.create<cggi::SubOp>(elemTy, z1b1_m, z1b0);
456-
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);
456+
auto z1b1 = b.create<cggi::SubOp>(elemTy, z1b1_s, z1b2);
457457

458458
auto out2Kara = b.create<cggi::AddOp>(elemTy, z1a0, z1b0);
459459
auto out2Carry = b.create<cggi::AddOp>(elemTy, out2Kara, z02);

lib/Dialect/CGGI/IR/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ cc_library(
2626
"@heir//lib/Dialect:HEIRInterfaces",
2727
"@heir//lib/Dialect/LWE/IR:Dialect",
2828
"@llvm-project//llvm:Support",
29+
"@llvm-project//mlir:ArithDialect",
2930
"@llvm-project//mlir:IR",
3031
"@llvm-project//mlir:InferTypeOpInterface",
3132
"@llvm-project//mlir:Support",
@@ -37,13 +38,16 @@ td_library(
3738
srcs = [
3839
"BooleanGates.td",
3940
"CGGIAttributes.td",
41+
"CGGIBinOps.td",
4042
"CGGIDialect.td",
4143
"CGGIEnums.td",
4244
"CGGIOps.td",
45+
"CGGIPBSOps.td",
4346
],
4447
# include from the heir-root to enable fully-qualified include-paths
4548
includes = ["../../../.."],
4649
deps = [
50+
"@llvm-project//mlir:ArithOpsTdFiles",
4751
"@llvm-project//mlir:BuiltinDialectTdFiles",
4852
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
4953
"@llvm-project//mlir:OpBaseTdFiles",

0 commit comments

Comments
 (0)