@@ -4657,22 +4657,29 @@ def __init__(
46574657 model : "PreTrainedModel" ,
46584658 model_kwargs : Dict [str , Any ],
46594659 ):
4660- model .__orig_forward = model .forward
4661- # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
4662- # Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
4663- if (
4664- hasattr (model , "model" )
4665- and hasattr (model .model , "get_image_features" )
4666- and is_transformers_version ("<" , "5" )
4667- ):
4668- model .forward = model .model .get_image_features
4669- else :
4670- model .forward = model .get_image_features
46714660 super ().__init__ (config , model , model_kwargs )
46724661
4673- def __exit__ (self , exc_type , exc_value , traceback ):
4674- super ().__exit__ (exc_type , exc_value , traceback )
4675- self ._model .forward = self ._model .__orig_forward
4662+ @functools .wraps (self .orig_forward )
4663+ def patched_forward (* args , ** kwargs ):
4664+ # Adapted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L835
4665+ # Adapted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1321
4666+ if (
4667+ hasattr (self ._model , "model" )
4668+ and hasattr (self ._model .model , "get_image_features" )
4669+ and is_transformers_version ("<" , "5" )
4670+ ):
4671+ get_image_features = self ._model .model .get_image_features
4672+ else :
4673+ get_image_features = self ._model .get_image_features
4674+
4675+ outputs = get_image_features (* args , ** kwargs )
4676+
4677+ if is_transformers_version (">=" , "5" ):
4678+ outputs = BaseModelOutputWithPooling (pooler_output = outputs .pooler_output )
4679+
4680+ return outputs
4681+
4682+ self .patched_forward = patched_forward
46764683
46774684
46784685# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147
0 commit comments