Commit 712fd5d

Update
[ghstack-poisoned]
1 parent ca127f0 commit 712fd5d

File tree

1 file changed: +43 additions, -10 deletions


benchmarks/float8/float8_roofline.py

Lines changed: 43 additions & 10 deletions
@@ -70,6 +70,7 @@
     ScalingType,
     CastConfig,
 )
+from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName
 
 
 class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
         else:
             # cache does not exist yet, create it
             cache = dict()
+    else:
+        cache = dict()
     key = f"{M},{K},{N},{fast_accum}"
     if key in cache:
         return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
         )
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
+    scale_a = torch.ones(M, 1, device=device)
+    scale_b = torch.ones(1, N, device=device)
+    fast_accum = True  # for axiswise
+    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s]
+        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
         with open(cache_filename, 'w') as f:
             json.dump(cache, f)
 
-    return bf16_time_s, f8_time_s
+    return bf16_time_s, f8_time_s, f8_axs_time_s
 
 def run(
     outfile: str,
@@ -231,13 +239,15 @@ def run(
     headers = [
         'fwd_M', 'fwd_K', 'fwd_N',
         # gemm microbenchmarks
-        'bf16_gemm_s', 'fp8_gemm_s',
+        'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
         # roofline memory overhead estimates
         'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
         'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
         # actual e2e measurements
-        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
-        'fp8_dyn_speedup', 'fp8_del_speedup',
+        'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
+        # 'fp8_lw_s',
+        'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
+        # 'fp8_lw_sp',
     ]
     results = []
 
@@ -248,15 +258,18 @@ def run(
             break
 
         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
-            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
-            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
+            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
+            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
+            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            # for now, assume axiswise gemm is similar to tensorwise
+            fp8_axs_gemm_time_s = fp8_gemm_time_s
 
         fp8_mem_time_dyn_limit_s = \
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -291,23 +304,43 @@ def run(
             cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
         )
-        m_fp8_del = convert_to_float8_training(m_orig)
+        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
         m_fp8_del = torch.compile(m_fp8_del)
         fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)
 
+        # get the float8 dynamic axiswise scaling gpu kernel time
+        torch._dynamo.reset()
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+
+        # get the lw recipe scaling gpu kernel time
+        # TODO(future PR): enable below once basic performance issues
+        # are fixed
+        # torch._dynamo.reset()
+        # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
+        # m_fp8_lw = torch.compile(m_fp8_lw)
+        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+
         results.append([
             M_val, K_val, N_val,
             # gemm microbenchmarks
-            bf16_time_val, fp8_gemm_time_s,
+            bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
             # roofline overhead estimates
             fp8_mem_time_dyn_limit_s,
             fp8_mem_time_dyn_nolimit_s,
             fp8_mem_time_del_limit_s,
             fp8_mem_time_del_nolimit_s,
             # e2e numbers
             bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
+            fp8_dyn_axs_time_actual_s,
+            # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
             bf16_time_actual_s / fp8_del_time_actual_s,
+            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
+            # bf16_time_actual_s / fp8_lw_time_actual_s,
         ])
 
     df = pd.DataFrame(results, columns=headers)
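
Usage note (not part of the commit): below is a minimal sketch of how the ALL_AXISWISE recipe exercised by this benchmark could be applied to a toy model, using only the calls that appear in the diff (recipe_name_to_linear_config, Float8LinearRecipeName, convert_to_float8_training, torch.compile). The toy model shape and the torchao.float8 import path for convert_to_float8_training are assumptions; a CUDA GPU with float8 support is required.

# Illustrative sketch only; assumes a float8-capable CUDA GPU and that
# convert_to_float8_training is exported from torchao.float8.
import copy

import torch
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# toy model; dimensions divisible by 16 to satisfy float8 gemm constraints
m_orig = torch.nn.Sequential(
    torch.nn.Linear(2048, 4096, bias=False),
).cuda().to(torch.bfloat16)

# dynamic axiswise scaling, same recipe name as used in the benchmark above
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

# one forward/backward pass through the converted, compiled model
x = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
y = m_fp8_dyn_axs(x)
y.sum().backward()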
