gpt_builders.py (forked from NVIDIA/Megatron-LM)
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_inference_spec,
    get_gpt_layer_with_transformer_engine_spec,
    get_gpt_mtp_block_spec,
)
from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
    get_gpt_heterogeneous_layer_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml

import megatron.legacy.model  # isort: skip
# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import


def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_collection=None):
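    """Build a GPT model (legacy or Megatron Core) from the training arguments.

    Args:
        args: Megatron training arguments.
        pre_process (bool): Whether this pipeline stage computes the input embeddings.
        post_process (bool): Whether this pipeline stage computes the output layer.
        vp_stage (int, optional): Virtual pipeline stage index.
        config (optional): Pre-built transformer config; derived from args when None.
        pg_collection (optional): Collection of process groups used for model parallelism.

    Returns:
        The constructed GPT model.
    """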
    print_rank_0('building GPT model ...')

    # Build the transformer config from a YAML config or from the CLI arguments
    # unless one was passed in explicitly.
    if config is None:
        if args.yaml_cfg is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)

    if args.use_legacy_models:
        model = megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process,
        )
    else:  # using core models
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            use_te = args.transformer_impl == "transformer_engine"
            if args.num_experts:
                assert config.transformer_impl != "inference_optimized"
                # Define the decoder block spec.
                transformer_layer_spec = get_gpt_decoder_block_spec(
                    config,
                    use_transformer_engine=use_te,
                    normalization=args.normalization,
                    qk_l2_norm=args.qk_l2_norm,
                    vp_stage=vp_stage,
                )
            elif args.heterogeneous_layers_config_path is not None:
                assert config.transformer_impl != "inference_optimized"
                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
            else:
                # Define the decoder layer spec.
                transformer_layer_spec = _get_transformer_layer_spec(use_te, config)
        mtp_block_spec = None
        if args.mtp_num_layers is not None:
            assert config.transformer_impl != "inference_optimized"
            if (
                hasattr(transformer_layer_spec, 'layer_specs')
                and len(transformer_layer_spec.layer_specs) == 0
            ):
                # Get the decoder layer spec explicitly when there is no decoder layer
                # in the last stage; this only happens with a block spec
                # (TransformerBlockSubmodules) when using MoE.
                transformer_layer_spec_for_mtp = _get_transformer_layer_spec(use_te, config)
            else:
                transformer_layer_spec_for_mtp = transformer_layer_spec
            mtp_block_spec = get_gpt_mtp_block_spec(
                config,
                transformer_layer_spec_for_mtp,
                use_transformer_engine=use_te,
                vp_stage=vp_stage,
            )
        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling,
            mtp_block_spec=mtp_block_spec,
            vp_stage=vp_stage,
            pg_collection=pg_collection,
        )

    return model


def _get_transformer_layer_spec(use_te, config):
"""Get transformer layer specification based on configuration.
Args:
use_te (bool): Whether to use Transformer Engine
args: Training arguments
config: Model configuration
Returns:
transformer_layer_spec: The transformer layer specification
"""
args = get_args()
    if use_te:
        return get_gpt_layer_with_transformer_engine_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            qk_l2_norm=args.qk_l2_norm,
            use_kitchen=config.use_kitchen,
            use_kitchen_attention=config.use_kitchen_attention,
            kitchen_attention_backend=config.kitchen_attention_backend,
        )
    elif config.transformer_impl == "inference_optimized":
        return get_gpt_layer_with_inference_spec(
            args.qk_layernorm,
            args.multi_latent_attention,
            qk_l2_norm=args.qk_l2_norm,
        )
    else:
        return get_gpt_layer_local_spec(
            args.num_experts,
            args.moe_grouped_gemm,
            args.qk_layernorm,
            args.multi_latent_attention,
            moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
            normalization=args.normalization,
            use_kitchen=config.use_kitchen,
            use_kitchen_attention=config.use_kitchen_attention,
            kitchen_attention_backend=config.kitchen_attention_backend,
        )
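
# Usage sketch: gpt_builder is intended to be called from a model-provider
# function during pretraining. The wrapper below is a minimal illustration,
# assuming the standard Megatron model-provider pattern; `model_provider` and
# its keyword defaults are hypothetical names, not part of this module.
#
#     def model_provider(pre_process=True, post_process=True, vp_stage=None):
#         args = get_args()
#         return gpt_builder(args, pre_process, post_process, vp_stage=vp_stage)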