Commit b07e76a
[Bench][Blackwell] Fix warp specialization for fp8 x mxfp4 bench (#6537)
This PR chain brings the performance of the mixed fp8 x mxfp4 MoE kernel on par with the fp8 x fp8 kernel:

* About 10% slower in the dense benchmarks
* About 10% faster in the llama4 benchmarks

It applies a bug fix for padded scale loads in fp8 x mxfp4 mode, ensuring TMA load requirements are met when the unpacked (padded) fp4 layout is used. The bug only manifests once warp specialization is enabled.
1 parent a0e3e78 commit b07e76a
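For context, here is a minimal Python sketch of the index translation the fix restores. The helper name `translate_tma_indices` and the doubling rule are assumptions that mirror the `ttng::translateTMAIndices` call added in the C++ change below, not the kernel's actual code: in the unpacked (padded) fp4 layout each fp4 value occupies a full byte rather than a nibble, so the innermost TMA coordinate has to be rescaled before the copy is issued.

```python
# Illustrative sketch only -- assumed semantics of the index translation,
# not Triton's actual implementation. In the unpacked (padded) fp4 layout
# each fp4 element takes a full byte instead of half a byte, so a tile's
# innermost coordinate must be doubled to satisfy the TMA descriptor.
def translate_tma_indices(indices: list[int], fp4_padded: bool) -> list[int]:
    """Adjust TMA block coordinates for an fp4-padded destination layout."""
    if fp4_padded:
        # Only the innermost (fastest-varying) coordinate is rescaled.
        indices = indices[:-1] + [indices[-1] * 2]
    return indices

# A tile starting at packed column 64 starts at padded column 128.
assert translate_tma_indices([0, 64], fp4_padded=True) == [0, 128]
assert translate_tma_indices([0, 64], fp4_padded=False) == [0, 64]
```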

2 files changed: 6 additions & 2 deletions


bench/triton_bench/matmul_ogs_details/_ptma_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -293,7 +293,7 @@ def _ptma_matmul_ogs(
     # Enable warp specialization when all loads are TMA loads. Don't enable it
     # for mixed-precision yet.
     ENABLE_WS: tl.constexpr = True
-    WARP_SPECIALIZE: tl.constexpr = ((USE_GATHER_TMA or X_USE_LOAD_TMA) and not is_microscaled_format) and ENABLE_WS
+    WARP_SPECIALIZE: tl.constexpr = (USE_GATHER_TMA or X_USE_LOAD_TMA) and ENABLE_WS
 
     for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=WARP_SPECIALIZE):
         expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
```
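With the padded scale loads fixed (see the compiler change below), warp specialization no longer needs to exclude microscaled formats, so the guard reduces to the TMA-load condition alone.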

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/LoadMMASpecialization.cpp

Lines changed: 5 additions & 1 deletion
```diff
@@ -12,6 +12,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonGPU/Transforms/WarpSpecialization.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 
 using namespace mlir;
 using namespace triton;
@@ -139,8 +140,11 @@ static void lowerTMACopy(ImplicitLocOpBuilder &b, Partition &partition,
   if (auto load = dyn_cast<DescriptorLoadOp>(op)) {
     Value tmaPtr = createInPartition<ttng::TensorDescToTMAPtrOp>(
         b, partition, load.getDesc());
+    auto indices = ttng::translateTMAIndices(
+        b, load.getLoc(), load.getDesc().getType().getBlockType().getEncoding(),
+        load.getIndices());
     createInPartition<ttng::AsyncTMACopyGlobalToLocalOp>(
-        b, partition, tmaPtr, load.getIndices(), barrier, view, truePred);
+        b, partition, tmaPtr, indices, barrier, view, truePred);
   } else {
     auto gather = cast<DescriptorGatherOp>(op);
     Value tmaPtr = createInPartition<ttng::TensorDescToTMAPtrOp>(
```
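Routing the coordinates through `ttng::translateTMAIndices` makes the warp-specialized TMA copy honor the fp4-padded shared-memory encoding, which is consistent with the bug only surfacing once warp specialization was enabled: the non-specialized lowering path already performed this translation.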
