Skip to content

Commit 1efcca7

Browse files
committed
Implement huggingface checkpoint loading
1 parent a1b067e commit 1efcca7

File tree

1 file changed

+189
-14
lines changed

1 file changed

+189
-14
lines changed

examples/pre-training/ernie/pretrain.py

Lines changed: 189 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
PdArgumentParser,
3636
get_last_checkpoint,
3737
)
38+
from paddleformers.trainer.unified_checkpoint import unified_checkpoint
39+
from paddleformers.transformers.model_utils import unwrap_model
40+
41+
from safetensors import safe_open
3842

3943
try:
4044
from paddleformers.utils.downloader import get_static_model_on_pdc
@@ -202,6 +206,185 @@ def _collate_data(data, stack_fn=Stack()):
202206
return train_dataset, valid_dataset, test_dataset, _collate_data
203207

204208

209+
def load_huggingface_checkpoint(model, args):
210+
fused_rms_norm_replace = [
211+
("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
212+
("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
213+
]
214+
shared_layers_prefix = "shared_layers.embed_weight_share."
215+
unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
216+
217+
logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
218+
with open(
219+
os.path.join(args.model_name_or_path, "model.safetensors.index.json")
220+
) as f:
221+
weight_map = json.load(f)["weight_map"]
222+
223+
ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
224+
ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
225+
expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
226+
227+
def param_to_weight(name):
228+
# for PP=1, we only need to substitute the fused_rms_norm and expert_id
229+
for src, dst in fused_rms_norm_replace:
230+
name = name.replace(src, dst)
231+
if m := re.search(r"mlp\.experts\.(\d+)", name):
232+
expert_id = expert_offset + int(m.group(1))
233+
s, e = m.span()
234+
name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
235+
if isinstance(model, ErnieMoEForCausalLM):
236+
return name
237+
238+
# for PP>1, we also need to handle special layers and adjust layer_idx
239+
if name.startswith(shared_layers_prefix):
240+
return "ernie." + name[len(shared_layers_prefix) :]
241+
layer_idx, stem = name.split(".", maxsplit=1)
242+
if stem == "weight":
243+
return unnamed_layers.pop(0)
244+
if stem.startswith("mtp"):
245+
return f"ernie.{stem}"
246+
return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
247+
248+
def try_torch_format(weight_key):
249+
if weight_key.startswith("ernie."):
250+
weight_key = "model." + weight_key[6:]
251+
252+
key_decompose = [weight_key]
253+
if ".up_gate_proj." in weight_key:
254+
key_decompose = [
255+
weight_key.replace(".up_gate_proj.", ".gate_proj."),
256+
weight_key.replace(".up_gate_proj.", ".up_proj."),
257+
]
258+
elif ".qkv_proj." in weight_key:
259+
key_decompose = [
260+
weight_key.replace(".qkv_proj.", ".q_proj."),
261+
weight_key.replace(".qkv_proj.", ".k_proj."),
262+
weight_key.replace(".qkv_proj.", ".v_proj."),
263+
]
264+
265+
tensor_decompose = []
266+
for key in key_decompose:
267+
if not (weight_file := weight_map.get(key)):
268+
return None
269+
with safe_open(
270+
os.path.join(args.model_name_or_path, weight_file),
271+
framework="numpy",
272+
) as f:
273+
tensor = paddle.to_tensor(f.get_tensor(key))
274+
if "_proj." in key or ".gate." in key or "lm_head" in key:
275+
tensor = tensor.T.contiguous()
276+
tensor_decompose.append(tensor)
277+
278+
if len(tensor_decompose) == 1:
279+
return tensor_decompose[0]
280+
else:
281+
return paddle.concat(tensor_decompose, axis=-1)
282+
283+
def auto_fix_shape(param, weight):
284+
assert len(param.shape) == len(weight.shape), "rank not match"
285+
assert all(
286+
p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
287+
), "weight too small"
288+
indices = tuple(slice(0, dim) for dim in param.shape)
289+
return weight[indices].contiguous()
290+
291+
for name, param in model.named_parameters():
292+
weight_key = param_to_weight(name)
293+
if weight_file := weight_map.get(weight_key):
294+
with safe_open(
295+
os.path.join(args.model_name_or_path, weight_file),
296+
framework="numpy",
297+
) as f:
298+
weight = paddle.to_tensor(f.get_tensor(weight_key))
299+
elif (weight := try_torch_format(weight_key)) is None:
300+
logger.warning(
301+
f"param `{name}`'s weight `{weight_key}` not found. "
302+
"Skip initializing."
303+
)
304+
continue
305+
if param.shape != weight.shape:
306+
logger.warning(
307+
f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
308+
f"{param.shape} and {weight.shape}. Auto fixing."
309+
)
310+
weight = auto_fix_shape(param, weight)
311+
param.copy_(weight)
312+
313+
314+
def get_expected_state_dict(model, **kwargs):
315+
fused_rms_norm_replace = [
316+
("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
317+
("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
318+
]
319+
shared_layers_prefix = "shared_layers.embed_weight_share."
320+
unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
321+
322+
model = unwrap_model(model)
323+
hcg = fleet.get_hybrid_communicate_group()
324+
ep_degree = hcg.get_expert_parallel_world_size()
325+
ep_rank = hcg.get_expert_parallel_rank()
326+
expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
327+
328+
if model.config.head_dim is None:
329+
head_dim = model.config.hidden_size // model.config.num_attention_heads
330+
else:
331+
head_dim = model.config.head_dim
332+
q_dim = head_dim * model.config.num_attention_heads
333+
kv_dim = head_dim * model.config.num_key_value_heads
334+
335+
def copy_attr(out, param):
336+
if hasattr(param, "is_distributed"):
337+
out.is_distributed = param.is_distributed
338+
if hasattr(param, "no_sync"):
339+
out.no_sync = param.no_sync
340+
return out
341+
342+
def param_to_weight(name):
343+
# for PP=1, we only need to substitute the fused_rms_norm and expert_id
344+
for src, dst in fused_rms_norm_replace:
345+
name = name.replace(src, dst)
346+
if m := re.search(r"\.experts\.(\d+)\.", name):
347+
expert_id = expert_offset + int(m.group(1))
348+
s, e = m.span()
349+
name = name[:s] + f".experts.{expert_id}." + name[e:]
350+
if isinstance(model, ErnieMoEForCausalLM):
351+
return name
352+
353+
# for PP>1, we also need to handle special layers and adjust layer_idx
354+
if name.startswith(shared_layers_prefix):
355+
return "ernie." + name[len(shared_layers_prefix) :]
356+
layer_idx, stem = name.split(".", maxsplit=1)
357+
if stem == "weight":
358+
return unnamed_layers.pop(0)
359+
if stem.startswith("mtp"):
360+
return f"ernie.{stem}"
361+
return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
362+
363+
state_dict = {}
364+
for name, param in model.state_dict().items():
365+
name = param_to_weight(name)
366+
if name.startswith("ernie."):
367+
name = "model." + name[6:]
368+
369+
if "_proj." in name or ".gate." in name or "lm_head" in name:
370+
param = copy_attr(param.T, param)
371+
372+
if ".up_gate_proj." in name:
373+
gate, up = param.split(2)
374+
gate, up = copy_attr(gate, param), copy_attr(up, param)
375+
state_dict[name.replace(".up_gate_proj.", ".gate_proj.")] = gate
376+
state_dict[name.replace(".up_gate_proj.", ".up_proj.")] = up
377+
elif ".qkv_proj." in name:
378+
assert q_dim + kv_dim * 2 == param.shape[0]
379+
state_dict[name.replace(".qkv_proj.", ".q_proj.")] = param[:q_dim]
380+
state_dict[name.replace(".qkv_proj.", ".k_proj.")] = param[q_dim:-kv_dim]
381+
state_dict[name.replace(".qkv_proj.", ".v_proj.")] = param[-kv_dim:]
382+
else:
383+
state_dict[name] = param
384+
385+
return state_dict
386+
387+
205388
def main():
206389
if set_affinity is not None:
207390
set_affinity_code = set_affinity()
@@ -520,21 +703,12 @@ def sname_to_tname(pp_model):
520703
cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
521704
register_pp_reshard_information(cfg.num_hidden_layers)
522705

523-
if args.from_scratch:
524-
model = ErnieMoEForCausalLMPipe(cfg)
525-
else:
526-
model = ErnieMoEForCausalLMPipe.from_pretrained(
527-
args.model_name_or_path,
528-
config=cfg,
529-
)
706+
model = ErnieMoEForCausalLMPipe(cfg)
530707
else:
531-
if args.from_scratch:
532-
model = ErnieMoEForCausalLM(cfg)
533-
else:
534-
model = ErnieMoEForCausalLM.from_pretrained(
535-
args.model_name_or_path,
536-
config=cfg,
537-
)
708+
model = ErnieMoEForCausalLM(cfg)
709+
710+
if not args.from_scratch:
711+
load_huggingface_checkpoint(model, args)
538712

539713
cfg = model.config
540714
logger.info(f"using model type:{type(model)}")
@@ -581,6 +755,7 @@ def sname_to_tname(pp_model):
581755
if args.do_train:
582756
train_result = trainer.train(resume_from_checkpoint=checkpoint)
583757
metrics = train_result.metrics
758+
unified_checkpoint.get_expected_state_dict = get_expected_state_dict
584759
trainer.save_model(args.output_dir)
585760
trainer.log_metrics("train", metrics)
586761
trainer.save_metrics("train", metrics)

0 commit comments

Comments
 (0)