2
2
3
3
#include < cstdlib>
4
4
#include < string>
5
+ #include < vector>
5
6
6
7
#include " lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h"
7
8
#include " lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h"
15
16
#include " lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h"
16
17
#include " lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.h"
17
18
#include " lib/Dialect/Secret/Transforms/DistributeGeneric.h"
19
+ #include " lib/Dialect/Secret/Transforms/MergeAdjacentGenerics.h"
18
20
#include " lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.h"
19
21
#include " lib/Dialect/TensorExt/Transforms/InsertRotate.h"
20
22
#include " lib/Dialect/TensorExt/Transforms/RotateAndReduce.h"
26
28
#include " lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h"
27
29
#include " lib/Transforms/SecretInsertMgmt/Passes.h"
28
30
#include " lib/Transforms/Secretize/Passes.h"
31
+ #include " llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
29
32
#include " llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
30
33
#include " mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
31
34
#include " mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project
32
35
#include " mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
33
36
34
37
namespace mlir ::heir {
35
38
36
- void heirSIMDVectorizerPipelineBuilder (OpPassManager &manager) {
39
+ void heirSIMDVectorizerPipelineBuilder (OpPassManager &manager,
40
+ bool disableLoopUnroll) {
37
41
// For now we unroll loops to enable insert-rotate, but we would like to be
38
42
// smarter about this and do an affine loop analysis.
39
43
// TODO(#589): avoid unrolling loops
40
- manager.addPass (createFullLoopUnroll ());
44
+ if (!disableLoopUnroll) {
45
+ manager.addPass (createFullLoopUnroll ());
46
+ }
41
47
42
48
// These two passes are required in this position for a relatively nuanced
43
49
// reason. insert-rotate doesn't have general match support. In particular,
@@ -76,25 +82,38 @@ void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
76
82
manager.addPass (createCSEPass ());
77
83
}
78
84
79
- void mlirToSecretArithmeticPipelineBuilder (OpPassManager &pm) {
85
+ void mlirToSecretArithmeticPipelineBuilder (
86
+ OpPassManager &pm, const MlirToRLWEPipelineOptions &options) {
80
87
pm.addPass (createWrapGeneric ());
81
88
convertToDataObliviousPipelineBuilder (pm);
82
89
pm.addPass (createCanonicalizerPass ());
83
90
pm.addPass (createCSEPass ());
84
91
92
+ // Apply linalg kernels
85
93
// Linalg canonicalization
86
94
// TODO(#1191): enable dropping unit dims to convert matmul to matvec/vecmat
87
95
// pm.addPass(createDropUnitDims());
88
96
pm.addPass (createLinalgCanonicalizations ());
89
-
90
97
// Layout assignment and lowering
91
98
// TODO(#1191): enable layout propagation after implementing the rest
92
99
// of the layout lowering pipeline.
93
100
// pm.addPass(createLayoutPropagation());
94
- pm.addPass (heir::linalg::createLinalgToTensorExt ());
101
+ // Note: LinalgToTensorExt requires that linalg.matmuls are the only operation
102
+ // within a secret.generic. This is to ensure that any tensor type conversions
103
+ // (padding a rectangular matrix to a square diagonalized matrix) can be
104
+ // performed without any type mismatches.
105
+ std::vector<std::string> opsToDistribute = {" linalg.matmul" };
106
+ auto distributeOpts = secret::SecretDistributeGenericOptions{
107
+ .opsToDistribute = llvm::to_vector (opsToDistribute)};
108
+ pm.addPass (createSecretDistributeGeneric (distributeOpts));
109
+ pm.addPass (createCanonicalizerPass ());
110
+ auto linalgToTensorExtOptions = linalg::LinalgToTensorExtOptions{};
111
+ linalgToTensorExtOptions.tilingSize = options.ciphertextDegree ;
112
+ pm.addPass (heir::linalg::createLinalgToTensorExt (linalgToTensorExtOptions));
113
+ pm.addPass (secret::createSecretMergeAdjacentGenerics ());
95
114
96
115
// Vectorize and optimize rotations
97
- heirSIMDVectorizerPipelineBuilder (pm);
116
+ heirSIMDVectorizerPipelineBuilder (pm, options. experimentalDisableLoopUnroll );
98
117
99
118
// Balance Operations
100
119
pm.addPass (createOperationBalancer ());
@@ -103,7 +122,7 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
103
122
void mlirToRLWEPipeline (OpPassManager &pm,
104
123
const MlirToRLWEPipelineOptions &options,
105
124
const RLWEScheme scheme) {
106
- mlirToSecretArithmeticPipelineBuilder (pm);
125
+ mlirToSecretArithmeticPipelineBuilder (pm, options );
107
126
108
127
// place mgmt.op and MgmtAttr for BGV
109
128
// which is required for secret-to-<scheme> lowering
0 commit comments