@@ -234,8 +234,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
234234 elif provider == "cutlass" :
235235 with autotune ():
236236 _run_mm_fp4 (
237- a_fp4 , b_fp4_T , a_scale_interleaved , b_sf_T ,
238- alpha , dtype , res_fi , backend = "cutlass" ,
237+ a_fp4 ,
238+ b_fp4_T ,
239+ a_scale_interleaved ,
240+ b_sf_T ,
241+ alpha ,
242+ dtype ,
243+ res_fi ,
244+ backend = "cutlass" ,
239245 )
240246 times_ms = bench_gpu_time (
241247 fn = partial (_run_mm_fp4 , backend = "cutlass" ),
@@ -253,8 +259,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
253259 elif provider == "cudnn" :
254260 with autotune ():
255261 _run_mm_fp4 (
256- a_fp4 , b_fp4_T , a_scale_interleaved , b_sf_T ,
257- alpha , dtype , res_fi , backend = "cudnn" ,
262+ a_fp4 ,
263+ b_fp4_T ,
264+ a_scale_interleaved ,
265+ b_sf_T ,
266+ alpha ,
267+ dtype ,
268+ res_fi ,
269+ backend = "cudnn" ,
258270 )
259271 times_ms = bench_gpu_time (
260272 fn = partial (_run_mm_fp4 , backend = "cudnn" ),
@@ -274,8 +286,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
274286 b_sf_u8_T = b_sf_T .to (torch .uint8 )
275287 with autotune ():
276288 _run_mm_fp4 (
277- a_fp4 , b_fp4_T , a_sf_u8 , b_sf_u8_T ,
278- alpha , dtype , res_fi , backend = "trtllm" ,
289+ a_fp4 ,
290+ b_fp4_T ,
291+ a_sf_u8 ,
292+ b_sf_u8_T ,
293+ alpha ,
294+ dtype ,
295+ res_fi ,
296+ backend = "trtllm" ,
279297 )
280298 times_ms = bench_gpu_time (
281299 fn = partial (_run_mm_fp4 , backend = "trtllm" ),
@@ -285,8 +303,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
285303 elif provider == "cute-dsl" :
286304 with autotune ():
287305 _run_mm_fp4 (
288- a_fp4 , b_fp4_T , a_scale_interleaved , b_sf_T ,
289- alpha , dtype , res_fi , backend = "cute-dsl" ,
306+ a_fp4 ,
307+ b_fp4_T ,
308+ a_scale_interleaved ,
309+ b_sf_T ,
310+ alpha ,
311+ dtype ,
312+ res_fi ,
313+ backend = "cute-dsl" ,
290314 )
291315 times_ms = bench_gpu_time (
292316 fn = partial (_run_mm_fp4 , backend = "cute-dsl" ),
@@ -304,8 +328,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
304328 elif provider == "auto" :
305329 with autotune ():
306330 _run_mm_fp4 (
307- a_fp4 , b_fp4_T , a_scale_interleaved , b_sf_T ,
308- alpha , dtype , res_fi , backend = "auto" ,
331+ a_fp4 ,
332+ b_fp4_T ,
333+ a_scale_interleaved ,
334+ b_sf_T ,
335+ alpha ,
336+ dtype ,
337+ res_fi ,
338+ backend = "auto" ,
309339 )
310340 times_ms = bench_gpu_time (
311341 fn = partial (_run_mm_fp4 , backend = "auto" ),
0 commit comments