Skip to content

Commit 9584708

Browse files
committed
Update: apply liger
1 parent 812b050 commit 9584708

1 file changed

Lines changed: 25 additions & 25 deletions

File tree

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,8 @@ def apply_liger_kernel_to_paligemma(
667667
modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
668668

669669
if model is not None:
670+
text_model_name = model.config.text_config.model_type
671+
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
670672
# The model instance already exists, so we need to additionally patch the
671673
# instance variables that reference already-instantiated modules
672674

@@ -683,31 +685,29 @@ def apply_liger_kernel_to_paligemma(
683685
_patch_layer_norm_module(layer.layer_norm1)
684686
_patch_layer_norm_module(layer.layer_norm2)
685687

686-
language_model = model.language_model
687-
688-
if isinstance(language_model, GemmaForCausalLM):
689-
apply_liger_kernel_to_gemma(
690-
rope=rope,
691-
cross_entropy=False,
692-
fused_linear_cross_entropy=False,
693-
rms_norm=rms_norm,
694-
geglu=geglu,
695-
model=language_model,
696-
)
697-
698-
elif isinstance(language_model, Gemma2ForCausalLM):
699-
apply_liger_kernel_to_gemma2(
700-
rope=rope,
701-
cross_entropy=False,
702-
fused_linear_cross_entropy=False,
703-
rms_norm=rms_norm,
704-
geglu=geglu,
705-
model=language_model,
706-
)
707-
else:
708-
raise TypeError(
709-
"The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
710-
)
688+
kwargs = {
689+
"rope": rope,
690+
"cross_entropy": False,
691+
"fused_linear_cross_entropy": False,
692+
"rms_norm": rms_norm,
693+
"geglu": geglu,
694+
"model": model.language_model,
695+
}
696+
if text_liger_fn:
697+
accept_params = inspect.signature(text_liger_fn).parameters
698+
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
699+
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
700+
701+
if remain_params:
702+
logger.warning(
703+
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
704+
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
705+
)
706+
text_kwargs["model"] = model.language_model
707+
text_liger_fn(**text_kwargs)
708+
709+
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
710+
logger.warning(f"{text_model_name} is not supported by Liger kernel.")
711711

712712

713713
def apply_liger_kernel_to_qwen2(

0 commit comments

Comments
 (0)