Skip to content

Commit e5651d6

Browse files
Refactor: Move StableHLO matchers to dialect utils and reorganize namespaces (#730)
- Move StablehloMatchers from Transforms to Dialect/StablehloExt/Utils - Simplify matcher header to focus on softmax pattern matching - Move pass implementations from mlir to mtrt namespace - Rename LinalgExtElementwiseFusionPass to MTRTLinalgElementwiseFusionPass - Reorganize test infrastructure under test/lib/Dialect/StablehloExt - Update build system and pass registrations accordingly GitOrigin-RevId: 7ecb025b071c65fcc6ccdc9543cd6f9f3ec389eb
1 parent 677a743 commit e5651d6

29 files changed

Lines changed: 586 additions & 317 deletions

File tree

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//===- StablehloMatchers.h -----------------------------------------------===//
2+
//
3+
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
4+
// All rights reserved.
5+
// SPDX-License-Identifier: Apache-2.0
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// This file defines different matchers for StableHLO.
22+
///
23+
//===----------------------------------------------------------------------===//
24+
#ifndef MLIR_TENSORRT_DIALECT_STABLEHLOEXT_UTILS_STABLEHLOMATCHERS
25+
#define MLIR_TENSORRT_DIALECT_STABLEHLOEXT_UTILS_STABLEHLOMATCHERS
26+
27+
#include "stablehlo/dialect/StablehloOps.h"
28+
29+
namespace mlir::stablehlo {
30+
31+
namespace detail {
32+
/// Matcher for stablehlo
33+
/// reduce(max)->subtract->exponential->reduce(add)->divide The Softmax Matcher
34+
/// is rooted at stablehlo::DivOp operation.
35+
template <typename BcastInDimOp, typename ReduceOp, typename SubOp,
36+
typename ExpnOp, typename DivideOp, typename MaxOp, typename AddOp>
37+
struct HLOToSoftmaxMatcher {
38+
HLOToSoftmaxMatcher(Value &softmaxInputDeduced, int64_t &reductionDim)
39+
: softmaxInputOperand(softmaxInputDeduced), softmaxAxis(reductionDim) {}
40+
Value &softmaxInputOperand;
41+
int64_t &softmaxAxis;
42+
bool match(Operation *op);
43+
};
44+
45+
} // namespace detail
46+
47+
/// Match a pattern rooted at stablhehlo.divide to a softmax
48+
/// operation.
49+
inline auto m_StableHLOSoftmaxMatcher(Value &softmaxInputOperand,
50+
int64_t &softmax_axis) {
51+
return detail::HLOToSoftmaxMatcher<
52+
mlir::stablehlo::BroadcastInDimOp, mlir::stablehlo::ReduceOp,
53+
mlir::stablehlo::SubtractOp, mlir::stablehlo::ExpOp,
54+
mlir::stablehlo::DivOp, mlir::stablehlo::MaxOp, mlir::stablehlo::AddOp>(
55+
softmaxInputOperand, softmax_axis);
56+
}
57+
58+
} // namespace mlir::stablehlo
59+
60+
#endif // MLIR_TENSORRT_DIALECT_STABLEHLOEXT_UTILS_STABLEHLOMATCHERS

mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/Passes.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,12 @@ namespace mlir {
3232
namespace tensor {
3333
class TensorDialect;
3434
}
35+
} // namespace mlir
3536

37+
namespace mtrt {
3638
#define GEN_PASS_DECL
3739
#define GEN_PASS_REGISTRATION
3840
#include "mlir-tensorrt/Transforms/Passes.h.inc"
39-
40-
} // namespace mlir
41+
} // namespace mtrt
4142

4243
#endif // MLIR_TENSORRT_TRANSFORMS_PASSES_H

mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/Passes.td

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,6 @@
2222

2323
include "mlir/Pass/PassBase.td"
2424

25-
#ifdef MLIR_TENSORRT_ENABLE_HLO
26-
//===----------------------------------------------------------------------===//
27-
// TestStablehloMatchersPass to test different matchers for stablehlo
28-
//===----------------------------------------------------------------------===//
29-
def TestStablehloMatchersPass : Pass<"test-stablehlo-matchers"> {
30-
let summary = "tests pattern matching utilities for StableHLO.";
31-
let description =[{
32-
StableHLOMatchers.h defines matchers to raise to different patterns for
33-
dot-product attention, softmax, etc. This pass tests a matcher by looking
34-
for the `__matched__` attribute added by a matcher.
35-
}];
36-
}
37-
#endif // MLIR_TENSORRT_ENABLE_HLO
38-
3925
//===----------------------------------------------------------------------===//
4026
// DropNestedModulesPass
4127
//===----------------------------------------------------------------------===//
@@ -44,6 +30,25 @@ def DropNestedModulesPass : Pass<"drop-nested-modules", "::mlir::ModuleOp"> {
4430
" nested within the top-level Module";
4531
}
4632

33+
//===----------------------------------------------------------------------===//
34+
// LinalgElementwiseFusionPass
35+
//===----------------------------------------------------------------------===//
36+
37+
def LinalgElementwiseFusionPass : Pass<"mtrt-linalg-elementwise-fusion"> {
38+
let summary =
39+
"Performs elementwise fusion on linalg operations specific to MLIR-TensorRT";
40+
41+
let description = [{
42+
Runs linalg elementwise fusion with a control function that is specific
43+
to MLIR-TensorRT.
44+
}];
45+
46+
let dependentDialects = [
47+
"::mlir::tensor::TensorDialect",
48+
"::mlir::linalg::LinalgDialect"
49+
];
50+
}
51+
4752
//===----------------------------------------------------------------------===//
4853
// MemRefCastEliminationPass
4954
//===----------------------------------------------------------------------===//

mlir-tensorrt/compiler/include/mlir-tensorrt/Transforms/StablehloMatchers/StablehloMatchers.h

Lines changed: 0 additions & 150 deletions
This file was deleted.

mlir-tensorrt/compiler/lib/Compiler/InitAllPasses.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void mtrt::compiler::registerAllPasses() {
6767
mlir::registerLowerAffinePass();
6868
mlir::registerMLIRTensorRTCommonConversionPasses();
6969
mlir::registerMLIRTensorRTConversionPasses();
70-
mlir::registerMLIRTensorRTGenericTransformsPasses();
70+
mtrt::registerMLIRTensorRTGenericTransformsPasses();
7171
mlir::registerTransformsPasses();
7272
mlir::tensorrt::registerTensorRTPasses();
7373
mtrt::compiler::registerTensorRTToExecutablePasses();

mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_mlir_tensorrt_library(MLIRTensorRTCompilerStableHloToExecutable
2626
MLIRTensorRTHostBackend
2727
MLIRTensorRTHostToEmitC
2828
MLIRTensorRTHostToLLVM
29+
MLIRTensorRTLinalgElementwiseFusion
2930
MLIRTensorRTMemRefToCUDA
3031
MLIRTensorRTPlanToExecutor
3132
MLIRTensorRTPlanTransforms

mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/Passes.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,14 @@ class ProcessHostClustersPass
9797
ProcessHostClustersPass() {
9898
dynamicPM = OpPassManager("func.func");
9999
dynamicPM.addPass(mlir::createStablehloToLinalgPass());
100+
dynamicPM.addPass(mlir::createLinalgGeneralizeNamedOpsPass());
101+
dynamicPM.addPass(mtrt::createLinalgElementwiseFusionPass());
100102
dynamicPM.addPass(mlir::createLinalgDetensorizePass(
101103
mlir::LinalgDetensorizePassOptions{/*aggressiveMode=*/true}));
102104
dynamicPM.addPass(mlir::createConvertToLoops());
103105
dynamicPM.addPass(mlir::createCSEPass());
104106
dynamicPM.addPass(mlir::createCanonicalizerPass());
105-
dynamicPM.addPass(mlir::createSCFDetensorizeLoopsPass());
107+
dynamicPM.addPass(createSCFDetensorizeLoopsPass());
106108
}
107109

108110
void runOnOperation() override {

mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/StableHloInputPipelines.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ void mtrt::compiler::buildStablehloPreProcessingPipeline(
9393
// `convert-stablehlo-to-scf`:
9494
if (opts.legalizeControlFlowToSCF) {
9595
pm.addNestedPass<func::FuncOp>(mlir::createConvertStablehloToScfPass());
96-
pm.addNestedPass<func::FuncOp>(mlir::createUnrollForLoopsPass(
97-
mlir::UnrollForLoopsPassOptions{opts.unrollThreshold}));
96+
pm.addNestedPass<func::FuncOp>(createUnrollForLoopsPass(
97+
mtrt::UnrollForLoopsPassOptions{opts.unrollThreshold}));
9898
}
9999

100100
// `stablehlo-ext-constant-folding`:

mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ add_mlir_tensorrt_library(MLIRTensorRTStablehloToTensorRT
1313
MLIRTensorRTConvertToTensorRTCommon
1414
MLIRTensorRTTensorRTDialect
1515
MLIRTensorRTStableHloExtUtils
16-
MLIRTensorRTStablehloMatchers
1716
MLIRTensorRTTensorRTUtils
1817
MLIRTensorRTUtilsShapeInfo
1918
StablehloOps

mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
#include "mlir-tensorrt/Conversion/Patterns.h"
3232
#include "mlir-tensorrt/Conversion/TensorRTCommon/ConvertToTensorRTCommon.h"
3333
#include "mlir-tensorrt/Dialect/StablehloExt/Utils/GatherScatterUtils.h"
34-
#include "mlir-tensorrt/Transforms/StablehloMatchers/StablehloMatchers.h"
34+
#include "mlir-tensorrt/Dialect/StablehloExt/Utils/StablehloMatchers.h"
3535
#include "mlir/Dialect/Arith/IR/Arith.h"
3636
#include "mlir/Dialect/Func/IR/FuncOps.h"
3737
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
@@ -3713,7 +3713,7 @@ struct ConvertHLOSoftmax : public OpRewritePattern<stablehlo::DivOp> {
37133713
mlir::Value softmaxInputOperand;
37143714
int64_t deducedAxis = -1;
37153715
if (!matchPattern(op.getOperation(),
3716-
mlir::matchers::m_StableHLOSoftmaxMatcher(
3716+
mlir::stablehlo::m_StableHLOSoftmaxMatcher(
37173717
softmaxInputOperand, deducedAxis)))
37183718
return failure();
37193719

0 commit comments

Comments
 (0)