2121 from transformers .models .llama .configuration_llama import LlamaConfig
2222
2323
24- VALID_MODLE_TYPE = {"llama" , "qwen2" , "qwen2_vl" , "qwen2_5_vl" , "qwen3" }
25-
26-
2724def get_device_flops (unit : str = "T" ) -> float :
2825 def unit_convert (number : float , level : str ):
2926 units = ["B" , "K" , "M" , "G" , "T" , "P" ]
@@ -51,6 +48,7 @@ def unit_convert(number: float, level: str):
5148 flops = 148e12
5249 elif "910B" in device_name :
5350 flops = 354e12
51+
5452 flops_unit = unit_convert (flops , unit )
5553 return flops_unit
5654
@@ -65,16 +63,19 @@ class FlopsCounter:
6563 """
6664
6765 def __init__ (self , config : "LlamaConfig" ):
68- if config .model_type not in VALID_MODLE_TYPE :
69- print (f"Only support { VALID_MODLE_TYPE } , but got { config .model_type } . MFU will always be zero." )
70-
71- self .estimate_func = {
66+ _ESTIMATE_FUNC = {
7267 "llama" : self ._estimate_llama_flops ,
7368 "qwen2" : self ._estimate_llama_flops ,
7469 "qwen2_vl" : self ._estimate_llama_flops ,
7570 "qwen2_5_vl" : self ._estimate_llama_flops ,
71+ "qwen3" : self ._estimate_llama_flops ,
7672 }
73+
74+ if config .model_type not in _ESTIMATE_FUNC :
75+ print (f"Only support { _ESTIMATE_FUNC .keys ()} , but got { config .model_type } . MFU will always be zero." )
76+
7777 self .config = config
78+ self ._estimate_flops = _ESTIMATE_FUNC .get (config .model_type , self ._estimate_unknown_flops )
7879
7980 def _estimate_unknown_flops (self , tokens_sum : int , batch_seqlens : List [int ], delta_time : float ) -> float :
8081 return 0
@@ -127,7 +128,6 @@ def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[f
127128 promised_flops (float): The expected FLOPS of the current device.
128129 """
129130 tokens_sum = sum (batch_seqlens )
130- func = self .estimate_func .get (self .config .model_type , self ._estimate_unknown_flops )
131- estimated_flops = func (tokens_sum , batch_seqlens , delta_time )
131+ estimated_flops = self ._estimate_flops (tokens_sum , batch_seqlens , delta_time )
132132 promised_flops = get_device_flops ()
133133 return estimated_flops , promised_flops
0 commit comments