
Commit 17cc629

njriasan authored and meta-codesync[bot] committed
[AutoWS] Fix AutoWS Hopper GEMM partition structure with DP=2 (#1322)
Summary: Fixes the partition structure with DP=2. Previously the two data-partitioned dots were mapped to the same partition, so there was no benefit. Now they are mapped to separate partitions, allowing greater parallelism and the opportunity to use ping-pong. This also fixes several structural bugs in the compiler.

Pull Request resolved: #1322
Reviewed By: manman-ren
Differential Revision: D102372962
Pulled By: njriasan
fbshipit-source-id: 33ee700faa08086a798dd21f0d078264d61e3f78
1 parent 495023b · commit 17cc629

6 files changed

Lines changed: 327 additions & 70 deletions


lib/Dialect/TritonGPU/Transforms/WarpSpecialization/OptimizePartitionWarps.cpp

Lines changed: 4 additions & 2 deletions
@@ -164,6 +164,7 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
   SmallVector<unsigned> maxTensorRegs;
   for (Region *partition : wsOp.getPartitionRegions()) {
     unsigned &tensorRegs = maxTensorRegs.emplace_back(0);
+
     partition->walk([&](Operation *op) {
       for (Type type :
            llvm::concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
@@ -207,7 +208,9 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
       if (isa<ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp>(op))
         *minWarps = 2;
       // TMEM ops require at least 4 warps to be able to read all lanes.
-      else if (isa<ttng::TMEMLoadOp, ttng::TMEMStoreOp, ttng::TMEMAllocOp>(op))
+      // WarpGroupDotOp requires a full warp group (4 warps).
+      else if (isa<ttng::TMEMLoadOp, ttng::TMEMStoreOp, ttng::TMEMAllocOp,
+               ttng::WarpGroupDotOp>(op))
         *minWarps = 4;
     });
   }
@@ -306,7 +309,6 @@ static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
   for (auto [partition, newNumWarps, prevNumWarps, tensorRegs, estRegs] :
        llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
                  wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
-
     // "Guess" the register usage for each partition.
     estRegs = tensorRegs ? maxRegAutoWS : minRegAutoWS;

python/test/unit/language/test_tutorial09_warp_specialization.py

Lines changed: 6 additions & 0 deletions
@@ -1110,6 +1110,7 @@ def test_hopper_matmul_tma_warp_specialize(
     """Test matmul_kernel_tma with warp_specialize=True on Hopper (K-loop based)."""
     if DATA_PARTITION_FACTOR != 1 and BLOCK_SIZE_M != 128:
         pytest.skip("DATA_PARTITION_FACTOR != 1 requires BLOCK_SIZE_M == 128")
+
     if BLOCK_SIZE_N == 256 and BLOCK_SIZE_K == 128 and not (BLOCK_SIZE_M == 64 and num_stages == 2):
         pytest.skip("OOM: shared memory exceeds H100 limit")

@@ -1169,6 +1170,7 @@ def alloc_fn(size, align, stream):
             num_warps=num_warps,
             early_tma_store_lowering=use_early_tma_store_lowering,
             pingpongAutoWS=enable_pingpong,
+            maxRegAutoWS=208 if DATA_PARTITION_FACTOR > 1 else 252,
         )

         ttgir = kernel.asm["ttgir"]
@@ -1221,6 +1223,7 @@ def test_hopper_matmul_tma_persistent_warp_specialize(
     """
     if DATA_PARTITION_FACTOR != 1 and BLOCK_SIZE_M != 128:
         pytest.skip("DATA_PARTITION_FACTOR != 1 requires BLOCK_SIZE_M == 128")
+
     if BLOCK_SIZE_N == 256 and BLOCK_SIZE_K == 128 and not (BLOCK_SIZE_M == 64 and num_stages == 2):
         pytest.skip("OOM: shared memory exceeds H100 limit")

@@ -1295,6 +1298,7 @@ def alloc_fn(size, align, stream):
             num_warps=num_warps,
             early_tma_store_lowering=use_early_tma_store_lowering,
             pingpongAutoWS=enable_pingpong,
+            maxRegAutoWS=208 if DATA_PARTITION_FACTOR > 1 else 252,
         )

         ttgir = kernel.asm["ttgir"]
@@ -1348,6 +1352,7 @@ def test_hopper_matmul_descriptor_persistent_warp_specialize(
     """
     if DATA_PARTITION_FACTOR != 1 and BLOCK_SIZE_M != 128:
         pytest.skip("DATA_PARTITION_FACTOR != 1 requires BLOCK_SIZE_M == 128")
+
     if BLOCK_SIZE_N == 256 and BLOCK_SIZE_K == 128 and not (BLOCK_SIZE_M == 64 and num_stages == 2):
         pytest.skip("OOM: shared memory exceeds H100 limit")

@@ -1407,6 +1412,7 @@ def alloc_fn(size, align, stream):
             num_warps=num_warps,
             early_tma_store_lowering=use_early_tma_store_lowering,
             pingpongAutoWS=enable_pingpong,
+            maxRegAutoWS=208 if DATA_PARTITION_FACTOR > 1 else 252,
         )

         ttgir = kernel.asm["ttgir"]
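
The new maxRegAutoWS values are consistent with a simple register-file budget. A back-of-the-envelope check, assuming an sm_90 SM with 65,536 32-bit registers, 4-warp (128-thread) partitions, and an illustrative low budget for the load partition (the load-partition figure and partition shapes are assumptions, not read from this diff):

REGS_PER_SM = 64 * 1024         # 32-bit registers per SM on H100 (sm_90)
THREADS_PER_PARTITION = 4 * 32  # a full warp group
LOAD_PARTITION_REGS = 40        # assumed low budget for the TMA/load partition

def total_regs(dp_factor, max_reg_auto_ws):
    compute = dp_factor * THREADS_PER_PARTITION * max_reg_auto_ws
    load = THREADS_PER_PARTITION * LOAD_PARTITION_REGS
    return compute + load

assert total_regs(1, 252) <= REGS_PER_SM  # 32,256 + 5,120 = 37,376: fits
assert total_regs(2, 208) <= REGS_PER_SM  # 53,248 + 5,120 = 58,368: fits
assert total_regs(2, 252) > REGS_PER_SM   # 69,632: why DP=2 needs the lower cap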
Lines changed: 134 additions & 0 deletions
// RUN: triton-opt %s --nvgpu-partition-scheduling-meta --verify-each=false | FileCheck %s

// Tests that on Hopper (cuda:90) with DATA_PARTITION_FACTOR=2 and
// WarpGroupDotOp, the partition scheduler correctly creates per-dpId
// computation partitions using the WarpGroupDotOp fallback (since
// WSDataPartition already split the dots, leaving no DataPartition-
// categorized ops in backward slices). Epilogue is merged into
// computation partitions so each MMA's truncf + TMA store lives
// alongside it.

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
#smem = #ttg.shared_memory

module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {

// CHECK-LABEL: hopper_data_partitioned_gemm
//
// --- Inner k-loop: descriptor_loads and local_allocs → load partition ---
// CHECK: descriptor_load{{.*}}ttg.partition = array<i32: [[LOAD:[0-9]+]]>
// CHECK: descriptor_load{{.*}}ttg.partition = array<i32: [[LOAD]]>
// CHECK: descriptor_load{{.*}}ttg.partition = array<i32: [[LOAD]]>
// CHECK: local_alloc{{.*}}ttg.partition = array<i32: [[LOAD]]>
// CHECK: local_alloc{{.*}}ttg.partition = array<i32: [[LOAD]]>
// CHECK: local_alloc{{.*}}ttg.partition = array<i32: [[LOAD]]>
//
// --- Inner k-loop: each warp_group_dot in its own computation partition ---
// CHECK: warp_group_dot{{.*}}ttg.partition = array<i32: [[COMP_A:[0-9]+]]>
// CHECK: warp_group_dot{{.*}}ttg.partition = array<i32: [[COMP_B:[0-9]+]]>
//
// --- Epilogue: each half's truncf + TMA store in same partition as its MMA ---
// CHECK: truncf{{.*}}ttg.partition = array<i32: [[COMP_A]]>
// CHECK: truncf{{.*}}ttg.partition = array<i32: [[COMP_B]]>
// CHECK: async_tma_copy_local_to_global{{.*}}ttg.partition = array<i32: [[COMP_A]]>
// CHECK: async_tma_copy_local_to_global{{.*}}ttg.partition = array<i32: [[COMP_B]]>
//
// --- Partition types: computation partitions before load ---
// CHECK: partition.types = ["computation", "computation", "load"
tt.func public @hopper_data_partitioned_gemm(
    %a_desc: !tt.tensordesc<tensor<64x64xf16, #shared>>,
    %b_desc: !tt.tensordesc<tensor<128x64xf16, #shared>>,
    %c_desc: !tt.tensordesc<tensor<64x128xf16, #shared>>,
    %M: i32 {tt.divisibility = 16 : i32},
    %N: i32 {tt.divisibility = 16 : i32},
    %K: i32 {tt.divisibility = 16 : i32}
) {
  %c132_i32 = arith.constant 132 : i32
  %c8_i32 = arith.constant 8 : i32
  %c128_i32 = arith.constant 128 : i32
  %c64_i32 = arith.constant 64 : i32
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c127_i32 = arith.constant 127 : i32
  %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #mma>

  %start_pid = tt.get_program_id x : i32
  %num_pid_m = arith.addi %M, %c127_i32 : i32
  %num_pid_m_div = arith.divsi %num_pid_m, %c128_i32 : i32
  %num_pid_n = arith.addi %N, %c127_i32 : i32
  %num_pid_n_div = arith.divsi %num_pid_n, %c128_i32 : i32
  %k_tiles = arith.addi %K, %c64_i32 : i32
  %k_tiles_div = arith.divsi %k_tiles, %c64_i32 : i32
  %num_tiles = arith.muli %num_pid_m_div, %num_pid_n_div : i32
  %tile_id_c_init = arith.subi %start_pid, %c132_i32 : i32
  %num_pid_in_group = arith.muli %num_pid_n_div, %c8_i32 : i32

  %tile_id_c_out = scf.for %tile_id = %start_pid to %num_tiles step %c132_i32
      iter_args(%tile_id_c = %tile_id_c_init) -> (i32) : i32 {
    %group_id = arith.divsi %tile_id, %num_pid_in_group : i32
    %first_pid_m = arith.muli %group_id, %c8_i32 : i32
    %group_size_m = arith.subi %num_pid_m_div, %first_pid_m : i32
    %group_size_m_clamped = arith.minsi %group_size_m, %c8_i32 : i32
    %pid_m = arith.remsi %tile_id, %group_size_m_clamped : i32
    %pid_m_final = arith.addi %first_pid_m, %pid_m : i32
    %pid_n_tmp = arith.remsi %tile_id, %num_pid_in_group : i32
    %pid_n = arith.divsi %pid_n_tmp, %group_size_m_clamped : i32
    %offs_am = arith.muli %pid_m_final, %c128_i32 : i32
    %offs_am_1 = arith.addi %offs_am, %c64_i32 : i32
    %offs_bn = arith.muli %pid_n, %c128_i32 : i32

    // Inner k-loop with two WarpGroupDotOps (data-partitioned)
    %acc:2 = scf.for %ki = %c0_i32 to %k_tiles_div step %c1_i32
        iter_args(%acc0 = %cst, %acc1 = %cst) -> (tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>) : i32 {
      %offs_k = arith.muli %ki, %c64_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32

      %a0 = tt.descriptor_load %a_desc[%offs_am, %offs_k] {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #blocked>
      %a1 = tt.descriptor_load %a_desc[%offs_am_1, %offs_k] {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<64x64xf16, #shared>> -> tensor<64x64xf16, #blocked>
      %b = tt.descriptor_load %b_desc[%offs_bn, %offs_k] {loop.cluster = 1 : i32, loop.stage = 0 : i32} : !tt.tensordesc<tensor<128x64xf16, #shared>> -> tensor<128x64xf16, #blocked>

      %a0_smem = ttg.local_alloc %a0 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
      %a1_smem = ttg.local_alloc %a1 {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>
      %b_smem = ttg.local_alloc %b {loop.cluster = 0 : i32, loop.stage = 1 : i32} : (tensor<128x64xf16, #blocked>) -> !ttg.memdesc<128x64xf16, #shared, #smem>
      %b_trans = ttg.memdesc_trans %b_smem {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array<i32: 1, 0>} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared1, #smem>

      %dot0 = ttng.warp_group_dot %a0_smem, %b_trans, %acc0 {inputPrecision = 0 : i32, loop.cluster = 0 : i32, loop.stage = 1 : i32} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x128xf16, #shared1, #smem> -> tensor<64x128xf32, #mma>
      %dot1 = ttng.warp_group_dot %a1_smem, %b_trans, %acc1 {inputPrecision = 0 : i32, loop.cluster = 0 : i32, loop.stage = 1 : i32} : !ttg.memdesc<64x64xf16, #shared, #smem> * !ttg.memdesc<64x128xf16, #shared1, #smem> -> tensor<64x128xf32, #mma>

      scf.yield %dot0, %dot1 : tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>
    } {tt.scheduled_max_stage = 1 : i32}

    // Epilogue
    %tile_id_c_next = arith.addi %tile_id_c, %c132_i32 : i32
    %group_id_c = arith.divsi %tile_id_c_next, %num_pid_in_group : i32
    %first_pid_m_c = arith.muli %group_id_c, %c8_i32 : i32
    %group_size_m_c = arith.subi %num_pid_m_div, %first_pid_m_c : i32
    %group_size_m_c_clamped = arith.minsi %group_size_m_c, %c8_i32 : i32
    %pid_m_c = arith.remsi %tile_id_c_next, %group_size_m_c_clamped : i32
    %pid_m_c_final = arith.addi %first_pid_m_c, %pid_m_c : i32
    %pid_n_c_tmp = arith.remsi %tile_id_c_next, %num_pid_in_group : i32
    %pid_n_c = arith.divsi %pid_n_c_tmp, %group_size_m_c_clamped : i32
    %offs_am_c = arith.muli %pid_m_c_final, %c128_i32 : i32
    %offs_am_c_1 = arith.addi %offs_am_c, %c64_i32 : i32
    %offs_bn_c = arith.muli %pid_n_c, %c128_i32 : i32

    %c0_f16 = arith.truncf %acc#0 : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma>
    %c1_f16 = arith.truncf %acc#1 : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma>
    %c0_cvt = ttg.convert_layout %c0_f16 : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1>
    %c1_cvt = ttg.convert_layout %c1_f16 : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1>
    %c0_smem = ttg.local_alloc %c0_cvt : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
    %store_tok0 = ttng.async_tma_copy_local_to_global %c_desc[%offs_am_c, %offs_bn_c] %c0_smem : !tt.tensordesc<tensor<64x128xf16, #shared>>, !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !ttg.async.token
    ttng.async_tma_store_token_wait %store_tok0 : !ttg.async.token
    %c1_smem = ttg.local_alloc %c1_cvt : (tensor<64x128xf16, #blocked1>) -> !ttg.memdesc<64x128xf16, #shared, #smem, mutable>
    %store_tok1 = ttng.async_tma_copy_local_to_global %c_desc[%offs_am_c_1, %offs_bn_c] %c1_smem : !tt.tensordesc<tensor<64x128xf16, #shared>>, !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> !ttg.async.token
    ttng.async_tma_store_token_wait %store_tok1 : !ttg.async.token

    scf.yield %tile_id_c_next : i32
  } {tt.data_partition_factor = 2 : i32, tt.smem_alloc_algo = 0 : i32, tt.warp_specialize}
  tt.return
}

} // module
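
The CHECK lines above encode the shape of the expected schedule rather than its mechanics. As a toy model (a deliberate simplification under assumed op tags, not the scheduler's actual algorithm): loads and their local_allocs share one load partition, each warp_group_dot anchors a computation partition keyed by its dpId, and epilogue ops follow the dot whose result they consume.

def toy_schedule(ops):
    # ops: (op_name, dp_id) pairs; dp_id is None for loads shared by both halves.
    parts = {"load": [], "computation": {}}
    for name, dp_id in ops:
        if dp_id is None:
            parts["load"].append(name)
        else:
            parts["computation"].setdefault(dp_id, []).append(name)
    return parts

parts = toy_schedule([
    ("descriptor_load", None), ("local_alloc", None),
    ("warp_group_dot", 0), ("warp_group_dot", 1),
    ("truncf", 0), ("truncf", 1),
    ("async_tma_copy_local_to_global", 0), ("async_tma_copy_local_to_global", 1),
])
assert len(parts["computation"]) == 2  # one computation partition per dpId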

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 3 deletions
@@ -393,8 +393,6 @@ def make_ttgir(mod, metadata, opt, capability):
         passes.ttgpuir.add_assign_latencies(pm, opt.num_stages, use_meta_swp_schedule)
         passes.ttgpuir.add_schedule_loops(pm, opt.num_stages, use_meta_swp_schedule)
         passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
-        if knobs.nvidia.use_meta_ws:
-            passes.ttgpuir.add_optimize_partition_warps(pm)
     elif capability // 10 >= 10:
         if not knobs.nvidia.use_modulo_schedule:
             passes.ttgpuir.add_fuse_nested_loops(pm)
@@ -464,7 +462,7 @@ def make_ttgir(mod, metadata, opt, capability):
     passes.common.add_symbol_dce(pm)
     # Optimize the number of warps and registers after TMA lowering, so
     # that any local loads eliminated by TMA lowering do not inflate them.
-    if capability // 10 >= 10 and knobs.nvidia.use_meta_ws:
+    if capability // 10 >= 9 and knobs.nvidia.use_meta_ws:
         passes.ttgpuir.add_optimize_partition_warps(pm)
     nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
     nvidia.passes.ttnvgpuir.add_lower_mma(pm)
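
The net effect: the post-TMA-lowering invocation of optimize_partition_warps was Blackwell-only; now Hopper qualifies too, and the redundant Hopper-path invocation before pipelining is dropped. A small sketch of the new gate (the helper name is hypothetical; only the expression itself comes from the diff):

def runs_late_optimize_partition_warps(capability, use_meta_ws):
    # capability = 10 * major + minor, e.g. 90 for Hopper, 100 for Blackwell.
    return capability // 10 >= 9 and use_meta_ws

assert runs_late_optimize_partition_warps(90, True)       # Hopper: newly enabled
assert runs_late_optimize_partition_warps(100, True)      # Blackwell: unchanged
assert not runs_late_optimize_partition_warps(80, True)   # Ampere: still excluded
assert not runs_late_optimize_partition_warps(90, False)  # knob off: excluded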

third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/CodePartitionUtility.cpp

Lines changed: 19 additions & 2 deletions
@@ -2717,8 +2717,6 @@ static void createChannelPost(Operation *allocOp, mlir::DominanceInfo &dom,
   if (!producerOp)
     return;
   auto producerTaskIds = getAsyncTaskIds(producerOp);
-  assert(producerTaskIds.size() == 1);
-  auto producerTaskId = producerTaskIds.front();
   // Collect consumer task IDs from all consumers. With data partitioning,
   // different consumers may have different task IDs (e.g., K/V buffers
   // consumed by multiple computation partitions).
@@ -2730,6 +2728,25 @@ static void createChannelPost(Operation *allocOp, mlir::DominanceInfo &dom,
       consumerTaskIds.push_back(id);
     }
   }
+
+  // When a producer has multiple task IDs (e.g., a shared local_alloc
+  // consumed by data-partitioned computation groups), no channel is needed
+  // for any producer that is co-located with a consumer. It is unclear if
+  // this is sufficient when there are multiple consumers.
+  AsyncTaskId producerTaskId = -1;
+  if (producerTaskIds.size() > 1 && consumerTaskIds.size() == 1) {
+    auto consumerTaskId = consumerTaskIds.front();
+    for (auto id : producerTaskIds) {
+      if (id != consumerTaskId) {
+        assert(producerTaskId == -1 &&
+               "Multiple producers encountered for 1 consumer");
+        producerTaskId = id;
+      }
+    }
+  } else {
+    assert(producerTaskIds.size() == 1);
+    producerTaskId = producerTaskIds.front();
+  }
   // Remove producer task id from consumerTaskIds.
   auto iter = std::remove(consumerTaskIds.begin(), consumerTaskIds.end(),
                           producerTaskId);
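
The selection logic added above can be restated compactly. A Python sketch of the same decision (illustrative; plain lists of ints stand in for the C++ task-ID containers):

def pick_producer_task_id(producer_task_ids, consumer_task_ids):
    # New case: a producer shared by data-partitioned groups carries several
    # task IDs. With exactly one consumer, the producer ID co-located with
    # that consumer needs no channel, so pick the single ID that differs.
    if len(producer_task_ids) > 1 and len(consumer_task_ids) == 1:
        consumer = consumer_task_ids[0]
        others = [tid for tid in producer_task_ids if tid != consumer]
        assert len(others) == 1, "Multiple producers encountered for 1 consumer"
        return others[0]
    # Original case: exactly one producer task ID.
    assert len(producer_task_ids) == 1
    return producer_task_ids[0]

assert pick_producer_task_id([0, 2], [0]) == 2  # skip the co-located ID 0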
