Skip to content

Commit 8c7f8c5

Browse files
Alex-Wenggclaude
andcommitted
[GlobalOpt] Reimplement softmax matcher natively (#24466)
Replace the wholesale relocation of the iree-dialects TransformMatchers DSL with a self-contained, native matcher in RaiseSpecialOps.cpp, per the review feedback that this should port only what RaiseSpecialOps needs rather than carry the generic StructuredOpMatcher framework into GlobalOptimization. - Add a local matchSoftmax() plus small helpers that walk the softmax linalg-op graph directly (reduce_max -> sub -> exp -> reduce_add -> mul/reciprocal or div), handling both implicit (projected map) and explicit (pass-through generic) broadcasts. This is behaviorally faithful to makeSoftmaxMatcher, including the same-source invariant. - Delete GlobalOptimization/TransformMatchers.{h,cpp} (~3000 lines) and drop the transform-dialect build deps that only existed for them. The iree-dialects deps stay removed. - Add negative lit cases (wrong max init, mismatched source) alongside the existing softmax raising tests. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Han || Alex <36247722+Alex-Wengg@users.noreply.github.com>
1 parent 7da62a7 commit 8c7f8c5

6 files changed

Lines changed: 391 additions & 3071 deletions

File tree

compiler/src/iree/compiler/GlobalOptimization/BUILD.bazel

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,11 @@ iree_compiler_cc_library(
7171
"RaiseSpecialOps.cpp",
7272
"RemoveZeroExtentTensors.cpp",
7373
"SimplifyPackUnpack.cpp",
74-
"TransformMatchers.cpp",
7574
"Utils.cpp",
7675
"WarnOnUninitializedValues.cpp",
7776
],
7877
hdrs = [
7978
"Passes.h",
80-
"TransformMatchers.h",
8179
"Utils.h",
8280
],
8381
deps = [
@@ -107,31 +105,25 @@ iree_compiler_cc_library(
107105
"//compiler/src/iree/compiler/Utils",
108106
"@llvm-project//llvm:Support",
109107
"@llvm-project//mlir:AffineDialect",
110-
"@llvm-project//mlir:Analysis",
111108
"@llvm-project//mlir:ArithDialect",
112109
"@llvm-project//mlir:ArithUtils",
113110
"@llvm-project//mlir:ControlFlowDialect",
114111
"@llvm-project//mlir:DialectUtils",
115-
"@llvm-project//mlir:FuncDialect",
116112
"@llvm-project//mlir:FunctionInterfaces",
117113
"@llvm-project//mlir:IR",
118114
"@llvm-project//mlir:LinalgDialect",
119-
"@llvm-project//mlir:LinalgInterfaces",
120115
"@llvm-project//mlir:LinalgTransforms",
121116
"@llvm-project//mlir:LinalgUtils",
122117
"@llvm-project//mlir:MathDialect",
123118
"@llvm-project//mlir:MemRefDialect",
124119
"@llvm-project//mlir:MemRefTransforms",
125120
"@llvm-project//mlir:Pass",
126-
"@llvm-project//mlir:Rewrite",
127121
"@llvm-project//mlir:SCFDialect",
128122
"@llvm-project//mlir:SCFTransforms",
129123
"@llvm-project//mlir:Support",
130124
"@llvm-project//mlir:TensorDialect",
131125
"@llvm-project//mlir:TensorTransforms",
132126
"@llvm-project//mlir:TensorUtils",
133-
"@llvm-project//mlir:TransformDialect",
134-
"@llvm-project//mlir:TransformDialectInterfaces",
135127
"@llvm-project//mlir:TransformUtils",
136128
"@llvm-project//mlir:Transforms",
137129
],

compiler/src/iree/compiler/GlobalOptimization/CMakeLists.txt

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ iree_cc_library(
3838
GlobalOptimization
3939
HDRS
4040
"Passes.h"
41-
"TransformMatchers.h"
4241
"Utils.h"
4342
SRCS
4443
"CleanupNumericNarrowing.cpp"
@@ -63,38 +62,31 @@ iree_cc_library(
6362
"RaiseSpecialOps.cpp"
6463
"RemoveZeroExtentTensors.cpp"
6564
"SimplifyPackUnpack.cpp"
66-
"TransformMatchers.cpp"
6765
"Utils.cpp"
6866
"WarnOnUninitializedValues.cpp"
6967
DEPS
7068
::PassHeaders
7169
::PassesIncGen
7270
LLVMSupport
7371
MLIRAffineDialect
74-
MLIRAnalysis
7572
MLIRArithDialect
7673
MLIRArithUtils
7774
MLIRControlFlowDialect
78-
MLIRFuncDialect
7975
MLIRFunctionInterfaces
8076
MLIRIR
8177
MLIRLinalgDialect
82-
MLIRLinalgInterfacesIncGenLib
8378
MLIRLinalgTransforms
8479
MLIRLinalgUtils
8580
MLIRMathDialect
8681
MLIRMemRefDialect
8782
MLIRMemRefTransforms
8883
MLIRPass
89-
MLIRRewrite
9084
MLIRSCFDialect
9185
MLIRSCFTransforms
9286
MLIRSupport
9387
MLIRTensorDialect
9488
MLIRTensorTransforms
9589
MLIRTensorUtils
96-
MLIRTransformDialect
97-
MLIRTransformDialectInterfaces
9890
MLIRTransformUtils
9991
MLIRTransforms
10092
iree::compiler::Codegen::Common

0 commit comments

Comments
 (0)