Skip to content

Commit dec9235

Browse files
author
Wangxiaoxiaoa
committed
feat: support mm_token_type_ids and 3D RoPE alignment for Qwen2/3-VL (#436)
1 parent 4bb7d74 commit dec9235

2 files changed

Lines changed: 40 additions & 13 deletions

File tree

roll/datasets/collator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
138138
# model_inputs for hf/deepspeed: input_id, attention_mask, pixel_values, image_grid_thw
139139
padded_features = defaultdict(list)
140140
un_padded_features = defaultdict(list)
141+
mm_token_type_id_features = []
141142
mm_feature_keys = set()
142143
for feature in features:
143144
# cannot process as batch directly though processor output as batch
@@ -165,6 +166,8 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
165166
model_inputs.pop(key)
166167
for key in filter(lambda k: k in model_inputs, self.padded_keys):
167168
padded_features[key].append(model_inputs.pop(key)[0])
169+
if "mm_token_type_ids" in model_inputs:
170+
mm_token_type_id_features.append(torch.as_tensor(model_inputs.pop("mm_token_type_ids")[0]))
168171
# mm feature fileds can be different because of mixed data
169172
mm_feature_keys = mm_feature_keys.union(model_inputs.keys())
170173
# to tensors except padded_keys which would be converted after padding
@@ -208,6 +211,22 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
208211
return_tensors=self.return_tensors,
209212
)
210213
batch.update(un_padded_features)
214+
if mm_token_type_id_features:
215+
target_len = batch["input_ids"].shape[-1]
216+
padded_mm_token_type_ids = []
217+
for token_type_ids in mm_token_type_id_features:
218+
pad_len = target_len - token_type_ids.shape[-1]
219+
if pad_len < 0:
220+
raise ValueError(
221+
f"mm_token_type_ids length {token_type_ids.shape[-1]} exceeds padded input length {target_len}"
222+
)
223+
pad = torch.zeros(pad_len, dtype=token_type_ids.dtype, device=token_type_ids.device)
224+
if self.tokenizer.padding_side == "left":
225+
token_type_ids = torch.cat([pad, token_type_ids], dim=-1)
226+
else:
227+
token_type_ids = torch.cat([token_type_ids, pad], dim=-1)
228+
padded_mm_token_type_ids.append(token_type_ids)
229+
batch["mm_token_type_ids"] = torch.stack(padded_mm_token_type_ids, dim=0)
211230

212231
# other custom data fields: mainly for specific position_ids currently
213232
# position_ids for qwen2-vl is optional and make sure it is a 3D tensor
@@ -226,6 +245,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
226245
kwargs[key] = fun_params[key].default
227246
extra_data = self.extra_data_provider(**kwargs)
228247
batch.update(extra_data)
248+
batch.pop("mm_token_type_ids", None)
229249

230250
# each field should be a tensor or np.array(val=list_data, dtype=object)
231251
# to be stored in DataProto

roll/models/model_providers.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ def load_model(
277277
freeze_model(model, model_args)
278278
else:
279279
model = setup_lora_training(config, model, model_args, is_trainable)
280+
if not model_args.disable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"):
281+
model.enable_input_require_grads()
280282

281283
if add_valuehead:
282284
from trl import AutoModelForCausalLMWithValueHead
@@ -710,8 +712,6 @@ def get_extra_data_provider(model_name_or_path: str, processor=None):
710712
if isinstance(model_type, str) and (("qwen2" in model_type) or (model_type in ("qwen3_vl", "qwen3_vl_moe"))):
711713
import types
712714

713-
from transformers import BatchFeature # help define a object to accesss attr
714-
715715
def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
716716
sig = inspect.signature(fn)
717717
params = sig.parameters
@@ -745,17 +745,13 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
745745
"<|vision_start|>"
746746
)
747747

748-
dummy_self = BatchFeature(
749-
{
750-
"config": BatchFeature(
751-
{
752-
"vision_config": BatchFeature(vc),
753-
"image_token_id": image_token_id,
754-
"video_token_id": video_token_id,
755-
"vision_start_token_id": vision_start_token_id,
756-
}
757-
)
758-
}
748+
dummy_self = types.SimpleNamespace(
749+
config=types.SimpleNamespace(
750+
vision_config=types.SimpleNamespace(**vc),
751+
image_token_id=image_token_id,
752+
video_token_id=video_token_id,
753+
vision_start_token_id=vision_start_token_id,
754+
)
759755
)
760756

761757
is_tf_ge_4_52 = is_transformers_version_greater_than("4.52.0")
@@ -771,6 +767,9 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
771767
elif model_type in ("qwen3_vl", "qwen3_vl_moe"):
772768
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
773769

770+
dummy_self.get_vision_position_ids = types.MethodType(
771+
Qwen3VLModel.get_vision_position_ids, dummy_self
772+
)
774773
get_rope_index = types.MethodType(Qwen3VLModel.get_rope_index, dummy_self)
775774
else:
776775
if is_tf_ge_4_52:
@@ -787,8 +786,15 @@ def extra_data_provider(
787786
image_grid_thw: Optional[torch.LongTensor] = None,
788787
video_grid_thw: Optional[torch.LongTensor] = None,
789788
attention_mask: Optional[torch.Tensor] = None,
789+
mm_token_type_ids: Optional[torch.Tensor] = None,
790790
second_per_grid_ts: Optional[torch.Tensor] = None,
791791
):
792+
if model_type in ("qwen3_vl", "qwen3_vl_moe") and mm_token_type_ids is None:
793+
mm_token_type_ids = torch.zeros_like(input_ids)
794+
if image_token_id is not None:
795+
mm_token_type_ids = torch.where(input_ids == image_token_id, 1, mm_token_type_ids)
796+
if video_token_id is not None:
797+
mm_token_type_ids = torch.where(input_ids == video_token_id, 2, mm_token_type_ids)
792798
# Keep kwargs to be resilient to HF signature changes between versions/models.
793799
out = _call_get_rope_index(
794800
get_rope_index,
@@ -797,6 +803,7 @@ def extra_data_provider(
797803
video_grid_thw=video_grid_thw,
798804
second_per_grid_ts=second_per_grid_ts,
799805
attention_mask=attention_mask,
806+
mm_token_type_ids=mm_token_type_ids,
800807
)
801808
rope_index = out[0]
802809
# PumpkinComment:

0 commit comments

Comments
 (0)