
Commit e9621bb

upgrade transformers to 4.54
1 parent a67a971 · commit e9621bb

8 files changed

Lines changed: 73 additions & 220 deletions


.github/requirements-test.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -8,4 +8,4 @@ ruff
 tensordict
 torch
 torchvision
-transformers>=4.51.0,<=4.56.1
+transformers>=4.54.0,<=4.56.2
```
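
The same floor matters at runtime. A minimal sketch of a fail-fast guard (hypothetical, not part of this commit; assumes `packaging` is installed, which transformers already requires):

```python
# Hypothetical guard, not part of this commit: fail fast if transformers is too old.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("transformers"))
if installed < Version("4.54.0"):
    raise RuntimeError(f"transformers>=4.54.0 is required, found {installed}")
```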

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,7 +36,7 @@ EasyR1 is efficient and scalable due to the design of **[HybirdEngine](https://a
 ### Software Requirements
 
 - Python 3.9+
-- transformers>=4.51.0
+- transformers>=4.54.0
 - flash-attn>=2.4.3
 - vllm>=0.8.3
 
```

requirements.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -15,6 +15,6 @@ qwen-vl-utils
 ray[default]
 tensordict
 torchdata
-transformers>=4.51.0,<=4.56.1
+transformers>=4.54.0,<=4.56.2
 vllm>=0.8.0
 wandb
```

verl/models/monkey_patch.py

Lines changed: 31 additions & 37 deletions
```diff
@@ -15,48 +15,42 @@
 
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
-from ..utils.py_functional import is_transformers_version_greater_than
 from .transformers.flash_attention_utils import flash_attention_forward
-from .transformers.qwen2_vl import (
-    qwen2_vl_attn_forward,
-    qwen2_vl_base_forward_new,
-    qwen2_vl_forward_new,
-    qwen2_vl_forward_old,
+from .transformers.qwen2_vl import qwen2_vl_base_forward, qwen2_vl_model_forward
+
+
+SUPPORTED_MODEL_TYPE = (
+    "llama",
+    "gemma",
+    "gemma2",
+    "mistral",
+    "qwen2",
+    "qwen2_moe",
+    "qwen3",
+    "qwen3_moe",
+    "qwen2_vl",
+    "qwen2_5_vl",
 )
 
+SUPPORTED_VLM_TYPE = ("qwen2_vl", "qwen2_5_vl")
+
 
 def apply_ulysses_patch(model_type: str) -> None:
-    if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
+    if model_type in SUPPORTED_MODEL_TYPE:
         ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
-    elif model_type in ("qwen2_vl", "qwen2_5_vl"):
-        if is_transformers_version_greater_than("4.54.0"):
-            # transformers 4.54.0 does not need special patch: https://github.com/huggingface/transformers/pull/39447
-            ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
-        elif is_transformers_version_greater_than("4.53.0"):
-            raise NotImplementedError("Transformers 4.53.* is not compatible with Qwen2-VL. Use 4.54.0 or later.")
-        else:
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
-
-            Qwen2VLFlashAttention2.forward = qwen2_vl_attn_forward
-            Qwen2_5_VLFlashAttention2.forward = qwen2_vl_attn_forward
-
-        if is_transformers_version_greater_than("4.52.0"):
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-                Qwen2_5_VLForConditionalGeneration,
-                Qwen2_5_VLModel,
-            )
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel
-
-            Qwen2VLModel.forward = qwen2_vl_base_forward_new
-            Qwen2_5_VLModel.forward = qwen2_vl_base_forward_new
-            Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_new
-            Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_new
-        else:
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
-
-            Qwen2VLForConditionalGeneration.forward = qwen2_vl_forward_old
-            Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_forward_old
     else:
         raise NotImplementedError(f"Model architecture {model_type} is not supported yet.")
+
+    if model_type in SUPPORTED_VLM_TYPE:
+        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+            Qwen2_5_VLForConditionalGeneration,
+            Qwen2_5_VLModel,
+        )
+        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel
+
+        # fix text-image mixed data
+        Qwen2VLModel.forward = qwen2_vl_base_forward
+        Qwen2_5_VLModel.forward = qwen2_vl_base_forward
+        # TODO: add linear cross entropy kernels
+        Qwen2VLForConditionalGeneration.forward = qwen2_vl_model_forward
+        Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_model_forward
```
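
For context, a minimal usage sketch of the patched entry point (the checkpoint name and call site are illustrative assumptions, not part of this commit): `apply_ulysses_patch` keys off `config.model_type`, installs the custom flash-attention function for every supported architecture, and additionally swaps the Qwen2-VL / Qwen2.5-VL forward methods.

```python
# Hypothetical call site (illustration only): apply the patch before running the model.
from transformers import AutoConfig
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration

from verl.models.monkey_patch import apply_ulysses_patch

model_name = "Qwen/Qwen2.5-VL-7B-Instruct"  # assumed example checkpoint
config = AutoConfig.from_pretrained(model_name)

# model_type is "qwen2_5_vl", so the attention kernel and both VL forwards are replaced.
apply_ulysses_patch(config.model_type)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_name, attn_implementation="flash_attention_2"
)
```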

verl/models/transformers/flash_attention_utils.py

Lines changed: 13 additions & 16 deletions
```diff
@@ -47,11 +47,12 @@ def prepare_fa2_from_position_ids(
     query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
+    tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
     position_ids = position_ids.view(-1)
     cu_seqlens = torch.cat(
         (
-            (position_ids == 0).nonzero().view(-1).to(torch.int32),
-            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+            (position_ids == 0).nonzero().view(-1).to(**tensor_kwargs),
+            torch.tensor(position_ids.size(), **tensor_kwargs),
         )
     )
     max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
@@ -90,12 +91,9 @@
         query_states, key_states, value_states, target_dtype=torch.bfloat16
     )
 
-    if position_ids is not None:
-        assert position_ids.ndim == 2  # (batch_size, seq_length)
-
     sp_size = get_ulysses_sequence_parallel_world_size()
     if sp_size > 1:
-        # qkv: (batch_size, seq_length, num_head, head_size)
+        # qkv: (batch_size, seq_length / sp_size, num_head, head_size)
         query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
         key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
         value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
@@ -105,19 +103,17 @@
 
     if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
         batch_size = query_states.size(0)
-        query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+        q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
         )
-        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
         attn_output = flash_attn_varlen_func(
-            query_states,
-            key_states,
-            value_states,
+            q,
+            k,
+            v,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_in_batch_q,
-            max_seqlen_k=max_seqlen_in_batch_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             dropout_p=kwargs.pop("dropout", 0.0),
             softmax_scale=kwargs.pop("softmax_scale", None),
             causal=is_causal,
@@ -132,14 +128,15 @@
             attention_mask,
             query_length,
             is_causal=is_causal,
+            position_ids=position_ids,
             sliding_window=sliding_window,
             use_top_left_mask=use_top_left_mask,
             deterministic=deterministic,
             **kwargs,
-        )  # do not pass position_ids to old flash_attention_forward
+        )
 
     if sp_size > 1:
-        # (batch_size, seq_length, num_head, head_size)
+        # output: (batch_size, seq_length / sp_size, num_head, head_size)
         attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
 
     return attn_output
```
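
To make the `cu_seqlens` bookkeeping concrete, a standalone toy sketch (made-up sequence lengths, mirroring the expression in `prepare_fa2_from_position_ids` above, not the library code itself):

```python
# Toy illustration only: three packed sequences of lengths 4, 2 and 3 share one row;
# position_ids restart at 0 at the start of each sequence.
import torch

position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 0, 1, 2])

tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
cu_seqlens = torch.cat(
    (
        (position_ids == 0).nonzero().view(-1).to(**tensor_kwargs),  # sequence starts: [0, 4, 6]
        torch.tensor(position_ids.size(), **tensor_kwargs),          # total length:    [9]
    )
)
print(cu_seqlens)               # tensor([0, 4, 6, 9], dtype=torch.int32)
print(cu_seqlens.diff().max())  # tensor(4, dtype=torch.int32) -> max_seqlen for flash_attn_varlen_func
```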
