Skip to content

Commit 9a8b04c

Browse files
committed
upd lint
1 parent 12793cf commit 9a8b04c

1 file changed

Lines changed: 40 additions & 10 deletions

File tree

sgl-kernel/benchmark/bench_fp4_gemm.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
234234
elif provider == "cutlass":
235235
with autotune():
236236
_run_mm_fp4(
237-
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
238-
alpha, dtype, res_fi, backend="cutlass",
237+
a_fp4,
238+
b_fp4_T,
239+
a_scale_interleaved,
240+
b_sf_T,
241+
alpha,
242+
dtype,
243+
res_fi,
244+
backend="cutlass",
239245
)
240246
times_ms = bench_gpu_time(
241247
fn=partial(_run_mm_fp4, backend="cutlass"),
@@ -253,8 +259,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
253259
elif provider == "cudnn":
254260
with autotune():
255261
_run_mm_fp4(
256-
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
257-
alpha, dtype, res_fi, backend="cudnn",
262+
a_fp4,
263+
b_fp4_T,
264+
a_scale_interleaved,
265+
b_sf_T,
266+
alpha,
267+
dtype,
268+
res_fi,
269+
backend="cudnn",
258270
)
259271
times_ms = bench_gpu_time(
260272
fn=partial(_run_mm_fp4, backend="cudnn"),
@@ -274,8 +286,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
274286
b_sf_u8_T = b_sf_T.to(torch.uint8)
275287
with autotune():
276288
_run_mm_fp4(
277-
a_fp4, b_fp4_T, a_sf_u8, b_sf_u8_T,
278-
alpha, dtype, res_fi, backend="trtllm",
289+
a_fp4,
290+
b_fp4_T,
291+
a_sf_u8,
292+
b_sf_u8_T,
293+
alpha,
294+
dtype,
295+
res_fi,
296+
backend="trtllm",
279297
)
280298
times_ms = bench_gpu_time(
281299
fn=partial(_run_mm_fp4, backend="trtllm"),
@@ -285,8 +303,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
285303
elif provider == "cute-dsl":
286304
with autotune():
287305
_run_mm_fp4(
288-
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
289-
alpha, dtype, res_fi, backend="cute-dsl",
306+
a_fp4,
307+
b_fp4_T,
308+
a_scale_interleaved,
309+
b_sf_T,
310+
alpha,
311+
dtype,
312+
res_fi,
313+
backend="cute-dsl",
290314
)
291315
times_ms = bench_gpu_time(
292316
fn=partial(_run_mm_fp4, backend="cute-dsl"),
@@ -304,8 +328,14 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
304328
elif provider == "auto":
305329
with autotune():
306330
_run_mm_fp4(
307-
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
308-
alpha, dtype, res_fi, backend="auto",
331+
a_fp4,
332+
b_fp4_T,
333+
a_scale_interleaved,
334+
b_sf_T,
335+
alpha,
336+
dtype,
337+
res_fi,
338+
backend="auto",
309339
)
310340
times_ms = bench_gpu_time(
311341
fn=partial(_run_mm_fp4, backend="auto"),

0 commit comments

Comments
 (0)