diff --git a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD index e6220aa7f..84a3b0f99 100644 --- a/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD +++ b/lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/BUILD @@ -15,6 +15,8 @@ cc_library( ":pass_inc_gen", "@heir//lib/Analysis/SecretnessAnalysis", "@heir//lib/Dialect/Secret/IR:Dialect", + "@heir//lib/Dialect/Secret/IR:SecretPatterns", + "@heir//lib/Dialect/Secret/Transforms:DistributeGeneric", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@heir//lib/Utils", "@heir//lib/Utils:ConversionUtils", diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index 2d44a4419..cef79a35b 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "lib/Dialect/BGV/Conversions/BGVToLWE/BGVToLWE.h" #include "lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h" @@ -15,6 +16,7 @@ #include "lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.h" #include "lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.h" #include "lib/Dialect/Secret/Transforms/DistributeGeneric.h" +#include "lib/Dialect/Secret/Transforms/MergeAdjacentGenerics.h" #include "lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.h" #include "lib/Dialect/TensorExt/Transforms/InsertRotate.h" #include "lib/Dialect/TensorExt/Transforms/RotateAndReduce.h" @@ -26,6 +28,7 @@ #include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h" #include "lib/Transforms/SecretInsertMgmt/Passes.h" #include "lib/Transforms/Secretize/Passes.h" +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassOptions.h" // from @llvm-project @@ -33,11 +36,14 @@ namespace mlir::heir { -void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) { +void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager, + bool disableLoopUnroll) { // For now we unroll loops to enable insert-rotate, but we would like to be // smarter about this and do an affine loop analysis. // TODO(#589): avoid unrolling loops - manager.addPass(createFullLoopUnroll()); + if (!disableLoopUnroll) { + manager.addPass(createFullLoopUnroll()); + } // These two passes are required in this position for a relatively nuanced // reason. insert-rotate doesn't have general match support. In particular, @@ -76,25 +82,38 @@ void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) { manager.addPass(createCSEPass()); } -void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) { +void mlirToSecretArithmeticPipelineBuilder( + OpPassManager &pm, const MlirToRLWEPipelineOptions &options) { pm.addPass(createWrapGeneric()); convertToDataObliviousPipelineBuilder(pm); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); + // Apply linalg kernels // Linalg canonicalization // TODO(#1191): enable dropping unit dims to convert matmul to matvec/vecmat // pm.addPass(createDropUnitDims()); pm.addPass(createLinalgCanonicalizations()); - // Layout assignment and lowering // TODO(#1191): enable layout propagation after implementing the rest // of the layout lowering pipeline. // pm.addPass(createLayoutPropagation()); - pm.addPass(heir::linalg::createLinalgToTensorExt()); + // Note: LinalgToTensorExt requires that linalg.matmuls are the only operation + // within a secret.generic. This is to ensure that any tensor type conversions + // (padding a rectangular matrix to a square diagonalized matrix) can be + // performed without any type mismatches. + std::vector opsToDistribute = {"linalg.matmul"}; + auto distributeOpts = secret::SecretDistributeGenericOptions{ + .opsToDistribute = llvm::to_vector(opsToDistribute)}; + pm.addPass(createSecretDistributeGeneric(distributeOpts)); + pm.addPass(createCanonicalizerPass()); + auto linalgToTensorExtOptions = linalg::LinalgToTensorExtOptions{}; + linalgToTensorExtOptions.tilingSize = options.ciphertextDegree; + pm.addPass(heir::linalg::createLinalgToTensorExt(linalgToTensorExtOptions)); + pm.addPass(secret::createSecretMergeAdjacentGenerics()); // Vectorize and optimize rotations - heirSIMDVectorizerPipelineBuilder(pm); + heirSIMDVectorizerPipelineBuilder(pm, options.experimentalDisableLoopUnroll); // Balance Operations pm.addPass(createOperationBalancer()); @@ -103,7 +122,7 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) { void mlirToRLWEPipeline(OpPassManager &pm, const MlirToRLWEPipelineOptions &options, const RLWEScheme scheme) { - mlirToSecretArithmeticPipelineBuilder(pm); + mlirToSecretArithmeticPipelineBuilder(pm, options); // place mgmt.op and MgmtAttr for BGV // which is required for secret-to- lowering diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.h b/lib/Pipelines/ArithmeticPipelineRegistration.h index eb47cb3e7..914833a1b 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.h +++ b/lib/Pipelines/ArithmeticPipelineRegistration.h @@ -14,10 +14,19 @@ namespace mlir::heir { // RLWE scheme selector enum RLWEScheme { ckksScheme, bgvScheme }; -void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager); +struct SimdVectorizerOptions + : public PassPipelineOptions { + PassOptions::Option experimentalDisableLoopUnroll{ + *this, "experimental-disable-loop-unroll", + llvm::cl::desc("Experimental: disable loop unroll, may break analyses " + "(default to false)"), + llvm::cl::init(false)}; +}; + +void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager, + bool disableLoopUnroll); -struct MlirToRLWEPipelineOptions - : public PassPipelineOptions { +struct MlirToRLWEPipelineOptions : public SimdVectorizerOptions { PassOptions::Option ciphertextDegree{ *this, "ciphertext-degree", llvm::cl::desc("The degree of the polynomials to use for ciphertexts; " @@ -57,7 +66,8 @@ void mlirToRLWEPipeline(OpPassManager &pm, const MlirToRLWEPipelineOptions &options, RLWEScheme scheme); -void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm); +void mlirToSecretArithmeticPipelineBuilder( + OpPassManager &pm, const MlirToRLWEPipelineOptions &options); RLWEPipelineBuilder mlirToRLWEPipelineBuilder(RLWEScheme scheme); diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index d17f9f11f..45d78aac4 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -99,6 +99,7 @@ cc_library( "@heir//lib/Dialect/Secret/Conversions/SecretToCGGI", "@heir//lib/Dialect/Secret/Conversions/SecretToCKKS", "@heir//lib/Dialect/Secret/Transforms:DistributeGeneric", + "@heir//lib/Dialect/Secret/Transforms:MergeAdjacentGenerics", "@heir//lib/Dialect/TOSA/Conversions/TosaToSecretArith", "@heir//lib/Dialect/TensorExt/Transforms:CollapseInsertionChains", "@heir//lib/Dialect/TensorExt/Transforms:InsertRotate", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index feaa7a213..03102ccf6 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -351,18 +351,24 @@ int main(int argc, char **argv) { "Lower basic MLIR to LLVM", ::mlir::heir::basicMLIRToLLVMPipelineBuilder); - PassPipelineRegistration<>( + PassPipelineRegistration( "heir-simd-vectorizer", "Run scheme-agnostic passes to convert FHE programs that operate on " "scalar types to equivalent programs that operate on vectors and use " "tensor_ext.rotate", - mlir::heir::heirSIMDVectorizerPipelineBuilder); + [](OpPassManager &pm, const SimdVectorizerOptions &options) { + ::mlir::heir::heirSIMDVectorizerPipelineBuilder( + pm, options.experimentalDisableLoopUnroll); + }); - PassPipelineRegistration<>( + PassPipelineRegistration( "mlir-to-secret-arithmetic", "Convert a func using standard MLIR dialects to secret dialect with " "arithmetic ops", - mlirToSecretArithmeticPipelineBuilder); + [](OpPassManager &pm, + const mlir::heir::MlirToRLWEPipelineOptions &options) { + mlirToSecretArithmeticPipelineBuilder(pm, options); + }); PassPipelineRegistration( "mlir-to-bgv",