-
Notifications
You must be signed in to change notification settings - Fork 388
Expand file tree
/
Copy pathmodeling_qwen25_vl.py
More file actions
282 lines (247 loc) · 13.1 KB
/
Copy pathmodeling_qwen25_vl.py
File metadata and controls
282 lines (247 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from typing import Optional
import torch
import transformers
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel import scatter_to_sequence_parallel_region
from megatron.core.transformer.module import MegatronModule
from packaging.version import Version as PkgVersion
from torch import Tensor
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLModel,
)
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
def is_transformers_min_version(version):
    """Return True when the installed ``transformers`` package is at least ``version``.

    Both the installed version string and ``version`` are parsed with
    ``packaging.version.Version``. Any failure along the way is swallowed and
    reported as "requirement not met" (False), so the result can be used
    directly as a feature gate without a try/except at the call site.
    """
    try:
        installed = PkgVersion(transformers.__version__)
        required = PkgVersion(version)
        return installed >= required
    except Exception:
        # Unparseable/unavailable version information: conservatively answer False.
        return False
class Qwen25VLModel(MegatronModule):
    """
    Qwen2.5 VL Model: an HF Qwen2.5-VL vision tower feeding a Megatron GPT language model.

    The vision tower is only built on stages with ``pre_process=True``; the language model is
    built by ``config.provide_language_model``. Several helpers (``get_image_features``,
    ``get_video_features``, ``get_placeholder_mask``, ``get_rope_index`` and, on transformers
    >= 5.3.0, ``get_vision_position_ids``) are taken from HF's ``Qwen2_5_VLModel`` and bound to
    this instance.

    Args:
        config (GPTModelProvider):
            Language model provider. Besides building the language model, it is read here for
            ``vision_config``, ``share_embeddings_and_output_weights``, ``sequence_parallel``,
            ``_pg_collection``, ``image_token_id`` and ``video_token_id``.
        pre_process (bool, optional):
            Include the embedding layer and the vision tower (used with pipeline parallelism).
            Defaults to True.
        post_process (bool, optional):
            Include an output layer (used with pipeline parallelism). Defaults to True.
        vp_stage (Optional[int], optional):
            Virtual pipeline stage, forwarded to ``config.provide_language_model``.
            Defaults to None.

    Raises:
        RuntimeError: If the installed transformers is older than 4.55.0, which lacks
            ``Qwen2_5_VLModel.get_placeholder_mask``.
    """

    def __init__(
        self,
        config: GPTModelProvider,
        pre_process: bool = True,
        post_process: bool = True,
        vp_stage: Optional[int] = None,
    ) -> None:
        # Fail fast, before the vision tower and language model are allocated:
        # get_placeholder_mask is only available in transformers 4.55+.
        if not is_transformers_min_version("4.55.0"):
            raise RuntimeError(
                f"transformers version {transformers.__version__} is not supported. "
                "get_placeholder_mask requires transformers >= 4.55.0. "
                "Please upgrade transformers: pip install 'transformers>=4.55.0'"
            )

        super().__init__(config=config)
        self.pre_process = pre_process
        self.post_process = post_process
        self.vp_stage = vp_stage

        if pre_process:
            # Only stages that embed tokens own the vision tower; its outputs are merged into
            # the token embeddings in forward().
            self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config)
            # Ensure HF visual tower params are marked for TP grad sync and future assignments are hooked.
            hook_hf_module_setattr_for_tp_grad_sync(self.visual)

        self.language_model = self.config.provide_language_model(
            pre_process=pre_process, post_process=post_process, vp_stage=vp_stage
        )

        # The finalize-grads step needs these bound on this module, not only on language_model.
        self.share_embeddings_and_output_weights = config.share_embeddings_and_output_weights
        self.shared_embedding_or_output_weight = self.language_model.shared_embedding_or_output_weight

        # Bind methods from HF's Qwen2_5_VLModel so that, inside them, `self` is this model.
        self.get_placeholder_mask = types.MethodType(Qwen2_5_VLModel.get_placeholder_mask, self)
        self.get_image_features = types.MethodType(Qwen2_5_VLModel.get_image_features, self)
        self.get_video_features = types.MethodType(Qwen2_5_VLModel.get_video_features, self)
        self.get_rope_index = types.MethodType(Qwen2_5_VLModel.get_rope_index, self)
        # get_vision_position_ids is only available in transformers 5.3.0+
        if is_transformers_min_version("5.3.0"):
            self.get_vision_position_ids = types.MethodType(Qwen2_5_VLModel.get_vision_position_ids, self)

    @property
    def decoder(self):
        """Expose language model decoder for mcore inference compatibility.

        mcore's MambaInferenceStateConfig.from_model() calls get_attr_wrapped_model(model, "decoder"),
        which only traverses .module wrappers. VLM models store the decoder under language_model.decoder,
        so we expose it here to allow the Mamba check to run and correctly return None.
        """
        return getattr(self.language_model, "decoder", None)

    def set_input_tensor(self, input_tensor) -> None:
        """Set model chunk input tensor (delegates to the language model)."""
        self.language_model.set_input_tensor(input_tensor)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        mm_token_type_ids: Optional[torch.IntTensor] = None,
        labels: Tensor = None,
        inference_context: BaseInferenceContext = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        *,
        inference_params: Optional[BaseInferenceContext] = None,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Run the multimodal forward pass.

        On ``pre_process`` stages, token embeddings are computed (unless ``inputs_embeds`` is
        given), image/video features are scattered into their placeholder positions, the result
        is transposed to ``[decoder_seq_len, b, h]`` and, with sequence parallelism, scattered
        across the TP group. On every stage, MRoPE ``position_ids`` are recomputed from
        ``input_ids`` and the grid info; any caller-supplied ``position_ids`` are replaced.

        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
            The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
        mm_token_type_ids (`torch.IntTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Per-token modality ids (0 = text, 1 = image, 2 = video) passed to ``get_rope_index``
            on transformers >= 5.3.0. If None, they are derived from ``input_ids`` using
            ``config.image_token_id`` / ``config.video_token_id``. Unused on older transformers.

        ``attention_mask``, ``labels``, ``loss_mask``, ``runtime_gather_output`` and
        ``packed_seq_params`` are forwarded unchanged to the language model.

        NOTE(review): ``inference_context``, ``inference_params`` and ``extra_block_kwargs`` are
        accepted but currently not forwarded to the language model.

        Returns:
            The output of ``self.language_model.forward`` (presumably logits, or per-token loss
            when ``labels`` is given — confirm against the language model).
        """
        if self.pre_process:
            if inputs_embeds is None:
                inputs_embeds = self.language_model.embedding(
                    input_ids=input_ids, position_ids=None
                )  # [decoder_seq_len, b, h_language]
                inputs_embeds = inputs_embeds.transpose(1, 0).contiguous()  # [b, decoder_seq_len, h_language]

            if pixel_values is not None:
                image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
                image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
                image_mask, _ = self.get_placeholder_mask(
                    input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                video_embeds = self.get_video_features(
                    pixel_values_videos, video_grid_thw, return_dict=True
                ).pooler_output
                video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
                _, video_mask = self.get_placeholder_mask(
                    input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
                )
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            # [b, decoder_seq_len, h_language] -> [decoder_seq_len, b, h_language] (Megatron layout)
            inputs_embeds = inputs_embeds.transpose(1, 0)

            if self.config.sequence_parallel:
                tp_group = self.config._pg_collection.tp if self.config._pg_collection is not None else None
                inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds, group=tp_group)

        # Compute MRoPE position_ids on ALL pipeline stages.
        # Each stage has input_ids and visual grid info from the data iterator,
        # which avoids any broadcasting overhead.
        # No HF-style attention mask is passed to get_rope_index here.
        hf_attention_mask = None

        # In transformers 5.3.0+, get_rope_index requires mm_token_type_ids as the second argument.
        if is_transformers_min_version("5.3.0"):
            if mm_token_type_ids is None:
                # Derive modality ids from the placeholder tokens: 0=text, 1=image, 2=video.
                mm_token_type_ids = torch.zeros_like(input_ids, dtype=torch.int)
                mm_token_type_ids[input_ids == self.config.image_token_id] = 1
                mm_token_type_ids[input_ids == self.config.video_token_id] = 2
            position_ids, _rope_deltas = self.get_rope_index(
                input_ids,
                mm_token_type_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                attention_mask=hf_attention_mask,
            )
        else:
            position_ids, _rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                attention_mask=hf_attention_mask,
            )

        outputs = self.language_model.forward(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            decoder_input=inputs_embeds,
            labels=labels,
            loss_mask=loss_mask,
            runtime_gather_output=runtime_gather_output,
            packed_seq_params=packed_seq_params,
        )
        return outputs

    def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):
        """Freeze model modules.

        Make specific modules non-trainable by setting requires_grad to False on all of their
        parameters. Submodules that do not exist on this stage (e.g. no ``visual`` when
        ``pre_process`` is False) are skipped.

        Args:
            freeze_language_model (bool): Freeze the language model module.
            freeze_vision_model (bool): Freeze the vision model module (patch_embed and blocks).
            freeze_vision_projection (bool): Freeze the vision projection module (merger).
        """
        language_model = getattr(self, "language_model", None)
        visual = getattr(self, "visual", None)

        modules = []
        if freeze_language_model and language_model is not None:
            modules.append(language_model)
        if freeze_vision_model and visual is not None:
            # Vision model consists of patch_embed and blocks
            modules.extend(getattr(visual, name) for name in ("patch_embed", "blocks") if hasattr(visual, name))
        if freeze_vision_projection and visual is not None and hasattr(visual, "merger"):
            # Vision projection is the merger module
            modules.append(visual.merger)

        for module in modules:
            for param in module.parameters():
                param.requires_grad = False