Skip to content

Commit

Permalink
linalg-to-tensor-ext: plumb ciphertext-degree to linalg.matmul rewrit…
Browse files Browse the repository at this point in the history
…es and distribute generic before pass

1. Allow linalg-to-tensor-ext to take in the tilingSize from the ciphertext degree for now
2. linalg-to-tensor-ext currently requires a type converter that upgrades the tensor sizes, and it currently requires (for matching) that matmuls are the single op in the generic. So distribute around matmuls in the rlwe pipelines

TODO: I just want to make sure (2) is actually required...

Part of lowering the MLP demo with loops for #1232

PiperOrigin-RevId: 725272277
  • Loading branch information
asraa authored and copybara-github committed Feb 10, 2025
1 parent 67a2ca8 commit 11ea2e0
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 15 deletions.
2 changes: 2 additions & 0 deletions lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ cc_library(
":pass_inc_gen",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:SecretPatterns",
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Utils",
"@heir//lib/Utils:ConversionUtils",
Expand Down
33 changes: 26 additions & 7 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdlib>
#include <string>
#include <vector>

#include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h"
#include "lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h"
Expand All @@ -15,6 +16,7 @@
#include "lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h"
#include "lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.h"
#include "lib/Dialect/Secret/Transforms/DistributeGeneric.h"
#include "lib/Dialect/Secret/Transforms/MergeAdjacentGenerics.h"
#include "lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.h"
#include "lib/Dialect/TensorExt/Transforms/InsertRotate.h"
#include "lib/Dialect/TensorExt/Transforms/RotateAndReduce.h"
Expand All @@ -26,18 +28,22 @@
#include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h"
#include "lib/Transforms/SecretInsertMgmt/Passes.h"
#include "lib/Transforms/Secretize/Passes.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir::heir {

void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager,
bool disableLoopUnroll) {
// For now we unroll loops to enable insert-rotate, but we would like to be
// smarter about this and do an affine loop analysis.
// TODO(#589): avoid unrolling loops
manager.addPass(createFullLoopUnroll());
if (!disableLoopUnroll) {
manager.addPass(createFullLoopUnroll());
}

// These two passes are required in this position for a relatively nuanced
// reason. insert-rotate doesn't have general match support. In particular,
Expand Down Expand Up @@ -76,25 +82,38 @@ void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
manager.addPass(createCSEPass());
}

void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
void mlirToSecretArithmeticPipelineBuilder(
OpPassManager &pm, const MlirToRLWEPipelineOptions &options) {
pm.addPass(createWrapGeneric());
convertToDataObliviousPipelineBuilder(pm);
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

// Apply linalg kernels
// Linalg canonicalization
// TODO(#1191): enable dropping unit dims to convert matmul to matvec/vecmat
// pm.addPass(createDropUnitDims());
pm.addPass(createLinalgCanonicalizations());

// Layout assignment and lowering
// TODO(#1191): enable layout propagation after implementing the rest
// of the layout lowering pipeline.
// pm.addPass(createLayoutPropagation());
pm.addPass(heir::linalg::createLinalgToTensorExt());
// Note: LinalgToTensorExt requires that linalg.matmuls are the only operation
// within a secret.generic. This is to ensure that any tensor type conversions
// (padding a rectangular matrix to a square diagonalized matrix) can be
// performed without any type mismatches.
std::vector<std::string> opsToDistribute = {"linalg.matmul"};
auto distributeOpts = secret::SecretDistributeGenericOptions{
.opsToDistribute = llvm::to_vector(opsToDistribute)};
pm.addPass(createSecretDistributeGeneric(distributeOpts));
pm.addPass(createCanonicalizerPass());
auto linalgToTensorExtOptions = linalg::LinalgToTensorExtOptions{};
linalgToTensorExtOptions.tilingSize = options.ciphertextDegree;
pm.addPass(heir::linalg::createLinalgToTensorExt(linalgToTensorExtOptions));
pm.addPass(secret::createSecretMergeAdjacentGenerics());

// Vectorize and optimize rotations
heirSIMDVectorizerPipelineBuilder(pm);
heirSIMDVectorizerPipelineBuilder(pm, options.experimentalDisableLoopUnroll);

// Balance Operations
pm.addPass(createOperationBalancer());
Expand All @@ -103,7 +122,7 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
void mlirToRLWEPipeline(OpPassManager &pm,
const MlirToRLWEPipelineOptions &options,
const RLWEScheme scheme) {
mlirToSecretArithmeticPipelineBuilder(pm);
mlirToSecretArithmeticPipelineBuilder(pm, options);

// place mgmt.op and MgmtAttr for BGV
// which is required for secret-to-<scheme> lowering
Expand Down
18 changes: 14 additions & 4 deletions lib/Pipelines/ArithmeticPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,19 @@ namespace mlir::heir {
// RLWE scheme selector
enum RLWEScheme { ckksScheme, bgvScheme };

void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager);
struct SimdVectorizerOptions
: public PassPipelineOptions<SimdVectorizerOptions> {
PassOptions::Option<bool> experimentalDisableLoopUnroll{
*this, "experimental-disable-loop-unroll",
llvm::cl::desc("Experimental: disable loop unroll, may break analyses "
"(default to false)"),
llvm::cl::init(false)};
};

void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager,
bool disableLoopUnroll);

struct MlirToRLWEPipelineOptions
: public PassPipelineOptions<MlirToRLWEPipelineOptions> {
struct MlirToRLWEPipelineOptions : public SimdVectorizerOptions {
PassOptions::Option<int> ciphertextDegree{
*this, "ciphertext-degree",
llvm::cl::desc("The degree of the polynomials to use for ciphertexts; "
Expand Down Expand Up @@ -57,7 +66,8 @@ void mlirToRLWEPipeline(OpPassManager &pm,
const MlirToRLWEPipelineOptions &options,
RLWEScheme scheme);

void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm);
void mlirToSecretArithmeticPipelineBuilder(
OpPassManager &pm, const MlirToRLWEPipelineOptions &options);

RLWEPipelineBuilder mlirToRLWEPipelineBuilder(RLWEScheme scheme);

Expand Down
1 change: 1 addition & 0 deletions lib/Pipelines/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ cc_library(
"@heir//lib/Dialect/Secret/Conversions/SecretToCGGI",
"@heir//lib/Dialect/Secret/Conversions/SecretToCKKS",
"@heir//lib/Dialect/Secret/Transforms:DistributeGeneric",
"@heir//lib/Dialect/Secret/Transforms:MergeAdjacentGenerics",
"@heir//lib/Dialect/TOSA/Conversions/TosaToSecretArith",
"@heir//lib/Dialect/TensorExt/Transforms:CollapseInsertionChains",
"@heir//lib/Dialect/TensorExt/Transforms:InsertRotate",
Expand Down
14 changes: 10 additions & 4 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,18 +351,24 @@ int main(int argc, char **argv) {
"Lower basic MLIR to LLVM",
::mlir::heir::basicMLIRToLLVMPipelineBuilder);

PassPipelineRegistration<>(
PassPipelineRegistration<SimdVectorizerOptions>(
"heir-simd-vectorizer",
"Run scheme-agnostic passes to convert FHE programs that operate on "
"scalar types to equivalent programs that operate on vectors and use "
"tensor_ext.rotate",
mlir::heir::heirSIMDVectorizerPipelineBuilder);
[](OpPassManager &pm, const SimdVectorizerOptions &options) {
::mlir::heir::heirSIMDVectorizerPipelineBuilder(
pm, options.experimentalDisableLoopUnroll);
});

PassPipelineRegistration<>(
PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
"mlir-to-secret-arithmetic",
"Convert a func using standard MLIR dialects to secret dialect with "
"arithmetic ops",
mlirToSecretArithmeticPipelineBuilder);
[](OpPassManager &pm,
const mlir::heir::MlirToRLWEPipelineOptions &options) {
mlirToSecretArithmeticPipelineBuilder(pm, options);
});

PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
"mlir-to-bgv",
Expand Down

0 comments on commit 11ea2e0

Please sign in to comment.