@@ -667,6 +667,8 @@ def apply_liger_kernel_to_paligemma(
667667 modeling_paligemma .PaliGemmaForConditionalGeneration .forward = lce_forward_deprecated
668668
669669 if model is not None :
670+ text_model_name = model .config .text_config .model_type
671+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN .get (text_model_name , None )
670672 # The model instance already exists, so we need to additionally patch the
671673 # instance variables that reference already-instantiated modules
672674
@@ -683,31 +685,29 @@ def apply_liger_kernel_to_paligemma(
683685 _patch_layer_norm_module (layer .layer_norm1 )
684686 _patch_layer_norm_module (layer .layer_norm2 )
685687
686- language_model = model .language_model
687-
688- if isinstance (language_model , GemmaForCausalLM ):
689- apply_liger_kernel_to_gemma (
690- rope = rope ,
691- cross_entropy = False ,
692- fused_linear_cross_entropy = False ,
693- rms_norm = rms_norm ,
694- geglu = geglu ,
695- model = language_model ,
696- )
697-
698- elif isinstance (language_model , Gemma2ForCausalLM ):
699- apply_liger_kernel_to_gemma2 (
700- rope = rope ,
701- cross_entropy = False ,
702- fused_linear_cross_entropy = False ,
703- rms_norm = rms_norm ,
704- geglu = geglu ,
705- model = language_model ,
706- )
707- else :
708- raise TypeError (
709- "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
710- )
688+ kwargs = {
689+ "rope" : rope ,
690+ "cross_entropy" : False ,
691+ "fused_linear_cross_entropy" : False ,
692+ "rms_norm" : rms_norm ,
693+ "geglu" : geglu ,
694+ "model" : model .language_model ,
695+ }
696+ if text_liger_fn :
697+ accept_params = inspect .signature (text_liger_fn ).parameters
698+ remain_params = set (kwargs ) - (set (accept_params ) & set (kwargs ))
699+ text_kwargs = {k : v for k , v in kwargs .items () if k not in remain_params }
700+
701+ if remain_params :
702+ logger .warning (
703+ f"These parameters are not supported by { text_model_name } . Enter the remaining { list (text_kwargs .keys ())} except for { list (remain_params )} \n "
704+ f"Parameters accepted by { text_model_name } : { list (accept_params .keys ())} "
705+ )
706+ text_kwargs ["model" ] = model .language_model
707+ text_liger_fn (** text_kwargs )
708+
709+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN :
710+ logger .warning (f"{ text_model_name } is not supported by Liger kernel." )
711711
712712
713713def apply_liger_kernel_to_qwen2 (
0 commit comments