Skip to content

Commit 47d41ab

Browse files
yzhang93 and claude authored
[GPUHeuristics] Improve large GEMM intrinsic selection on CDNA4 (#24115)
Extend the compute-throughput-first intrinsic preference to LargeGemm shapes, preferring MFMA_F32_32x32x16_F16 over MFMA_F32_16x16x32_F16 (4x more output per instruction). Add VGPR pressure cap to prevent spilling when MNT boost sets high tile counts with 32x32 intrinsics. Top GEMM improvements on MI355X: ``` 4096x1024x150000: 2112us -> 1538us (1.37x) 2268x4096x150000: 11359us -> 8529us (1.33x) 1024x4096x150000: 1982us -> 1573us (1.26x) 4096x2048x150000: 4015us -> 3307us (1.21x) 2048x8192x4096: 183us -> 154us (1.19x) ``` Top conv improvements on MI355X (NHWC, fp16): ``` n32 c256 H100xW100 k2376 3x3 wgrad: 7983us -> 6634us (1.20x) n32 c256 H25xW25 k2376 3x3 wgrad: 777us -> 664us (1.17x) n32 c256 H100xW100 k2376 3x3 fwd: 7042us -> 6122us (1.15x) n32 c256 H25xW25 k2376 3x3 fwd: 452us -> 405us (1.12x) n32 c256 H50xW50 k2376 3x3 fwd: 1711us -> 1541us (1.11x) ``` Overall GEMM benchmark: **+6.3%** geomean speedup. Overall Proxy conv benchmark: **+2.5%** geomean speedup. Some regressions exist in K-dominated wgrad shapes due to larger workgroup tiles, but overall improvements outweigh regressions across all benchmarks. --------- Signed-off-by: yzhang93 <zhyuhang88@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent 89b536e commit 47d41ab

4 files changed

Lines changed: 73 additions & 26 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,8 @@ static double computeMNUtilization(const GPUMatmulShapeType &problem,
697697
/// returns true if the lhs is ordered before rhs.
698698
static bool compareIntrinsics(const GPUMatmulShapeType &problem,
699699
const GPUIntrinsicType &lhs,
700-
const GPUIntrinsicType &rhs) {
700+
const GPUIntrinsicType &rhs,
701+
bool preferHighComputeIntrinsic = false) {
701702
// When both M and N need padding, prefer the intrinsic with better M*N
702703
// utilization. This targets grouped convolutions where per-group channels
703704
// are small (e.g., 8x8 problem: 16x16 at 25% util >> 32x32 at 6.25%).
@@ -775,7 +776,7 @@ static bool compareIntrinsics(const GPUMatmulShapeType &problem,
775776
// (compute=8192, area=512) because throughput matters more. Among
776777
// 16x16x32 and 32x32x16 (both area=1024), prefer smaller K (16 vs 32)
777778
// for less operand staging pressure.
778-
if (problem.gemmSize == GemmSizeKind::VeryLargeGemm) {
779+
if (preferHighComputeIntrinsic) {
779780
int64_t lhsCompute = intrinsicCompute(lhs);
780781
int64_t rhsCompute = intrinsicCompute(rhs);
781782
if (lhsCompute != rhsCompute) {
@@ -806,11 +807,12 @@ static bool compareIntrinsics(const GPUMatmulShapeType &problem,
806807

807808
static SmallVector<GPUIntrinsicType>
808809
sortMMAIntrinsics(GPUMatmulShapeType problem,
809-
ArrayRef<GPUIntrinsicType> intrinsics) {
810+
ArrayRef<GPUIntrinsicType> intrinsics,
811+
bool preferHighComputeIntrinsic = false) {
810812
SmallVector<GPUIntrinsicType> sortedIntrinsics(intrinsics);
811813
llvm::stable_sort(sortedIntrinsics, [&](const GPUIntrinsicType &lhs,
812814
const GPUIntrinsicType &rhs) {
813-
return compareIntrinsics(problem, lhs, rhs);
815+
return compareIntrinsics(problem, lhs, rhs, preferHighComputeIntrinsic);
814816
});
815817
return sortedIntrinsics;
816818
}
@@ -834,14 +836,16 @@ static int64_t computeEstimatedWorkgroupCount(const GPUMMAHeuristicSeeds &seeds,
834836
}
835837

836838
/// Adjust M*N tile-count (bestMNTileCountPerSubgroup) seeds based on target
837-
/// hardware and problem characteristics. Three independent adjustments, applied
839+
/// hardware and problem characteristics. Four independent adjustments, applied
838840
/// in order:
839841
/// 1. Baseline (all targets): reduces bestMNTileCountPerSubgroup until the
840842
/// estimated workgroup count fills all CUs.
841843
/// 2. Tile-count boost (when boostMNTileCountPerSubgroup is set): for GEMMs
842844
/// with balanced K, boosts tile count to the architecture-specific target.
843845
/// 3. Utilization guard (when minUtilizationThreshold is set): halves tile
844846
/// count until GPU utilization meets the threshold.
847+
/// 4. VGPR pressure cap: limits MN tile count based on per-thread output
848+
/// register pressure from the selected intrinsic, preventing spilling.
845849
static void adjustSeedsForTarget(GPUMMAHeuristicSeeds &seeds,
846850
const GPUMatmulShapeType &problem,
847851
const GPUIntrinsicType &intrinsic,
@@ -898,6 +902,12 @@ static void adjustSeedsForTarget(GPUMMAHeuristicSeeds &seeds,
898902
std::max(seeds.bestMNTileCountPerSubgroup, boostMNT);
899903
LDBG() << "Boosting MNT to " << seeds.bestMNTileCountPerSubgroup
900904
<< " for balanced large gemm";
905+
// Halve subgroup count to offset the MNT boost, keeping the total
906+
// workgroup resource footprint (threads, LDS) in check for occupancy.
907+
seeds.bestSubgroupCountPerWorkgroup =
908+
std::max<int64_t>(1, seeds.bestSubgroupCountPerWorkgroup / 2);
909+
LDBG() << "Halving subgroup count to "
910+
<< seeds.bestSubgroupCountPerWorkgroup << " to offset MNT boost";
901911
}
902912
}
903913

@@ -928,6 +938,27 @@ static void adjustSeedsForTarget(GPUMMAHeuristicSeeds &seeds,
928938
<< seeds.bestMNTileCountPerSubgroup;
929939
}
930940
}
941+
942+
// Cap per-subgroup MN tile count based on output VGPR pressure from the
943+
// selected intrinsic. Only applies when the MNT boost (step 2) is
944+
// configured, since the boost can push MN tile counts high enough to
945+
// cause spilling with large-output intrinsics (32x32). Capping at 128
946+
// output VGPRs per thread (8 MN tiles for 32x32, 32 for 16x16) prevents
947+
// spilling while preserving the boost for intrinsics that can handle
948+
// higher tile counts.
949+
if (seeds.maxOutputVGPRsPerThread) {
950+
int64_t subgroupSize = target.getPreferredSubgroupSize();
951+
int64_t outputVGPRsPerTile =
952+
(intrinsic.mSizes[0] * intrinsic.nSizes[0]) / subgroupSize;
953+
int64_t maxMNTiles = *seeds.maxOutputVGPRsPerThread / outputVGPRsPerTile;
954+
if (seeds.bestMNTileCountPerSubgroup > maxMNTiles) {
955+
LDBG() << "VGPR cap: reducing bestMNTileCountPerSubgroup from "
956+
<< seeds.bestMNTileCountPerSubgroup << " to " << maxMNTiles
957+
<< " (intrinsic " << intrinsic.mSizes[0] << "x"
958+
<< intrinsic.nSizes[0] << ")";
959+
seeds.bestMNTileCountPerSubgroup = maxMNTiles;
960+
}
961+
}
931962
}
932963

933964
FailureOr<GPUMMASchedule> deduceMMASchedule(
@@ -938,8 +969,19 @@ FailureOr<GPUMMASchedule> deduceMMASchedule(
938969
bool useDirectLoad, int64_t prefetchNumStages, bool mustBeAligned,
939970
bool doCPromotion, int64_t splitReductionTripCnt) {
940971

972+
// Prefer higher-compute intrinsics (e.g., 32x32x16 over 16x16x32) for:
973+
// - VeryLargeGemm: always compute-bound, higher throughput wins.
974+
// - LargeGemm on architectures with MNT boost (e.g., CDNA4): the boost
975+
// indicates the target benefits from larger output tiles. Gated by
976+
// !doCPromotion to avoid regressing addmm shapes that need accumulator
977+
// promotion to shared memory.
978+
bool isLargeGemmWithBoost = problem.gemmSize == GemmSizeKind::LargeGemm &&
979+
seeds.boostMNTileCountPerSubgroup.has_value() &&
980+
!doCPromotion;
981+
bool preferHighComputeIntrinsic =
982+
problem.gemmSize == GemmSizeKind::VeryLargeGemm || isLargeGemmWithBoost;
941983
SmallVector<GPUIntrinsicType> sortedIntrinsics =
942-
sortMMAIntrinsics(problem, intrinsics);
984+
sortMMAIntrinsics(problem, intrinsics, preferHighComputeIntrinsic);
943985

944986
// Compute product of M and N problem sizes to decide if block intrinsics
945987
// should be considered. If both M and N products exceed the threshold, skip

compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ struct GPUMMAHeuristicSeeds {
102102
// per workgroup), which can improve performance when the GPU has enough work
103103
// to stay saturated.
104104
std::optional<int64_t> boostMNTileCountPerSubgroup = std::nullopt;
105+
// Maximum output VGPRs per thread for the VGPR pressure cap. When set,
106+
// adjustSeedsForTarget will reduce bestMNTileCountPerSubgroup to keep
107+
// per-thread output register pressure within this limit.
108+
std::optional<int64_t> maxOutputVGPRsPerThread = std::nullopt;
105109
};
106110

107111
struct GPUMMASchedule {

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,9 +1351,10 @@ static constexpr ArchSeedSet kCDNA4Seeds = {
13511351
/*gemm=*/{
13521352
/*SmallGemm=*/ {2, 2, 4, 2 * kCacheLineSizeBits},
13531353
/*MediumGemm=*/ {4, 8, 4, 2 * kCacheLineSizeBits},
1354-
/*LargeGemm=*/ {4, 16, 2, kCacheLineSizeBits / 2,
1354+
/*LargeGemm=*/ {8, 16, 2, kCacheLineSizeBits / 2,
13551355
/*minUtilizationThreshold=*/0.50,
1356-
/*boostMNTileCountPerSubgroup=*/32},
1356+
/*boostMNTileCountPerSubgroup=*/32,
1357+
/*maxOutputVGPRsPerThread=*/128},
13571358
/*VeryLargeGemm=*/ {4, 16, 2, kCacheLineSizeBits / 2},
13581359
},
13591360
/*scaledGemm=*/{

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -412,13 +412,13 @@ func.func @matmul_large_symmetric_f16(
412412

413413
// MI355X-LABEL: func.func @matmul_large_symmetric_f16
414414
// MI355X-SAME: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse>
415-
// MI355X-SAME: workgroup_size = [256, 1, 1] subgroup_size = 64
415+
// MI355X-SAME: workgroup_size = [512, 1, 1] subgroup_size = 64
416416
// MI355X: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
417-
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F16>
417+
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x16_F16>
418418
// MI355X-SAME: promote_operands = [0, 1]
419-
// MI355X-SAME: reduction = [0, 0, 1]
420-
// MI355X-SAME: subgroup = [4, 8, 0]
421-
// MI355X-SAME: workgroup = [128, 256, 0]
419+
// MI355X-SAME: reduction = [0, 0, 2]
420+
// MI355X-SAME: subgroup = [2, 4, 0]
421+
// MI355X-SAME: workgroup = [256, 256, 0]
422422

423423
// -----
424424

@@ -437,13 +437,13 @@ func.func @matmul_large_tall_m_f16(
437437

438438
// MI355X-LABEL: func.func @matmul_large_tall_m_f16
439439
// MI355X-SAME: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse>
440-
// MI355X-SAME: workgroup_size = [256, 1, 1] subgroup_size = 64
440+
// MI355X-SAME: workgroup_size = [512, 1, 1] subgroup_size = 64
441441
// MI355X: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
442-
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F16>
442+
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x16_F16>
443443
// MI355X-SAME: promote_operands = [0, 1]
444-
// MI355X-SAME: reduction = [0, 0, 1]
445-
// MI355X-SAME: subgroup = [4, 8, 0]
446-
// MI355X-SAME: workgroup = [128, 256, 0]
444+
// MI355X-SAME: reduction = [0, 0, 2]
445+
// MI355X-SAME: subgroup = [2, 4, 0]
446+
// MI355X-SAME: workgroup = [256, 256, 0]
447447

448448
// -----
449449

@@ -462,13 +462,13 @@ func.func @matmul_large_wide_n_f16(
462462

463463
// MI355X-LABEL: func.func @matmul_large_wide_n_f16
464464
// MI355X-SAME: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse>
465-
// MI355X-SAME: workgroup_size = [256, 1, 1] subgroup_size = 64
465+
// MI355X-SAME: workgroup_size = [512, 1, 1] subgroup_size = 64
466466
// MI355X: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
467-
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F16>
467+
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x16_F16>
468468
// MI355X-SAME: promote_operands = [0, 1]
469-
// MI355X-SAME: reduction = [0, 0, 1]
470-
// MI355X-SAME: subgroup = [4, 8, 0]
471-
// MI355X-SAME: workgroup = [128, 256, 0]
469+
// MI355X-SAME: reduction = [0, 0, 2]
470+
// MI355X-SAME: subgroup = [2, 4, 0]
471+
// MI355X-SAME: workgroup = [256, 256, 0]
472472

473473
// -----
474474

@@ -490,11 +490,11 @@ func.func @matmul_large_very_tall_m_f16(
490490
// MI355X-SAME: #iree_codegen.translation_info<pipeline = #iree_gpu.pipeline<TileAndFuse>
491491
// MI355X-SAME: workgroup_size = [256, 1, 1] subgroup_size = 64
492492
// MI355X: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
493-
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x32_F16>
493+
// MI355X-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x16_F16>
494494
// MI355X-SAME: padding = [128, 256, 32]
495495
// MI355X-SAME: promote_operands = [0, 1]
496-
// MI355X-SAME: reduction = [0, 0, 1]
497-
// MI355X-SAME: subgroup = [4, 8, 0]
496+
// MI355X-SAME: reduction = [0, 0, 2]
497+
// MI355X-SAME: subgroup = [2, 4, 0]
498498
// MI355X-SAME: workgroup = [128, 256, 0]
499499

500500
// -----

0 commit comments

Comments (0)