@@ -88,10 +88,12 @@ def _matmul_launch_metadata(grid, kernel, args):
8888 ret = {}
8989 M , N , K = args ["M" ], args ["N" ], args ["K" ]
9090 kernel_name = kernel .name
91- if "ELEM_PER_BYTE " and "VEC_SIZE" in args :
92- if args ["ELEM_PER_BYTE " ] == 1 :
91+ if "ELEM_PER_BYTE_A" and "ELEM_PER_BYTE_B " and "VEC_SIZE" in args :
92+ if args ["ELEM_PER_BYTE_A" ] == 1 and args [ "ELEM_PER_BYTE_B " ] == 1 :
9393 kernel_name += "_mxfp8"
94- elif args ["ELEM_PER_BYTE" ] == 2 :
94+ elif args ["ELEM_PER_BYTE_A" ] == 1 and args ["ELEM_PER_BYTE_B" ] == 2 :
95+ kernel_name += "_mixed"
96+ elif args ["ELEM_PER_BYTE_A" ] == 2 and args ["ELEM_PER_BYTE_B" ] == 2 :
9597 if args ["VEC_SIZE" ] == 16 :
9698 kernel_name += "_nvfp4"
9799 elif args ["VEC_SIZE" ] == 32 :
@@ -104,23 +106,29 @@ def _matmul_launch_metadata(grid, kernel, args):
104106@triton .jit (launch_metadata = _matmul_launch_metadata )
105107def block_scaled_matmul_kernel ( #
106108 a_desc , a_scale , #
107- b_desc , b_scale , #
109+ b_desc_or_tensor , b_scale , #
108110 c_desc , #
109111 M : tl .constexpr , N : tl .constexpr , K : tl .constexpr , #
110112 stride_sk : tl .constexpr , stride_sb : tl .constexpr , stride_sc : tl .constexpr , stride_sd : tl .constexpr ,
111113 output_type : tl .constexpr , #
112- ELEM_PER_BYTE : tl .constexpr , #
114+ ELEM_PER_BYTE_A : tl .constexpr , #
115+ ELEM_PER_BYTE_B : tl .constexpr , #
113116 VEC_SIZE : tl .constexpr , #
114117 BLOCK_M : tl .constexpr , #
115118 BLOCK_N : tl .constexpr , #
116119 BLOCK_K : tl .constexpr , #
117120 NUM_STAGES : tl .constexpr , #
118121 USE_2D_SCALE_LOAD : tl .constexpr ): #
119122
120- if ELEM_PER_BYTE == 1 :
121- dtype = tl .float8e4nv
122- elif ELEM_PER_BYTE == 2 :
123- dtype = tl .dtype ("uint8" )
123+ if ELEM_PER_BYTE_A == 1 :
124+ dtype_a = tl .float8e4nv
125+ elif ELEM_PER_BYTE_A == 2 :
126+ dtype_a = tl .dtype ("uint8" )
127+
128+ if ELEM_PER_BYTE_B == 1 :
129+ dtype_b = tl .float8e4nv
130+ elif ELEM_PER_BYTE_B == 2 :
131+ dtype_b = tl .dtype ("uint8" )
124132
125133 if output_type == 0 :
126134 output_dtype = tl .float32
@@ -129,25 +137,38 @@ def block_scaled_matmul_kernel( #
129137 elif output_type == 2 :
130138 output_dtype = tl .float8e4nv
131139
132- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [a_desc ], dtype = tl .int32 , is_pure = False ,
133- pack = 1 )
134- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [b_desc ], dtype = tl .int32 , is_pure = False ,
135- pack = 1 )
136- tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [c_desc ], dtype = tl .int32 , is_pure = False ,
137- pack = 1 )
138-
139140 pid = tl .program_id (axis = 0 )
140141 num_pid_m = tl .cdiv (M , BLOCK_M )
141142 pid_m = pid % num_pid_m
142143 pid_n = pid // num_pid_m
143144 offs_am = pid_m * BLOCK_M
144145 offs_bn = pid_n * BLOCK_N
145- offs_k = 0
146+ offs_k_a = 0
147+ offs_k_b = 0
146148
147149 ## block scale offsets
148150 offs_sm = (pid_m * (BLOCK_M // 128 ) + tl .arange (0 , BLOCK_M // 128 )) % M
149151 offs_sn = (pid_n * (BLOCK_N // 128 ) + tl .arange (0 , BLOCK_N // 128 )) % N
150152
153+ MIXED_PREC : tl .constexpr = ELEM_PER_BYTE_A == 1 and ELEM_PER_BYTE_B == 2
154+
155+ if MIXED_PREC :
156+ b_desc = tl .make_tensor_descriptor (
157+ b_desc_or_tensor ,
158+ shape = [N , K // ELEM_PER_BYTE_B ],
159+ strides = [K // ELEM_PER_BYTE_B , 1 ],
160+ block_shape = [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
161+ )
162+ else :
163+ b_desc = b_desc_or_tensor
164+ tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [b_desc ], dtype = tl .int32 ,
165+ is_pure = False , pack = 1 )
166+
167+ tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [a_desc ], dtype = tl .int32 , is_pure = False ,
168+ pack = 1 )
169+ tl .inline_asm_elementwise ("prefetch.tensormap [$1]; // dummy $0" , "=r,l" , [c_desc ], dtype = tl .int32 , is_pure = False ,
170+ pack = 1 )
171+
151172 # For now it is recommended to use 2D scale loads for better performance.
152173 # In the future we will bring additional optimizations to either allow 5D loads,
153174 # the use of TMAs for scale factors, or both.
@@ -171,26 +192,39 @@ def block_scaled_matmul_kernel( #
171192
172193 accumulator = tl .zeros ((BLOCK_M , BLOCK_N ), dtype = tl .float32 )
173194 for k in tl .range (0 , tl .cdiv (K , BLOCK_K ), num_stages = NUM_STAGES ):
174- a = tl ._experimental_descriptor_load (a_desc , [offs_am , offs_k ], [BLOCK_M , BLOCK_K // ELEM_PER_BYTE ], dtype )
175- b = tl ._experimental_descriptor_load (b_desc , [offs_bn , offs_k ], [BLOCK_N , BLOCK_K // ELEM_PER_BYTE ], dtype )
195+ a = tl ._experimental_descriptor_load (a_desc , [offs_am , offs_k_a ], [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ],
196+ dtype_a )
197+
198+ if MIXED_PREC :
199+ b = b_desc .load ([offs_bn , offs_k_b ])
200+ else :
201+ b = tl ._experimental_descriptor_load (b_desc , [offs_bn , offs_k_b ], [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ],
202+ dtype_b )
203+
176204 scale_a = tl .load (a_scale_ptr )
177205 scale_b = tl .load (b_scale_ptr )
178206 if USE_2D_SCALE_LOAD :
179207 scale_a = scale_a .reshape (BLOCK_M // 128 , BLOCK_K // VEC_SIZE // 4 , 32 , 4 , 4 )
180208 scale_b = scale_b .reshape (BLOCK_N // 128 , BLOCK_K // VEC_SIZE // 4 , 32 , 4 , 4 )
181209 scale_a = scale_a .trans (0 , 3 , 2 , 1 , 4 ).reshape (BLOCK_M , BLOCK_K // VEC_SIZE )
182210 scale_b = scale_b .trans (0 , 3 , 2 , 1 , 4 ).reshape (BLOCK_N , BLOCK_K // VEC_SIZE )
183- if ELEM_PER_BYTE == 2 :
211+
212+ if MIXED_PREC :
213+ accumulator = tl .dot_scaled (a , scale_a , "e4m3" , b .T , scale_b , "e2m1" , accumulator )
214+ elif ELEM_PER_BYTE_A == 2 and ELEM_PER_BYTE_B == 2 :
184215 accumulator = tl .dot_scaled (a , scale_a , "e2m1" , b .T , scale_b , "e2m1" , accumulator )
185216 else :
186217 accumulator = tl .dot_scaled (a , scale_a , "e4m3" , b .T , scale_b , "e4m3" , accumulator )
187- offs_k += BLOCK_K // ELEM_PER_BYTE
218+
219+ offs_k_a += BLOCK_K // ELEM_PER_BYTE_A
220+ offs_k_b += BLOCK_K // ELEM_PER_BYTE_B
188221 a_scale_ptr += (BLOCK_K // VEC_SIZE // 4 ) * stride_sb
189222 b_scale_ptr += (BLOCK_K // VEC_SIZE // 4 ) * stride_sb
223+
190224 tl ._experimental_descriptor_store (c_desc , accumulator .to (output_dtype ), [offs_am , offs_bn ])
191225
192226
193- def block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , dtype_dst , M , N , K , configs ):
227+ def block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , dtype_dst , M , N , K , configs ):
194228 output = torch .empty ((M , N ), dtype = dtype_dst , device = "cuda" )
195229 if dtype_dst == torch .float32 :
196230 dtype_dst = 0
@@ -205,11 +239,11 @@ def block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, dtype_dst, M, N, K, co
205239 output .element_size ())
206240
207241 grid = (triton .cdiv (M , configs ["BLOCK_SIZE_M" ]) * triton .cdiv (N , configs ["BLOCK_SIZE_N" ]), 1 )
208- block_scaled_matmul_kernel [grid ](a_desc , a_scale , b_desc , b_scale , c_desc , M , N , K , a_scale .stride (0 ),
242+ block_scaled_matmul_kernel [grid ](a_desc , a_scale , b_desc_or_tensor , b_scale , c_desc , M , N , K , a_scale .stride (0 ),
209243 a_scale .stride (1 ), a_scale .stride (2 ), a_scale .stride (3 ), dtype_dst ,
210- configs ["ELEM_PER_BYTE " ], configs ["VEC_SIZE " ], configs ["BLOCK_SIZE_M " ],
211- configs ["BLOCK_SIZE_N " ], configs ["BLOCK_SIZE_K " ], configs ["num_stages " ],
212- USE_2D_SCALE_LOAD = True )
244+ configs ["ELEM_PER_BYTE_A " ], configs ["ELEM_PER_BYTE_B " ], configs ["VEC_SIZE " ],
245+ configs ["BLOCK_SIZE_M " ], configs ["BLOCK_SIZE_N " ], configs ["BLOCK_SIZE_K " ],
246+ configs [ "num_stages" ], USE_2D_SCALE_LOAD = True )
213247 return output
214248
215249
@@ -218,8 +252,9 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
218252 BLOCK_N = 256
219253 BLOCK_K = 256 if "fp4" in block_scale_type else 128
220254 VEC_SIZE = 16 if block_scale_type == "nvfp4" else 32
221- assert block_scale_type in ["nvfp4" , "mxfp4" , "mxfp8" ], f"Invalid block scale type: { block_scale_type } "
222- ELEM_PER_BYTE = 2 if "fp4" in block_scale_type else 1
255+ assert block_scale_type in ["nvfp4" , "mxfp4" , "mxfp8" , "mixed" ], f"Invalid block scale type: { block_scale_type } "
256+ ELEM_PER_BYTE_A = 2 if "fp4" in block_scale_type else 1
257+ ELEM_PER_BYTE_B = 1 if block_scale_type == "mxfp8" else 2
223258
224259 device = "cuda"
225260 a_ref = MXFP4Tensor (size = (M , K ), device = device ).random ()
@@ -229,20 +264,32 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
229264 # the data is generated in col-major layout, packed along K for fp4, and then
230265 # logically transposed. Note that if one operand is of fp8 precision, unlike Hopper,
231266 # Blackwell supports both row-major and col-major layouts for the RHS matrix.
267+ # For the mixed-precision case, the fp4 RHS can be either in row or col-major layout.
268+ # But for performance reasons, it is recommended to use col-major layout. If TMA is used
269+ # for the fp4 RHS operand load in mixed-precision dot, as in this tutorial, it must be
270+ # in col-major layout.
232271 b_ref = MXFP4Tensor (size = (N , K ), device = device ).random ()
233- if block_scale_type == "mxfp8" :
272+ if block_scale_type in [ "mxfp8" , "mixed" ] :
234273 a_ref = a_ref .to (torch .float32 )
235- b_ref = b_ref .to (torch .float32 )
236274 a = a_ref .to (torch .float8_e4m3fn )
237- b = b_ref .to (torch .float8_e4m3fn )
238275 else :
239276 # Pack two fp4 elements per byte along K
240277 a = a_ref .to_packed_tensor (dim = 1 )
278+
279+ if block_scale_type == "mxfp8" :
280+ b_ref = b_ref .to (torch .float32 )
281+ b = b_ref .to (torch .float8_e4m3fn )
282+ else :
241283 b = b_ref .to_packed_tensor (dim = 1 )
284+
242285 b_ref = b_ref .to (torch .float32 ).T
243286
244- a_desc = TmaDescKernelParam (a .data_ptr (), a .shape , [BLOCK_M , BLOCK_K // ELEM_PER_BYTE ], 1 )
245- b_desc = TmaDescKernelParam (b .data_ptr (), b .shape , [BLOCK_N , BLOCK_K // ELEM_PER_BYTE ], 1 )
287+ a_desc = TmaDescKernelParam (a .data_ptr (), a .shape , [BLOCK_M , BLOCK_K // ELEM_PER_BYTE_A ], 1 )
288+
289+ if block_scale_type == "mixed" :
290+ b_desc_or_tensor = b
291+ else :
292+ b_desc_or_tensor = TmaDescKernelParam (b .data_ptr (), b .shape , [BLOCK_N , BLOCK_K // ELEM_PER_BYTE_B ], 1 )
246293
247294 epsilon = 1e-8
248295 a_scale = torch .rand ((M // 128 , K // VEC_SIZE // 4 , 32 , 4 , 4 ), device = device ) + epsilon
@@ -252,7 +299,7 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
252299 b_scale = b_scale .to (torch .float8_e4m3fn )
253300 a_scale_ref = a_scale
254301 b_scale_ref = b_scale
255- elif block_scale_type in ["mxfp4" , "mxfp8" ]:
302+ elif block_scale_type in ["mxfp4" , "mxfp8" , "mixed" ]:
256303 a_scale_ref = MXScaleTensor (a_scale )
257304 b_scale_ref = MXScaleTensor (b_scale )
258305 a_scale = a_scale_ref .data
@@ -276,16 +323,26 @@ def unpack_scale(packed):
276323 "BLOCK_SIZE_N" : BLOCK_N ,
277324 "BLOCK_SIZE_K" : BLOCK_K ,
278325 "num_stages" : 4 ,
279- "ELEM_PER_BYTE" : ELEM_PER_BYTE ,
326+ "ELEM_PER_BYTE_A" : ELEM_PER_BYTE_A ,
327+ "ELEM_PER_BYTE_B" : ELEM_PER_BYTE_B ,
280328 "VEC_SIZE" : VEC_SIZE ,
281329 }
282- return a_desc , a_scale , b_desc , b_scale , configs , reference
330+ return a_desc , a_scale , b_desc_or_tensor , b_scale , configs , reference
283331
284332
def validate_block_scaled(M, N, K, block_scale_type="nvfp4"):
    """Run the block-scaled matmul once on (M, N, K) and compare against the reference.

    For the "mixed" format the fp4 RHS is loaded through a device-created TMA
    descriptor, which requires a Triton workspace allocator to be installed first.
    """

    if block_scale_type == "mixed":
        # Device-side tensor-descriptor creation needs scratch memory; register an
        # allocator so make_tensor_descriptor inside the kernel can obtain it.
        def workspace_alloc(size: int, align: int, _):
            return torch.empty(size, dtype=torch.int8, device="cuda")

        triton.set_allocator(workspace_alloc)

    (a_desc, a_scale, b_desc_or_tensor, b_scale,
     configs, reference) = initialize_block_scaled(M, N, K, block_scale_type, compute_reference=True)
    output = block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, torch.float16, M, N, K, configs)
    torch.testing.assert_close(reference, output.to(torch.float32), atol=1e-3, rtol=1e-3)
    print(f"✅ (pass {block_scale_type})")
291348
@@ -296,13 +353,19 @@ def bench_block_scaled(K, block_scale_type="nvfp4", reps=10):
296353 N = 8192
297354 print (f"Problem Shape = { M } x{ N } x{ K } " )
298355
299- a_desc , a_scale , b_desc , b_scale , configs , _ = initialize_block_scaled (M , N , K , block_scale_type ,
300- compute_reference = False )
301- _ = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
356+ def alloc_fn (size : int , align : int , _ ):
357+ return torch .empty (size , dtype = torch .int8 , device = "cuda" )
358+
359+ if block_scale_type == "mixed" :
360+ triton .set_allocator (alloc_fn )
361+
362+ a_desc , a_scale , b_desc_or_tensor , b_scale , configs , _ = initialize_block_scaled (
363+ M , N , K , block_scale_type , compute_reference = False )
364+ _ = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
302365
303366 proton .activate (0 )
304367 for _ in range (reps ):
305- _ = block_scaled_matmul (a_desc , a_scale , b_desc , b_scale , torch .float16 , M , N , K , configs )
368+ _ = block_scaled_matmul (a_desc , a_scale , b_desc_or_tensor , b_scale , torch .float16 , M , N , K , configs )
306369 proton .deactivate (0 )
307370 print ("Done benchmarking" )
308371
@@ -321,7 +384,7 @@ def show_profile(profile_name):
321384 parser .add_argument ("--K_range" , type = int , nargs = 2 )
322385 parser .add_argument ("--K_step" , type = int , default = 512 )
323386 parser .add_argument ("--bench" , action = "store_true" )
324- parser .add_argument ("--format" , type = str , choices = ["mxfp4" , "nvfp4" , "mxfp8" ], default = "nvfp4" )
387+ parser .add_argument ("--format" , type = str , choices = ["mxfp4" , "nvfp4" , "mxfp8" , "mixed" ], default = "nvfp4" )
325388 args = parser .parse_args ()
326389
327390 if not supports_block_scaling ():
0 commit comments