Skip to content

Commit 6d04bee

Browse files
asraacopybara-github
authored andcommitted
integrate split-preprocessing into pipelines
This includes a couple of minor changes: * lwe-to-openfhe needs to update function signatures and call ops that have ciphertext data * update split-preprocessing to batch plaintexts into groups with a configurable maximum # of return values PiperOrigin-RevId: 872514667
1 parent 9478d3d commit 6d04bee

18 files changed

Lines changed: 574 additions & 105 deletions

File tree

lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
3333
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
3434
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
35+
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
3536
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
3637
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
3738
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
@@ -114,10 +115,8 @@ bool containsArgumentOfDialect(Operation* op) {
114115
return false;
115116
}
116117
return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) {
117-
if (isa<ShapedType>(argType)) {
118-
argType = cast<ShapedType>(argType).getElementType();
119-
}
120-
return DialectEqual<Dialects...>()(&argType.getDialect());
118+
return DialectEqual<Dialects...>()(
119+
&getElementTypeOrSelf(argType).getDialect());
121120
});
122121
}
123122

@@ -794,7 +793,8 @@ struct LWEToLattigo : public impl::LWEToLattigoBase<LWEToLattigo> {
794793
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
795794
typeConverter.isLegal(&op.getBody()) &&
796795
(!containsArgumentOfDialect<lwe::LWEDialect, bgv::BGVDialect,
797-
ckks::CKKSDialect>(op) ||
796+
ckks::CKKSDialect,
797+
lattigo::LattigoDialect>(op) ||
798798
hasCryptoContextArg);
799799
});
800800

@@ -809,7 +809,17 @@ struct LWEToLattigo : public impl::LWEToLattigoBase<LWEToLattigo> {
809809
!operandTypes.empty() &&
810810
mlir::isa<lattigo::BGVEvaluatorType, lattigo::CKKSEvaluatorType>(
811811
*operandTypes.begin());
812-
return (!containsCryptoArg || hasCryptoContextArg);
812+
// crypto context may need to be added for any function call whose callee
813+
// has crypto ops.
814+
auto containsCryptoOps = false;
815+
FailureOr<func::FuncOp> callee = getCalledFunction(op);
816+
if (succeeded(callee)) {
817+
containsCryptoOps =
818+
containsDialects<lwe::LWEDialect, bgv::BGVDialect,
819+
ckks::CKKSDialect, lattigo::LattigoDialect>(
820+
callee.value());
821+
}
822+
return (!(containsCryptoArg || containsCryptoOps) || hasCryptoContextArg);
813823
});
814824

815825
// All other operations are legal if they have no LWE typed operands or

lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ namespace {
7676
template <typename... Dialects>
7777
bool containsArgumentOfDialect(func::FuncOp funcOp) {
7878
return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) {
79-
return DialectEqual<Dialects...>()(&argType.getDialect());
79+
return DialectEqual<Dialects...>()(
80+
&getElementTypeOrSelf(argType).getDialect());
8081
});
8182
}
8283

@@ -109,11 +110,12 @@ struct AddCryptoContextArg : public OpConversionPattern<func::FuncOp> {
109110
}
110111

111112
auto containsCryptoOps =
112-
containsDialects<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect>(
113-
op);
113+
containsDialects<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect,
114+
openfhe::OpenfheDialect>(op);
114115
auto containsCryptoArg =
115116
containsArgumentOfDialect<lwe::LWEDialect, bgv::BGVDialect,
116-
ckks::CKKSDialect>(op);
117+
ckks::CKKSDialect, openfhe::OpenfheDialect>(
118+
op);
117119
if (!(containsCryptoOps || containsCryptoArg)) {
118120
return rewriter.notifyMatchFailure(
119121
op, "contains neither ops nor arg types from lwe/bgv/ckks dialects");
@@ -427,11 +429,12 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
427429
return hasCryptoContextArg;
428430
}
429431
auto containsCryptoOps =
430-
containsDialects<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect>(
431-
op);
432+
containsDialects<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect,
433+
openfhe::OpenfheDialect>(op);
432434
auto containsCryptoArg =
433435
containsArgumentOfDialect<lwe::LWEDialect, bgv::BGVDialect,
434-
ckks::CKKSDialect>(op);
436+
ckks::CKKSDialect, openfhe::OpenfheDialect>(
437+
op);
435438
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
436439
typeConverter.isLegal(&op.getBody()) &&
437440
(!(containsCryptoOps || containsCryptoArg) || hasCryptoContextArg);
@@ -447,11 +450,19 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
447450
return hasCryptoContextArg;
448451
}
449452
auto containsCryptoArg = llvm::any_of(operandTypes, [&](Type argType) {
450-
return DialectEqual<lwe::LWEDialect, bgv::BGVDialect,
451-
ckks::CKKSDialect>()(
453+
return DialectEqual<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect,
454+
openfhe::OpenfheDialect>()(
452455
&getElementTypeOrSelf(argType).getDialect());
453456
});
454-
return (!containsCryptoArg || hasCryptoContextArg);
457+
auto containsCryptoOps = false;
458+
FailureOr<func::FuncOp> callee = getCalledFunction(op);
459+
if (succeeded(callee)) {
460+
containsCryptoOps =
461+
containsDialects<lwe::LWEDialect, bgv::BGVDialect,
462+
ckks::CKKSDialect, openfhe::OpenfheDialect>(
463+
callee.value());
464+
}
465+
return (!(containsCryptoArg || containsCryptoOps) || hasCryptoContextArg);
455466
});
456467

457468
patterns.add<

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,6 @@ void mlirToSecretArithmeticPipelineBuilder(
186186
// Balance Operations
187187
pm.addPass(createOperationBalancer());
188188

189-
// Add a __preprocessed helper for offline pre-packing of plaintexts
190-
pm.addPass(createSplitPreprocessing());
191189
lowerAssignLayout(pm, false);
192190

193191
// Add encrypt/decrypt helper functions for each function argument and return
@@ -406,6 +404,11 @@ void mlirToRLWEPipeline(OpPassManager& pm,
406404
pm.addPass(createCanonicalizerPass());
407405
pm.addPass(createCSEPass());
408406

407+
// Add a __preprocessed helper for offline pre-packing of plaintexts
408+
auto splitPreprocessingOptions = SplitPreprocessingOptions{};
409+
splitPreprocessingOptions.maxReturnValues = options.splitPreprocessing;
410+
pm.addPass(createSplitPreprocessing(splitPreprocessingOptions));
411+
409412
ElementwiseToAffineOptions elementwiseOptions;
410413
elementwiseOptions.convertDialects = {"ckks", "bgv", "lwe"};
411414
pm.addPass(createElementwiseToAffine(elementwiseOptions));
@@ -526,6 +529,7 @@ void torchLinalgToCkksBuilder(OpPassManager& manager,
526529
suboptions.ckksBootstrapWaterline = options.ckksBootstrapWaterline;
527530
suboptions.scalingModBits = options.scalingModBits;
528531
suboptions.firstModBits = options.firstModBits;
532+
suboptions.splitPreprocessing = options.splitPreprocessing;
529533

530534
mlirToRLWEPipelineBuilder(mlir::heir::RLWEScheme::ckksScheme)(manager,
531535
suboptions);

lib/Pipelines/ArithmeticPipelineRegistration.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ struct MlirToRLWEPipelineOptions : public SimdVectorizerOptions {
100100
llvm::cl::desc("File name to import execution result from (c.f. --secret-"
101101
"import-execution-result)"),
102102
llvm::cl::init("")};
103+
PassOptions::Option<int> splitPreprocessing{
104+
*this, "split-preprocessing",
105+
llvm::cl::desc("Split preprocessing into separate function with N return "
106+
"values (default to no split)"),
107+
llvm::cl::init(0)};
103108
};
104109

105110
struct PlaintextBackendOptions
@@ -179,6 +184,11 @@ struct TorchLinalgToCkksPipelineOptions
179184
llvm::cl::desc("The number of levels to keep until bootstrapping in CKKS "
180185
"(c.f. --secret-insert-mgmt-ckks)"),
181186
llvm::cl::init(10)};
187+
PassOptions::Option<int> splitPreprocessing{
188+
*this, "split-preprocessing",
189+
llvm::cl::desc("Split preprocessing into separate function with N return "
190+
"values (default to no split)"),
191+
llvm::cl::init(0)};
182192
};
183193
void torchLinalgToCkksBuilder(OpPassManager& manager,
184194
const TorchLinalgToCkksPipelineOptions& options);

lib/Transforms/FoldConstantTensors/FoldConstantTensors.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ class CollapseShapeAfterConstant final
210210
return rewriter.notifyMatchFailure(
211211
collapseOp, "source of collapse must be a constant");
212212

213-
auto sourceAttr = llvm::dyn_cast<ElementsAttr>(constantOp.getValue());
213+
auto sourceAttr = llvm::dyn_cast<DenseElementsAttr>(constantOp.getValue());
214214
if (!sourceAttr)
215215
return rewriter.notifyMatchFailure(
216216
collapseOp, "source of collapse must be an elements attribute");

0 commit comments

Comments
 (0)