@@ -483,6 +483,59 @@ def make_base_actions():
         mappings = make_base_actions()
         return mappings

+    @classmethod
+    def _gen_aoa_config(cls, config: Gemma3TextConfig):
+        model_prefix = "" if cls == cls.base_model_prefix else "model."
+        aoa_config = {
+            "aoa_statements": [
+                # load tied weight
+                "model.embed_tokens.weight -> lm_head.weight",
+                # others
+                f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight",
+                f"model.norm.weight -> {model_prefix}norm.weight",
+                f"model.layers.$LAYER_ID.input_layernorm.weight -> {model_prefix}layers.$LAYER_ID.input_layernorm.weight",
+                f"model.layers.$LAYER_ID.post_attention_layernorm.weight -> {model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight",
+                f"model.layers.$LAYER_ID.pre_feedforward_layernorm.weight -> {model_prefix}layers.$LAYER_ID.pre_feedforward_layernorm.weight",
+                f"model.layers.$LAYER_ID.post_feedforward_layernorm.weight -> {model_prefix}layers.$LAYER_ID.post_feedforward_layernorm.weight",
+                # do transpose
+                f"model.layers.$LAYER_ID.mlp.gate_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.gate_proj.weight",
+                f"model.layers.$LAYER_ID.mlp.up_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.up_proj.weight",
+                f"model.layers.$LAYER_ID.mlp.down_proj.weight^T -> {model_prefix}layers.$LAYER_ID.mlp.down_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.q_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.q_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.k_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.k_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.v_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.v_proj.weight",
+                f"model.layers.$LAYER_ID.self_attn.o_proj.weight^T -> {model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight",
+            ]
+        }
+        return aoa_config
+
+    # NOTE: These aoa_config items will be removed later. The subsequent AOA parsing module will automatically generate the reverse AOA based on the forward (from_pretrained) AOA.
+    @classmethod
+    def _gen_inv_aoa_config(cls, config: Gemma3TextConfig):
+        model_prefix = "" if cls == cls.base_model_prefix else "model."
+        aoa_statements = [
+            # ignore tied weights
+            "lm_head.weight -> _",
+            # do transpose
+            f"{model_prefix}layers.$LAYER_ID.mlp.gate_proj.weight^T -> model.layers.$LAYER_ID.mlp.gate_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.mlp.up_proj.weight^T -> model.layers.$LAYER_ID.mlp.up_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.mlp.down_proj.weight^T -> model.layers.$LAYER_ID.mlp.down_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.q_proj.weight^T -> model.layers.$LAYER_ID.self_attn.q_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.k_proj.weight^T -> model.layers.$LAYER_ID.self_attn.k_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.v_proj.weight^T -> model.layers.$LAYER_ID.self_attn.v_proj.weight",
+            f"{model_prefix}layers.$LAYER_ID.self_attn.o_proj.weight^T -> model.layers.$LAYER_ID.self_attn.o_proj.weight",
+            # others
+            f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight",
+            f"{model_prefix}norm.weight -> model.norm.weight",
+            f"{model_prefix}layers.$LAYER_ID.input_layernorm.weight -> model.layers.$LAYER_ID.input_layernorm.weight",
+            f"{model_prefix}layers.$LAYER_ID.post_attention_layernorm.weight -> model.layers.$LAYER_ID.post_attention_layernorm.weight",
+            f"{model_prefix}layers.$LAYER_ID.pre_feedforward_layernorm.weight -> model.layers.$LAYER_ID.pre_feedforward_layernorm.weight",
+            f"{model_prefix}layers.$LAYER_ID.post_feedforward_layernorm.weight -> model.layers.$LAYER_ID.post_feedforward_layernorm.weight",
+        ]
+
+        aoa_config = {"aoa_statements": aoa_statements}
+        return aoa_config
+

 class Gemma3TextModel(Gemma3PreTrainedModel):
     config_class = Gemma3TextConfig
@@ -869,6 +922,8 @@ class Gemma3ForCausalLMPipe(GeneralModelForCausalLMPipe):
     _keep_in_fp32_modules = Gemma3TextModel._keep_in_fp32_modules
     _tied_weights_keys = ["lm_head.weight"]
     transpose_weight_keys = Gemma3TextModel.transpose_weight_keys
+    _gen_aoa_config = Gemma3ForCausalLM._gen_aoa_config
+    _gen_inv_aoa_config = Gemma3ForCausalLM._gen_inv_aoa_config


 __all__ = [
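The `aoa_statements` added in this diff follow a small `source -> target` mapping syntax: a `^T` suffix on the source marks a transpose, `$LAYER_ID` acts as a per-layer template variable, and `-> _` discards a key (used for the tied `lm_head.weight` in the inverse direction). The snippet below is a minimal sketch of how such statements could be applied to a checkpoint state dict; the function name `apply_aoa_statements`, its signature, and the NumPy arrays are illustrative assumptions, not the repository's actual AOA parsing module.

```python
import numpy as np


def apply_aoa_statements(state_dict, aoa_statements, num_hidden_layers):
    """Remap a source state dict onto target keys following AOA-style statements.

    Hypothetical helper for illustration only. It assumes the "src[^T] -> dst"
    syntax from the diff above, with "$LAYER_ID" expanded once per decoder
    layer and a target of "_" meaning the source weight is dropped.
    """
    remapped = {}
    for statement in aoa_statements:
        src, dst = (part.strip() for part in statement.split("->"))
        transpose = src.endswith("^T")
        if transpose:
            src = src[: -len("^T")]
        # Expand the $LAYER_ID template into one concrete mapping per layer.
        layer_ids = range(num_hidden_layers) if "$LAYER_ID" in src else [None]
        for layer_id in layer_ids:
            src_key = src if layer_id is None else src.replace("$LAYER_ID", str(layer_id))
            dst_key = dst if layer_id is None else dst.replace("$LAYER_ID", str(layer_id))
            if dst_key == "_":  # "-> _" ignores the weight (e.g. the tied lm_head)
                continue
            weight = state_dict[src_key]
            remapped[dst_key] = weight.T if transpose else weight
    return remapped


if __name__ == "__main__":
    # Tiny smoke test with random-free placeholder arrays standing in for checkpoint tensors.
    fake_ckpt = {
        "model.embed_tokens.weight": np.zeros((8, 4)),
        "model.layers.0.mlp.gate_proj.weight": np.zeros((4, 16)),
    }
    statements = [
        "model.embed_tokens.weight -> lm_head.weight",
        "model.layers.$LAYER_ID.mlp.gate_proj.weight^T -> model.layers.$LAYER_ID.mlp.gate_proj.weight",
    ]
    out = apply_aoa_statements(fake_ckpt, statements, num_hidden_layers=1)
    assert out["model.layers.0.mlp.gate_proj.weight"].shape == (16, 4)
```

Per the NOTE in the diff, `_gen_inv_aoa_config` is a stopgap: once the AOA parsing module can derive the reverse mapping from the forward (from_pretrained) statements, presumably by swapping the two sides of each statement while preserving the `^T` markers and skipping tied weights, the hand-written inverse configs can be removed.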