Skip to content

Commit 61b9384

Browse files
authored
[misc] fix flops counter (#401)
1 parent 601a37c commit 61b9384

3 files changed

Lines changed: 16 additions & 11 deletions

File tree

verl/models/transformers/flash_attention_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -171,7 +171,7 @@ def flash_attention_forward(
171171
value,
172172
attention_mask,
173173
query_length=q_len,
174-
is_causal=True,
174+
is_causal=module.is_causal,
175175
dropout=dropout,
176176
softmax_scale=scaling,
177177
sliding_window=sliding_window,

verl/utils/flops_counter.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,6 @@
2121
from transformers.models.llama.configuration_llama import LlamaConfig
2222

2323

24-
VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}
25-
26-
2724
def get_device_flops(unit: str = "T") -> float:
2825
def unit_convert(number: float, level: str):
2926
units = ["B", "K", "M", "G", "T", "P"]
@@ -51,6 +48,7 @@ def unit_convert(number: float, level: str):
5148
flops = 148e12
5249
elif "910B" in device_name:
5350
flops = 354e12
51+
5452
flops_unit = unit_convert(flops, unit)
5553
return flops_unit
5654

@@ -65,16 +63,19 @@ class FlopsCounter:
6563
"""
6664

6765
def __init__(self, config: "LlamaConfig"):
68-
if config.model_type not in VALID_MODLE_TYPE:
69-
print(f"Only support {VALID_MODLE_TYPE}, but got {config.model_type}. MFU will always be zero.")
70-
71-
self.estimate_func = {
66+
_ESTIMATE_FUNC = {
7267
"llama": self._estimate_llama_flops,
7368
"qwen2": self._estimate_llama_flops,
7469
"qwen2_vl": self._estimate_llama_flops,
7570
"qwen2_5_vl": self._estimate_llama_flops,
71+
"qwen3": self._estimate_llama_flops,
7672
}
73+
74+
if config.model_type not in _ESTIMATE_FUNC:
75+
print(f"Only support {_ESTIMATE_FUNC.keys()}, but got {config.model_type}. MFU will always be zero.")
76+
7777
self.config = config
78+
self._estimate_flops = _ESTIMATE_FUNC.get(config.model_type, self._estimate_unknown_flops)
7879

7980
def _estimate_unknown_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
8081
return 0
@@ -127,7 +128,6 @@ def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[f
127128
promised_flops (float): The expected FLOPS of the current device.
128129
"""
129130
tokens_sum = sum(batch_seqlens)
130-
func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
131-
estimated_flops = func(tokens_sum, batch_seqlens, delta_time)
131+
estimated_flops = self._estimate_flops(tokens_sum, batch_seqlens, delta_time)
132132
promised_flops = get_device_flops()
133133
return estimated_flops, promised_flops

verl/workers/actor/dp_actor.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424
from ray.experimental.tqdm_ray import tqdm
2525
from torch import nn
2626
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
27-
from transformers.modeling_flash_attention_utils import index_first_axis, pad_input, unpad_input
2827

2928
from ...protocol import DataProto
3029
from ...trainer.core_algos import average_loss, compute_kl, compute_policy_loss
@@ -35,6 +34,12 @@
3534
from .config import ActorConfig
3635

3736

37+
try:
38+
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
39+
except ImportError:
40+
pass
41+
42+
3843
__all__ = ["DataParallelPPOActor"]
3944

4045

0 commit comments

Comments (0)