@@ -53,6 +53,33 @@ class NVLlamaPreTrainedModel(PreTrainedModel):
    # Per the HF PreTrainedModel contract: module classes listed here are kept
    # whole (never sharded) when the model is split across devices.
    _no_split_modules = ("TransformerLayer",)
    # Keys excluded from automatic device placement — presumably because
    # past_key_values are managed by the generation loop; verify against caller.
    _skip_keys_device_placement = ("past_key_values",)

56+ def _init_weights (self , module ):
57+ """TE-specific weight initialization."""
58+ super ()._init_weights (module )
59+
60+ # Copied from transformers.modeling_utils.PreTrainedModel._init_weights
61+ if hasattr (self .config , "initializer_range" ):
62+ std = self .config .initializer_range
63+ else :
64+ # 0.02 is the standard default value across the library
65+ std = getattr (self .config .get_text_config (), "initializer_range" , 0.02 )
66+
67+ if isinstance (
68+ module , (nn .Linear , transformer_engine .pytorch .Linear , transformer_engine .pytorch .LayerNormLinear )
69+ ):
70+ module .weight .data .normal_ (mean = 0.0 , std = std )
71+ if module .bias is not None :
72+ module .bias .data .zero_ ()
73+ if isinstance (module , transformer_engine .pytorch .LayerNorm ):
74+ if hasattr (module , "weight" ) and module .weight is not None :
75+ module .weight .data .fill_ (1.0 )
76+ if hasattr (module , "bias" ) and module .bias is not None :
77+ module .bias .data .zero_ ()
78+ if isinstance (module , transformer_engine .pytorch .LayerNormLinear ):
79+ module .layer_norm_weight .data .fill_ (1.0 )
80+ if module .layer_norm_bias is not None :
81+ module .layer_norm_bias .data .zero_ ()
82+
5683
5784class NVLlamaModel (NVLlamaPreTrainedModel ):
5885 """Llama3 model implemented in Transformer Engine."""
0 commit comments