import dataclasses
from typing import Any
import torch
from mbridge.core.bridge import Bridge
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import DistributedDataParallelConfig as MCoreDDPConfig
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer import TransformerConfig
from transformers import AutoConfig, PretrainedConfig
from areal.api.cli_args import MegatronEngineConfig
from areal.models.mcore.qwen3 import (
    hf_to_mcore_config_qwen3_dense,
    make_mcore_layer_specs_qwen3_dense,
)
from areal.utils import logging

logger = logging.getLogger("MCoreRegistry")
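
# Utilities for building Megatron-Core GPT models from HuggingFace checkpoints,
# either through a bridge backend ("mbridge" / "megatron-bridge") or by direct
# construction for the architectures registered below (currently Qwen3 dense).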


class ValueHead(torch.nn.Linear):
    def __init__(
        self,
        input_size: int,
        output_size: int = 1,
        *,
        config: TransformerConfig,
        bias: bool = False,
    ) -> None:
        super().__init__(in_features=input_size, out_features=output_size, bias=bias)
        self.sequence_parallel = config.sequence_parallel
        if self.sequence_parallel:
            self.weight.sequence_parallel = True
        self.weight.data.normal_(mean=0.0, std=0.02)
        if bias:
            self.bias.data.zero_()

    def forward(
        self,
        input_: torch.Tensor,
        weight: torch.Tensor | None = None,
        runtime_gather_output: bool | None = None,
    ) -> tuple[torch.Tensor, None]:
        logits = super().forward(input_)
        logits = logits.float()
        if self.sequence_parallel:
            logits = tensor_parallel.gather_from_sequence_parallel_region(
                logits, tensor_parallel_output_grad=False
            )
        return logits, None
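
# Illustrative usage sketch (assumes a populated TransformerConfig `tf_config` and a
# `hidden_states` tensor shaped like GPTModel's pre-logit activations):
#
#     head = ValueHead(input_size=tf_config.hidden_size, config=tf_config)
#     values, _ = head(hidden_states)  # fp32 scores with a trailing dim of 1
#
# The unused `weight` / `runtime_gather_output` arguments and the `(logits, None)`
# return value mirror the call convention GPTModel uses for its output_layer, so a
# ValueHead can be swapped in without touching GPTModel.forward.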


def _replace_output_layer_with_value_head(
    model: GPTModel,
    tf_config: TransformerConfig,
) -> None:
    """Replace the model's output_layer with a ValueHead.

    This function can be used on any GPTModel instance, whether created
    via mbridge or directly. After replacement:

    - model.output_layer becomes a ValueHead instance
    - model.vocab_size is set to 1

    Args:
        model: The GPTModel instance to modify.
        tf_config: Transformer configuration containing hidden_size and
            sequence-parallel settings.
    """
    if not hasattr(model, "output_layer"):
        raise ValueError(
            "Model does not have output_layer. Ensure post_process=True when creating GPTModel."
        )
    dtype = tf_config.params_dtype
    model.output_layer = ValueHead(
        input_size=tf_config.hidden_size,
        output_size=1,
        config=tf_config,
        bias=False,
    ).to(dtype=dtype)
    model.vocab_size = 1
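
# Typical (illustrative) critic setup: build a GPTModel with post_process=True, then
# swap the vocab-sized output projection for a scalar head:
#
#     _replace_output_layer_with_value_head(model, tf_config)
#     # model.output_layer is now a ValueHead and model.vocab_size == 1, so the
#     # "logits" the model produces are per-token value estimates.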


def unwrap_to_gpt_model(model: torch.nn.Module) -> GPTModel:
    """Unwraps a model to the underlying GPTModel instance."""
    _model = model
    while not isinstance(_model, GPTModel) and hasattr(_model, "module"):
        _model = _model.module
    if not isinstance(_model, GPTModel):
        raise TypeError(f"Model could not be unwrapped to GPTModel. Got {type(_model)}")
    return _model
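
# Illustrative example: wrappers such as DistributedDataParallel (and Megatron's
# Float16Module) expose the wrapped module via `.module`, so the loop above peels
# them off one layer at a time:
#
#     ddp_model = DDP(config=tf_config, ddp_config=ddp_config, module=gpt_model)
#     gpt = unwrap_to_gpt_model(ddp_model)
#     assert isinstance(gpt, GPTModel)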


# Model registry for different architectures
def make_hf_and_mcore_config(
    hf_path: str,
    dtype: torch.dtype,
    bridge=None,
    bridge_type: str = "mbridge",
) -> tuple[PretrainedConfig, TransformerConfig]:
    if bridge is not None and bridge_type == "mbridge":
        hf_config = bridge.hf_config
        hf_config._name_or_path = hf_path
        return hf_config, bridge.config
    elif bridge is not None and bridge_type == "megatron-bridge":
        hf_config = getattr(bridge.hf_pretrained, "config", bridge.hf_pretrained)
        if hasattr(hf_config, "_name_or_path"):
            hf_config._name_or_path = hf_path
        return hf_config, bridge.transformer_config
    else:
        hf_config: PretrainedConfig = AutoConfig.from_pretrained(
            pretrained_model_name_or_path=hf_path,
            trust_remote_code=True,
        )
        assert len(hf_config.architectures) == 1
        architecture = hf_config.architectures[0]
        if architecture == "Qwen3ForCausalLM":
            return hf_config, hf_to_mcore_config_qwen3_dense(hf_config, dtype)
        else:
            raise ValueError(
                f"Architecture not registered for config conversion: {architecture}."
            )
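
# Illustrative call without a bridge (assumes a local or hub path to a Qwen3 dense
# checkpoint; other architectures raise ValueError until they are registered here):
#
#     hf_config, tf_config = make_hf_and_mcore_config(
#         "Qwen/Qwen3-8B", dtype=torch.bfloat16
#     )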


def make_mcore_layer_specs(hf_config: PretrainedConfig, tf_config: TransformerConfig):
    assert len(hf_config.architectures) == 1
    architecture = hf_config.architectures[0]
    if architecture == "Qwen3ForCausalLM":
        return make_mcore_layer_specs_qwen3_dense(tf_config, use_te=True)
    else:
        raise ValueError(
            f"Architecture not registered for layer spec creation: {architecture}."
        )


def make_mcore_model(
    hf_config: PretrainedConfig,
    tf_config: TransformerConfig,
    mcore_config: MegatronEngineConfig | None = None,
    bridge: Bridge | Any | None = None,
    bridge_type: str = "mbridge",
    is_critic: bool = False,
) -> list[GPTModel | DDP]:
    if bridge is not None and bridge_type == "mbridge":
        models = bridge.get_model(
            # TODO: Add DDP options when supporting training
            wrap_with_ddp=mcore_config.wrap_with_ddp,
            ddp_config=dataclasses.asdict(mcore_config.ddp),
            use_torch_fsdp2=mcore_config.use_torch_fsdp2,
            use_custom_fsdp=mcore_config.use_custom_fsdp,
            fp16=tf_config.fp16,
            bf16=tf_config.bf16,
            use_precision_aware_optimizer=mcore_config.use_precision_aware_optimizer,
            overlap_param_gather_with_optimizer_step=mcore_config.overlap_param_gather_with_optimizer_step,
        )
        models = list(models)
        # Replace output_layer with ValueHead for critic models
        if is_critic:
            for model in models:
                _model = unwrap_to_gpt_model(model)
                _replace_output_layer_with_value_head(_model, tf_config)
        return models
if bridge is not None and bridge_type == "megatron-bridge":
provider = bridge.to_megatron_provider(load_weights=False)
vpp_size = mcore_config.virtual_pipeline_parallel_size or 0
provider.tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()
provider.pipeline_model_parallel_size = (
mpu.get_pipeline_model_parallel_world_size()
)
provider.virtual_pipeline_model_parallel_size = (
vpp_size if vpp_size > 1 else None
)
provider.context_parallel_size = mpu.get_context_parallel_world_size()
provider.expert_model_parallel_size = mpu.get_expert_model_parallel_world_size()
provider.expert_tensor_parallel_size = (
mpu.get_expert_tensor_parallel_world_size()
)
provider.sequence_parallel = mpu.get_tensor_model_parallel_world_size() > 1
provider.pipeline_dtype = tf_config.params_dtype
provider.recompute_granularity = mcore_config.recompute_granularity
provider.recompute_method = mcore_config.recompute_method
provider.recompute_num_layers = mcore_config.recompute_num_layers
provider.distribute_saved_activations = (
mcore_config.distribute_saved_activations
)
provider.recompute_modules = mcore_config.recompute_modules
provider.account_for_embedding_in_pipeline_split = False
provider.account_for_loss_in_pipeline_split = False
# Keep these four flags aligned with mbridge base defaults.
provider.variable_seq_lengths = True
logger.warning(
"Ignoring mcore_config.moe_token_dispatcher_type=%s for bridge_type='megatron-bridge'; "
"using 'alltoall' and variable_seq_lengths=True.",
mcore_config.moe_token_dispatcher_type,
)
provider.moe_token_dispatcher_type = "alltoall"
provider.batch_p2p_comm = False
provider.overlap_p2p_comm = (
vpp_size > 1 and provider.pipeline_model_parallel_size > 1
)
# Aligning tf config settings with provider for consistency.
tf_config.variable_seq_lengths = provider.variable_seq_lengths
tf_config.moe_token_dispatcher_type = provider.moe_token_dispatcher_type
tf_config.batch_p2p_comm = provider.batch_p2p_comm
tf_config.overlap_p2p_comm = provider.overlap_p2p_comm
provider.finalize()
models = provider.provide_distributed_model(
ddp_config=MCoreDDPConfig(**dataclasses.asdict(mcore_config.ddp)),
fp16=tf_config.fp16,
bf16=tf_config.bf16,
use_megatron_fsdp=mcore_config.use_custom_fsdp,
use_torch_fsdp2=mcore_config.use_torch_fsdp2,
wrap_with_ddp=mcore_config.wrap_with_ddp,
overlap_param_gather_with_optimizer_step=mcore_config.overlap_param_gather_with_optimizer_step,
)
models = list(models)
if is_critic:
for model in models:
_model = unwrap_to_gpt_model(model)
_replace_output_layer_with_value_head(_model, tf_config)
return models
    else:
        if (
            mcore_config is not None
            and mcore_config.virtual_pipeline_parallel_size is not None
            and mcore_config.virtual_pipeline_parallel_size > 1
        ):
            raise NotImplementedError(
                "Virtual pipeline parallelism requires mbridge-backed models."
            )
        transformer_layer_spec = make_mcore_layer_specs(hf_config, tf_config)
        rope_scaling_args = {}
        if hf_config.rope_scaling is not None:
            if hf_config.rope_scaling["type"] != "linear":
                raise NotImplementedError(
                    f"Rope scaling type {hf_config.rope_scaling['type']} not supported yet."
                )
            rope_scaling_args["seq_len_interpolation_factor"] = hf_config.rope_scaling[
                "factor"
            ]
        model = GPTModel(
            config=tf_config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=hf_config.vocab_size,
            max_sequence_length=hf_config.max_position_embeddings,
            pre_process=True,  # TODO: pipeline parallel
            post_process=True,  # TODO: pipeline parallel
            share_embeddings_and_output_weights=False,  # TODO: implement share output weights
            position_embedding_type="rope",
            rotary_base=hf_config.rope_theta,
            **rope_scaling_args,
            # vp_stage=None TODO: virtual pipeline parallel
        )
        # Replace output_layer with ValueHead for critic models
        if is_critic:
            _replace_output_layer_with_value_head(model, tf_config)
        # Guard against mcore_config=None (the default) before reading DDP settings.
        if mcore_config is not None and mcore_config.wrap_with_ddp:
            ddp_config = MCoreDDPConfig(**dataclasses.asdict(mcore_config.ddp))
            wrapped = DDP(
                config=tf_config,
                ddp_config=ddp_config,
                module=model,
                disable_bucketing=False,
            )
            return [wrapped]
        return [model]
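
# End-to-end usage sketch (illustrative; assumes torch.distributed and Megatron's
# parallel_state have already been initialized, and that `mcore_config` is a
# populated MegatronEngineConfig):
#
#     hf_config, tf_config = make_hf_and_mcore_config("Qwen/Qwen3-8B", torch.bfloat16)
#     models = make_mcore_model(hf_config, tf_config, mcore_config=mcore_config)
#     critic_models = make_mcore_model(
#         hf_config, tf_config, mcore_config=mcore_config, is_critic=True
#     )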