Skip to content

Commit d5234a7

Browse files
committed
[AMD][TDM] Minor formatting cleanup
1 parent 75b8c66 commit d5234a7

3 files changed

Lines changed: 12 additions & 9 deletions

File tree

python/src/gluon_ir.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,10 +1023,10 @@ void init_gluon_ir(py::module &&m) {
10231023
[](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
10241024
Value result, Value pred, Value barrier,
10251025
std::vector<int64_t> warpBases) {
1026-
auto warpBasesAttr = warpBases.empty()
1027-
? DenseI64ArrayAttr()
1028-
: DenseI64ArrayAttr::get(
1029-
self.getContext(), warpBases);
1026+
auto warpBasesAttr =
1027+
warpBases.empty()
1028+
? DenseI64ArrayAttr()
1029+
: DenseI64ArrayAttr::get(self.getContext(), warpBases);
10301030
self.create<ttag::AsyncTDMCopyGlobalToLocalOp>(
10311031
descPtr, indices, result, pred, barrier, warpBasesAttr);
10321032
})

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
#include "Dialect/TritonAMDGPU/IR/Dialect.cpp.inc"
4040
// clang-format on
4141

42+
#include "third_party/amd/backend/include/TDMCommon.h"
4243
#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h"
4344
#include "third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h"
44-
#include "third_party/amd/backend/include/TDMCommon.h"
4545

4646
using namespace mlir;
4747
using namespace mlir::triton::amdgpu;

third_party/amd/python/examples/gluon/f16_gemm_warp_pipeline_gfx1250.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,11 @@ def gemm_tdm_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
104104
# Duplicate warps get pred=0 (hardware no-op), freeing TDM bandwidth.
105105
# ---------------------------------------------------------------------------
106106

107+
107108
@gluon.jit
108109
def issue_loads_specialized(producer, a_desc, b_desc, off_am, off_bn, a_buffer, b_buffer, BLOCK_K: ttgl.constexpr,
109-
NUM_BUFFERS: ttgl.constexpr, TRANSPOSE_B: ttgl.constexpr,
110-
TDM_WARP_BASES: ttgl.constexpr, pred=1):
110+
NUM_BUFFERS: ttgl.constexpr, TRANSPOSE_B: ttgl.constexpr, TDM_WARP_BASES: ttgl.constexpr,
111+
pred=1):
111112
pred_i32 = pred.to(ttgl.int32) if hasattr(pred, 'to') else pred
112113
ttgl.amd.gfx1250.tdm.async_load(a_desc, [off_am, producer * BLOCK_K], a_buffer.index(producer % NUM_BUFFERS),
113114
pred=pred_i32, warp_bases=TDM_WARP_BASES)
@@ -171,8 +172,8 @@ def gemm_tdm_specialized_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
171172
with ttgl.amd.warp_pipeline_stage("stage0", priority=1):
172173
consumer, a, b = lds_load(consumer, a_buffer, OPERAND_LAYOUT_A, b_buffer, OPERAND_LAYOUT_B, NUM_BUFFERS,
173174
TRANSPOSE_B)
174-
producer = issue_loads_specialized(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K,
175-
NUM_BUFFERS, TRANSPOSE_B, TDM_WARP_BASES)
175+
producer = issue_loads_specialized(producer, a_desc, b_desc, 0, 0, a_buffer, b_buffer, BLOCK_K, NUM_BUFFERS,
176+
TRANSPOSE_B, TDM_WARP_BASES)
176177
with ttgl.amd.warp_pipeline_stage("stage1", priority=0):
177178
accumulator = issue_wmma_compute(a, b, accumulator)
178179
ttgl.amd.gfx1250.tdm.async_wait(2)
@@ -193,6 +194,7 @@ def gemm_tdm_specialized_pipelined_warp_pipelined_kernel(a_ptr, b_ptr, c_ptr, #
193194
# Helper
194195
# ---------------------------------------------------------------------------
195196

197+
196198
def _compute_tdm_warp_bases(block_shape, num_warps, active_warps):
197199
"""Compute warp_bases for partial TDM copy with the given active warp count.
198200
@@ -231,6 +233,7 @@ def _compute_tdm_warp_bases(block_shape, num_warps, active_warps):
231233
# Tests
232234
# ---------------------------------------------------------------------------
233235

236+
234237
@pytest.mark.parametrize("BLOCK_M,BLOCK_N,BLOCK_K", [(256, 256, 64)])
235238
@pytest.mark.parametrize("NUM_BUFFERS", [3])
236239
@pytest.mark.parametrize("TRANSPOSE_B", [True])

0 commit comments

Comments
 (0)