Skip to content

Commit a02e85f

Browse files
authored
[CPU] Propagate the reduction tile sizes to producers because of fusion. (#23660)
The codegen pipeline is designed to fuse produers into reduction loops for less memory footprint. Thus, the tile sizes should be propagated to producers. Previously, it triggered the vector input sizes from lowering config, which leads to numeric issues. Fixes #23638 ci-extra: linux_arm64_clang Signed-off-by: hanhanW <hanhan0912@gmail.com>
1 parent 098465b commit a02e85f

7 files changed

Lines changed: 96 additions & 6 deletions

File tree

compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3797,7 +3797,14 @@ void MultiLoweringConfigGenerator::setNewTilingConfigs() {
37973797
// level is `VectorReductionTiles`, skip it.
37983798
if ((iterType == utils::IteratorType::reduction) ^
37993799
(level == IREE::CPU::TilingLevel::VectorReductionTiles)) {
3800-
continue;
3800+
// Producer ops are fused during reduction tiling, so their
3801+
// parallel dims that correspond to root reduction dims need the
3802+
// reduction tile sizes in their config.
3803+
if (!(isProducerOfRootOp(op, rootOperation) &&
3804+
level == IREE::CPU::TilingLevel::VectorReductionTiles &&
3805+
iterType == utils::IteratorType::parallel)) {
3806+
continue;
3807+
}
38013808
}
38023809
tileSizes[pos] = globalTileSizes[level][globalDimIdx];
38033810
scalableFlags[pos] = globalScalableTileFlags[level][globalDimIdx];

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include "iree/compiler/Codegen/LLVMCPU/KernelDispatch.h"
1313
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
1414
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
15+
#include "iree/compiler/Codegen/Utils/CPUUtils.h"
16+
#include "iree/compiler/Codegen/Utils/Utils.h"
1517
#include "mlir/Pass/Pass.h"
1618
#include "mlir/Pass/PassManager.h"
1719
#include "mlir/Pass/PassRegistry.h"
@@ -46,8 +48,12 @@ static bool isValidInterchange(ArrayRef<int64_t> interchange, int numLoops) {
4648
}
4749

4850
/// Verifies if the tile sizes from `loweringConfig` are valid for each level.
51+
/// `rootOp` is the root compute op in the dispatch; producer ops (before
52+
/// the root) may have parallel dims set at reduction tiling levels because
53+
/// they are fused during reduction tiling.
4954
static LogicalResult verifyMultiTilingExpertPassPipelineConfig(
50-
Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {
55+
Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig,
56+
Operation *rootOp) {
5157

5258
auto interfaceOp = dyn_cast_if_present<TilingInterface>(op);
5359
if (!interfaceOp) {
@@ -89,6 +95,12 @@ static LogicalResult verifyMultiTilingExpertPassPipelineConfig(
8995
}
9096
case IREE::CPU::TilingLevel::CacheReductionTiles:
9197
case IREE::CPU::TilingLevel::VectorReductionTiles: {
98+
// Producer ops (before the root) are fused during reduction tiling,
99+
// so their parallel dims may carry reduction tile sizes inherited
100+
// from the root op. Skip this check for producers.
101+
if (isProducerOfRootOp(op, rootOp)) {
102+
break;
103+
}
92104
for (auto [index, tileSize] :
93105
llvm::enumerate(tilingLevelAttr.getSizes())) {
94106
if (tileSize != 0 && pLoopsSet.contains(index)) {
@@ -122,7 +134,8 @@ static LogicalResult verifyMultiTilingExpertPassPipelineConfig(
122134
/// lower dim ops. It requires {Distribution, VectorCommonParallel,
123135
/// VectorReduction} tiling levels.
124136
static LogicalResult verifyConvTileAndDecomposeExpertConfig(
125-
Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {
137+
Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig,
138+
Operation * /*rootOp*/) {
126139
if (!isa<linalg::ConvolutionOpInterface>(op)) {
127140
return success();
128141
}
@@ -218,6 +231,11 @@ static LogicalResult verifyConvTileAndDecomposeExpertConfig(
218231
template <typename F>
219232
static LogicalResult verifyLoweringConfiguration(FunctionOpInterface funcOp,
220233
F verificationFn) {
234+
// Find the root op for producer/consumer distinction in verification.
235+
SmallVector<Operation *> computeOps = getComputeOps(funcOp);
236+
FailureOr<Operation *> rootOp = getRootOperation(computeOps);
237+
Operation *root = succeeded(rootOp) ? rootOp.value() : nullptr;
238+
221239
auto walkResult = funcOp.walk([&](Operation *op) -> WalkResult {
222240
if (isa<IREE::LinalgExt::CustomOp>(op)) {
223241
return WalkResult::advance();
@@ -226,7 +244,7 @@ static LogicalResult verifyLoweringConfiguration(FunctionOpInterface funcOp,
226244
if (!loweringConfig) {
227245
return WalkResult::advance();
228246
}
229-
return verificationFn(op, loweringConfig);
247+
return verificationFn(op, loweringConfig, root);
230248
});
231249
return failure(walkResult.wasInterrupted());
232250
}

compiler/src/iree/compiler/Codegen/LLVMCPU/test/select_aarch64_lowering_strategy.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,52 @@ func.func @mmt4d_384x384x512_4x1x4_dispatch_0(%3: tensor<96x384x4x1xf32>, %4: te
213213
// CHECK: func.func @mmt4d_384x384x512_4x1x4_dispatch_0(
214214
// CHECK: linalg.mmt4d
215215
// CHECK-SAME: lowering_config = #[[CONFIG]]
216+
217+
// -----
218+
219+
// Verify that gather producers of attention get vector_reduction tile sizes
220+
// for dims that map to attention's reduction dims. Without this, the gather
221+
// would be vectorized with incorrect tile sizes (0 -> 1 replacement) causing
222+
// wrong numerical results.
223+
224+
#executable_target = #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu = "generic", cpu_features = "+reserve-x18", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32", native_vector_size = 16 : i64, target_triple = "aarch64-unknown-unknown-eabi-elf"}>
225+
func.func @gather_attention(
226+
%key_table: tensor<?x4x16x32xf16>, %indices: tensor<32x?xi64>,
227+
%value_table: tensor<?x4x16x32xf16>, %query: tensor<32x4x2x32xf16>,
228+
%mask: tensor<32x4x2x?x16xf16>, %dim0: index, %dim1: index,
229+
%dim2: index, %dim3: index) -> tensor<32x4x2x32xf16>
230+
attributes {hal.executable.target = #executable_target} {
231+
%cst = arith.constant 1.767580e-01 : f16
232+
%empty = tensor.empty(%dim1) : tensor<32x?x4x16x32xf16>
233+
%k_gather = iree_linalg_ext.gather dimension_map = [0]
234+
ins(%key_table, %indices : tensor<?x4x16x32xf16>, tensor<32x?xi64>)
235+
outs(%empty : tensor<32x?x4x16x32xf16>) -> tensor<32x?x4x16x32xf16>
236+
%v_gather = iree_linalg_ext.gather dimension_map = [0]
237+
ins(%value_table, %indices : tensor<?x4x16x32xf16>, tensor<32x?xi64>)
238+
outs(%empty : tensor<32x?x4x16x32xf16>) -> tensor<32x?x4x16x32xf16>
239+
%out = tensor.empty() : tensor<32x4x2x32xf16>
240+
%result = iree_linalg_ext.attention {
241+
indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4)>,
242+
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d5, d1, d6, d4)>,
243+
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d5, d1, d6, d3)>,
244+
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>,
245+
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d5, d6)>,
246+
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>]}
247+
ins(%query, %k_gather, %v_gather, %cst, %mask
248+
: tensor<32x4x2x32xf16>, tensor<32x?x4x16x32xf16>,
249+
tensor<32x?x4x16x32xf16>, f16, tensor<32x4x2x?x16xf16>)
250+
outs(%out : tensor<32x4x2x32xf16>) {
251+
^bb0(%arg0: f32):
252+
iree_linalg_ext.yield %arg0 : f32
253+
} -> tensor<32x4x2x32xf16>
254+
return %result : tensor<32x4x2x32xf16>
255+
}
256+
// Gather ops should have vector_reduction set for dims mapping to attention
257+
// reduction dims (d5, d6). This is critical for correct vectorization.
258+
// CHECK-DAG: #[[GATHER_CONFIG:.+]] = #iree_cpu.lowering_config<vector_common_parallel = [1, 0, 1, 0, {{[0-9]+}}], vector_reduction = [0, 1, 0, 4, 0]>
259+
// CHECK-DAG: #[[ATTN_CONFIG:.+]] = #iree_cpu.lowering_config<distribution = [1, 1, 2, 32, 0, 0, 0], vector_common_parallel = [1, 1, 1, 2, 0, 0, 0], vector_reduction = [0, 0, 0, 0, 0, 1, 4]>
260+
// CHECK: func.func @gather_attention(
261+
// CHECK: iree_linalg_ext.gather
262+
// CHECK-SAME: lowering_config = #[[GATHER_CONFIG]]
263+
// CHECK: iree_linalg_ext.attention
264+
// CHECK-SAME: lowering_config = #[[ATTN_CONFIG]]

compiler/src/iree/compiler/Codegen/Utils/CPUUtils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,18 @@ unsigned getUserVscaleValue() {
111111
return clVscaleFromUser;
112112
}
113113

114+
bool isProducerOfRootOp(Operation *op, Operation *rootOp) {
115+
if (!rootOp || op == rootOp) {
116+
return false;
117+
}
118+
for (Value result : op->getResults()) {
119+
for (Operation *user : result.getUsers()) {
120+
if (user == rootOp) {
121+
return true;
122+
}
123+
}
124+
}
125+
return false;
126+
}
127+
114128
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Utils/CPUUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ bool isScalableVectorizationEnabled();
4545
/// is resolved.
4646
unsigned getUserVscaleValue();
4747

48+
/// Returns true if `op` is a direct producer of `rootOp`, i.e., at least one
49+
/// of `op`'s results is used as an operand of `rootOp`.
50+
bool isProducerOfRootOp(Operation *op, Operation *rootOp);
51+
4852
} // namespace mlir::iree_compiler
4953

5054
#endif // IREE_COMPILER_CODEGEN_UTILS_CPUUTILS_H_

tests/e2e/regression/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ iree_check_single_backend_test_suite(
8686
compiler_flags = ["--iree-llvmcpu-target-cpu=generic"],
8787
driver = "local-task",
8888
tags = [
89-
"noaarch64",
9089
"noriscv",
9190
],
9291
target_backend = "llvm-cpu",

tests/e2e/regression/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ iree_check_single_backend_test_suite(
8181
COMPILER_FLAGS
8282
"--iree-llvmcpu-target-cpu=generic"
8383
LABELS
84-
"noaarch64"
8584
"noriscv"
8685
)
8786

0 commit comments

Comments
 (0)