Skip to content

Commit 11ea2e0

Browse files
asraacopybara-github
authored andcommitted
linalg-to-tensor-ext: plumb ciphertext-degree to linalg.matmul rewrites and distribute generic before pass
1. Allow linalg-to-tensor-ext to take in the tilingSize from the ciphertext degree for now 2. linalg-to-tensor-ext currently requires a type converter that upgrades the tensor sizes, and it currently requires (for matching) that matmuls are the single op in the generic. So distribute around matmuls in the rlwe pipelines TODO: I just want to make sure (2) is actually required... Part of lowering the MLP demo with loops for #1232 PiperOrigin-RevId: 725272277
1 parent 67a2ca8 commit 11ea2e0

File tree

5 files changed

+53
-15
lines changed

5 files changed

+53
-15
lines changed

lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ cc_library(
1515
":pass_inc_gen",
1616
"@heir//lib/Analysis/SecretnessAnalysis",
1717
"@heir//lib/Dialect/Secret/IR:Dialect",
18+
"@heir//lib/Dialect/Secret/IR:SecretPatterns",
19+
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
1820
"@heir//lib/Dialect/TensorExt/IR:Dialect",
1921
"@heir//lib/Utils",
2022
"@heir//lib/Utils:ConversionUtils",

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <cstdlib>
44
#include <string>
5+
#include <vector>
56

67
#include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h"
78
#include "lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h"
@@ -15,6 +16,7 @@
1516
#include "lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h"
1617
#include "lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.h"
1718
#include "lib/Dialect/Secret/Transforms/DistributeGeneric.h"
19+
#include "lib/Dialect/Secret/Transforms/MergeAdjacentGenerics.h"
1820
#include "lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.h"
1921
#include "lib/Dialect/TensorExt/Transforms/InsertRotate.h"
2022
#include "lib/Dialect/TensorExt/Transforms/RotateAndReduce.h"
@@ -26,18 +28,22 @@
2628
#include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h"
2729
#include "lib/Transforms/SecretInsertMgmt/Passes.h"
2830
#include "lib/Transforms/Secretize/Passes.h"
31+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
2932
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
3033
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
3134
#include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project
3235
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
3336

3437
namespace mlir::heir {
3538

36-
void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
39+
void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager,
40+
bool disableLoopUnroll) {
3741
// For now we unroll loops to enable insert-rotate, but we would like to be
3842
// smarter about this and do an affine loop analysis.
3943
// TODO(#589): avoid unrolling loops
40-
manager.addPass(createFullLoopUnroll());
44+
if (!disableLoopUnroll) {
45+
manager.addPass(createFullLoopUnroll());
46+
}
4147

4248
// These two passes are required in this position for a relatively nuanced
4349
// reason. insert-rotate doesn't have general match support. In particular,
@@ -76,25 +82,38 @@ void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
7682
manager.addPass(createCSEPass());
7783
}
7884

79-
void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
85+
void mlirToSecretArithmeticPipelineBuilder(
86+
OpPassManager &pm, const MlirToRLWEPipelineOptions &options) {
8087
pm.addPass(createWrapGeneric());
8188
convertToDataObliviousPipelineBuilder(pm);
8289
pm.addPass(createCanonicalizerPass());
8390
pm.addPass(createCSEPass());
8491

92+
// Apply linalg kernels
8593
// Linalg canonicalization
8694
// TODO(#1191): enable dropping unit dims to convert matmul to matvec/vecmat
8795
// pm.addPass(createDropUnitDims());
8896
pm.addPass(createLinalgCanonicalizations());
89-
9097
// Layout assignment and lowering
9198
// TODO(#1191): enable layout propagation after implementing the rest
9299
// of the layout lowering pipeline.
93100
// pm.addPass(createLayoutPropagation());
94-
pm.addPass(heir::linalg::createLinalgToTensorExt());
101+
// Note: LinalgToTensorExt requires that linalg.matmuls are the only operation
102+
// within a secret.generic. This is to ensure that any tensor type conversions
103+
// (padding a rectangular matrix to a square diagonalized matrix) can be
104+
// performed without any type mismatches.
105+
std::vector<std::string> opsToDistribute = {"linalg.matmul"};
106+
auto distributeOpts = secret::SecretDistributeGenericOptions{
107+
.opsToDistribute = llvm::to_vector(opsToDistribute)};
108+
pm.addPass(createSecretDistributeGeneric(distributeOpts));
109+
pm.addPass(createCanonicalizerPass());
110+
auto linalgToTensorExtOptions = linalg::LinalgToTensorExtOptions{};
111+
linalgToTensorExtOptions.tilingSize = options.ciphertextDegree;
112+
pm.addPass(heir::linalg::createLinalgToTensorExt(linalgToTensorExtOptions));
113+
pm.addPass(secret::createSecretMergeAdjacentGenerics());
95114

96115
// Vectorize and optimize rotations
97-
heirSIMDVectorizerPipelineBuilder(pm);
116+
heirSIMDVectorizerPipelineBuilder(pm, options.experimentalDisableLoopUnroll);
98117

99118
// Balance Operations
100119
pm.addPass(createOperationBalancer());
@@ -103,7 +122,7 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
103122
void mlirToRLWEPipeline(OpPassManager &pm,
104123
const MlirToRLWEPipelineOptions &options,
105124
const RLWEScheme scheme) {
106-
mlirToSecretArithmeticPipelineBuilder(pm);
125+
mlirToSecretArithmeticPipelineBuilder(pm, options);
107126

108127
// place mgmt.op and MgmtAttr for BGV
109128
// which is required for secret-to-<scheme> lowering

lib/Pipelines/ArithmeticPipelineRegistration.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,19 @@ namespace mlir::heir {
1414
// RLWE scheme selector
1515
enum RLWEScheme { ckksScheme, bgvScheme };
1616

17-
void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager);
17+
struct SimdVectorizerOptions
18+
: public PassPipelineOptions<SimdVectorizerOptions> {
19+
PassOptions::Option<bool> experimentalDisableLoopUnroll{
20+
*this, "experimental-disable-loop-unroll",
21+
llvm::cl::desc("Experimental: disable loop unroll, may break analyses "
22+
"(default to false)"),
23+
llvm::cl::init(false)};
24+
};
25+
26+
void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager,
27+
bool disableLoopUnroll);
1828

19-
struct MlirToRLWEPipelineOptions
20-
: public PassPipelineOptions<MlirToRLWEPipelineOptions> {
29+
struct MlirToRLWEPipelineOptions : public SimdVectorizerOptions {
2130
PassOptions::Option<int> ciphertextDegree{
2231
*this, "ciphertext-degree",
2332
llvm::cl::desc("The degree of the polynomials to use for ciphertexts; "
@@ -57,7 +66,8 @@ void mlirToRLWEPipeline(OpPassManager &pm,
5766
const MlirToRLWEPipelineOptions &options,
5867
RLWEScheme scheme);
5968

60-
void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm);
69+
void mlirToSecretArithmeticPipelineBuilder(
70+
OpPassManager &pm, const MlirToRLWEPipelineOptions &options);
6171

6272
RLWEPipelineBuilder mlirToRLWEPipelineBuilder(RLWEScheme scheme);
6373

lib/Pipelines/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ cc_library(
9999
"@heir//lib/Dialect/Secret/Conversions/SecretToCGGI",
100100
"@heir//lib/Dialect/Secret/Conversions/SecretToCKKS",
101101
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
102+
"@heir//lib/Dialect/Secret/Transforms:MergeAdjacentGenerics",
102103
"@heir//lib/Dialect/TOSA/Conversions/TosaToSecretArith",
103104
"@heir//lib/Dialect/TensorExt/Transforms:CollapseInsertionChains",
104105
"@heir//lib/Dialect/TensorExt/Transforms:InsertRotate",

tools/heir-opt.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,18 +351,24 @@ int main(int argc, char **argv) {
351351
"Lower basic MLIR to LLVM",
352352
::mlir::heir::basicMLIRToLLVMPipelineBuilder);
353353

354-
PassPipelineRegistration<>(
354+
PassPipelineRegistration<SimdVectorizerOptions>(
355355
"heir-simd-vectorizer",
356356
"Run scheme-agnostic passes to convert FHE programs that operate on "
357357
"scalar types to equivalent programs that operate on vectors and use "
358358
"tensor_ext.rotate",
359-
mlir::heir::heirSIMDVectorizerPipelineBuilder);
359+
[](OpPassManager &pm, const SimdVectorizerOptions &options) {
360+
::mlir::heir::heirSIMDVectorizerPipelineBuilder(
361+
pm, options.experimentalDisableLoopUnroll);
362+
});
360363

361-
PassPipelineRegistration<>(
364+
PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
362365
"mlir-to-secret-arithmetic",
363366
"Convert a func using standard MLIR dialects to secret dialect with "
364367
"arithmetic ops",
365-
mlirToSecretArithmeticPipelineBuilder);
368+
[](OpPassManager &pm,
369+
const mlir::heir::MlirToRLWEPipelineOptions &options) {
370+
mlirToSecretArithmeticPipelineBuilder(pm, options);
371+
});
366372

367373
PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
368374
"mlir-to-bgv",

0 commit comments

Comments
 (0)