
Commit b978da4

【Gemma3】apply flexcheckpoint

1 parent bd6586f


paddleformers/transformers/gemma3_text/modeling.py

Lines changed: 55 additions & 0 deletions
@@ -483,6 +483,59 @@ def make_base_actions():
         mappings = make_base_actions()
         return mappings
 
+    @classmethod
+    def _gen_aoa_config(cls, config: Gemma3TextConfig):
+        model_prefix = "" if cls == cls.base_model_prefix else "model."
+        aoa_config = {
+            "aoa_statements": [
+                # load tied weight
+                "model.embed_tokens.weight -> lm_head.weight",
+                # others
+                f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight",
+                f"model.norm.weight -> {model_prefix}norm.weight",
+                f"model.layers.$LAYER_ID.input_layernorm.weight -> {model_prefix}layers.$LAYER_ID.input_layernorm.weight",
+                f"model.layers.$LAYER_ID.post_attention_layernorm.weight -> {model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight",
+                f"model.layers.$LAYER_ID.pre_feedforward_layernorm.weight -> {model_prefix}layers.$LAYER_ID.pre_feedforward_layernorm.weight",
+                f"model.layers.$LAYER_ID.post_feedforward_layernorm.weight -> {model_prefix}layers.$LAYER_ID.post_feedforward_layernorm.weight",
+                # do transpose
+                f"model.layers.$LAYER_ID.mlp.gate_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.gate_proj.weight",
+                f"model.layers.$LAYER_ID.mlp.up_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.up_proj.weight",
+                f"model.layers.$LAYER_ID.mlp.down_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.down_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.q_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.q_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.k_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.k_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.v_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.v_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.o_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight",
+            ]
+        }
+        return aoa_config
+
+    # NOTE: These aoa_config items will be removed later. The subsequent AOA parsing module will automatically generate the reverse AOA based on the forward (from_pretrained) AOA.
+    @classmethod
+    def _gen_inv_aoa_config(cls, config: Gemma3TextConfig):
+        model_prefix = "" if cls == cls.base_model_prefix else "model."
+        aoa_statements = [
+            # ignore tied weights
+            "lm_head.weight -> _",
+            # do transpose
+            f"{model_prefix}layers.$LAYER_ID.mlp.gate_proj.weight^T -> model.layers.$LAYER_ID.mlp.gate_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.mlp.up_proj.weight^T -> model.layers.$LAYER_ID.mlp.up_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.mlp.down_proj.weight^T -> model.layers.$LAYER_ID.mlp.down_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.q_proj.weight^T -> model.layers.$LAYER_ID.self_attn.q_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.k_proj.weight^T -> model.layers.$LAYER_ID.self_attn.k_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.v_proj.weight^T -> model.layers.$LAYER_ID.self_attn.v_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight^T -> model.layers.$LAYER_ID.self_attn.o_proj.weight",
+            # others
+            f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight",
+            f"{model_prefix}norm.weight -> model.norm.weight",
+            f"{model_prefix}layers.$LAYER_ID.input_layernorm.weight -> model.layers.$LAYER_ID.input_layernorm.weight",
+            f"{model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight -> model.layers.$LAYER_ID.post_attention_layernorm.weight",
+            f"{model_prefix}layers.$LAYER_ID.pre_feedforward_layernorm.weight -> model.layers.$LAYER_ID.pre_feedforward_layernorm.weight",
+            f"{model_prefix}layers.$LAYER_ID.post_feedforward_layernorm.weight -> model.layers.$LAYER_ID.post_feedforward_layernorm.weight",
+        ]
+
+        aoa_config = {"aoa_statements": aoa_statements}
+        return aoa_config
+
 
 class Gemma3TextModel(Gemma3PreTrainedModel):
     config_class = Gemma3TextConfig
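
Each statement above follows a small "source[^T] -> destination" grammar: "$LAYER_ID" is a per-layer placeholder, "^T" marks projection weights that need a transpose (presumably between the source checkpoint's [out, in] Linear layout and Paddle's [in, out] layout), and "-> _" drops a weight. Below is a minimal sketch of how such statements could be interpreted against a plain state dict; it is a hypothetical helper for illustration, not the FlexCheckpoint/AOA machinery in paddleformers, and numpy stands in for the real tensor type.

from typing import Dict, Iterable

import numpy as np


def apply_aoa_statements(
    statements: Iterable[str],
    src_weights: Dict[str, np.ndarray],
    num_layers: int,
) -> Dict[str, np.ndarray]:
    """Interpret "src[^T] -> dst" statements like those from _gen_aoa_config (illustrative only)."""
    dst_weights = {}
    for stmt in statements:
        src, dst = (part.strip() for part in stmt.split("->"))
        transpose = src.endswith("^T")
        if transpose:
            src = src[: -len("^T")]
        # "$LAYER_ID" expands to one concrete statement per decoder layer.
        layer_ids = range(num_layers) if "$LAYER_ID" in src else [None]
        for lid in layer_ids:
            s = src if lid is None else src.replace("$LAYER_ID", str(lid))
            d = dst if lid is None else dst.replace("$LAYER_ID", str(lid))
            if d == "_":  # "-> _": drop this weight (e.g. a tied lm_head on save)
                continue
            if s not in src_weights:
                continue  # missing keys are left for the caller to report
            w = src_weights[s]
            dst_weights[d] = w.T if transpose else w
    return dst_weights

With num_layers taken from something like config.num_hidden_layers, the forward statements rename a source checkpoint onto this model's parameter names for from_pretrained, while _gen_inv_aoa_config expresses the reverse direction for saving.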
@@ -869,6 +922,8 @@ class Gemma3ForCausalLMPipe(GeneralModelForCausalLMPipe):
     _keep_in_fp32_modules = Gemma3TextModel._keep_in_fp32_modules
     _tied_weights_keys = ["lm_head.weight"]
     transpose_weight_keys = Gemma3TextModel.transpose_weight_keys
+    _gen_aoa_config = Gemma3ForCausalLM._gen_aoa_config
+    _gen_inv_aoa_config = Gemma3ForCausalLM._gen_inv_aoa_config
 
 
 __all__ = [
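
The NOTE in the first hunk says the hand-written inverse config is temporary: a later AOA parsing module is expected to derive the reverse statements from the forward (from_pretrained) ones. Mechanically, that derivation is mostly a swap of source and destination with "^T" carried over to the new source; the special case is tied weights, where two forward statements share one source and one of the inverted copies has to become a "-> _" drop, as the hand-written config does for lm_head.weight. A rough, hypothetical sketch of the naive part (not part of paddleformers):

from typing import Iterable, List


def invert_aoa_statements(statements: Iterable[str]) -> List[str]:
    """Naive inversion: "a^T -> b" becomes "b^T -> a", "a -> b" becomes "b -> a".

    Tied weights are not resolved here: when two forward statements share a
    source, the caller still has to rewrite one inverted copy as "<dst> -> _",
    as _gen_inv_aoa_config does for lm_head.weight.
    """
    inverse = []
    for stmt in statements:
        src, dst = (part.strip() for part in stmt.split("->"))
        transpose = src.endswith("^T")
        if transpose:
            src = src[: -len("^T")]
        inverse.append(f"{dst}{'^T' if transpose else ''} -> {src}")
    return inverse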
