Skip to content

Commit 9795e2f

Browse files
committed
Adding support of the Hello World Small example to Arith-to-CGGI
1 parent d7fbf06 commit 9795e2f

File tree

5 files changed

+303
-82
lines changed

5 files changed

+303
-82
lines changed

lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp

Lines changed: 193 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,53 @@ static bool allowedRemainArith(Operation *op) {
5151
}
5252
return false;
5353
})
54+
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
55+
[](auto op) {
56+
// This lambda will be called for any of the matched operation types
57+
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
58+
auto lshAllowed = allowedRemainArith(lhsDefOp);
59+
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
60+
auto rhsAllowed = allowedRemainArith(rhsDefOp);
61+
return lshAllowed && rhsAllowed;
62+
}
63+
}
64+
return false;
65+
})
5466
.Default([](Operation *) {
5567
// Default case for operations that don't match any of the types
5668
return false;
5769
});
5870
}
5971

6072
static bool hasLWEAnnotation(Operation *op) {
61-
return static_cast<bool>(
62-
op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
73+
auto check =
74+
static_cast<bool>(op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
75+
76+
if (check) return check;
77+
78+
// Check recursively if a defining op has a LWE annotation
79+
return llvm::TypeSwitch<Operation *, bool>(op)
80+
.Case<mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
81+
[](auto op) {
82+
if (auto *defOp = op.getIn().getDefiningOp()) {
83+
return hasLWEAnnotation(defOp);
84+
}
85+
return static_cast<bool>(
86+
op->template getAttrOfType<mlir::StringAttr>("lwe_annotation"));
87+
})
88+
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
89+
[](auto op) {
90+
// This lambda will be called for any of the matched operation types
91+
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
92+
auto lshAllowed = hasLWEAnnotation(lhsDefOp);
93+
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
94+
auto rhsAllowed = hasLWEAnnotation(rhsDefOp);
95+
return lshAllowed || rhsAllowed;
96+
}
97+
}
98+
return false;
99+
})
100+
.Default([](Operation *) { return false; });
63101
}
64102

65103
static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
@@ -70,10 +108,18 @@ static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
70108
llvm_unreachable(
71109
"Non-integer types should never be the input to a materializeTarget.");
72110

73-
auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>();
74-
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());
111+
if (auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>()) {
112+
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());
75113

76-
return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
114+
return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
115+
}
116+
// Comes from function/loop argument: Trivial encrypt through LWE
117+
auto encoding = cast<lwe::LWECiphertextType>(type).getEncoding();
118+
auto ptxtTy = lwe::LWEPlaintextType::get(builder.getContext(), encoding);
119+
return builder.create<lwe::TrivialEncryptOp>(
120+
loc, type,
121+
builder.create<lwe::EncodeOp>(loc, ptxtTy, inputs[0], encoding),
122+
lwe::LWEParamsAttr());
77123
}
78124

79125
class ArithToCGGITypeConverter : public TypeConverter {
@@ -156,18 +202,109 @@ struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
156202
}
157203
};
158204

159-
struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
160-
ConvertShRUIOp(mlir::MLIRContext *context)
161-
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
205+
struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
206+
ConvertCmpOp(mlir::MLIRContext *context)
207+
: OpConversionPattern<mlir::arith::CmpIOp>(context) {}
162208

163209
using OpConversionPattern::OpConversionPattern;
164210

165211
LogicalResult matchAndRewrite(
166-
mlir::arith::ShRUIOp op, OpAdaptor adaptor,
212+
mlir::arith::CmpIOp op, OpAdaptor adaptor,
167213
ConversionPatternRewriter &rewriter) const override {
168214
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
169215

170-
auto cteShiftSizeOp = op.getRhs().getDefiningOp<mlir::arith::ConstantOp>();
216+
auto lweBooleanType = lwe::LWECiphertextType::get(
217+
op->getContext(),
218+
lwe::UnspecifiedBitFieldEncodingAttr::get(op->getContext(), 1),
219+
lwe::LWEParamsAttr());
220+
221+
if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
222+
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
223+
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
224+
adaptor.getRhs(), op.getLhs());
225+
rewriter.replaceOp(op, result);
226+
return success();
227+
}
228+
}
229+
230+
if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
231+
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
232+
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
233+
adaptor.getLhs(), op.getRhs());
234+
rewriter.replaceOp(op, result);
235+
return success();
236+
}
237+
}
238+
239+
auto cmpOp = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
240+
adaptor.getLhs(), adaptor.getRhs());
241+
242+
rewriter.replaceOp(op, cmpOp);
243+
return success();
244+
}
245+
};
246+
247+
struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
248+
ConvertSubOp(mlir::MLIRContext *context)
249+
: OpConversionPattern<mlir::arith::SubIOp>(context) {}
250+
251+
using OpConversionPattern::OpConversionPattern;
252+
253+
LogicalResult matchAndRewrite(
254+
mlir::arith::SubIOp op, OpAdaptor adaptor,
255+
ConversionPatternRewriter &rewriter) const override {
256+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
257+
258+
if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
259+
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
260+
auto result = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
261+
adaptor.getLhs(), op.getRhs());
262+
rewriter.replaceOp(op, result);
263+
return success();
264+
}
265+
}
266+
267+
auto subOp = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
268+
adaptor.getLhs(), adaptor.getRhs());
269+
rewriter.replaceOp(op, subOp);
270+
return success();
271+
}
272+
};
273+
274+
struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
275+
ConvertSelectOp(mlir::MLIRContext *context)
276+
: OpConversionPattern<mlir::arith::SelectOp>(context) {}
277+
278+
using OpConversionPattern::OpConversionPattern;
279+
280+
LogicalResult matchAndRewrite(
281+
mlir::arith::SelectOp op, OpAdaptor adaptor,
282+
ConversionPatternRewriter &rewriter) const override {
283+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
284+
285+
auto cmuxOp = b.create<cggi::SelectOp>(
286+
adaptor.getTrueValue().getType(), adaptor.getCondition(),
287+
adaptor.getTrueValue(), adaptor.getFalseValue());
288+
289+
rewriter.replaceOp(op, cmuxOp);
290+
return success();
291+
}
292+
};
293+
294+
template <typename SourceArithShOp, typename TargetCGGIShOp>
295+
struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
296+
ConvertShOp(mlir::MLIRContext *context)
297+
: OpConversionPattern<SourceArithShOp>(context) {}
298+
299+
using OpConversionPattern<SourceArithShOp>::OpConversionPattern;
300+
301+
LogicalResult matchAndRewrite(
302+
SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
303+
ConversionPatternRewriter &rewriter) const override {
304+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
305+
306+
auto cteShiftSizeOp =
307+
op.getRhs().template getDefiningOp<mlir::arith::ConstantOp>();
171308

172309
if (cteShiftSizeOp) {
173310
auto outputType = adaptor.getLhs().getType();
@@ -179,14 +316,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
179316
auto inputValue =
180317
mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount);
181318

182-
auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
183-
outputType, adaptor.getLhs(), inputValue);
319+
auto shiftOp =
320+
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
184321
rewriter.replaceOp(op, shiftOp);
185322

186323
return success();
187324
}
188325

189-
cteShiftSizeOp = op.getLhs().getDefiningOp<mlir::arith::ConstantOp>();
326+
cteShiftSizeOp =
327+
op.getLhs().template getDefiningOp<mlir::arith::ConstantOp>();
190328

191329
auto outputType = adaptor.getRhs().getType();
192330

@@ -196,15 +334,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
196334
auto inputValue =
197335
mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount);
198336

199-
auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
200-
outputType, adaptor.getLhs(), inputValue);
337+
auto shiftOp =
338+
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
201339
rewriter.replaceOp(op, shiftOp);
202340

203341
return success();
204342
}
205343
};
206344

207-
template <typename SourceArithOp, typename TargetModArithOp>
345+
template <typename SourceArithOp, typename TargetCGGIOp>
208346
struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
209347
ConvertArithBinOp(mlir::MLIRContext *context)
210348
: OpConversionPattern<SourceArithOp>(context) {}
@@ -218,24 +356,24 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
218356

219357
if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
220358
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
221-
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
222-
adaptor.getRhs(), op.getLhs());
359+
auto result = b.create<TargetCGGIOp>(adaptor.getRhs().getType(),
360+
adaptor.getRhs(), op.getLhs());
223361
rewriter.replaceOp(op, result);
224362
return success();
225363
}
226364
}
227365

228366
if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
229367
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
230-
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
231-
adaptor.getLhs(), op.getRhs());
368+
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
369+
adaptor.getLhs(), op.getRhs());
232370
rewriter.replaceOp(op, result);
233371
return success();
234372
}
235373
}
236374

237-
auto result = b.create<TargetModArithOp>(
238-
adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs());
375+
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
376+
adaptor.getLhs(), adaptor.getRhs());
239377
rewriter.replaceOp(op, result);
240378
return success();
241379
}
@@ -277,10 +415,29 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
277415
target.addIllegalDialect<mlir::arith::ArithDialect>();
278416
target.addLegalOp<mlir::arith::ConstantOp>();
279417

418+
target.addDynamicallyLegalOp<mlir::arith::SubIOp, mlir::arith::AddIOp,
419+
mlir::arith::MulIOp>([&](Operation *op) {
420+
if (auto *defLhsOp = op->getOperand(0).getDefiningOp()) {
421+
if (auto *defRhsOp = op->getOperand(1).getDefiningOp()) {
422+
return !hasLWEAnnotation(defLhsOp) && !hasLWEAnnotation(defRhsOp) &&
423+
allowedRemainArith(defLhsOp) && allowedRemainArith(defRhsOp);
424+
}
425+
}
426+
return false;
427+
});
428+
280429
target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
281430
if (auto *defOp =
282431
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
283-
return hasLWEAnnotation(defOp) || allowedRemainArith(defOp);
432+
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
433+
}
434+
return false;
435+
});
436+
437+
target.addDynamicallyLegalOp<mlir::arith::ExtUIOp>([&](Operation *op) {
438+
if (auto *defOp =
439+
cast<mlir::arith::ExtUIOp>(op).getOperand().getDefiningOp()) {
440+
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
284441
}
285442
return false;
286443
});
@@ -298,14 +455,16 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
298455
// accepts Check if there is at least one Store op that is a constants
299456
auto containsAnyStoreOp = llvm::any_of(op->getUses(), [&](OpOperand &op) {
300457
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
301-
return allowedRemainArith(defOp.getValue().getDefiningOp());
458+
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
459+
allowedRemainArith(defOp.getValue().getDefiningOp());
302460
}
303461
return false;
304462
});
305463
auto allStoreOpsAreArith =
306464
llvm::all_of(op->getUses(), [&](OpOperand &op) {
307465
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
308-
return allowedRemainArith(defOp.getValue().getDefiningOp());
466+
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
467+
allowedRemainArith(defOp.getValue().getDefiningOp());
309468
}
310469
return true;
311470
});
@@ -371,10 +530,17 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
371530
});
372531

373532
patterns.add<
374-
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
533+
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertSelectOp,
534+
ConvertCmpOp, ConvertSubOp,
535+
ConvertShOp<mlir::arith::ShRSIOp, cggi::ScalarShiftRightOp>,
536+
ConvertShOp<mlir::arith::ShRUIOp, cggi::ScalarShiftRightOp>,
537+
ConvertShOp<mlir::arith::ShLIOp, cggi::ScalarShiftLeftOp>,
375538
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
376539
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
377-
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
540+
ConvertArithBinOp<mlir::arith::MaxSIOp, cggi::MaxOp>,
541+
ConvertArithBinOp<mlir::arith::MinSIOp, cggi::MinOp>,
542+
ConvertArithBinOp<mlir::arith::MaxUIOp, cggi::MaxOp>,
543+
ConvertArithBinOp<mlir::arith::MinUIOp, cggi::MinOp>,
378544
ConvertAny<memref::LoadOp>, ConvertAllocOp,
379545
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
380546
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,

lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,17 +419,17 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
419419
auto z01_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
420420
auto z01_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
421421
auto z01_m = b.create<cggi::MulOp>(elemTy, z01_p1, z01_p2);
422-
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
423-
auto z01 = b.create<cggi::SubOp>(z01_s, z02);
422+
auto z01_s = b.create<cggi::SubOp>(elemTy, z01_m, z00);
423+
auto z01 = b.create<cggi::SubOp>(elemTy, z01_s, z02);
424424

425425
// Second part I of Karatsuba algorithm
426426
auto z1a0 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[2]);
427427
auto z1a2 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[3]);
428428
auto z1a1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
429429
auto z1a1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[2], splitRhs[3]);
430430
auto z1a1_m = b.create<cggi::MulOp>(elemTy, z1a1_p1, z1a1_p2);
431-
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
432-
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);
431+
auto z1a1_s = b.create<cggi::SubOp>(elemTy, z1a1_m, z1a0);
432+
auto z1a1 = b.create<cggi::SubOp>(elemTy, z1a1_s, z1a2);
433433

434434
// Second part II of Karatsuba algorithm
435435
auto z1b0 = b.create<cggi::MulOp>(elemTy, splitLhs[2], splitRhs[0]);
@@ -438,7 +438,7 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
438438
auto z1b1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
439439
auto z1b1_m = b.create<cggi::MulOp>(elemTy, z1b1_p1, z1b1_p2);
440440
auto z1b1_s = b.create<cggi::SubOp>(elemTy, z1b1_m, z1b0);
441-
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);
441+
auto z1b1 = b.create<cggi::SubOp>(elemTy, z1b1_s, z1b2);
442442

443443
auto out2Kara = b.create<cggi::AddOp>(elemTy, z1a0, z1b0);
444444
auto out2Carry = b.create<cggi::AddOp>(elemTy, out2Kara, z02);

lib/Dialect/CGGI/IR/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ td_library(
4444
# include from the heir-root to enable fully-qualified include-paths
4545
includes = ["../../../.."],
4646
deps = [
47+
"@llvm-project//mlir:ArithOpsTdFiles",
4748
"@llvm-project//mlir:BuiltinDialectTdFiles",
4849
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
4950
"@llvm-project//mlir:OpBaseTdFiles",

lib/Dialect/CGGI/IR/CGGIOps.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
66
#include "lib/Dialect/HEIRInterfaces.h"
77
#include "lib/Dialect/LWE/IR/LWETypes.h"
8-
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
9-
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
10-
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
11-
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
8+
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
9+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
10+
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
11+
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
12+
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
1213
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
1314

1415
#define GET_OP_CLASSES

0 commit comments

Comments
 (0)