@@ -277,6 +277,8 @@ def load_model(
277277 freeze_model (model , model_args )
278278 else :
279279 model = setup_lora_training (config , model , model_args , is_trainable )
280+ if not model_args .disable_gradient_checkpointing and hasattr (model , "enable_input_require_grads" ):
281+ model .enable_input_require_grads ()
280282
281283 if add_valuehead :
282284 from trl import AutoModelForCausalLMWithValueHead
@@ -710,8 +712,6 @@ def get_extra_data_provider(model_name_or_path: str, processor=None):
710712 if isinstance (model_type , str ) and (("qwen2" in model_type ) or (model_type in ("qwen3_vl" , "qwen3_vl_moe" ))):
711713 import types
712714
713- from transformers import BatchFeature # help define a object to accesss attr
714-
715715 def _call_get_rope_index (fn , input_ids : torch .LongTensor , ** candidate_kwargs ):
716716 sig = inspect .signature (fn )
717717 params = sig .parameters
@@ -745,17 +745,13 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
745745 "<|vision_start|>"
746746 )
747747
748- dummy_self = BatchFeature (
749- {
750- "config" : BatchFeature (
751- {
752- "vision_config" : BatchFeature (vc ),
753- "image_token_id" : image_token_id ,
754- "video_token_id" : video_token_id ,
755- "vision_start_token_id" : vision_start_token_id ,
756- }
757- )
758- }
748+ dummy_self = types .SimpleNamespace (
749+ config = types .SimpleNamespace (
750+ vision_config = types .SimpleNamespace (** vc ),
751+ image_token_id = image_token_id ,
752+ video_token_id = video_token_id ,
753+ vision_start_token_id = vision_start_token_id ,
754+ )
759755 )
760756
761757 is_tf_ge_4_52 = is_transformers_version_greater_than ("4.52.0" )
@@ -771,6 +767,9 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs):
771767 elif model_type in ("qwen3_vl" , "qwen3_vl_moe" ):
772768 from transformers .models .qwen3_vl .modeling_qwen3_vl import Qwen3VLModel
773769
770+ dummy_self .get_vision_position_ids = types .MethodType (
771+ Qwen3VLModel .get_vision_position_ids , dummy_self
772+ )
774773 get_rope_index = types .MethodType (Qwen3VLModel .get_rope_index , dummy_self )
775774 else :
776775 if is_tf_ge_4_52 :
@@ -787,8 +786,15 @@ def extra_data_provider(
787786 image_grid_thw : Optional [torch .LongTensor ] = None ,
788787 video_grid_thw : Optional [torch .LongTensor ] = None ,
789788 attention_mask : Optional [torch .Tensor ] = None ,
789+ mm_token_type_ids : Optional [torch .Tensor ] = None ,
790790 second_per_grid_ts : Optional [torch .Tensor ] = None ,
791791 ):
792+ if model_type in ("qwen3_vl" , "qwen3_vl_moe" ) and mm_token_type_ids is None :
793+ mm_token_type_ids = torch .zeros_like (input_ids )
794+ if image_token_id is not None :
795+ mm_token_type_ids = torch .where (input_ids == image_token_id , 1 , mm_token_type_ids )
796+ if video_token_id is not None :
797+ mm_token_type_ids = torch .where (input_ids == video_token_id , 2 , mm_token_type_ids )
792798 # Keep kwargs to be resilient to HF signature changes between versions/models.
793799 out = _call_get_rope_index (
794800 get_rope_index ,
@@ -797,6 +803,7 @@ def extra_data_provider(
797803 video_grid_thw = video_grid_thw ,
798804 second_per_grid_ts = second_per_grid_ts ,
799805 attention_mask = attention_mask ,
806+ mm_token_type_ids = mm_token_type_ids ,
800807 )
801808 rope_index = out [0 ]
802809 # PumpkinComment:
0 commit comments