Skip to content

Commit 5616f84

Browse files
authored
add te weight init (#1353)
Adds TE-specific weight initialization to `_init_weights`. Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 0ea4461 commit 5616f84

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,33 @@ class NVLlamaPreTrainedModel(PreTrainedModel):
5353
_no_split_modules = ("TransformerLayer",)
5454
_skip_keys_device_placement = ("past_key_values",)
5555

56+
def _init_weights(self, module):
    """TE-specific weight initialization.

    Runs the stock HF initialization first, then re-initializes the
    Transformer Engine modules that the base implementation does not
    recognize (TE Linear / LayerNorm / LayerNormLinear).
    """
    super()._init_weights(module)

    # Copied from transformers.modeling_utils.PreTrainedModel._init_weights
    te = transformer_engine.pytorch
    cfg = self.config
    if hasattr(cfg, "initializer_range"):
        std = cfg.initializer_range
    else:
        # 0.02 is the standard default value across the library
        std = getattr(cfg.get_text_config(), "initializer_range", 0.02)

    # Linear-like weights: normal(0, std), biases zeroed.
    if isinstance(module, (nn.Linear, te.Linear, te.LayerNormLinear)):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()

    # Standalone TE LayerNorm: affine params reset to the identity transform.
    if isinstance(module, te.LayerNorm):
        ln_weight = getattr(module, "weight", None)
        if ln_weight is not None:
            ln_weight.data.fill_(1.0)
        ln_bias = getattr(module, "bias", None)
        if ln_bias is not None:
            ln_bias.data.zero_()

    # Fused LayerNormLinear keeps its norm params under dedicated names.
    if isinstance(module, te.LayerNormLinear):
        module.layer_norm_weight.data.fill_(1.0)
        if module.layer_norm_bias is not None:
            module.layer_norm_bias.data.zero_()
82+
5683

5784
class NVLlamaModel(NVLlamaPreTrainedModel):
5885
"""Llama3 model implemented in Transformer Engine."""

bionemo-recipes/recipes/llama3_native_te/example_checkpoint/llama3_nv.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,33 @@ class NVLlamaPreTrainedModel(PreTrainedModel):
5353
_no_split_modules = ("TransformerLayer",)
5454
_skip_keys_device_placement = ("past_key_values",)
5555

56+
def _init_weights(self, module):
    """TE-specific weight initialization.

    Runs the stock HF initialization first, then re-initializes the
    Transformer Engine modules that the base implementation does not
    recognize (TE Linear / LayerNorm / LayerNormLinear).
    """
    super()._init_weights(module)

    # Copied from transformers.modeling_utils.PreTrainedModel._init_weights
    te = transformer_engine.pytorch
    cfg = self.config
    if hasattr(cfg, "initializer_range"):
        std = cfg.initializer_range
    else:
        # 0.02 is the standard default value across the library
        std = getattr(cfg.get_text_config(), "initializer_range", 0.02)

    # Linear-like weights: normal(0, std), biases zeroed.
    if isinstance(module, (nn.Linear, te.Linear, te.LayerNormLinear)):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()

    # Standalone TE LayerNorm: affine params reset to the identity transform.
    if isinstance(module, te.LayerNorm):
        ln_weight = getattr(module, "weight", None)
        if ln_weight is not None:
            ln_weight.data.fill_(1.0)
        ln_bias = getattr(module, "bias", None)
        if ln_bias is not None:
            ln_bias.data.zero_()

    # Fused LayerNormLinear keeps its norm params under dedicated names.
    if isinstance(module, te.LayerNormLinear):
        module.layer_norm_weight.data.fill_(1.0)
        if module.layer_norm_bias is not None:
            module.layer_norm_bias.data.zero_()
82+
5683

5784
class NVLlamaModel(NVLlamaPreTrainedModel):
5885
"""Llama3 model implemented in Transformer Engine."""

0 commit comments

Comments
 (0)