Skip to content

Commit

Permalink
Adding support of the Hello World Small example to Arith-to-CGGI
Browse files Browse the repository at this point in the history
  • Loading branch information
WoutLegiest committed Feb 13, 2025
1 parent d7fbf06 commit 9795e2f
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 82 deletions.
220 changes: 193 additions & 27 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,53 @@ static bool allowedRemainArith(Operation *op) {
}
return false;
})
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
auto lshAllowed = allowedRemainArith(lhsDefOp);
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
auto rhsAllowed = allowedRemainArith(rhsDefOp);
return lshAllowed && rhsAllowed;
}
}
return false;
})
.Default([](Operation *) {
// Default case for operations that don't match any of the types
return false;
});
}

static bool hasLWEAnnotation(Operation *op) {
return static_cast<bool>(
op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
auto check =
static_cast<bool>(op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));

if (check) return check;

// Check recursively if a defining op has a LWE annotation
return llvm::TypeSwitch<Operation *, bool>(op)
.Case<mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
[](auto op) {
if (auto *defOp = op.getIn().getDefiningOp()) {
return hasLWEAnnotation(defOp);
}
return static_cast<bool>(
op->template getAttrOfType<mlir::StringAttr>("lwe_annotation"));
})
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
auto lshAllowed = hasLWEAnnotation(lhsDefOp);
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
auto rhsAllowed = hasLWEAnnotation(rhsDefOp);
return lshAllowed || rhsAllowed;
}
}
return false;
})
.Default([](Operation *) { return false; });
}

static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
Expand All @@ -70,10 +108,18 @@ static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
llvm_unreachable(
"Non-integer types should never be the input to a materializeTarget.");

auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>();
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());
if (auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>()) {
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());

return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
}
// Comes from function/loop argument: Trivial encrypt through LWE
auto encoding = cast<lwe::LWECiphertextType>(type).getEncoding();
auto ptxtTy = lwe::LWEPlaintextType::get(builder.getContext(), encoding);
return builder.create<lwe::TrivialEncryptOp>(
loc, type,
builder.create<lwe::EncodeOp>(loc, ptxtTy, inputs[0], encoding),
lwe::LWEParamsAttr());
}

class ArithToCGGITypeConverter : public TypeConverter {
Expand Down Expand Up @@ -156,18 +202,109 @@ struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
}
};

struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
ConvertShRUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
ConvertCmpOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::CmpIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ShRUIOp op, OpAdaptor adaptor,
mlir::arith::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp = op.getRhs().getDefiningOp<mlir::arith::ConstantOp>();
auto lweBooleanType = lwe::LWECiphertextType::get(
op->getContext(),
lwe::UnspecifiedBitFieldEncodingAttr::get(op->getContext(), 1),
lwe::LWEParamsAttr());

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
return success();
}
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto cmpOp = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getLhs(), adaptor.getRhs());

rewriter.replaceOp(op, cmpOp);
return success();
}
};

struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
ConvertSubOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::SubIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::SubIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto subOp = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, subOp);
return success();
}
};

struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
ConvertSelectOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::SelectOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmuxOp = b.create<cggi::SelectOp>(
adaptor.getTrueValue().getType(), adaptor.getCondition(),
adaptor.getTrueValue(), adaptor.getFalseValue());

rewriter.replaceOp(op, cmuxOp);
return success();
}
};

template <typename SourceArithShOp, typename TargetCGGIShOp>
struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
ConvertShOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithShOp>(context) {}

using OpConversionPattern<SourceArithShOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp =
op.getRhs().template getDefiningOp<mlir::arith::ConstantOp>();

if (cteShiftSizeOp) {
auto outputType = adaptor.getLhs().getType();
Expand All @@ -179,14 +316,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
auto inputValue =
mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount);

auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
outputType, adaptor.getLhs(), inputValue);
auto shiftOp =
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
rewriter.replaceOp(op, shiftOp);

return success();
}

cteShiftSizeOp = op.getLhs().getDefiningOp<mlir::arith::ConstantOp>();
cteShiftSizeOp =
op.getLhs().template getDefiningOp<mlir::arith::ConstantOp>();

auto outputType = adaptor.getRhs().getType();

Expand All @@ -196,15 +334,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
auto inputValue =
mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount);

auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
outputType, adaptor.getLhs(), inputValue);
auto shiftOp =
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
rewriter.replaceOp(op, shiftOp);

return success();
}
};

template <typename SourceArithOp, typename TargetModArithOp>
template <typename SourceArithOp, typename TargetCGGIOp>
struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
ConvertArithBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}
Expand All @@ -218,24 +356,24 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
auto result = b.create<TargetCGGIOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
return success();
}
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto result = b.create<TargetModArithOp>(
adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs());
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -277,10 +415,29 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::arith::ConstantOp>();

target.addDynamicallyLegalOp<mlir::arith::SubIOp, mlir::arith::AddIOp,
mlir::arith::MulIOp>([&](Operation *op) {
if (auto *defLhsOp = op->getOperand(0).getDefiningOp()) {
if (auto *defRhsOp = op->getOperand(1).getDefiningOp()) {
return !hasLWEAnnotation(defLhsOp) && !hasLWEAnnotation(defRhsOp) &&
allowedRemainArith(defLhsOp) && allowedRemainArith(defRhsOp);
}
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
return hasLWEAnnotation(defOp) || allowedRemainArith(defOp);
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ExtUIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtUIOp>(op).getOperand().getDefiningOp()) {
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
}
return false;
});
Expand All @@ -298,14 +455,16 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
// accepts Check if there is at least one Store op that is a constants
auto containsAnyStoreOp = llvm::any_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
allowedRemainArith(defOp.getValue().getDefiningOp());
}
return false;
});
auto allStoreOpsAreArith =
llvm::all_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
allowedRemainArith(defOp.getValue().getDefiningOp());
}
return true;
});
Expand Down Expand Up @@ -371,10 +530,17 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

patterns.add<
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertSelectOp,
ConvertCmpOp, ConvertSubOp,
ConvertShOp<mlir::arith::ShRSIOp, cggi::ScalarShiftRightOp>,
ConvertShOp<mlir::arith::ShRUIOp, cggi::ScalarShiftRightOp>,
ConvertShOp<mlir::arith::ShLIOp, cggi::ScalarShiftLeftOp>,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertArithBinOp<mlir::arith::MaxSIOp, cggi::MaxOp>,
ConvertArithBinOp<mlir::arith::MinSIOp, cggi::MinOp>,
ConvertArithBinOp<mlir::arith::MaxUIOp, cggi::MaxOp>,
ConvertArithBinOp<mlir::arith::MinUIOp, cggi::MinOp>,
ConvertAny<memref::LoadOp>, ConvertAllocOp,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,17 +419,17 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto z01_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(elemTy, z01_p1, z01_p2);
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
auto z01 = b.create<cggi::SubOp>(z01_s, z02);
auto z01_s = b.create<cggi::SubOp>(elemTy, z01_m, z00);
auto z01 = b.create<cggi::SubOp>(elemTy, z01_s, z02);

// Second part I of Karatsuba algorithm
auto z1a0 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(elemTy, z1a1_p1, z1a1_p2);
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);
auto z1a1_s = b.create<cggi::SubOp>(elemTy, z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(elemTy, z1a1_s, z1a2);

// Second part II of Karatsuba algorithm
auto z1b0 = b.create<cggi::MulOp>(elemTy, splitLhs[2], splitRhs[0]);
Expand All @@ -438,7 +438,7 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto z1b1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(elemTy, z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(elemTy, z1b1_m, z1b0);
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);
auto z1b1 = b.create<cggi::SubOp>(elemTy, z1b1_s, z1b2);

auto out2Kara = b.create<cggi::AddOp>(elemTy, z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(elemTy, out2Kara, z02);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/CGGI/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
9 changes: 5 additions & 4 deletions lib/Dialect/CGGI/IR/CGGIOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/HEIRInterfaces.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

#define GET_OP_CLASSES
Expand Down
Loading

0 comments on commit 9795e2f

Please sign in to comment.