Skip to content

Commit 8156d71

Browse files
authored
fix load hf ckpt core dump (#2943)
1 parent 9134817 commit 8156d71

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

examples/experiments/deepseek_v3_pretrain/load_hf_ckpt.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,22 +342,28 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
342342
if len(files) == 1:
343343
tensor0 = f.get_tensor(hf_name[0])
344344
tensor1 = f.get_tensor(hf_name[1])
345+
target_shape = model.state_dict()[pd_param].shape
346+
prepared_tensor = prepare_tensor([tensor0, tensor1], target_shape)
347+
model.state_dict()[pd_param].set_value(prepared_tensor)
345348
else:
346349
if weight_map[hf_name[0]] == filename:
347350
tensor0 = f.get_tensor(hf_name[0])
348351
with safe_open(
349352
ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu"
350353
) as f_other:
351354
tensor1 = f_other.get_tensor(hf_name[1])
355+
target_shape = model.state_dict()[pd_param].shape
356+
prepared_tensor = prepare_tensor([tensor0, tensor1], target_shape)
357+
model.state_dict()[pd_param].set_value(prepared_tensor)
352358
else:
353359
with safe_open(
354360
ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu"
355361
) as f_other:
356362
tensor0 = f_other.get_tensor(hf_name[0])
357-
tensor1 = f.get_tensor(hf_name[1])
358-
model.state_dict()[pd_param].set_value(
359-
prepare_tensor([tensor0, tensor1], model.state_dict()[pd_param].shape)
360-
)
363+
tensor1 = f.get_tensor(hf_name[1])
364+
model.state_dict()[pd_param].set_value(
365+
prepare_tensor([tensor0, tensor1], model.state_dict()[pd_param].shape)
366+
)
361367
check_list.append(pd_param)
362368

363369
except Exception as e:

0 commit comments

Comments
 (0)