diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py
index 2f04b8ee8..19c6cc21b 100644
--- a/benchmarks/float8/float8_roofline.py
+++ b/benchmarks/float8/float8_roofline.py
@@ -70,6 +70,7 @@
     ScalingType,
     CastConfig,
 )
+from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName
 
 
 class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
         else:
             # cache does not exist yet, create it
             cache = dict()
+    else:
+        cache = dict()
     key = f"{M},{K},{N},{fast_accum}"
     if key in cache:
         return cache[key]
@@ -153,13 +156,18 @@ def do_matmul(A, B):
     )
     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+    scale_a = torch.ones(M, 1, device=device)
+    scale_b = torch.ones(1, N, device=device)
+    fast_accum = True  # for axiswise
+    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+
     # save to cache if needed
     if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s]
+        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
         with open(cache_filename, 'w') as f:
             json.dump(cache, f)
 
-    return bf16_time_s, f8_time_s
+    return bf16_time_s, f8_time_s, f8_axs_time_s
 
 
 def run(
     outfile: str,
@@ -231,13 +239,15 @@ def run(
     headers = [
         'fwd_M', 'fwd_K', 'fwd_N',
         # gemm microbenchmarks
-        'bf16_gemm_s', 'fp8_gemm_s',
+        'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
         # roofline memory overhead estimates
         'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
         'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
         # actual e2e measurements
-        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
-        'fp8_dyn_speedup', 'fp8_del_speedup',
+        'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
+        # 'fp8_lw_s',
+        'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
+        # 'fp8_lw_sp',
     ]
     results = []
 
@@ -248,15 +258,18 @@
             break
 
         if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
-            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
-            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
+            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
+            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
+            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
             bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
             fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
         else:
             assert gemm_time_strategy == "roofline", "unsupported"
             bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
             fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            # for now, assume axiswise gemm is similar to tensorwise
+            fp8_axs_gemm_time_s = fp8_gemm_time_s
 
         fp8_mem_time_dyn_limit_s = \
             fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
@@ -291,14 +304,30 @@
             cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
             cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
         )
-        m_fp8_del = convert_to_float8_training(m_orig)
+        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
         m_fp8_del = torch.compile(m_fp8_del)
         fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)
 
+        # get the float8 dynamic axiswise scaling gpu kernel time
+        torch._dynamo.reset()
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+
+        # get the lw recipe scaling gpu kernel time
+        # TODO(future PR): enable below once basic performance issues
+        # are fixed
+        # torch._dynamo.reset()
+        # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
+        # m_fp8_lw = torch.compile(m_fp8_lw)
+        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+
         results.append([
             M_val, K_val, N_val,
             # gemm microbenchmarks
-            bf16_time_val, fp8_gemm_time_s,
+            bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
             # roofline overhead estimates
             fp8_mem_time_dyn_limit_s,
             fp8_mem_time_dyn_nolimit_s,
@@ -306,8 +335,12 @@
             fp8_mem_time_del_nolimit_s,
             # e2e numbers
             bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
+            fp8_dyn_axs_time_actual_s,
+            # fp8_lw_time_actual_s,
             bf16_time_actual_s / fp8_dyn_time_actual_s,
             bf16_time_actual_s / fp8_del_time_actual_s,
+            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
+            # bf16_time_actual_s / fp8_lw_time_actual_s,
         ])
 
     df = pd.DataFrame(results, columns=headers)
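Note on the `get_gemm_times` hunk: the new axiswise timing reuses the existing `do_matmul` closure, so rebinding `scale_a`/`scale_b`/`fast_accum` in the enclosing scope is what switches `torch._scaled_mm` from tensorwise to axiswise (rowwise) scaling. A minimal standalone sketch of that pattern, assuming an H100-class GPU and a PyTorch build whose `torch._scaled_mm` accepts rowwise scales; shapes and sizes here are illustrative, not taken from the benchmark:

```python
import torch

M, K, N = 1024, 2048, 4096
device = "cuda"
d3 = torch.bfloat16  # output dtype

A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
# the second operand must be column-major for torch._scaled_mm
B = torch.randn(K, N, device=device).to(torch.float8_e4m3fn).t().contiguous().t()

# tensorwise: one scale per tensor
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
fast_accum = True

def do_matmul(A, B):
    # reads scale_a / scale_b / fast_accum from the enclosing scope
    return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum)

out_tensorwise = do_matmul(A, B)

# axiswise, as in the diff: one scale per row of A, one per column of B
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
out_axiswise = do_matmul(A, B)  # same function object, new scales via closure
```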
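The new e2e block follows torchao's recipe-based conversion flow, and the `copy.deepcopy(m_orig)` additions suggest `convert_to_float8_training` swaps submodules in place, so converting `m_orig` directly would contaminate later variants. A sketch of that flow in isolation, assuming the torchao float8 APIs imported in the diff; the module shape and sizes are illustrative:

```python
import copy

import torch
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# toy stand-in for the benchmark's model
m_orig = torch.nn.Sequential(torch.nn.Linear(2048, 4096, bias=False)).cuda().bfloat16()

# reset dynamo between variants so compiled artifacts don't leak across runs,
# and deepcopy so each variant starts from the same pristine bf16 weights
torch._dynamo.reset()
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
y = m_fp8_dyn_axs(x)  # dynamic axiswise-scaled float8 forward
```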