Skip to content

Commit ab8bf6f

Browse files
authored
[misc] update flops_counter (#523)
1 parent da0399d commit ab8bf6f

1 file changed

Lines changed: 11 additions & 10 deletions

File tree

verl/utils/flops_counter.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,21 +78,22 @@ def __init__(self, config: "LlamaConfig"):
7878
if config.model_type not in _ESTIMATE_FUNC:
7979
print(f"Only support {_ESTIMATE_FUNC.keys()}, but got {config.model_type}. MFU will always be zero.")
8080

81-
self.config = config
81+
self.config = getattr(config, "text_config", config)
8282
self._estimate_flops = _ESTIMATE_FUNC.get(config.model_type, self._estimate_unknown_flops)
8383

8484
def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
8585
return 0
8686

8787
def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
88-
hidden_size = self.config.hidden_size
89-
vocab_size = self.config.vocab_size
90-
num_hidden_layers = self.config.num_hidden_layers
91-
num_key_value_heads = self.config.num_key_value_heads
92-
num_attention_heads = self.config.num_attention_heads
93-
intermediate_size = self.config.intermediate_size
94-
95-
head_dim = hidden_size // num_attention_heads
88+
config = self.config
89+
hidden_size = config.hidden_size
90+
vocab_size = config.vocab_size
91+
num_hidden_layers = config.num_hidden_layers
92+
num_key_value_heads = config.num_key_value_heads
93+
num_attention_heads = config.num_attention_heads
94+
intermediate_size = config.intermediate_size
95+
96+
head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads)
9697
q_size = num_attention_heads * head_dim
9798
k_size = num_key_value_heads * head_dim
9899
v_size = num_key_value_heads * head_dim
@@ -120,7 +121,7 @@ def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta
120121
return flops_achieved
121122

122123
def _estimate_qwen2_moe_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
123-
config = self.config.text_config if hasattr(self.config, "text_config") else self.config
124+
config = self.config
124125
hidden_size = config.hidden_size
125126
vocab_size = config.vocab_size
126127
num_hidden_layers = config.num_hidden_layers

0 commit comments

Comments
 (0)