@@ -342,22 +342,28 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
342342 if len (files ) == 1 :
343343 tensor0 = f .get_tensor (hf_name [0 ])
344344 tensor1 = f .get_tensor (hf_name [1 ])
345+ target_shape = model .state_dict ()[pd_param ].shape
346+ prepared_tensor = prepare_tensor ([tensor0 , tensor1 ], target_shape )
347+ model .state_dict ()[pd_param ].set_value (prepared_tensor )
345348 else :
346349 if weight_map [hf_name [0 ]] == filename :
347350 tensor0 = f .get_tensor (hf_name [0 ])
348351 with safe_open (
349352 ckpt_pre + weight_map [hf_name [1 ]], framework = "paddle" , device = "cpu"
350353 ) as f_other :
351354 tensor1 = f_other .get_tensor (hf_name [1 ])
355+ target_shape = model .state_dict ()[pd_param ].shape
356+ prepared_tensor = prepare_tensor ([tensor0 , tensor1 ], target_shape )
357+ model .state_dict ()[pd_param ].set_value (prepared_tensor )
352358 else :
353359 with safe_open (
354360 ckpt_pre + weight_map [hf_name [0 ]], framework = "paddle" , device = "cpu"
355361 ) as f_other :
356362 tensor0 = f_other .get_tensor (hf_name [0 ])
357- tensor1 = f .get_tensor (hf_name [1 ])
358- model .state_dict ()[pd_param ].set_value (
359- prepare_tensor ([tensor0 , tensor1 ], model .state_dict ()[pd_param ].shape )
360- )
363+ tensor1 = f .get_tensor (hf_name [1 ])
364+ model .state_dict ()[pd_param ].set_value (
365+ prepare_tensor ([tensor0 , tensor1 ], model .state_dict ()[pd_param ].shape )
366+ )
361367 check_list .append (pd_param )
362368
363369 except Exception as e :
0 commit comments