Commit cd4072d authored
[trainer] fix: fallback vision tower to flash_attention_2 for Qwen2.5-VL when u… (#4670)
# Fix: Fallback Vision Tower to Flash Attention 2 for Qwen2.5-VL when using Flash Attention 3
## Description
This PR adds a patch for Qwen2.5-VL models that falls back the vision
tower's attention implementation to flash_attention_2 when the main
model uses flash_attention_3.
## Motivation
Qwen2.5-VL's vision tower does not support flash_attention_3 properly.
When `attn_implementation` is set to `flash_attention_3`, using FA3 for
the vision tower causes significant performance degradation compared to
flash_attention_2.
## Experimental Validation
We tested this patch across the entire Qwen2.5-VL series (3B, 7B,
32B, and 72B) using the Transformers library on an 8×H100 GPU
machine with automatic device placement.
Below is the performance comparison for Qwen2.5-VL-7B with an input of
one 1260×700 image plus 150 tokens of text:
```
======================================================================
COMPARISON SUMMARY
======================================================================
Implementation        Avg Latency (ms)    Throughput (tok/s)
-------------------------------------------------------------
flash_attention_2               102.85              12503.46
flash_attention_3               309.49               4155.19

FA3 vs FA2 Speedup: 0.33x
Memory Difference:  +0.00 GB
```
**Test Environment:**
- Hardware: 8×H100 GPUs
- Library: Transformers with auto device placement
- Models tested: Qwen2.5-VL-3B, 7B, 32B, 72B
**Key Findings:**
- Flash Attention 3 is **3x slower** than Flash Attention 2 for the
vision tower
- No memory benefit from using FA3 for vision components
- Consistent behavior observed across all model sizes (3B, 7B, 32B, 72B)
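
The exact benchmark harness is not included in this PR; the sketch below shows how such a prefill latency/throughput comparison can be reproduced with plain Transformers. The checkpoint name, the dummy 1260×700 image, the ~150-token prompt, and the assumption that the installed Transformers build accepts `attn_implementation="flash_attention_3"` are all illustrative, not taken from the PR.

```python
import time

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"  # assumed checkpoint name


def benchmark(attn_implementation: str, num_iters: int = 10) -> None:
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Dummy 1260x700 image plus roughly 150 tokens of text, mirroring the input above.
    image = Image.new("RGB", (1260, 700))
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the image in detail. " * 25},
        ],
    }]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

    with torch.no_grad():
        model(**inputs)                 # warm-up pass
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iters):      # timed prefill passes
            model(**inputs)
        torch.cuda.synchronize()
    avg_ms = (time.perf_counter() - start) / num_iters * 1000
    tokens = inputs["input_ids"].shape[1]
    print(f"{attn_implementation}: {avg_ms:.2f} ms, {tokens / (avg_ms / 1000):.2f} tok/s")


if __name__ == "__main__":
    for impl in ("flash_attention_2", "flash_attention_3"):
        benchmark(impl)
```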
## Changes
- Added a check for `qwen2_5_vl` model type
- When `attn_implementation == "flash_attention_3"`, automatically set
`actor_model_config.vision_config._attn_implementation =
"flash_attention_2"` for the vision tower
- This allows the language model to use FA3 while the vision tower uses
FA2, achieving optimal performance
## Impact
This change ensures that Qwen2.5-VL models can benefit from
flash_attention_3 for text processing while maintaining optimal
performance for vision encoding.
## Technical Details
The patch is applied in `verl/workers/fsdp_workers.py` in the
`_build_model_optimizer` method:
```python
# patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2
# because the vision tower does not support flash_attention_3
if (
getattr(actor_model_config, "model_type", None) == "qwen2_5_vl"
and attn_implementation == "flash_attention_3"
and hasattr(actor_model_config, "vision_config")
):
actor_model_config.vision_config._attn_implementation = "flash_attention_2"
```
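
Outside of verl, the same guard can be exercised directly on the Hugging Face config to confirm the fallback takes effect. This is a minimal sketch; the checkpoint name is an assumption, and any Qwen2.5-VL checkpoint exposes the same config layout.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")  # assumed checkpoint
attn_implementation = "flash_attention_3"

# Same guard as the patch: only Qwen2.5-VL configs with a vision_config are touched,
# and only when flash_attention_3 is requested for the main model.
if (
    getattr(config, "model_type", None) == "qwen2_5_vl"
    and attn_implementation == "flash_attention_3"
    and hasattr(config, "vision_config")
):
    config.vision_config._attn_implementation = "flash_attention_2"

print(config.model_type)                          # qwen2_5_vl
print(config.vision_config._attn_implementation)  # flash_attention_2
```

Because every condition is guarded with `getattr`/`hasattr`, the patch is a no-op for non-Qwen2.5-VL models and for runs that do not request flash_attention_3.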
## Testing
Tested on:
- Qwen2.5-VL-3B
- Qwen2.5-VL-7B
- Qwen2.5-VL-32B
- Qwen2.5-VL-72B
All models show consistent performance improvements with this patch when
using flash_attention_3 for the language model.
1 file changed: `verl/workers/fsdp_workers.py` (+9, −0), lines 335–343 added.