
Update
[ghstack-poisoned]
vkuzo committed Oct 5, 2024
1 parent ca127f0 commit 712fd5d
Showing 1 changed file with 43 additions and 10 deletions.
benchmarks/float8/float8_roofline.py: 43 additions & 10 deletions
@@ -70,6 +70,7 @@
     ScalingType,
     CastConfig,
 )
+from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName
 
 
 class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
         else:
             # cache does not exist yet, create it
             cache = dict()
+    else:
+        cache = dict()
     key = f"{M},{K},{N},{fast_accum}"
     if key in cache:
         return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
         )
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
 
+    scale_a = torch.ones(M, 1, device=device)
+    scale_b = torch.ones(1, N, device=device)
+    fast_accum = True # for axiswise
+    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s]
+        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
         with open(cache_filename, 'w') as f:
             json.dump(cache, f)
 
-    return bf16_time_s, f8_time_s
+    return bf16_time_s, f8_time_s, f8_axs_time_s
 
 def run(
     outfile: str,
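For context (not part of the commit): a minimal sketch of inspecting the extended gemm-time cache that get_gemm_times now writes. The "gemm_cache.json" path is an assumption standing in for whatever gemm_cache_filename the caller passes; the key and value layout follow the code above.

import json

# Assumed cache path; in practice this is the gemm_cache_filename argument.
with open("gemm_cache.json", "r") as f:
    cache = json.load(f)

# Each key is "M,K,N,fast_accum"; each value is
# [bf16_time_s, f8_time_s, f8_axs_time_s] as written by get_gemm_times.
for key, times in cache.items():
    if len(times) != 3:
        continue  # entry from the older two-element cache layout
    bf16_s, f8_s, f8_axs_s = times
    print(f"{key}: fp8 {bf16_s / f8_s:.2f}x, fp8 axiswise {bf16_s / f8_axs_s:.2f}x vs bf16")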
@@ -231,13 +239,15 @@ def run(
     headers = [
         'fwd_M', 'fwd_K', 'fwd_N',
         # gemm microbenchmarks
-        'bf16_gemm_s', 'fp8_gemm_s',
+        'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
         # roofline memory overhead estimates
         'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
         'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
         # actual e2e measurements
-        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
-        'fp8_dyn_speedup', 'fp8_del_speedup',
+        'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
+        # 'fp8_lw_s',
+        'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
+        # 'fp8_lw_sp',
     ]
     results = []
 
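A hedged sketch (not from the commit) of consuming the renamed columns after a run. It assumes the script writes outfile as a CSV readable by pandas and that the file is named "float8_roofline.csv"; both are illustrative, while the column names come from the headers above.

import pandas as pd

# Assumes the benchmark was run with outfile="float8_roofline.csv".
df = pd.read_csv("float8_roofline.csv")

# Shape columns plus the speedup columns introduced in this commit.
cols = ["fwd_M", "fwd_K", "fwd_N", "fp8_dyn_sp", "fp8_del_sp", "fp8_dyn_axs_sp"]
print(df[cols].sort_values("fp8_dyn_axs_sp", ascending=False).head())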
@@ -248,15 +258,18 @@ def run(
             break
 
         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
-            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
-            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
+            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
+            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
+            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            # for now, assume axiswise gemm is similar to tensorwise
+            fp8_axs_gemm_time_s = fp8_gemm_time_s
 
         fp8_mem_time_dyn_limit_s = \
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
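The roofline branch above substitutes concrete M, K, N into symbolic gemm-time estimates defined elsewhere in this file. A hedged sketch of that substitution pattern follows; the compute-bound formula and the H100-class peak throughput numbers are illustrative assumptions, not the file's actual sympy expressions.

import sympy

M, K, N = sympy.symbols("M K N")

# Illustrative peak throughputs for an H100-class GPU (assumptions).
PEAK_BF16_FLOPS = 989e12   # FLOP/s
PEAK_FP8_FLOPS = 1979e12   # FLOP/s

# Compute-bound estimate for one forward gemm (2*M*K*N FLOPs).
bf16_gemm_time_sympy = 2 * M * K * N / PEAK_BF16_FLOPS
fp8_gemm_time_sympy = 2 * M * K * N / PEAK_FP8_FLOPS

# Same substitution pattern as in the diff above.
M_val, K_val, N_val = 4096, 2048, 8192
bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
print(float(bf16_time_val), float(fp8_gemm_time_s))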
@@ -291,23 +304,43 @@ def run(
             cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
         )
-        m_fp8_del = convert_to_float8_training(m_orig)
+        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
         m_fp8_del = torch.compile(m_fp8_del)
         fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)
 
+        # get the float8 dynamic axiswise scaling gpu kernel time
+        torch._dynamo.reset()
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+
+        # get the lw recipe scaling gpu kernel time
+        # TODO(future PR): enable below once basic performance issues
+        # are fixed
+        # torch._dynamo.reset()
+        # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
+        # m_fp8_lw = torch.compile(m_fp8_lw)
+        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+
         results.append([
             M_val, K_val, N_val,
             # gemm microbenchmarks
-            bf16_time_val, fp8_gemm_time_s,
+            bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
             # roofline overhead estimates
             fp8_mem_time_dyn_limit_s,
             fp8_mem_time_dyn_nolimit_s,
             fp8_mem_time_del_limit_s,
             fp8_mem_time_del_nolimit_s,
             # e2e numbers
             bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
+            fp8_dyn_axs_time_actual_s,
+            # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
             bf16_time_actual_s / fp8_del_time_actual_s,
+            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
+            # bf16_time_actual_s / fp8_lw_time_actual_s,
         ])
 
     df = pd.DataFrame(results, columns=headers)
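For reference (not part of the commit): a standalone sketch of the dynamic axiswise path that the new benchmark block exercises. The toy model, shapes, and device are illustrative assumptions, and a float8-capable GPU (H100-class) is required; the config and conversion calls mirror the diff as of this commit.

import copy
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName

# Toy model standing in for m_orig; sizes are arbitrary multiples of 16.
m_orig = nn.Sequential(nn.Linear(2048, 4096, bias=False)).cuda().bfloat16()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)

# Convert a copy of the model to float8 training with axiswise scaling,
# then compile it, matching the calls added in this commit.
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

# One forward/backward pass through the converted model.
y = m_fp8_dyn_axs(x)
y.sum().backward()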
