44# Needed for internal dev flow for now; will remove later
55hip .hip .hipInit (0 )
66
7- import re
87import torch
98import pytest
109import triton
1514from triton .language .core import _aggregate as aggregate
1615from triton .tools .mxfp import MXFP4Tensor , MXScaleTensor
1716
18-
def static_profile(kernel):
    """Print and return the static metrics found in a kernel's AMDGCN assembly.

    The text in ``kernel.asm['amdgcn']`` is searched once per metric with a
    regular expression, and the first match of each is converted to ``int``.

    Args:
        kernel: Object exposing an ``asm`` mapping with an ``'amdgcn'`` entry
            that holds the assembly text.

    Returns:
        dict mapping each metric name to its integer value, or to ``None``
        when the assembly text has no matching line for that metric.
    """
    amdgcn = kernel.asm['amdgcn']

    # (label, pattern) pairs, listed in the order they are reported.
    metric_patterns = (
        ('sgpr_count', r'\.sgpr_count:\s+(\d+)'),
        ('sgpr_spill_count', r'\.sgpr_spill_count:\s+(\d+)'),
        ('vgpr_count', r'\.vgpr_count:\s+(\d+)'),
        ('vgpr_spill_count', r'\.vgpr_spill_count:\s+(\d+)'),
        ('scratch_size', r';\s+ScratchSize:\s+(\d+)'),
        ('code_len_in_byte', r';\s+codeLenInByte\s+=\s+(\d+)'),
        ('occupancy', r';\s+Occupancy:\s+(\d+)'),
    )

    metrics = {}
    for label, pattern in metric_patterns:
        match = re.search(pattern, amdgcn)
        # A missing metric is recorded as None instead of aborting the whole
        # report with an AttributeError raised from `None.group(1)`.
        metrics[label] = int(match.group(1)) if match else None

    # All values are collected before anything is printed, so the report is
    # emitted in a single print call, one "- name: value" line per metric.
    report = "".join(
        f"- {label}: {'n/a' if value is None else value}\n"
        for label, value in metrics.items()
    )
    print(report)
    return metrics
17+ # Handle imports for both pytest (module context) and direct execution
18+ try :
19+ from .gfx1250_utils import static_profile
20+ except ImportError :
21+ from gfx1250_utils import static_profile
3722
3823
3924@gluon .constexpr_function
@@ -78,7 +63,6 @@ class MXFPGEMMConfig:
7863 BLOCK_M_PRESHUFFLED : gl .constexpr
7964 BLOCK_N_PRESHUFFLED : gl .constexpr
8065 BLOCK_K_SCALE_PRESHUFFLED : gl .constexpr
81- tiles_per_warp : gl .constexpr
8266 SCALE_BLOCK : gl .constexpr
8367 ASYNC_COPY_SCALE : gl .constexpr
8468
@@ -116,8 +100,6 @@ def __init__(self, BLOCK_M, BLOCK_N, BLOCK_K, DTYPE_A, DTYPE_B, SCALE_BLOCK, NUM
116100 reg_bases : gl .constexpr = []
117101 warp_bases : gl .constexpr = [[0 , 1 ], [1 , 0 ]]
118102
119- self .tiles_per_warp = gl .constexpr ([2 , 2 ] if SCALE_PRESHUFFLE else [1 , 1 ])
120-
121103 self .BLOCK_M_PRESHUFFLED = gl .constexpr (BLOCK_M // self .PRESHUFFLE_FACTOR )
122104 self .BLOCK_N_PRESHUFFLED = gl .constexpr (BLOCK_N // self .PRESHUFFLE_FACTOR )
123105 self .BLOCK_K_SCALE_PRESHUFFLED = gl .constexpr (BLOCK_K_SCALE * self .PRESHUFFLE_FACTOR )
@@ -280,7 +262,7 @@ def initialize(cfg: MXFPGEMMConfig, a_desc, b_desc, a_scale_desc, b_scale_desc,
280262 a_scale_desc , b_scale_desc , c_ptr , c_offs , c_mask )
281263
282264 @gluon .jit
283- def issue_loads (self , load_idx , pred = True ):
265+ def issue_loads (self , load_idx , pred = 1 ):
284266 cfg = self .cfg
285267 NUM_SUBTILES_K = cfg .NUM_SUBTILES [2 ]
286268 BLOCK_K_PACKED_A : gl .constexpr = cfg .BLOCK_K // cfg .DIV_FACTOR_A // NUM_SUBTILES_K
@@ -359,9 +341,12 @@ def pipeline(self, K):
359341
360342 accumulator = gl .zeros ((cfg .BLOCK_M , cfg .BLOCK_N ), dtype = gl .float32 , layout = self .cfg .acc_layout )
361343 loop_ub = gl .cdiv (K , cfg .BLOCK_K )
344+ gl .assume (loop_ub > 0 )
362345 epilogue_lb = loop_ub - (cfg .NUM_BUFFERS - 1 )
363346 for i in range (0 , loop_ub ):
364- load_idx = self .issue_loads (load_idx , pred = (i < epilogue_lb ))
347+ pred = i - epilogue_lb
348+ pred = (pred >> 31 ) & 1
349+ load_idx = self .issue_loads (load_idx , pred = pred )
365350
366351 gl .amd .gfx1250 .tdm .async_wait ((cfg .NUM_BUFFERS - 1 ) * self .cfg .NUM_LOADS_IN_BATCH )
367352
@@ -554,7 +539,7 @@ def issue_local_load_b(self, wmma_idx, b_buffer, b_scale_buffer):
554539 return b , scale_b
555540
556541 @gluon .jit
557- def issue_load_a (self , load_idx , a_buffer , a_scale_buffer , pred = True ):
542+ def issue_load_a (self , load_idx , a_buffer , a_scale_buffer , pred = 1 ):
558543 cfg = self .cfg
559544 NUM_SUBTILES_K : gl .constexpr = cfg .NUM_SUBTILES [2 ]
560545 BLOCK_K : gl .constexpr = cfg .BLOCK_K // cfg .DIV_FACTOR_A // NUM_SUBTILES_K
@@ -574,7 +559,7 @@ def issue_load_a(self, load_idx, a_buffer, a_scale_buffer, pred=True):
574559 return load_idx + 1
575560
576561 @gluon .jit
577- def issue_load_b (self , load_idx , b_buffer , b_scale_buffer , pred = True ):
562+ def issue_load_b (self , load_idx , b_buffer , b_scale_buffer , pred = 1 ):
578563 cfg = self .cfg
579564 NUM_SUBTILES_N : gl .constexpr = cfg .NUM_SUBTILES [1 ]
580565 NUM_SUBTILES_K : gl .constexpr = cfg .NUM_SUBTILES [2 ]
@@ -642,9 +627,11 @@ def pipeline(self, K):
642627 layout = cfg .acc_layout )
643628
644629 loop_ub = gl .cdiv (K , cfg .BLOCK_K )
630+ gl .assume (loop_ub > 0 )
645631 epilogue_lb = loop_ub - (cfg .NUM_BUFFERS - 1 )
646632 for i in range (0 , loop_ub ):
647- pred = (i < epilogue_lb )
633+ pred = i - epilogue_lb
634+ pred = (pred >> 31 ) & 1
648635
649636 # iter i + 1
650637 load_a_idx = self .issue_load_a (load_a_idx , self .a_buffer0 , self .a_scale_buffer0 , pred = pred )
0 commit comments