@@ -78,21 +78,22 @@ def __init__(self, config: "LlamaConfig"):
7878 if config .model_type not in _ESTIMATE_FUNC :
7979 print (f"Only support { _ESTIMATE_FUNC .keys ()} , but got { config .model_type } . MFU will always be zero." )
8080
81- self .config = config
81+ self .config = getattr ( config , "text_config" , config )
8282 self ._estimate_flops = _ESTIMATE_FUNC .get (config .model_type , self ._estimate_unknown_flops )
8383
8484 def _estimate_unknown_flops (self , tokens_sum : int , batch_seqlens : List [int ], delta_time : float ) -> float :
8585 return 0
8686
8787 def _estimate_llama_flops (self , tokens_sum : int , batch_seqlens : List [int ], delta_time : float ) -> float :
88- hidden_size = self .config .hidden_size
89- vocab_size = self .config .vocab_size
90- num_hidden_layers = self .config .num_hidden_layers
91- num_key_value_heads = self .config .num_key_value_heads
92- num_attention_heads = self .config .num_attention_heads
93- intermediate_size = self .config .intermediate_size
94-
95- head_dim = hidden_size // num_attention_heads
88+ config = self .config
89+ hidden_size = config .hidden_size
90+ vocab_size = config .vocab_size
91+ num_hidden_layers = config .num_hidden_layers
92+ num_key_value_heads = config .num_key_value_heads
93+ num_attention_heads = config .num_attention_heads
94+ intermediate_size = config .intermediate_size
95+
96+ head_dim = getattr (config , "head_dim" , hidden_size // num_attention_heads )
9697 q_size = num_attention_heads * head_dim
9798 k_size = num_key_value_heads * head_dim
9899 v_size = num_key_value_heads * head_dim
@@ -120,7 +121,7 @@ def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta
120121 return flops_achieved
121122
122123 def _estimate_qwen2_moe_flops (self , tokens_sum : int , batch_seqlens : List [int ], delta_time : float ) -> float :
123- config = self .config . text_config if hasattr ( self . config , "text_config" ) else self . config
124+ config = self .config
124125 hidden_size = config .hidden_size
125126 vocab_size = config .vocab_size
126127 num_hidden_layers = config .num_hidden_layers
0 commit comments