133 changes: 119 additions & 14 deletions examples/pre-training/ernie/pretrain.py
@@ -63,6 +63,8 @@

from config import get_config

from safetensors import safe_open

try:
from paddleformers.trainer.trainer_utils import log_trainer_start
except ImportError:
@@ -202,6 +204,118 @@ def _collate_data(data, stack_fn=Stack()):
return train_dataset, valid_dataset, test_dataset, _collate_data


def load_huggingface_checkpoint(model, args):
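    """Initialize ``model`` in place from a HuggingFace safetensors checkpoint.

    Parameter names are mapped to checkpoint keys, handling the fused RMSNorm
    parameters, expert-parallel expert offsets, and pipeline-parallel prefixes.
    If a key is missing, a torch-format fallback reassembles fused qkv/up_gate
    projections from their split counterparts and transposes linear weights.
    """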
fused_rms_norm_replace = [
("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
]
shared_layers_prefix = "shared_layers.embed_weight_share."
unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]

logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
with open(
os.path.join(args.model_name_or_path, "model.safetensors.index.json")
) as f:
weight_map = json.load(f)["weight_map"]

ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
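    # Each EP rank holds moe_num_experts // ep_degree experts; this is the
    # global index of the first expert owned by this rank.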
expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank

def param_to_weight(name):
        # For PP=1 (non-pipeline model), only the fused_rms_norm names and expert ids need substituting
for src, dst in fused_rms_norm_replace:
name = name.replace(src, dst)
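        # Remap local expert indices to their global indices in the checkpoint.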
if m := re.search(r"mlp\.experts\.(\d+)", name):
expert_id = expert_offset + int(m.group(1))
s, e = m.span()
name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
if isinstance(model, ErnieMoEForCausalLM):
return name

        # For PP>1 (pipeline model), special layers must also be handled and layer_idx adjusted
if name.startswith(shared_layers_prefix):
return "ernie." + name[len(shared_layers_prefix) :]
layer_idx, stem = name.split(".", maxsplit=1)
if stem == "weight":
return unnamed_layers.pop(0)
if stem.startswith("mtp"):
return f"ernie.{stem}"
return f"ernie.layers.{int(layer_idx) - 1}.{stem}"

def try_torch_format(weight_key):
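        # Torch-format checkpoints use the "model." prefix, store q/k/v and
        # gate/up projections separately, and keep nn.Linear weights as
        # [out_features, in_features]; rename, transpose, and concatenate to
        # recover the fused Paddle parameter.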
if weight_key.startswith("ernie."):
weight_key = "model." + weight_key[6:]

key_decompose = [weight_key]
if ".up_gate_proj." in weight_key:
key_decompose = [
weight_key.replace(".up_gate_proj.", ".gate_proj."),
weight_key.replace(".up_gate_proj.", ".up_proj."),
]
elif ".qkv_proj." in weight_key:
key_decompose = [
weight_key.replace(".qkv_proj.", ".q_proj."),
weight_key.replace(".qkv_proj.", ".k_proj."),
weight_key.replace(".qkv_proj.", ".v_proj."),
]

tensor_decompose = []
for key in key_decompose:
if not (weight_file := weight_map.get(key)):
return None
with safe_open(
os.path.join(args.model_name_or_path, weight_file),
framework="paddle",
device="cpu",
) as f:
tensor = f.get_tensor(key)
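                # torch nn.Linear stores weights as [out_features, in_features];
                # transpose to Paddle's [in_features, out_features] layout.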
if "_proj." in key or ".gate." in key:
tensor = tensor.T.contiguous()
tensor_decompose.append(tensor)

if len(tensor_decompose) == 1:
return tensor_decompose[0]
else:
return paddle.concat(tensor_decompose, axis=-1)

def auto_fix_shape(param, weight):
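        # Tolerate two kinds of mismatch: transposed 2-D weights, and checkpoint
        # tensors larger than the parameter (sliced down to the parameter shape).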
assert len(param.shape) == len(weight.shape), "rank not match"
if (
len(param.shape) == 2
and param.shape[0] == weight.shape[1]
and param.shape[1] == weight.shape[0]
):
return weight.T.contiguous()
assert all(
p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
), "weight too small"
indices = tuple(slice(0, dim) for dim in param.shape)
return weight[indices].contiguous()

for name, param in model.named_parameters():
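        # Prefer the key in native Paddle layout; fall back to the torch-format
        # lookup if the key is absent from the safetensors index.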
weight_key = param_to_weight(name)
if weight_file := weight_map.get(weight_key):
with safe_open(
os.path.join(args.model_name_or_path, weight_file),
framework="paddle",
) as f:
weight = f.get_tensor(weight_key)
elif (weight := try_torch_format(weight_key)) is None:
logger.warning(
f"param `{name}`'s weight `{weight_key}` not found. "
"Skip initializing."
)
continue
if param.shape != weight.shape:
logger.warning(
f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
f"{param.shape} and {weight.shape}. Auto fixing."
)
weight = auto_fix_shape(param, weight)
param.copy_(weight)


def main():
if set_affinity is not None:
set_affinity_code = set_affinity()
@@ -520,21 +634,12 @@ def sname_to_tname(pp_model):
         cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
         register_pp_reshard_information(cfg.num_hidden_layers)
 
-        if args.from_scratch:
-            model = ErnieMoEForCausalLMPipe(cfg)
-        else:
-            model = ErnieMoEForCausalLMPipe.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLMPipe(cfg)
     else:
-        if args.from_scratch:
-            model = ErnieMoEForCausalLM(cfg)
-        else:
-            model = ErnieMoEForCausalLM.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLM(cfg)
+
+    if not args.from_scratch:
+        load_huggingface_checkpoint(model, args)
 
     cfg = model.config
     logger.info(f"using model type:{type(model)}")