Skip to content

Commit

Permalink
Merge pull request #1403 from WoutLegiest:signed
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726985149
  • Loading branch information
copybara-github committed Feb 14, 2025
2 parents b320e5f + a119c47 commit fae5e95
Show file tree
Hide file tree
Showing 8 changed files with 559 additions and 333 deletions.
220 changes: 193 additions & 27 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,53 @@ static bool allowedRemainArith(Operation *op) {
}
return false;
})
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
auto lshAllowed = allowedRemainArith(lhsDefOp);
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
auto rhsAllowed = allowedRemainArith(rhsDefOp);
return lshAllowed && rhsAllowed;
}
}
return false;
})
.Default([](Operation *) {
// Default case for operations that don't match any of the types
return false;
});
}

static bool hasLWEAnnotation(Operation *op) {
return static_cast<bool>(
op->getAttrOfType<mlir::StringAttr>("lwe_annotation"));
mlir::StringAttr check =
op->getAttrOfType<mlir::StringAttr>("lwe_annotation");

if (check) return true;

// Check recursively if a defining op has a LWE annotation
return llvm::TypeSwitch<Operation *, bool>(op)
.Case<mlir::arith::ExtUIOp, mlir::arith::ExtSIOp, mlir::arith::TruncIOp>(
[](auto op) {
if (auto *defOp = op.getIn().getDefiningOp()) {
return hasLWEAnnotation(defOp);
}
return op->template getAttrOfType<mlir::StringAttr>(
"lwe_annotation") != nullptr;
})
.Case<mlir::arith::SubIOp, mlir::arith::AddIOp, mlir::arith::MulIOp>(
[](auto op) {
// This lambda will be called for any of the matched operation types
if (auto lhsDefOp = op.getOperand(0).getDefiningOp()) {
auto lshAllowed = hasLWEAnnotation(lhsDefOp);
if (auto rhsDefOp = op.getOperand(1).getDefiningOp()) {
auto rhsAllowed = hasLWEAnnotation(rhsDefOp);
return lshAllowed || rhsAllowed;
}
}
return false;
})
.Default([](Operation *) { return false; });
}

static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
Expand All @@ -89,10 +127,18 @@ static Value materializeTarget(OpBuilder &builder, Type type, ValueRange inputs,
llvm_unreachable(
"Non-integer types should never be the input to a materializeTarget.");

auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>();
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());
if (auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>()) {
auto intAttr = cast<IntegerAttr>(inValue.getValueAttr());

return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
return builder.create<cggi::CreateTrivialOp>(loc, type, intAttr);
}
// Comes from function/loop argument: Trivial encrypt through LWE
auto encoding = cast<lwe::LWECiphertextType>(type).getEncoding();
auto ptxtTy = lwe::LWEPlaintextType::get(builder.getContext(), encoding);
return builder.create<lwe::TrivialEncryptOp>(
loc, type,
builder.create<lwe::EncodeOp>(loc, ptxtTy, inputs[0], encoding),
lwe::LWEParamsAttr());
}

class ArithToCGGITypeConverter : public TypeConverter {
Expand Down Expand Up @@ -175,18 +221,109 @@ struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
}
};

struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
ConvertShRUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
ConvertCmpOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::CmpIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ShRUIOp op, OpAdaptor adaptor,
mlir::arith::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp = op.getRhs().getDefiningOp<mlir::arith::ConstantOp>();
auto lweBooleanType = lwe::LWECiphertextType::get(
op->getContext(),
lwe::UnspecifiedBitFieldEncodingAttr::get(op->getContext(), 1),
lwe::LWEParamsAttr());

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
return success();
}
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto cmpOp = b.create<cggi::CmpOp>(lweBooleanType, op.getPredicate(),
adaptor.getLhs(), adaptor.getRhs());

rewriter.replaceOp(op, cmpOp);
return success();
}
};

struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
ConvertSubOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::SubIOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::SubIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto subOp = b.create<cggi::SubOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, subOp);
return success();
}
};

struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
ConvertSelectOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::SelectOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmuxOp = b.create<cggi::SelectOp>(
adaptor.getTrueValue().getType(), adaptor.getCondition(),
adaptor.getTrueValue(), adaptor.getFalseValue());

rewriter.replaceOp(op, cmuxOp);
return success();
}
};

template <typename SourceArithShOp, typename TargetCGGIShOp>
struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
ConvertShOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithShOp>(context) {}

using OpConversionPattern<SourceArithShOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp =
op.getRhs().template getDefiningOp<mlir::arith::ConstantOp>();

if (cteShiftSizeOp) {
auto outputType = adaptor.getLhs().getType();
Expand All @@ -198,14 +335,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
auto inputValue =
mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount);

auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
outputType, adaptor.getLhs(), inputValue);
auto shiftOp =
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
rewriter.replaceOp(op, shiftOp);

return success();
}

cteShiftSizeOp = op.getLhs().getDefiningOp<mlir::arith::ConstantOp>();
cteShiftSizeOp =
op.getLhs().template getDefiningOp<mlir::arith::ConstantOp>();

auto outputType = adaptor.getRhs().getType();

Expand All @@ -215,15 +353,15 @@ struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
auto inputValue =
mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount);

auto shiftOp = b.create<cggi::ScalarShiftRightOp>(
outputType, adaptor.getLhs(), inputValue);
auto shiftOp =
b.create<TargetCGGIShOp>(outputType, adaptor.getLhs(), inputValue);
rewriter.replaceOp(op, shiftOp);

return success();
}
};

template <typename SourceArithOp, typename TargetModArithOp>
template <typename SourceArithOp, typename TargetCGGIOp>
struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
ConvertArithBinOp(mlir::MLIRContext *context)
: OpConversionPattern<SourceArithOp>(context) {}
Expand All @@ -237,24 +375,24 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
if (!hasLWEAnnotation(lhsDefOp) && allowedRemainArith(lhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
auto result = b.create<TargetCGGIOp>(adaptor.getRhs().getType(),
adaptor.getRhs(), op.getLhs());
rewriter.replaceOp(op, result);
return success();
}
}

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
if (!hasLWEAnnotation(rhsDefOp) && allowedRemainArith(rhsDefOp)) {
auto result = b.create<TargetModArithOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), op.getRhs());
rewriter.replaceOp(op, result);
return success();
}
}

auto result = b.create<TargetModArithOp>(
adaptor.getLhs().getType(), adaptor.getLhs(), adaptor.getRhs());
auto result = b.create<TargetCGGIOp>(adaptor.getLhs().getType(),
adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -296,10 +434,29 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
target.addIllegalDialect<mlir::arith::ArithDialect>();
target.addLegalOp<mlir::arith::ConstantOp>();

target.addDynamicallyLegalOp<mlir::arith::SubIOp, mlir::arith::AddIOp,
mlir::arith::MulIOp>([&](Operation *op) {
if (auto *defLhsOp = op->getOperand(0).getDefiningOp()) {
if (auto *defRhsOp = op->getOperand(1).getDefiningOp()) {
return !hasLWEAnnotation(defLhsOp) && !hasLWEAnnotation(defRhsOp) &&
allowedRemainArith(defLhsOp) && allowedRemainArith(defRhsOp);
}
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ExtSIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtSIOp>(op).getOperand().getDefiningOp()) {
return hasLWEAnnotation(defOp) || allowedRemainArith(defOp);
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
}
return false;
});

target.addDynamicallyLegalOp<mlir::arith::ExtUIOp>([&](Operation *op) {
if (auto *defOp =
cast<mlir::arith::ExtUIOp>(op).getOperand().getDefiningOp()) {
return !hasLWEAnnotation(defOp) && allowedRemainArith(defOp);
}
return false;
});
Expand All @@ -317,14 +474,16 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
// accepts Check if there is at least one Store op that is a constants
auto containsAnyStoreOp = llvm::any_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
allowedRemainArith(defOp.getValue().getDefiningOp());
}
return false;
});
auto allStoreOpsAreArith =
llvm::all_of(op->getUses(), [&](OpOperand &op) {
if (auto defOp = dyn_cast<memref::StoreOp>(op.getOwner())) {
return allowedRemainArith(defOp.getValue().getDefiningOp());
return !hasLWEAnnotation(defOp.getValue().getDefiningOp()) &&
allowedRemainArith(defOp.getValue().getDefiningOp());
}
return true;
});
Expand Down Expand Up @@ -390,10 +549,17 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

patterns.add<
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertSelectOp,
ConvertCmpOp, ConvertSubOp,
ConvertShOp<mlir::arith::ShRSIOp, cggi::ScalarShiftRightOp>,
ConvertShOp<mlir::arith::ShRUIOp, cggi::ScalarShiftRightOp>,
ConvertShOp<mlir::arith::ShLIOp, cggi::ScalarShiftLeftOp>,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertArithBinOp<mlir::arith::MaxSIOp, cggi::MaxOp>,
ConvertArithBinOp<mlir::arith::MinSIOp, cggi::MinOp>,
ConvertArithBinOp<mlir::arith::MaxUIOp, cggi::MaxOp>,
ConvertArithBinOp<mlir::arith::MinUIOp, cggi::MinOp>,
ConvertAny<memref::LoadOp>, ConvertAllocOp,
ConvertAny<memref::DeallocOp>, ConvertAny<memref::SubViewOp>,
ConvertAny<memref::CopyOp>, ConvertAny<memref::StoreOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,17 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto z01_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z01_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z01_m = b.create<cggi::MulOp>(elemTy, z01_p1, z01_p2);
auto z01_s = b.create<cggi::SubOp>(z01_m, z00);
auto z01 = b.create<cggi::SubOp>(z01_s, z02);
auto z01_s = b.create<cggi::SubOp>(elemTy, z01_m, z00);
auto z01 = b.create<cggi::SubOp>(elemTy, z01_s, z02);

// Second part I of Karatsuba algorithm
auto z1a0 = b.create<cggi::MulOp>(elemTy, splitLhs[0], splitRhs[2]);
auto z1a2 = b.create<cggi::MulOp>(elemTy, splitLhs[1], splitRhs[3]);
auto z1a1_p1 = b.create<cggi::AddOp>(elemTy, splitLhs[0], splitLhs[1]);
auto z1a1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[2], splitRhs[3]);
auto z1a1_m = b.create<cggi::MulOp>(elemTy, z1a1_p1, z1a1_p2);
auto z1a1_s = b.create<cggi::SubOp>(z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(z1a1_s, z1a2);
auto z1a1_s = b.create<cggi::SubOp>(elemTy, z1a1_m, z1a0);
auto z1a1 = b.create<cggi::SubOp>(elemTy, z1a1_s, z1a2);

// Second part II of Karatsuba algorithm
auto z1b0 = b.create<cggi::MulOp>(elemTy, splitLhs[2], splitRhs[0]);
Expand All @@ -453,7 +453,7 @@ struct ConvertQuartMulI final : OpConversionPattern<mlir::arith::MulIOp> {
auto z1b1_p2 = b.create<cggi::AddOp>(elemTy, splitRhs[0], splitRhs[1]);
auto z1b1_m = b.create<cggi::MulOp>(elemTy, z1b1_p1, z1b1_p2);
auto z1b1_s = b.create<cggi::SubOp>(elemTy, z1b1_m, z1b0);
auto z1b1 = b.create<cggi::SubOp>(z1b1_s, z1b2);
auto z1b1 = b.create<cggi::SubOp>(elemTy, z1b1_s, z1b2);

auto out2Kara = b.create<cggi::AddOp>(elemTy, z1a0, z1b0);
auto out2Carry = b.create<cggi::AddOp>(elemTy, out2Kara, z02);
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/CGGI/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ cc_library(
"@heir//lib/Dialect:HEIRInterfaces",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:Support",
Expand All @@ -37,13 +38,16 @@ td_library(
srcs = [
"BooleanGates.td",
"CGGIAttributes.td",
"CGGIBinOps.td",
"CGGIDialect.td",
"CGGIEnums.td",
"CGGIOps.td",
"CGGIPBSOps.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@llvm-project//mlir:ArithOpsTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
Loading

0 comments on commit fae5e95

Please sign in to comment.