From 24299f031b592fca3083ad418a81f85a0899b42e Mon Sep 17 00:00:00 2001 From: WoutLegiest Date: Wed, 25 Dec 2024 01:30:54 +0000 Subject: [PATCH] Tfhe-rs Ops with the scalar definition --- .../Conversions/ArithToCGGI/ArithToCGGI.cpp | 36 ++--------------- .../CGGIToTfheRust/CGGIToTfheRust.cpp | 11 +++--- lib/Target/TfheRustHL/TfheRustHLEmitter.cpp | 39 ++++++++++++++----- tests/Examples/tfhe_rust_hl/cpu/README.md | 2 +- 4 files changed, 39 insertions(+), 49 deletions(-) diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp index 63c8aa4f7..93069a849 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp @@ -1,6 +1,7 @@ #include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h" -#include +#include +#include #include "lib/Dialect/CGGI/IR/CGGIDialect.h" #include "lib/Dialect/CGGI/IR/CGGIOps.h" @@ -48,35 +49,6 @@ class ArithToCGGITypeConverter : public TypeConverter { } }; -struct ConvertConstantOp : public OpConversionPattern { - ConvertConstantOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mlir::arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (isa(op.getValue().getType())) { - return failure(); - } - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - auto intValue = cast(op.getValue()).getValue().getSExtValue(); - auto inputValue = mlir::IntegerAttr::get(op.getType(), intValue); - - auto encoding = lwe::UnspecifiedBitFieldEncodingAttr::get( - op->getContext(), op.getValue().getType().getIntOrFloatBitWidth()); - auto lweType = lwe::LWECiphertextType::get(op->getContext(), encoding, - lwe::LWEParamsAttr()); - - auto encrypt = b.create(lweType, inputValue); - - rewriter.replaceOp(op, encrypt); - return success(); - } -}; - struct ConvertTruncIOp : public OpConversionPattern { ConvertTruncIOp(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -301,8 +273,8 @@ struct ArithToCGGI : public impl::ArithToCGGIBase { }); patterns.add< - ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, - ConvertShRUIOp, ConvertArithBinOp, + ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp, + ConvertArithBinOp, ConvertArithBinOp, ConvertArithBinOp, ConvertAny, ConvertAny, diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp index d1129e879..4e757db6b 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -175,12 +175,10 @@ struct ConvertLut3Op : public OpConversionPattern { serverKey, adaptor.getC(), b.getIndexAttr(2)); auto shiftedB = b.create( serverKey, adaptor.getB(), b.getIndexAttr(1)); - auto outputType = - getTypeConverter()->convertType(shiftedB.getResult().getType()); - auto summedBC = - b.create(outputType, serverKey, shiftedC, shiftedB); - auto summedABC = b.create(outputType, serverKey, summedBC, - adaptor.getA()); + auto summedBC = b.create(adaptor.getB().getType(), + serverKey, shiftedC, shiftedB); + auto summedABC = b.create( + adaptor.getB().getType(), serverKey, summedBC, adaptor.getA()); rewriter.replaceOp( op, b.create(serverKey, summedABC, lut)); @@ -259,6 +257,7 @@ struct ConvertCGGITRBinOp : public OpConversionPattern { Value serverKey = result.value(); CGGIToTfheRustTypeConverter typeConverter(op->getContext()); auto outputType = typeConverter.convertType(op.getResult().getType()); + rewriter.replaceOp( op, b.create(outputType, serverKey, adaptor.getLhs(), adaptor.getRhs())); diff --git a/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp b/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp index 1ad45731f..80b29bcf0 100644 --- a/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp +++ b/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp @@ -308,7 +308,12 @@ LogicalResult TfheRustHLEmitter::printOperation(CreateTrivialOp op) { os << "FheUint" << getTfheRustBitWidth(op.getResult().getType()) << "::try_encrypt_trivial(" - << variableNames->getNameForValue(op.getValue()) << ").unwrap();\n"; + << variableNames->getNameForValue(op.getValue()); + + if (op.getValue().getType().isSigned()) + os << " as u" << getTfheRustBitWidth(op.getResult().getType()); + + os << ").unwrap();\n"; return success(); } @@ -359,7 +364,7 @@ LogicalResult TfheRustHLEmitter::printOperation(arith::ConstantOp op) { // By default, it emits an unsigned integer. emitAssignPrefix(op.getResult()); if (auto intAttr = dyn_cast(valueAttr)) { - os << intAttr.getValue().abs() << "u64;\n"; + os << intAttr.getValue().abs() << convertType(op.getType()) << ";\n"; } else { return op.emitError() << "Unknown constant type " << valueAttr.getType(); } @@ -383,6 +388,17 @@ LogicalResult TfheRustHLEmitter::printBinaryOp(::mlir::Value result, std::string_view op) { emitAssignPrefix(result); + if (auto cteOp = dyn_cast(rhs.getDefiningOp())) { + auto intValue = + cast(cteOp.getValue()).getValue().getZExtValue(); + os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op + << " " << intValue << "u" << cteOp.getType().getIntOrFloatBitWidth() + << ";\n"; + return success(); + } + + // Note: arith.constant op requires signless integer types, but here we + // manually emit an unsigned integer type. os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op << " " << checkOrigin(rhs) << variableNames->getNameForValue(rhs) << ";\n"; return success(); @@ -430,8 +446,8 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::AllocOp op) { if (failed(emitType(op.getMemref().getType().getElementType()))) { return op.emitOpError() << "Failed to get memref element type"; } - os << "> = BTreeMap::new();\n"; + return success(); } @@ -463,12 +479,11 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::LoadOp op) { // We assume here that the indices are SSA values (not integer attributes). if (isa(op.getMemref())) { emitAssignPrefix(op.getResult()); - os << "&" << variableNames->getNameForValue(op.getMemRef()) << "[" - << flattenIndexExpression(op.getMemRefType(), op.getIndices(), - [&](Value value) { - return variableNames->getNameForValue(value); - }) - << "];\n"; + os << "&" << variableNames->getNameForValue(op.getMemRef()); + for (auto value : op.getIndices()) { + os << "[" << variableNames->getNameForValue(value) << "]"; + } + os << ";\n"; return success(); } @@ -586,6 +601,7 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) { return std::string(prefix) + variableNames->getNameForValue(value) + cloneStr; }) << "];\n"; + return success(); } @@ -662,9 +678,12 @@ FailureOr TfheRustHLEmitter::convertType(Type type) { } auto width = getRustIntegerType(type.getWidth()); if (failed(width)) return failure(); - return (type.isUnsigned() ? std::string("u") : "") + "i" + + return (type.isSigned() ? std::string("i") : std::string("u")) + std::to_string(width.value()); }) + .Case([&](IndexType type) -> FailureOr { + return std::string("usize"); + }) .Case( [&](auto type) { return std::string("LookupTableOwned"); }) .Default([&](Type &) { return failure(); }); diff --git a/tests/Examples/tfhe_rust_hl/cpu/README.md b/tests/Examples/tfhe_rust_hl/cpu/README.md index ab8e01a8d..9308b340d 100644 --- a/tests/Examples/tfhe_rust_hl/cpu/README.md +++ b/tests/Examples/tfhe_rust_hl/cpu/README.md @@ -17,7 +17,7 @@ if you overrode the default option when installing Cargo. ```bash bazel query "filter('.mlir.test$', //tests/Examples/tfhe_rust_hl/cpu/...)" \ - | xargs bazel test --noincompatible_strict_action_env -test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@" + | xargs bazel test --noincompatible_strict_action_env --test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@" ``` The `manual` tag is added to the targets in this directory to ensure that they