# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
    get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml

import megatron.legacy.model  # isort: skip

# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import


def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None):
    """Build a GPT model from the parsed training arguments.

    Returns either a legacy ``megatron.legacy.model.GPTModel`` or a Megatron Core
    ``GPTModel``, depending on ``args.use_legacy_models``.
    """
    print_rank_0('building GPT model ...')

    if config is None:
        if args.yaml_cfg is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        # Compute this up front: it is needed both for building the layer spec and,
        # further down, for the MTP block spec (even when a custom spec is imported).
        use_te = args.transformer_impl == "transformer_engine"

        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts:
                # Define the decoder block spec
                transformer_layer_spec = get_gpt_decoder_block_spec(
                    config,
                    use_transformer_engine=use_te,
                    normalization=args.normalization,
                    qk_l2_norm=args.qk_l2_norm,
                    vp_stage=vp_stage,
                )
            elif args.heterogeneous_layers_config_path is not None:
                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
            else:
                # Define the decoder layer spec
                transformer_layer_spec = _get_transformer_layer_spec(use_te, config)

        mtp_block_spec = None
        if args.mtp_num_layers is not None:
            if (
                hasattr(transformer_layer_spec, 'layer_specs')
                and len(transformer_layer_spec.layer_specs) == 0
            ):
                # Get the decoder layer spec explicitly if there is no decoder layer in
                # the last stage; this only happens with a block spec
                # (TransformerBlockSubmodules) when using MoE.
                transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config)
            else:
                transformer_layer_spec_for_mtp = transformer_layer_spec
            mtp_block_spec = get_gpt_mtp_block_spec(
                config,
                transformer_layer_spec_for_mtp,
                use_transformer_engine=use_te,
                vp_stage=vp_stage,
            )

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
            vp_stage=vp_stage,
        )

    return model


def _get_transformer_layer_spec(use_te, config):
    """Get the transformer layer specification based on the configuration.

    Training arguments are read from the global ``get_args()``.

    Args:
        use_te (bool): Whether to use Transformer Engine.
        config: Model configuration.

    Returns:
        transformer_layer_spec: The transformer layer specification.
    """
    args = get_args()
    if use_te:
        return get_gpt_layer_with_transformer_engine_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            qk_l2_norm=args.qk_l2_norm,
            use_kitchen=config.use_kitchen,
        )
    else:
        return get_gpt_layer_local_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            normalization=args.normalization,
            use_kitchen=config.use_kitchen,
        )
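

# A minimal, hedged sketch of how `gpt_builder` is typically consumed: a
# model-provider-style wrapper invoked after `initialize_megatron()` has parsed
# the CLI arguments, so that `get_args()` is populated and `pre_process` /
# `post_process` reflect the caller's pipeline stage. The wrapper name below is
# an illustrative assumption, not an upstream Megatron API.
def build_gpt_model_example(pre_process=True, post_process=True, vp_stage=None):
    """Illustrative wrapper: build a GPT model from the globally parsed args."""
    args = get_args()
    return gpt_builder(args, pre_process, post_process, vp_stage=vp_stage)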