Skip to content

Commit

Permalink
Merge pull request #1288 from WoutLegiest:ptxt-ctxt-hl
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721013199
  • Loading branch information
copybara-github committed Jan 29, 2025
2 parents 77842a6 + 24299f0 commit 8f0bff7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 50 deletions.
35 changes: 2 additions & 33 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h"

#include <cstdint>

#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
Expand Down Expand Up @@ -48,35 +46,6 @@ class ArithToCGGITypeConverter : public TypeConverter {
}
};

struct ConvertConstantOp : public OpConversionPattern<mlir::arith::ConstantOp> {
ConvertConstantOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ConstantOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (isa<IndexType>(op.getValue().getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto intValue = cast<IntegerAttr>(op.getValue()).getValue().getSExtValue();
auto inputValue = mlir::IntegerAttr::get(op.getType(), intValue);

auto encoding = lwe::UnspecifiedBitFieldEncodingAttr::get(
op->getContext(), op.getValue().getType().getIntOrFloatBitWidth());
auto lweType = lwe::LWECiphertextType::get(op->getContext(), encoding,
lwe::LWEParamsAttr());

auto encrypt = b.create<cggi::CreateTrivialOp>(lweType, inputValue);

rewriter.replaceOp(op, encrypt);
return success();
}
};

struct ConvertTruncIOp : public OpConversionPattern<mlir::arith::TruncIOp> {
ConvertTruncIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::TruncIOp>(context) {}
Expand Down Expand Up @@ -301,8 +270,8 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

patterns.add<
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp,
ConvertShRUIOp, ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
Expand Down
11 changes: 5 additions & 6 deletions lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,10 @@ struct ConvertLut3Op : public OpConversionPattern<cggi::Lut3Op> {
serverKey, adaptor.getC(), b.getIndexAttr(2));
auto shiftedB = b.create<tfhe_rust::ScalarLeftShiftOp>(
serverKey, adaptor.getB(), b.getIndexAttr(1));
auto outputType =
getTypeConverter()->convertType(shiftedB.getResult().getType());
auto summedBC =
b.create<tfhe_rust::AddOp>(outputType, serverKey, shiftedC, shiftedB);
auto summedABC = b.create<tfhe_rust::AddOp>(outputType, serverKey, summedBC,
adaptor.getA());
auto summedBC = b.create<tfhe_rust::AddOp>(adaptor.getB().getType(),
serverKey, shiftedC, shiftedB);
auto summedABC = b.create<tfhe_rust::AddOp>(
adaptor.getB().getType(), serverKey, summedBC, adaptor.getA());

rewriter.replaceOp(
op, b.create<tfhe_rust::ApplyLookupTableOp>(serverKey, summedABC, lut));
Expand Down Expand Up @@ -243,6 +241,7 @@ struct ConvertCGGITRBinOp : public OpConversionPattern<BinOp> {
Value serverKey = result.value();
CGGIToTfheRustTypeConverter typeConverter(op->getContext());
auto outputType = typeConverter.convertType(op.getResult().getType());

rewriter.replaceOp(
op, b.create<TfheRustBinOp>(outputType, serverKey, adaptor.getLhs(),
adaptor.getRhs()));
Expand Down
39 changes: 29 additions & 10 deletions lib/Target/TfheRustHL/TfheRustHLEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,12 @@ LogicalResult TfheRustHLEmitter::printOperation(CreateTrivialOp op) {

os << "FheUint" << getTfheRustBitWidth(op.getResult().getType())
<< "::try_encrypt_trivial("
<< variableNames->getNameForValue(op.getValue()) << ").unwrap();\n";
<< variableNames->getNameForValue(op.getValue());

if (op.getValue().getType().isSigned())
os << " as u" << getTfheRustBitWidth(op.getResult().getType());

os << ").unwrap();\n";
return success();
}

Expand Down Expand Up @@ -359,7 +364,7 @@ LogicalResult TfheRustHLEmitter::printOperation(arith::ConstantOp op) {
// By default, it emits an unsigned integer.
emitAssignPrefix(op.getResult());
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
os << intAttr.getValue().abs() << "u64;\n";
os << intAttr.getValue().abs() << convertType(op.getType()) << ";\n";
} else {
return op.emitError() << "Unknown constant type " << valueAttr.getType();
}
Expand All @@ -383,6 +388,17 @@ LogicalResult TfheRustHLEmitter::printBinaryOp(::mlir::Value result,
std::string_view op) {
emitAssignPrefix(result);

if (auto cteOp = dyn_cast<mlir::arith::ConstantOp>(rhs.getDefiningOp())) {
auto intValue =
cast<IntegerAttr>(cteOp.getValue()).getValue().getZExtValue();
os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op
<< " " << intValue << "u" << cteOp.getType().getIntOrFloatBitWidth()
<< ";\n";
return success();
}

// Note: arith.constant op requires signless integer types, but here we
// manually emit an unsigned integer type.
os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op
<< " " << checkOrigin(rhs) << variableNames->getNameForValue(rhs) << ";\n";
return success();
Expand Down Expand Up @@ -430,8 +446,8 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::AllocOp op) {
if (failed(emitType(op.getMemref().getType().getElementType()))) {
return op.emitOpError() << "Failed to get memref element type";
}

os << "> = BTreeMap::new();\n";

return success();
}

Expand Down Expand Up @@ -463,12 +479,11 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::LoadOp op) {
// We assume here that the indices are SSA values (not integer attributes).
if (isa<BlockArgument>(op.getMemref())) {
emitAssignPrefix(op.getResult());
os << "&" << variableNames->getNameForValue(op.getMemRef()) << "["
<< flattenIndexExpression(op.getMemRefType(), op.getIndices(),
[&](Value value) {
return variableNames->getNameForValue(value);
})
<< "];\n";
os << "&" << variableNames->getNameForValue(op.getMemRef());
for (auto value : op.getIndices()) {
os << "[" << variableNames->getNameForValue(value) << "]";
}
os << ";\n";
return success();
}

Expand Down Expand Up @@ -586,6 +601,7 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) {
return std::string(prefix) + variableNames->getNameForValue(value) +
cloneStr;
}) << "];\n";

return success();
}

Expand Down Expand Up @@ -662,9 +678,12 @@ FailureOr<std::string> TfheRustHLEmitter::convertType(Type type) {
}
auto width = getRustIntegerType(type.getWidth());
if (failed(width)) return failure();
return (type.isUnsigned() ? std::string("u") : "") + "i" +
return (type.isSigned() ? std::string("i") : std::string("u")) +
std::to_string(width.value());
})
.Case<IndexType>([&](IndexType type) -> FailureOr<std::string> {
return std::string("usize");
})
.Case<LookupTableType>(
[&](auto type) { return std::string("LookupTableOwned"); })
.Default([&](Type &) { return failure(); });
Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/tfhe_rust_hl/cpu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ if you overrode the default option when installing Cargo.

```bash
bazel query "filter('.mlir.test$', //tests/Examples/tfhe_rust_hl/cpu/...)" \
| xargs bazel test --noincompatible_strict_action_env -test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@"
| xargs bazel test --noincompatible_strict_action_env --test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@"
```

The `manual` tag is added to the targets in this directory to ensure that they
Expand Down

0 comments on commit 8f0bff7

Please sign in to comment.