Skip to content

Commit 7f40dff

Browse files
authored
[AMD][gfx1250] Add Assumptions and Fix Predicate in MXGEMM Kernel (#9285)
This PR adds assumptions on the loop bounds, fixes the predicate to eliminate readfirstlane instructions, and moves `static_profile` and `composition` to a shared location.
1 parent efd3e32 commit 7f40dff

3 files changed

Lines changed: 58 additions & 64 deletions

File tree

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import re
2+
3+
4+
def composition(cls):
    """Class decorator that lets an aggregate read attributes of its members directly.

    Installs a ``__getattr__`` fallback on ``cls``. Python only calls ``__getattr__``
    after normal attribute lookup has failed, so attributes defined on the instance
    or the class always take precedence over delegated ones.

    The fallback walks the instance's ``__dict__`` values in insertion order and
    returns the attribute from the first member that provides it.

    Args:
        cls: The class to decorate. It is modified in place.

    Returns:
        The same ``cls`` object, with ``__getattr__`` installed.
    """
    # Sentinel distinguishing "member has no such attribute" from an attribute
    # whose value happens to be None/False.
    missing = object()

    def __getattr__(self, name):
        if name in self.__dict__:
            return object.__getattribute__(self, name)
        for member in self.__dict__.values():
            # Skip any member that lacks the attribute, aggregate or not, so that
            # a plain field stored before an aggregate member cannot abort the
            # search with an AttributeError about its own type.
            value = getattr(member, name, missing)
            if value is not missing:
                return value
        raise AttributeError(f"{type(self).__name__} object has no attribute '{name}'")

    cls.__getattr__ = __getattr__
    return cls
18+
19+
20+
def static_profile(kernel):
    """Print static resource statistics parsed from a kernel's AMDGCN assembly text.

    Reads ``kernel.asm['amdgcn']``, extracts the integer following each field
    listed in ``fields`` below, and prints them as a ``- name: value`` list
    followed by a blank line.

    Args:
        kernel: Object exposing an ``asm`` mapping that holds the assembly text
            under the ``'amdgcn'`` key.

    Raises:
        ValueError: If any expected field is not present in the assembly text.
    """
    amdgcn = kernel.asm['amdgcn']

    # (printed label, regex capturing the integer value), in output order.
    fields = (
        ("sgpr_count", r'\.sgpr_count:\s+(\d+)'),
        ("sgpr_spill_count", r'\.sgpr_spill_count:\s+(\d+)'),
        ("vgpr_count", r'\.vgpr_count:\s+(\d+)'),
        ("vgpr_spill_count", r'\.vgpr_spill_count:\s+(\d+)'),
        ("scratch_size", r';\s+ScratchSize:\s+(\d+)'),
        ("code_len_in_byte", r';\s+codeLenInByte\s+=\s+(\d+)'),
        ("occupancy", r';\s+Occupancy:\s+(\d+)'),
    )

    report = []
    for label, pattern in fields:
        found = re.search(pattern, amdgcn)
        if found is None:
            # Name the missing field instead of failing on None.group(1).
            raise ValueError(f"static_profile: field '{label}' not found in amdgcn assembly")
        report.append(f"- {label}: {int(found.group(1))}\n")

    # Single print call: the text ends with "\n" and print adds one more,
    # so the report is followed by a blank line.
    print("".join(report))

third_party/amd/python/examples/gluon/mxfp_fa_gfx1250.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,17 @@
2727
from triton.experimental.gluon.language.amd.gfx1250 import buffer_load, buffer_store
2828
from triton.experimental.gluon.language.amd.gfx1250 import async_copy as cp
2929

30+
# Handle imports for both pytest (module context) and direct execution
31+
try:
32+
from .gfx1250_utils import static_profile, composition
33+
except ImportError:
34+
from gfx1250_utils import static_profile, composition
35+
3036
# ===-----------------------------------------------------------------------===#
3137
# Kernel Utilities
3238
# ===-----------------------------------------------------------------------===#
3339

3440

35-
def composition(cls):
36-
""" A decorator lets aggregate type to directly access attributes from its aggregate member. """
37-
38-
def __getattr__(self, name):
39-
if name in self.__dict__:
40-
return object.__getattribute__(self, name)
41-
for member in self.__dict__.values():
42-
if getattr(member, "__triton_aggregate__", False) and not hasattr(member, name):
43-
continue
44-
return getattr(member, name)
45-
raise AttributeError(f"{type(self).__name__} object has no attribute '{name}'")
46-
47-
cls.__getattr__ = __getattr__
48-
return cls
49-
50-
5141
@gluon.constexpr_function
5242
def get_padded_shared_layout(shape, transposed=False):
5343
""" Get a padded shared layout without back conflict for a given tensor shape. """
@@ -1840,26 +1830,6 @@ def create_global_scale(dtype: str):
18401830
return scale, scale_ref
18411831

18421832

1843-
def static_profile(kernel):
1844-
amdgcn = kernel.asm['amdgcn']
1845-
1846-
sgpr_count = int(re.search(r'\.sgpr_count:\s+(\d+)', amdgcn).group(1))
1847-
sgpr_spill_count = int(re.search(r'\.sgpr_spill_count:\s+(\d+)', amdgcn).group(1))
1848-
vgpr_count = int(re.search(r'\.vgpr_count:\s+(\d+)', amdgcn).group(1))
1849-
vgpr_spill_count = int(re.search(r'\.vgpr_spill_count:\s+(\d+)', amdgcn).group(1))
1850-
scratch_size = int(re.search(r';\s+ScratchSize:\s+(\d+)', amdgcn).group(1))
1851-
code_len_in_byte = int(re.search(r';\s+codeLenInByte\s+=\s+(\d+)', amdgcn).group(1))
1852-
occupancy = int(re.search(r';\s+Occupancy:\s+(\d+)', amdgcn).group(1))
1853-
1854-
print(f"- sgpr_count: {sgpr_count}\n"
1855-
f"- sgpr_spill_count: {sgpr_spill_count}\n"
1856-
f"- vgpr_count: {vgpr_count}\n"
1857-
f"- vgpr_spill_count: {vgpr_spill_count}\n"
1858-
f"- scratch_size: {scratch_size}\n"
1859-
f"- code_len_in_byte: {code_len_in_byte}\n"
1860-
f"- occupancy: {occupancy}\n")
1861-
1862-
18631833
def get_source_mapping(block_scaling, subtile, pipelined, amdgcn):
18641834
"""
18651835
Create a mapping from amdgcn assembly to source code lines:

third_party/amd/python/examples/gluon/mxfp_gemm_gfx1250.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# Needed for internal dev flow for now; will remove later
55
hip.hip.hipInit(0)
66

7-
import re
87
import torch
98
import pytest
109
import triton
@@ -15,25 +14,11 @@
1514
from triton.language.core import _aggregate as aggregate
1615
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
1716

18-
19-
def static_profile(kernel):
20-
amdgcn = kernel.asm['amdgcn']
21-
22-
sgpr_count = int(re.search(r'\.sgpr_count:\s+(\d+)', amdgcn).group(1))
23-
sgpr_spill_count = int(re.search(r'\.sgpr_spill_count:\s+(\d+)', amdgcn).group(1))
24-
vgpr_count = int(re.search(r'\.vgpr_count:\s+(\d+)', amdgcn).group(1))
25-
vgpr_spill_count = int(re.search(r'\.vgpr_spill_count:\s+(\d+)', amdgcn).group(1))
26-
scratch_size = int(re.search(r';\s+ScratchSize:\s+(\d+)', amdgcn).group(1))
27-
code_len_in_byte = int(re.search(r';\s+codeLenInByte\s+=\s+(\d+)', amdgcn).group(1))
28-
occupancy = int(re.search(r';\s+Occupancy:\s+(\d+)', amdgcn).group(1))
29-
30-
print(f"- sgpr_count: {sgpr_count}\n"
31-
f"- sgpr_spill_count: {sgpr_spill_count}\n"
32-
f"- vgpr_count: {vgpr_count}\n"
33-
f"- vgpr_spill_count: {vgpr_spill_count}\n"
34-
f"- scratch_size: {scratch_size}\n"
35-
f"- code_len_in_byte: {code_len_in_byte}\n"
36-
f"- occupancy: {occupancy}\n")
17+
# Handle imports for both pytest (module context) and direct execution
18+
try:
19+
from .gfx1250_utils import static_profile
20+
except ImportError:
21+
from gfx1250_utils import static_profile
3722

3823

3924
@gluon.constexpr_function
@@ -78,7 +63,6 @@ class MXFPGEMMConfig:
7863
BLOCK_M_PRESHUFFLED: gl.constexpr
7964
BLOCK_N_PRESHUFFLED: gl.constexpr
8065
BLOCK_K_SCALE_PRESHUFFLED: gl.constexpr
81-
tiles_per_warp: gl.constexpr
8266
SCALE_BLOCK: gl.constexpr
8367
ASYNC_COPY_SCALE: gl.constexpr
8468

@@ -116,8 +100,6 @@ def __init__(self, BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B, SCALE_BLOCK, NUM
116100
reg_bases: gl.constexpr = []
117101
warp_bases: gl.constexpr = [[0, 1], [1, 0]]
118102

119-
self.tiles_per_warp = gl.constexpr([2, 2] if SCALE_PRESHUFFLE else [1, 1])
120-
121103
self.BLOCK_M_PRESHUFFLED = gl.constexpr(BLOCK_M // self.PRESHUFFLE_FACTOR)
122104
self.BLOCK_N_PRESHUFFLED = gl.constexpr(BLOCK_N // self.PRESHUFFLE_FACTOR)
123105
self.BLOCK_K_SCALE_PRESHUFFLED = gl.constexpr(BLOCK_K_SCALE * self.PRESHUFFLE_FACTOR)
@@ -280,7 +262,7 @@ def initialize(cfg: MXFPGEMMConfig, a_desc, b_desc, a_scale_desc, b_scale_desc,
280262
a_scale_desc, b_scale_desc, c_ptr, c_offs, c_mask)
281263

282264
@gluon.jit
283-
def issue_loads(self, load_idx, pred=True):
265+
def issue_loads(self, load_idx, pred=1):
284266
cfg = self.cfg
285267
NUM_SUBTILES_K = cfg.NUM_SUBTILES[2]
286268
BLOCK_K_PACKED_A: gl.constexpr = cfg.BLOCK_K // cfg.DIV_FACTOR_A // NUM_SUBTILES_K
@@ -359,9 +341,12 @@ def pipeline(self, K):
359341

360342
accumulator = gl.zeros((cfg.BLOCK_M, cfg.BLOCK_N), dtype=gl.float32, layout=self.cfg.acc_layout)
361343
loop_ub = gl.cdiv(K, cfg.BLOCK_K)
344+
gl.assume(loop_ub > 0)
362345
epilogue_lb = loop_ub - (cfg.NUM_BUFFERS - 1)
363346
for i in range(0, loop_ub):
364-
load_idx = self.issue_loads(load_idx, pred=(i < epilogue_lb))
347+
pred = i - epilogue_lb
348+
pred = (pred >> 31) & 1
349+
load_idx = self.issue_loads(load_idx, pred=pred)
365350

366351
gl.amd.gfx1250.tdm.async_wait((cfg.NUM_BUFFERS - 1) * self.cfg.NUM_LOADS_IN_BATCH)
367352

@@ -554,7 +539,7 @@ def issue_local_load_b(self, wmma_idx, b_buffer, b_scale_buffer):
554539
return b, scale_b
555540

556541
@gluon.jit
557-
def issue_load_a(self, load_idx, a_buffer, a_scale_buffer, pred=True):
542+
def issue_load_a(self, load_idx, a_buffer, a_scale_buffer, pred=1):
558543
cfg = self.cfg
559544
NUM_SUBTILES_K: gl.constexpr = cfg.NUM_SUBTILES[2]
560545
BLOCK_K: gl.constexpr = cfg.BLOCK_K // cfg.DIV_FACTOR_A // NUM_SUBTILES_K
@@ -574,7 +559,7 @@ def issue_load_a(self, load_idx, a_buffer, a_scale_buffer, pred=True):
574559
return load_idx + 1
575560

576561
@gluon.jit
577-
def issue_load_b(self, load_idx, b_buffer, b_scale_buffer, pred=True):
562+
def issue_load_b(self, load_idx, b_buffer, b_scale_buffer, pred=1):
578563
cfg = self.cfg
579564
NUM_SUBTILES_N: gl.constexpr = cfg.NUM_SUBTILES[1]
580565
NUM_SUBTILES_K: gl.constexpr = cfg.NUM_SUBTILES[2]
@@ -642,9 +627,11 @@ def pipeline(self, K):
642627
layout=cfg.acc_layout)
643628

644629
loop_ub = gl.cdiv(K, cfg.BLOCK_K)
630+
gl.assume(loop_ub > 0)
645631
epilogue_lb = loop_ub - (cfg.NUM_BUFFERS - 1)
646632
for i in range(0, loop_ub):
647-
pred = (i < epilogue_lb)
633+
pred = i - epilogue_lb
634+
pred = (pred >> 31) & 1
648635

649636
# iter i + 1
650637
load_a_idx = self.issue_load_a(load_a_idx, self.a_buffer0, self.a_scale_buffer0, pred=pred)

0 commit comments

Comments
 (0)