
Commit cd4072d

[trainer] fix: fallback vision tower to flash_attention_2 for Qwen2.5-VL when u… (#4670)
# Fix: Fallback Vision Tower to Flash Attention 2 for Qwen2.5-VL when using Flash Attention 3

## Description

This PR adds a patch for Qwen2.5-VL models that falls back the vision tower's attention implementation to flash_attention_2 when the main model uses flash_attention_3.

## Motivation

Qwen2.5-VL's vision tower does not support flash_attention_3 properly. When `attn_implementation` is set to `flash_attention_3`, using FA3 for the vision tower causes significant performance degradation compared to flash_attention_2.

## Experimental Validation

We tested this patch across the entire Qwen2.5-VL series (3B, 7B, 32B, and 72B models) using the Transformers library on an 8×H100 GPU machine with auto device placement. Below is the performance comparison for Qwen2.5-VL-7B with an input of one 1260×700 image plus 150 tokens of text:

```
======================================================================
COMPARISON SUMMARY
======================================================================
Implementation        Avg Latency (ms)    Throughput (tok/s)
-------------------------------------------------------------
flash_attention_2     102.85              12503.46
flash_attention_3     309.49              4155.19

FA3 vs FA2 Speedup:   0.33x
Memory Difference:    +0.00 GB
```

**Test Environment:**
- Hardware: 8×H100 GPUs
- Library: Transformers with auto device placement
- Models tested: Qwen2.5-VL-3B, 7B, 32B, 72B

**Key Findings:**
- Flash Attention 3 is **3x slower** than Flash Attention 2 for the vision tower
- No memory benefit from using FA3 for vision components
- Consistent behavior observed across all model sizes (3B, 7B, 32B, 72B)

## Changes

- Added a check for the `qwen2_5_vl` model type
- When `attn_implementation == "flash_attention_3"`, automatically set `actor_model_config.vision_config._attn_implementation = "flash_attention_2"` for the vision tower
- This allows the language model to use FA3 while the vision tower uses FA2, achieving optimal performance

## Impact

This change ensures that Qwen2.5-VL models can benefit from flash_attention_3 for text processing while maintaining optimal performance for vision encoding.

## Technical Details

The patch is applied in `verl/workers/fsdp_workers.py` in the `_build_model_optimizer` method:

```python
# patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2
# because the vision tower does not support flash_attention_3
if (
    getattr(actor_model_config, "model_type", None) == "qwen2_5_vl"
    and attn_implementation == "flash_attention_3"
    and hasattr(actor_model_config, "vision_config")
):
    actor_model_config.vision_config._attn_implementation = "flash_attention_2"
```

## Testing

Tested on:
- Qwen2.5-VL-3B
- Qwen2.5-VL-7B
- Qwen2.5-VL-32B
- Qwen2.5-VL-72B

All models show consistent performance improvements with this patch when using flash_attention_3 for the language model.
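For context, a micro-benchmark in the spirit of the comparison above could be structured like the sketch below. It is not the script behind the reported numbers; the checkpoint id (`Qwen/Qwen2.5-VL-7B-Instruct`), placeholder image, prompt, generation length, and run count are all illustrative assumptions, and a flash_attention_3-capable install is required for the FA3 run.

```python
# Hypothetical micro-benchmark sketch (not the script used for the numbers above).
# Loads Qwen2.5-VL with a given attention implementation and times generation
# for a single image + short text prompt.
import time

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # assumed checkpoint for illustration


def avg_generation_latency_ms(attn_implementation: str, image: Image.Image, prompt: str, n_runs: int = 10) -> float:
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        attn_implementation=attn_implementation,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)

    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

    # One warm-up run, then timed runs.
    model.generate(**inputs, max_new_tokens=128)
    latencies = []
    for _ in range(n_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        model.generate(**inputs, max_new_tokens=128)
        torch.cuda.synchronize()
        latencies.append(time.perf_counter() - start)
    return 1000 * sum(latencies) / len(latencies)


if __name__ == "__main__":
    # Placeholder image with the same resolution as the reported test input.
    image = Image.new("RGB", (1260, 700))
    prompt = "Describe this image in detail."
    for impl in ("flash_attention_2", "flash_attention_3"):
        print(impl, f"{avg_generation_latency_ms(impl, image, prompt):.2f} ms")
```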
1 parent c790552 commit cd4072d

File tree

1 file changed: +9 -0 lines changed


verl/workers/fsdp_workers.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -332,6 +332,15 @@ def _build_model_optimizer(
         if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"):
             actor_model_config.vision_config._attn_implementation = "eager"

+        # patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2
+        # because the vision tower does not support flash_attention_3
+        if (
+            getattr(actor_model_config, "model_type", None) == "qwen2_5_vl"
+            and attn_implementation == "flash_attention_3"
+            and hasattr(actor_model_config, "vision_config")
+        ):
+            actor_model_config.vision_config._attn_implementation = "flash_attention_2"
+
         # patch for kimi-vl
         if getattr(actor_model_config, "model_type", None) == "kimi_vl":
             actor_model_config.text_config.topk_method = "greedy"
```
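For users loading Qwen2.5-VL outside of verl, the same fallback can be reproduced by patching the Hugging Face config before the model is built. A minimal sketch mirroring the check above, with the checkpoint id as an illustrative assumption:

```python
# Minimal sketch: apply the same vision-tower fallback to a standalone
# Hugging Face config. The checkpoint id below is an illustrative assumption.
from transformers import AutoConfig

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
attn_implementation = "flash_attention_3"  # desired implementation for the language model

config = AutoConfig.from_pretrained(model_id)

# Same condition as the patch in _build_model_optimizer: keep FA3 for the
# language model, force the vision tower back to flash_attention_2.
if (
    getattr(config, "model_type", None) == "qwen2_5_vl"
    and attn_implementation == "flash_attention_3"
    and hasattr(config, "vision_config")
):
    config.vision_config._attn_implementation = "flash_attention_2"

print(config.vision_config._attn_implementation)  # -> "flash_attention_2"
```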
