
Commit e4a3b04

kuhar and claude authored
[Codegen][ROCm] Fix vector distribution for transposed outputs (#23791)
Layer norm-style dispatches with a multi-output generic that has a transposed output used to crash with `failed to distribute` on a proprietary model. Teach `shouldAttachLoweringConfig` to recognize non-identity output indexing maps so the op gets a `lowering_config` and proper `to_layout` anchors.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
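For intuition, the condition the fix keys on is whether an output's indexing map is the identity: a transposed result is written through a permutation such as (d0, d1) -> (d1, d0), which fails `isIdentity()`. Below is a minimal sketch using core MLIR `AffineMap` APIs; it is illustrative only and not code from this commit.

// Illustrative sketch, not part of this commit: a transposed output map is a
// permutation but not the identity, which is exactly what the fix checks for.
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

bool transposedOutputMapIsNonIdentity(mlir::MLIRContext *ctx) {
  unsigned perm[] = {1, 0};
  // (d0, d1) -> (d1, d0): how the transposed output is indexed.
  mlir::AffineMap transposed = mlir::AffineMap::getPermutationMap(perm, ctx);
  // (d0, d1) -> (d0, d1): how a non-transposed output is indexed.
  mlir::AffineMap identity = mlir::AffineMap::getMultiDimIdentityMap(2, ctx);
  return !transposed.isIdentity() && identity.isIdentity(); // expected: true
}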
1 parent 51f3912 commit e4a3b04

4 files changed

Lines changed: 116 additions & 2 deletions


compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ReductionConfigUtils.cpp

Lines changed: 13 additions & 2 deletions
@@ -428,6 +428,15 @@ populateConfigInfo(const llvm::SetVector<linalg::LinalgOp> &computeOps,
   // LinalgOp with only parallel dims. This is needed if the op cannot be fused
   // with a reduction or introduces new loop dimensions.
   auto shouldAttachLoweringConfig = [&](linalg::LinalgOp linalgOp) -> bool {
+    // If any output has a non-identity indexing map, the op needs its own
+    // layout anchors for vector distribution to handle the permuted write.
+    // Check this first since it takes precedence over fusion preferences.
+    for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
+      if (!linalgOp.getMatchingIndexingMap(&output).isIdentity()) {
+        return true;
+      }
+    }
+
     // If the operation has a gather, we want to fuse it with the
     // reduction.
     if (hasExternalCapture(cast<linalg::GenericOp>(linalgOp))) {
@@ -625,9 +634,11 @@ checkDispatchForVectorDistribution(Operation *parentOp) {
 ///    attached.
 /// 2. `populateConfigInfo` determines to which linalg operations it might
 ///    attach `lowering_config`. Currently, it attaches `lowering_config` to
-///    reduction operations and parallel operations that have new dimensions.
+///    reduction operations and parallel operations that have new dimensions or
+///    non-identity output indexing maps (e.g., transposed outputs).
 ///    a. `getVectorDistributeReductionConfig` determines the `lowering_config`
-///       for the reduction as well as parallel operations with new dimension.
+///       for the reduction as well as parallel operations with new dimensions or
+///       non-identity outputs.
 
 /// The workgroup, subgroup, and threadTileSizes are determined by the
 /// `setReductionConfig` operation, which are global
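Stripped of diff context, the new condition is a small predicate over the op's DPS inits. A standalone restatement is sketched below; the helper name is hypothetical, but the calls are the same ones used in the hunk above.

// Hypothetical standalone form of the check added above (the name is illustrative).
// It reuses the calls from the diff: getDpsInitsMutable, getMatchingIndexingMap,
// and AffineMap::isIdentity.
static bool hasNonIdentityOutputMap(linalg::LinalgOp linalgOp) {
  for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
    // A transposed output, e.g. (d0, d1) -> (d1, d0), fails isIdentity().
    if (!linalgOp.getMatchingIndexingMap(&output).isIdentity())
      return true;
  }
  return false;
}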

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ iree_lit_test_suite(
             "buffer_instructions_optimization.mlir",
             "config_direct_conv_tile_and_fuse.mlir",
             "config_igemm_tile_and_fuse.mlir",
+            "config_reduction_transposed_output.mlir",
             "config_tile_and_fuse.mlir",
             "config_tile_and_fuse_gfx1201.mlir",
             "config_tile_and_fuse_gfx950.mlir",

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ iree_lit_test_suite(
     "buffer_instructions_optimization.mlir"
     "config_direct_conv_tile_and_fuse.mlir"
     "config_igemm_tile_and_fuse.mlir"
+    "config_reduction_transposed_output.mlir"
     "config_tile_and_fuse.mlir"
     "config_tile_and_fuse_gfx1201.mlir"
     "config_tile_and_fuse_gfx950.mlir"
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_reduction_transposed_output.mlir

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+// RUN: iree-opt --mlir-print-local-scope --split-input-file \
+// RUN:   --iree-gpu-test-target=gfx942 \
+// RUN:   --iree-codegen-llvmgpu-use-vector-distribution \
+// RUN:   --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s \
+// RUN:   | FileCheck %s
+
+// Verify that reductions fused with multi-output generics that have transposed
+// outputs select LLVMGPUVectorDistribute and attach lowering configs to the
+// parallel op with the transposed output.
+
+// 2D case: reduction over dim 1, elementwise with (d0, d1) -> (d1, d0) output.
+
+// CHECK-LABEL: func.func @reduction_2d_transposed_output
+// CHECK-SAME: pipeline = LLVMGPUVectorDistribute
+// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: lowering_config
+// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: lowering_config
+
+func.func @reduction_2d_transposed_output(
+    %input: tensor<512x4096xf32>,
+    %result: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<512x4096xf32>>,
+    %result_t: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4096x512xf32>>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty_red = tensor.empty() : tensor<512xf32>
+  %filled = linalg.fill ins(%cst : f32) outs(%empty_red : tensor<512xf32>) -> tensor<512xf32>
+  %red = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%input : tensor<512x4096xf32>) outs(%filled : tensor<512xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %sq = arith.mulf %in, %in : f32
+    %add = arith.addf %sq, %out : f32
+    linalg.yield %add : f32
+  } -> tensor<512xf32>
+  %empty0 = tensor.empty() : tensor<512x4096xf32>
+  %empty1 = tensor.empty() : tensor<4096x512xf32>
+  %res:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>,
+                       affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d1, d0)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%input, %red : tensor<512x4096xf32>, tensor<512xf32>)
+      outs(%empty0, %empty1 : tensor<512x4096xf32>, tensor<4096x512xf32>) {
+  ^bb0(%in: f32, %r: f32, %o0: f32, %o1: f32):
+    %v = arith.mulf %in, %r : f32
+    linalg.yield %v, %v : f32, f32
+  } -> (tensor<512x4096xf32>, tensor<4096x512xf32>)
+  iree_tensor_ext.dispatch.tensor.store %res#0, %result, offsets = [0, 0], sizes = [512, 4096], strides = [1, 1] : tensor<512x4096xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<512x4096xf32>>
+  iree_tensor_ext.dispatch.tensor.store %res#1, %result_t, offsets = [0, 0], sizes = [4096, 512], strides = [1, 1] : tensor<4096x512xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4096x512xf32>>
+  return
+}
+
+// -----
+
+// 3D case: reduction over dim 2, elementwise with (d0, d1, d2) -> (d0, d2, d1) output.
+
+// CHECK-LABEL: func.func @reduction_3d_transposed_output
+// CHECK-SAME: pipeline = LLVMGPUVectorDistribute
+// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: lowering_config
+// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: lowering_config
+
+func.func @reduction_3d_transposed_output(
+    %input: tensor<16x32x4096xf32>,
+    %result: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16x32x4096xf32>>,
+    %result_t: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16x4096x32xf32>>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty_red = tensor.empty() : tensor<16x32xf32>
+  %filled = linalg.fill ins(%cst : f32) outs(%empty_red : tensor<16x32xf32>) -> tensor<16x32xf32>
+  %red = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%input : tensor<16x32x4096xf32>) outs(%filled : tensor<16x32xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %sq = arith.mulf %in, %in : f32
+    %add = arith.addf %sq, %out : f32
+    linalg.yield %add : f32
+  } -> tensor<16x32xf32>
+  %empty0 = tensor.empty() : tensor<16x32x4096xf32>
+  %empty1 = tensor.empty() : tensor<16x4096x32xf32>
+  %res:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d2, d1)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%input, %red : tensor<16x32x4096xf32>, tensor<16x32xf32>)
+      outs(%empty0, %empty1 : tensor<16x32x4096xf32>, tensor<16x4096x32xf32>) {
+  ^bb0(%in: f32, %r: f32, %o0: f32, %o1: f32):
+    %v = arith.mulf %in, %r : f32
+    linalg.yield %v, %v : f32, f32
+  } -> (tensor<16x32x4096xf32>, tensor<16x4096x32xf32>)
+  iree_tensor_ext.dispatch.tensor.store %res#0, %result, offsets = [0, 0, 0], sizes = [16, 32, 4096], strides = [1, 1, 1] : tensor<16x32x4096xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16x32x4096xf32>>
+  iree_tensor_ext.dispatch.tensor.store %res#1, %result_t, offsets = [0, 0, 0], sizes = [16, 4096, 32], strides = [1, 1, 1] : tensor<16x4096x32xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<16x4096x32xf32>>
+  return
+}

0 commit comments
