Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 188 additions & 14 deletions examples/pre-training/ernie/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
PdArgumentParser,
get_last_checkpoint,
)
from paddleformers.trainer.unified_checkpoint import unified_checkpoint
from paddleformers.transformers.model_utils import unwrap_model

from safetensors import safe_open

try:
from paddleformers.utils.downloader import get_static_model_on_pdc
Expand Down Expand Up @@ -202,6 +206,184 @@ def _collate_data(data, stack_fn=Stack()):
return train_dataset, valid_dataset, test_dataset, _collate_data


def load_huggingface_checkpoint(model, args):
    """Initialize ``model``'s parameters in place from a HuggingFace-format
    safetensors checkpoint directory at ``args.model_name_or_path``.

    Handles two checkpoint layouts:
      * Paddle layout — weight names match ``param_to_weight`` output directly.
      * torch layout — detected lazily via ``try_torch_format``: keys use the
        ``model.`` prefix, gate/up and q/k/v projections are stored split, and
        linear weights are transposed relative to Paddle.

    Parameters whose weights cannot be found are left at their current
    initialization with a warning; shape mismatches are cropped via
    ``auto_fix_shape``.
    """
    # Fused rms_norm+linear parameters map back to two separate checkpoint keys.
    fused_rms_norm_replace = [
        ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
        ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
    ]
    # Pipeline-parallel models expose the shared embedding under this prefix.
    shared_layers_prefix = "shared_layers.embed_weight_share."
    # PP parameters named just "<idx>.weight" are matched to these checkpoint
    # keys in encounter order (consumed via pop(0) below) — order-dependent.
    unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]

    logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
    # weight_map: checkpoint weight name -> shard file that contains it.
    with open(
        os.path.join(args.model_name_or_path, "model.safetensors.index.json")
    ) as f:
        weight_map = json.load(f)["weight_map"]

    # Each expert-parallel rank owns a contiguous slice of the experts; local
    # expert ids are shifted by this offset to recover the global id.
    ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
    ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
    expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
    # Sticky flag: set once any weight is loaded via the torch-layout path,
    # and then used to decide whether lm_head needs a transpose (see loop).
    use_torch_format = False

    def param_to_weight(name: str) -> str:
        """Translate a local parameter name to its checkpoint weight name."""
        # for PP=1, we only need to substitute the fused_rms_norm and expert_id
        for src, dst in fused_rms_norm_replace:
            name = name.replace(src, dst)
        if m := re.search(r"mlp\.experts\.(\d+)", name):
            expert_id = expert_offset + int(m.group(1))
            s, e = m.span()
            name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
        if isinstance(model, ErnieMoEForCausalLM):
            return name

        # for PP>1, we also need to handle special layers and adjust layer_idx
        if name.startswith(shared_layers_prefix):
            return "ernie." + name[len(shared_layers_prefix) :]
        layer_idx, stem = name.split(".", maxsplit=1)
        if stem == "weight":
            # Bare "<idx>.weight" params (final norm, lm_head) carry no stem;
            # consume the next expected key in order.
            return unnamed_layers.pop(0)
        if stem.startswith("mtp"):
            return f"ernie.{stem}"
        # NOTE(review): presumably pipeline stage 0 is the embedding, so
        # transformer layers are shifted down by one — confirm against the
        # PP layer numbering used by ErnieMoEForCausalLMPipe.
        return f"ernie.layers.{int(layer_idx) - 1}.{stem}"

    def try_torch_format(weight_key):
        """Look up ``weight_key`` assuming a torch-layout checkpoint.

        Returns the re-fused (and transposed where needed) tensor, or None
        if any required constituent key is missing from the index.
        """
        # torch checkpoints use the "model." prefix instead of "ernie.".
        if weight_key.startswith("ernie."):
            weight_key = "model." + weight_key[6:]

        # torch stores gate/up and q/k/v projections as separate weights;
        # gather every constituent key for this fused parameter.
        key_decompose = [weight_key]
        if ".up_gate_proj." in weight_key:
            key_decompose = [
                weight_key.replace(".up_gate_proj.", ".gate_proj."),
                weight_key.replace(".up_gate_proj.", ".up_proj."),
            ]
        elif ".qkv_proj." in weight_key:
            key_decompose = [
                weight_key.replace(".qkv_proj.", ".q_proj."),
                weight_key.replace(".qkv_proj.", ".k_proj."),
                weight_key.replace(".qkv_proj.", ".v_proj."),
            ]

        tensor_decompose = []
        for key in key_decompose:
            # All constituents must exist; otherwise this is not a torch
            # checkpoint (or the weight is genuinely absent).
            if not (weight_file := weight_map.get(key)):
                return None
            with safe_open(
                os.path.join(args.model_name_or_path, weight_file),
                framework="numpy",
            ) as f:
                tensor = paddle.to_tensor(f.get_tensor(key))
            # torch Linear weights are stored transposed relative to Paddle.
            if "_proj." in key or ".gate." in key:
                tensor = tensor.T.contiguous()
            tensor_decompose.append(tensor)

        if len(tensor_decompose) == 1:
            return tensor_decompose[0]
        else:
            # Re-fuse the split projections along the output dimension.
            return paddle.concat(tensor_decompose, axis=-1)

    def auto_fix_shape(param, weight):
        """Crop ``weight`` to ``param``'s shape by keeping the leading slice
        along every dimension; ``weight`` must be >= ``param`` in each dim."""
        assert len(param.shape) == len(weight.shape), "rank not match"
        assert all(
            p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
        ), "weight too small"
        indices = tuple(slice(0, dim) for dim in param.shape)
        return weight[indices].contiguous()

    for name, param in model.named_parameters():
        weight_key = param_to_weight(name)
        if weight_file := weight_map.get(weight_key):
            # Direct hit: the checkpoint already uses Paddle naming/layout.
            with safe_open(
                os.path.join(args.model_name_or_path, weight_file),
                framework="numpy",
            ) as f:
                weight = paddle.to_tensor(f.get_tensor(weight_key))
        elif (weight := try_torch_format(weight_key)) is not None:
            use_torch_format = True
        else:
            logger.warning(
                f"param `{name}`'s weight `{weight_key}` not found. "
                "Skip initializing."
            )
            continue
        # NOTE(review): this transpose is gated on the sticky flag, so it
        # assumes lm_head is visited after at least one torch-path load (or
        # is itself loaded via try_torch_format) — verify iteration order.
        if use_torch_format and "lm_head" in weight_key:
            weight = weight.T.contiguous()
        if param.shape != weight.shape:
            logger.warning(
                f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
                f"{param.shape} and {weight.shape}. Auto fixing."
            )
            weight = auto_fix_shape(param, weight)
        param.copy_(weight)


def get_expected_state_dict(model, **kwargs):
    """Build a HuggingFace/torch-layout state dict from a (possibly wrapped)
    Paddle model.

    Fused parameters are renamed back to their canonical checkpoint keys,
    expert indices are shifted to global ids for expert parallelism, linear
    weights are transposed to torch layout, and the fused up_gate/qkv
    projections are split into their individual weights. Installed as a
    replacement for ``unified_checkpoint.get_expected_state_dict`` so saved
    checkpoints use this layout.
    """
    fused_name_map = [
        ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
        ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
    ]
    shared_prefix = "embed_share."

    model = unwrap_model(model)
    hcg = fleet.get_hybrid_communicate_group()
    # Each expert-parallel rank holds a contiguous block of experts; shift
    # local ids by this offset to recover global expert ids.
    expert_offset = (
        model.config.moe_num_experts // hcg.get_expert_parallel_world_size()
    ) * hcg.get_expert_parallel_rank()

    cfg = model.config
    head_dim = (
        cfg.hidden_size // cfg.num_attention_heads
        if cfg.head_dim is None
        else cfg.head_dim
    )
    q_dim = head_dim * cfg.num_attention_heads
    kv_dim = head_dim * cfg.num_key_value_heads

    def with_attrs(dst, src):
        # Carry distributed metadata over to derived views so the saver
        # treats them like the source parameter.
        for attr in ("is_distributed", "no_sync"):
            if hasattr(src, attr):
                setattr(dst, attr, getattr(src, attr))
        return dst

    def rename(name):
        """Map a local parameter name to its checkpoint weight name."""
        # Undo the fused rms_norm+linear naming (applies regardless of PP).
        for fused, plain in fused_name_map:
            name = name.replace(fused, plain)
        m = re.search(r"\.experts\.(\d+)\.", name)
        if m:
            global_id = expert_offset + int(m.group(1))
            name = name[: m.start()] + f".experts.{global_id}." + name[m.end() :]
        # Pipeline-parallel models expose the shared embedding under a prefix.
        if not isinstance(model, ErnieMoEForCausalLM) and name.startswith(
            shared_prefix
        ):
            name = "ernie." + name[len(shared_prefix) :]
        return name

    state_dict = {}
    for raw_name, param in model.state_dict().items():
        name = rename(raw_name)
        # Checkpoint keys use the "model." prefix instead of "ernie.".
        if name.startswith("ernie."):
            name = "model." + name[len("ernie.") :]

        # Linear-style weights are stored transposed in torch layout.
        if any(tag in name for tag in ("_proj.", ".gate.", "lm_head")):
            param = with_attrs(param.T, param)

        if ".up_gate_proj." in name:
            # Split the fused MLP projection into separate gate/up weights.
            halves = param.split(2)
            for tag, half in zip((".gate_proj.", ".up_proj."), halves):
                state_dict[name.replace(".up_gate_proj.", tag)] = with_attrs(
                    half, param
                )
        elif ".qkv_proj." in name:
            # Split the fused attention projection into q/k/v weights.
            assert param.shape[0] == q_dim + 2 * kv_dim
            state_dict[name.replace(".qkv_proj.", ".q_proj.")] = param[:q_dim]
            state_dict[name.replace(".qkv_proj.", ".k_proj.")] = param[
                q_dim : q_dim + kv_dim
            ]
            state_dict[name.replace(".qkv_proj.", ".v_proj.")] = param[-kv_dim:]
        else:
            state_dict[name] = param

    return state_dict


def main():
if set_affinity is not None:
set_affinity_code = set_affinity()
Expand Down Expand Up @@ -520,21 +702,12 @@ def sname_to_tname(pp_model):
cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
register_pp_reshard_information(cfg.num_hidden_layers)

if args.from_scratch:
model = ErnieMoEForCausalLMPipe(cfg)
else:
model = ErnieMoEForCausalLMPipe.from_pretrained(
args.model_name_or_path,
config=cfg,
)
model = ErnieMoEForCausalLMPipe(cfg)
else:
if args.from_scratch:
model = ErnieMoEForCausalLM(cfg)
else:
model = ErnieMoEForCausalLM.from_pretrained(
args.model_name_or_path,
config=cfg,
)
model = ErnieMoEForCausalLM(cfg)

if not args.from_scratch:
load_huggingface_checkpoint(model, args)

cfg = model.config
logger.info(f"using model type:{type(model)}")
Expand Down Expand Up @@ -581,6 +754,7 @@ def sname_to_tname(pp_model):
if args.do_train:
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
unified_checkpoint.get_expected_state_dict = get_expected_state_dict
trainer.save_model(args.output_dir)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
Expand Down