Skip to content

Commit d84dc73

Browse files
[Codegen][GPU] Fix f32 attention compilation failure when head_dim=128 (#24138)
Fix attention schedule deduction where the QK intrinsic N was incorrectly set to the PV intrinsic K, causing the K2 tile to be smaller than the QK accumulator inner tile for f32, which later fails the packing in `GPUPackToIntrinsics`. Added a regression test for f32 attention with head_dim=128. Fixes #24135 --------- Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent 445cf9c commit d84dc73

3 files changed

Lines changed: 73 additions & 6 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,10 +1206,22 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
12061206
int64_t intrinsicAN = intrinsicA.nSizes[0];
12071207
int64_t intrinsicAK = intrinsicA.kSizes[0];
12081208
auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool {
1209+
// The output of the QK matmul must be a valid LHS of the PV matmul.
1210+
// The total LHS tile (M x K) of the PV matmul must be a multiple of
1211+
// the output tile (M x N) of the intrinsic used for the QK matmul.
1212+
int64_t pvMTile = schedule.getTotalMTileSize() *
1213+
schedule.getTotalMSize() *
1214+
schedule.getTotalMSubgroupCount();
1215+
int64_t pvKTile = schedule.getTotalKTileSize() * schedule.getTotalKSize();
1216+
if (pvMTile % intrinsicAM != 0 || pvKTile % intrinsicAN != 0) {
1217+
return false;
1218+
}
1219+
12091220
// Create a mma schedule for qkMatmul in attention.
12101221
// qkMatmul.M = pvMatmul.M
12111222
// qkMatmul.N = pvMatmul.K
1212-
// qkMatmul.K = problem.K
1223+
// qkMatmul.K = problem.K1
1224+
int64_t qkNTiles = pvKTile / intrinsicAN;
12131225
SmallVector<int64_t, 2> qkKSizes = qkMatmul.kSizes;
12141226
qkKSizes.back() = qkMatmul.kSizes.back() / intrinsicAK;
12151227
GPUMMASchedule qkSchedule{
@@ -1220,7 +1232,7 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
12201232
/*mSubgroupCount=*/schedule.mSubgroupCounts,
12211233
/*nSubgroupCount=*/SmallVector<int64_t>(qkMatmul.nSizes.size(), 1),
12221234
schedule.mTileSizes,
1223-
schedule.kTileSizes,
1235+
{qkNTiles},
12241236
qkKSizes};
12251237

12261238
bool isQKAligned =
@@ -1262,18 +1274,21 @@ FailureOr<std::pair<GPUMMASchedule, GPUMMASchedule>> deduceAttentionSchedule(
12621274
// Create a mma schedule for qkMatmul in attention.
12631275
// qkMatmul.M = pvMatmul.M
12641276
// qkMatmul.N = pvMatmul.K
1265-
// qkMatmul.K = problem.K
1277+
// qkMatmul.K = problem.K1
1278+
int64_t pvKTile =
1279+
pvSchedule->getTotalKTileSize() * pvSchedule->getTotalKSize();
1280+
int64_t qkNTiles = pvKTile / intrinsicAN;
12661281
SmallVector<int64_t, 2> qkKSizes = qkMatmul.kSizes;
12671282
qkKSizes.back() = qkMatmul.kSizes.back() / intrinsicAK;
12681283
GPUMMASchedule qkSchedule{
12691284
intrinsicA.mmaKind,
12701285
pvSchedule->mSizes,
1271-
pvSchedule->kSizes,
1286+
{intrinsicAN},
12721287
{intrinsicAK},
12731288
/*mSubgroupCount=*/pvSchedule->mSubgroupCounts,
12741289
/*nSubgroupCount=*/SmallVector<int64_t>(qkMatmul.nSizes.size(), 1),
12751290
pvSchedule->mTileSizes,
1276-
pvSchedule->kTileSizes,
1291+
{qkNTiles},
12771292
qkKSizes};
12781293

12791294
return std::pair(qkSchedule, pvSchedule.value());

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig(
869869
Type f32Type = b.getF32Type();
870870
GPUMatmulShapeType qkMatmul{
871871
/*m=*/getDimBounds(mDims),
872-
/*n=*/getDimBounds(nDims),
872+
/*n=*/getDimBounds(k2Dims),
873873
/*k=*/getDimBounds(k1Dims),
874874
/*batch=*/getDimBounds(batchDims),
875875
/*a=*/qElementType,

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx942.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,55 @@ func.func @attention_multi_m_dynamic(%arg0 : tensor<20x8x?x16x64xf16>, %arg1 : t
524524
// CHECK-SAME: #iree_gpu.lowering_config
525525
// CHECK-SAME: reduction = [0, 0, 0, 0, 0, 64, 0]
526526
// CHECK-SAME: workgroup = [1, 4, 1, 16, 0, 0, 64]
527+
528+
// -----
529+
530+
// CHECK: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<VectorDistribute>
531+
532+
// CHECK-LABEL: func.func @attention_f32_20x4096x128x4096x128()
533+
534+
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
535+
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>
536+
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
537+
#map3 = affine_map<(d0, d1, d2, d3, d4) -> ()>
538+
#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
539+
#map5 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
540+
#pipeline_layout = #hal.pipeline.layout<bindings = [
541+
#hal.pipeline.binding<storage_buffer>,
542+
#hal.pipeline.binding<storage_buffer>,
543+
#hal.pipeline.binding<storage_buffer>,
544+
#hal.pipeline.binding<storage_buffer>
545+
]>
546+
func.func @attention_f32_20x4096x128x4096x128() {
547+
%cst = arith.constant 1.250000e-01 : f32
548+
%c0 = arith.constant 0 : index
549+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>>
550+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>>
551+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>>
552+
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<20x4096x128xf32>>
553+
%4 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>> -> tensor<20x4096x128xf32>
554+
%5 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>> -> tensor<20x4096x128xf32>
555+
%6 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<20x4096x128xf32>> -> tensor<20x4096x128xf32>
556+
%7 = tensor.empty() : tensor<20x4096x128xf32>
557+
%8 = tensor.empty() : tensor<20x4096xf32>
558+
%cst_0 = arith.constant 0.000000e+00 : f32
559+
%cst_1 = arith.constant -3.40282347E+38 : f32
560+
%cst_2 = arith.constant 0.000000e+00 : f32
561+
%9 = linalg.fill ins(%cst_0 : f32) outs(%7 : tensor<20x4096x128xf32>) -> tensor<20x4096x128xf32>
562+
%10 = linalg.fill ins(%cst_1 : f32) outs(%8 : tensor<20x4096xf32>) -> tensor<20x4096xf32>
563+
%11 = linalg.fill ins(%cst_2 : f32) outs(%8 : tensor<20x4096xf32>) -> tensor<20x4096xf32>
564+
%12:3 = iree_linalg_ext.online_attention {indexing_maps = [#map, #map1, #map2, #map3, #map4, #map5, #map5]} ins(%4, %5, %6, %cst : tensor<20x4096x128xf32>, tensor<20x4096x128xf32>, tensor<20x4096x128xf32>, f32) outs(%9, %10, %11 : tensor<20x4096x128xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>) {
565+
^bb0(%arg0: f32):
566+
iree_linalg_ext.yield %arg0 : f32
567+
} -> tensor<20x4096x128xf32>, tensor<20x4096xf32>, tensor<20x4096xf32>
568+
iree_tensor_ext.dispatch.tensor.store %12#0, %3, offsets = [0, 0, 0], sizes = [20, 4096, 128], strides = [1, 1, 1] : tensor<20x4096x128xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<20x4096x128xf32>>
569+
return
570+
}
571+
572+
// CHECK: #iree_gpu.lowering_config
573+
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
574+
// CHECK-SAME: #iree_gpu.lowering_config
575+
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
576+
// CHECK-SAME: #iree_gpu.lowering_config
577+
// CHECK-SAME: reduction = [0, 0, 0, 16, 0]
578+
// CHECK-SAME: workgroup = [1, 64, 0, 0, 64]

0 commit comments

Comments (0)