Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support of the Hello World Small example to Arith-to-CGGI #1403

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 193 additions & 27 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,53 @@ static bool allowedRemainArith(Operation *op) {
}
return false;
})
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
auto lshAllowed = allowedRemainArith(lhsDefOp);
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
auto rhsAllowed = allowedRemainArith(rhsDefOp);
return lshAllowed && rhsAllowed;
}
}
return false;
})
.Default([](Operation *) {
// Default case for operations that don't match any of the types
return false;
});
}

static bool hasLWEAnnotation(Operation *op) {
return static_cast<bool>(
op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
mlir::StringAttr check =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation");

if (check) return true;

// Check recursively if a defining op has a LWE annotation
return llvm::TypeSwitch<Operation *, bool>(op)
.Case<mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
[](auto op) {
if (auto *defOp = op.getIn().getDefiningOp()) {
return hasLWEAnnotation(defOp);
}
return op->template getAttrOfType<mlir::StringAttr>(
"lwe_annotation") != nullptr;
})
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
auto lshAllowed = hasLWEAnnotation(lhsDefOp);
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
auto rhsAllowed = hasLWEAnnotation(rhsDefOp);
return lshAllowed || rhsAllowed;
}
}
return false;
})
.Default([](Operation *) { return false; });
}

static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
Expand All @@ -70,10 +108,18 @@ static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
llvm_unreachable(
"Non-integer types should never be the input to a materializeTarget.");

auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>();
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());
if (auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>()) {
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());

return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
}
// Comes from function/loop argument: Trivial encrypt through LWE
auto encoding = cast<lwe::LWECiphertextType>(type).getEncoding();
auto ptxtTy = lwe::LWEPlaintextType::get(builder.getContext(), encoding);
return builder.create<lwe::TrivialEncryptOp>(
loc, type,
builder.create<lwe::EncodeOp>(loc, ptxtTy, inputs[0], encoding),
lwe::LWEParamsAttr());
}

class ArithToCGGITypeConverter : public TypeConverter {
Expand Down Expand Up @@ -156,18 +202,109 @@ struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
}
};

struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
ConvertShRUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
ConvertCmpOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::CmpIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ShRUIOp op, OpAdaptor adaptor,
mlir::arith::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp = op.getRhs().getDefiningOp<mlir::arith::ConstantOp>();
auto lweBooleanType = lwe::LWECiphertextType::get(
op->getContext(),
lwe::UnspecifiedBitFieldEncodingAttr::get(op->getContext(), 1),
lwe::LWEParamsAttr());

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
return success();
}
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto cmpOp = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getLhs(), adaptor.getRhs());

rewriter.replaceOp(op, cmpOp);
return success();
}
};

struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
ConvertSubOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::SubIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::SubIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto subOp = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, subOp);
return success();
}
};

struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
ConvertSelectOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::SelectOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmuxOp = b.create<cggi::SelectOp>(
adaptor.getTrueValue().getType(), adaptor.getCondition(),
adaptor.getTrueValue(), adaptor.getFalseValue());

rewriter.replaceOp(op, cmuxOp);
return success();
}
};

template <typename SourceArithShOp, typename TargetCGGIShOp>
struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
ConvertShOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithShOp>(context) {}

using OpConversionPattern<SourceArithShOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp =
op.getRhs().template getDefiningOp<mlir::arith::ConstantOp>();

if (cteShiftSizeOp) {
auto outputType = adaptor.getLhs().getType();
Expand All @@ -179,14 +316,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
auto inputValue =
mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount);

auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
outputType, adaptor.getLhs(), inputValue);
auto shiftOp =
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
rewriter.replaceOp(op, shiftOp);

return success();
}

cteShiftSizeOp = op.getLhs().getDefiningOp<mlir::arith::ConstantOp>();
cteShiftSizeOp =
op.getLhs().template getDefiningOp<mlir::arith::ConstantOp>();

auto outputType = adaptor.getRhs().getType();

Expand All @@ -196,15 +334,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
auto inputValue =
mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount);

auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
outputType, adaptor.getLhs(), inputValue);
auto shiftOp =
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
rewriter.replaceOp(op, shiftOp);

return success();
}
};

template <typename SourceArithOp, typename TargetModArithOp>
template <typename SourceArithOp, typename TargetCGGIOp>
struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
ConvertArithBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}
Expand All @@ -218,24 +356,24 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
auto result = b.create<TargetCGGIOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
return success();
}
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto result = b.create<TargetModArithOp>(
adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs());
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -277,10 +415,29 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::arith::ConstantOp>();

target.addDynamicallyLegalOp<mlir::arith::SubIOp, mlir::arith::AddIOp,
mlir::arith::MulIOp>([&](Operation *op) {
if (auto *defLhsOp = op->getOperand(0).getDefiningOp()) {
if (auto *defRhsOp = op->getOperand(1).getDefiningOp()) {
return !hasLWEAnnotation(defLhsOp) && !hasLWEAnnotation(defRhsOp) &&
allowedRemainArith(defLhsOp) && allowedRemainArith(defRhsOp);
}
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
return hasLWEAnnotation(defOp) || allowedRemainArith(defOp);
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ExtUIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtUIOp>(op).getOperand().getDefiningOp()) {
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
}
return false;
});
Expand All @@ -298,14 +455,16 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
// accepts Check if there is at least one Store op that is a constants
auto containsAnyStoreOp = llvm::any_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
allowedRemainArith(defOp.getValue().getDefiningOp());
}
return false;
});
auto allStoreOpsAreArith =
llvm::all_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
allowedRemainArith(defOp.getValue().getDefiningOp());
}
return true;
});
Expand Down Expand Up @@ -371,10 +530,17 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

patterns.add<
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertSelectOp,
ConvertCmpOp, ConvertSubOp,
ConvertShOp<mlir::arith::ShRSIOp, cggi::ScalarShiftRightOp>,
ConvertShOp<mlir::arith::ShRUIOp, cggi::ScalarShiftRightOp>,
ConvertShOp<mlir::arith::ShLIOp, cggi::ScalarShiftLeftOp>,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertArithBinOp<mlir::arith::MaxSIOp, cggi::MaxOp>,
ConvertArithBinOp<mlir::arith::MinSIOp, cggi::MinOp>,
ConvertArithBinOp<mlir::arith::MaxUIOp, cggi::MaxOp>,
ConvertArithBinOp<mlir::arith::MinUIOp, cggi::MinOp>,
ConvertAny<memref::LoadOp>, ConvertAllocOp,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,17 +419,17 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto z01_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(elemTy, z01_p1, z01_p2);
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
auto z01 = b.create<cggi::SubOp>(z01_s, z02);
auto z01_s = b.create<cggi::SubOp>(elemTy, z01_m, z00);
auto z01 = b.create<cggi::SubOp>(elemTy, z01_s, z02);

// Second part I of Karatsuba algorithm
auto z1a0 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(elemTy, z1a1_p1, z1a1_p2);
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);
auto z1a1_s = b.create<cggi::SubOp>(elemTy, z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(elemTy, z1a1_s, z1a2);

// Second part II of Karatsuba algorithm
auto z1b0 = b.create<cggi::MulOp>(elemTy, splitLhs[2], splitRhs[0]);
Expand All @@ -438,7 +438,7 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto z1b1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(elemTy, z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(elemTy, z1b1_m, z1b0);
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);
auto z1b1 = b.create<cggi::SubOp>(elemTy, z1b1_s, z1b2);

auto out2Kara = b.create<cggi::AddOp>(elemTy, z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(elemTy, out2Kara, z02);
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/CGGI/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ td_library(
srcs = [
"BooleanGates.td",
"CGGIAttributes.td",
"CGGIBinOps.td",
"CGGIDialect.td",
"CGGIEnums.td",
"CGGIOps.td",
"CGGIPBSOps.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
Loading
Loading