Skip to content

Commit 8f0bff7

Browse files
Merge pull request #1288 from WoutLegiest:ptxt-ctxt-hl
PiperOrigin-RevId: 721013199
2 parents 77842a6 + 24299f0 commit 8f0bff7

File tree

4 files changed

+37
-50
lines changed

4 files changed

+37
-50
lines changed

lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h"
22

3-
#include <cstdint>
4-
53
#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
64
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
75
#include "lib/Dialect/LWE/IR/LWETypes.h"
@@ -48,35 +46,6 @@ class ArithToCGGITypeConverter : public TypeConverter {
4846
}
4947
};
5048

51-
struct ConvertConstantOp : public OpConversionPattern<mlir::arith::ConstantOp> {
52-
ConvertConstantOp(mlir::MLIRContext *context)
53-
: OpConversionPattern<mlir::arith::ConstantOp>(context) {}
54-
55-
using OpConversionPattern::OpConversionPattern;
56-
57-
LogicalResult matchAndRewrite(
58-
mlir::arith::ConstantOp op, OpAdaptor adaptor,
59-
ConversionPatternRewriter &rewriter) const override {
60-
if (isa<IndexType>(op.getValue().getType())) {
61-
return failure();
62-
}
63-
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
64-
65-
auto intValue = cast<IntegerAttr>(op.getValue()).getValue().getSExtValue();
66-
auto inputValue = mlir::IntegerAttr::get(op.getType(), intValue);
67-
68-
auto encoding = lwe::UnspecifiedBitFieldEncodingAttr::get(
69-
op->getContext(), op.getValue().getType().getIntOrFloatBitWidth());
70-
auto lweType = lwe::LWECiphertextType::get(op->getContext(), encoding,
71-
lwe::LWEParamsAttr());
72-
73-
auto encrypt = b.create<cggi::CreateTrivialOp>(lweType, inputValue);
74-
75-
rewriter.replaceOp(op, encrypt);
76-
return success();
77-
}
78-
};
79-
8049
struct ConvertTruncIOp : public OpConversionPattern<mlir::arith::TruncIOp> {
8150
ConvertTruncIOp(mlir::MLIRContext *context)
8251
: OpConversionPattern<mlir::arith::TruncIOp>(context) {}
@@ -301,8 +270,8 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
301270
});
302271

303272
patterns.add<
304-
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp,
305-
ConvertShRUIOp, ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
273+
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
274+
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
306275
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
307276
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
308277
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,

lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,10 @@ struct ConvertLut3Op : public OpConversionPattern<cggi::Lut3Op> {
159159
serverKey, adaptor.getC(), b.getIndexAttr(2));
160160
auto shiftedB = b.create<tfhe_rust::ScalarLeftShiftOp>(
161161
serverKey, adaptor.getB(), b.getIndexAttr(1));
162-
auto outputType =
163-
getTypeConverter()->convertType(shiftedB.getResult().getType());
164-
auto summedBC =
165-
b.create<tfhe_rust::AddOp>(outputType, serverKey, shiftedC, shiftedB);
166-
auto summedABC = b.create<tfhe_rust::AddOp>(outputType, serverKey, summedBC,
167-
adaptor.getA());
162+
auto summedBC = b.create<tfhe_rust::AddOp>(adaptor.getB().getType(),
163+
serverKey, shiftedC, shiftedB);
164+
auto summedABC = b.create<tfhe_rust::AddOp>(
165+
adaptor.getB().getType(), serverKey, summedBC, adaptor.getA());
168166

169167
rewriter.replaceOp(
170168
op, b.create<tfhe_rust::ApplyLookupTableOp>(serverKey, summedABC, lut));
@@ -243,6 +241,7 @@ struct ConvertCGGITRBinOp : public OpConversionPattern<BinOp> {
243241
Value serverKey = result.value();
244242
CGGIToTfheRustTypeConverter typeConverter(op->getContext());
245243
auto outputType = typeConverter.convertType(op.getResult().getType());
244+
246245
rewriter.replaceOp(
247246
op, b.create<TfheRustBinOp>(outputType, serverKey, adaptor.getLhs(),
248247
adaptor.getRhs()));

lib/Target/TfheRustHL/TfheRustHLEmitter.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,12 @@ LogicalResult TfheRustHLEmitter::printOperation(CreateTrivialOp op) {
308308

309309
os << "FheUint" << getTfheRustBitWidth(op.getResult().getType())
310310
<< "::try_encrypt_trivial("
311-
<< variableNames->getNameForValue(op.getValue()) << ").unwrap();\n";
311+
<< variableNames->getNameForValue(op.getValue());
312+
313+
if (op.getValue().getType().isSigned())
314+
os << " as u" << getTfheRustBitWidth(op.getResult().getType());
315+
316+
os << ").unwrap();\n";
312317
return success();
313318
}
314319

@@ -359,7 +364,7 @@ LogicalResult TfheRustHLEmitter::printOperation(arith::ConstantOp op) {
359364
// By default, it emits an unsigned integer.
360365
emitAssignPrefix(op.getResult());
361366
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
362-
os << intAttr.getValue().abs() << "u64;\n";
367+
os << intAttr.getValue().abs() << convertType(op.getType()) << ";\n";
363368
} else {
364369
return op.emitError() << "Unknown constant type " << valueAttr.getType();
365370
}
@@ -383,6 +388,17 @@ LogicalResult TfheRustHLEmitter::printBinaryOp(::mlir::Value result,
383388
std::string_view op) {
384389
emitAssignPrefix(result);
385390

391+
if (auto cteOp = dyn_cast<mlir::arith::ConstantOp>(rhs.getDefiningOp())) {
392+
auto intValue =
393+
cast<IntegerAttr>(cteOp.getValue()).getValue().getZExtValue();
394+
os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op
395+
<< " " << intValue << "u" << cteOp.getType().getIntOrFloatBitWidth()
396+
<< ";\n";
397+
return success();
398+
}
399+
400+
// Note: arith.constant op requires signless integer types, but here we
401+
// manually emit an unsigned integer type.
386402
os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op
387403
<< " " << checkOrigin(rhs) << variableNames->getNameForValue(rhs) << ";\n";
388404
return success();
@@ -430,8 +446,8 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::AllocOp op) {
430446
if (failed(emitType(op.getMemref().getType().getElementType()))) {
431447
return op.emitOpError() << "Failed to get memref element type";
432448
}
433-
434449
os << "> = BTreeMap::new();\n";
450+
435451
return success();
436452
}
437453

@@ -463,12 +479,11 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::LoadOp op) {
463479
// We assume here that the indices are SSA values (not integer attributes).
464480
if (isa<BlockArgument>(op.getMemref())) {
465481
emitAssignPrefix(op.getResult());
466-
os << "&" << variableNames->getNameForValue(op.getMemRef()) << "["
467-
<< flattenIndexExpression(op.getMemRefType(), op.getIndices(),
468-
[&](Value value) {
469-
return variableNames->getNameForValue(value);
470-
})
471-
<< "];\n";
482+
os << "&" << variableNames->getNameForValue(op.getMemRef());
483+
for (auto value : op.getIndices()) {
484+
os << "[" << variableNames->getNameForValue(value) << "]";
485+
}
486+
os << ";\n";
472487
return success();
473488
}
474489

@@ -586,6 +601,7 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) {
586601
return std::string(prefix) + variableNames->getNameForValue(value) +
587602
cloneStr;
588603
}) << "];\n";
604+
589605
return success();
590606
}
591607

@@ -662,9 +678,12 @@ FailureOr<std::string> TfheRustHLEmitter::convertType(Type type) {
662678
}
663679
auto width = getRustIntegerType(type.getWidth());
664680
if (failed(width)) return failure();
665-
return (type.isUnsigned() ? std::string("u") : "") + "i" +
681+
return (type.isSigned() ? std::string("i") : std::string("u")) +
666682
std::to_string(width.value());
667683
})
684+
.Case<IndexType>([&](IndexType type) -> FailureOr<std::string> {
685+
return std::string("usize");
686+
})
668687
.Case<LookupTableType>(
669688
[&](auto type) { return std::string("LookupTableOwned"); })
670689
.Default([&](Type &) { return failure(); });

tests/Examples/tfhe_rust_hl/cpu/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ if you overrode the default option when installing Cargo.
1717

1818
```bash
1919
bazel query "filter('.mlir.test$', //tests/Examples/tfhe_rust_hl/cpu/...)" \
20-
| xargs bazel test --noincompatible_strict_action_env -test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@"
20+
| xargs bazel test --noincompatible_strict_action_env --test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@"
2121
```
2222

2323
The `manual` tag is added to the targets in this directory to ensure that they

0 commit comments

Comments
 (0)