I fine-tuned chatglm2-6b with LOMO (and ZeRO-3) on 8 NVIDIA 3090 GPUs and saved checkpoints with LOMOTrainer's save_model method. After reloading a saved checkpoint, the validation loss I measured did not match the validation loss measured at the end of training. I rewrote save_model by following DeepSpeed's official model-saving code (rewritten version below), and the mismatch went away. This suggests the original save_model has a bug, though I have not yet located the exact cause.
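Side note: the "DeepSpeed official code" I followed is essentially the engine-side ZeRO-3 consolidation that DeepSpeed itself performs when saving a model. If self.model here is the DeepSpeedEngine itself (my assumption, given the self.model.module / self.model.optimizer usage in the rewrite below), the built-in API could probably achieve the same thing; a rough, untested sketch:

import os

def save_consolidated_checkpoint(engine, output_dir):
    # Rough sketch, not tested here: `engine` is assumed to be the
    # deepspeed.DeepSpeedEngine wrapping the model (what LOMOTrainer stores in
    # self.model). save_16bit_model() gathers the ZeRO-3 partitions onto rank 0
    # and writes a single consolidated pytorch_model.bin; it must be called on
    # every rank, and for ZeRO-3 it requires
    # "stage3_gather_16bit_weights_on_model_save": true in the DeepSpeed config.
    os.makedirs(output_dir, exist_ok=True)
    engine.save_16bit_model(output_dir, "pytorch_model.bin")

The rewritten save_model follows.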
# Module-level imports this method relies on:
import os
import shutil
from collections import OrderedDict
from pathlib import Path

import deepspeed
import torch


def save_model(self, index):
    # Rotate old checkpoints so that at most `save_total_limit` are kept.
    if self.training_args.local_rank in [-1, 0]:
        checkpoint_dirs = sorted(Path(self.training_args.output_dir).glob("checkpoint-*"))
        if len(checkpoint_dirs) >= self.training_args.save_total_limit:
            shutil.rmtree(checkpoint_dirs[0], ignore_errors=True)
    torch.distributed.barrier()

    if self.training_args.resume_step:
        output_dir = os.path.join(self.training_args.output_dir,
                                  f"checkpoint-{index + self.training_args.resume_step}")
    else:
        output_dir = os.path.join(self.training_args.output_dir, f"checkpoint-{index}")
    os.makedirs(output_dir, exist_ok=True)

    state_dict = OrderedDict() if torch.distributed.get_rank() == 0 else None
    shared_params = {}

    # Prepare for checkpoint save by ensuring all parameters are partitioned.
    self.model.optimizer.partition_all_parameters()

    with deepspeed.zero.GatheredParameters(list(self.model.module.parameters()), modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            for name, param in self.model.module.named_parameters():
                if param is None:
                    continue
                # Can't rely on param.data_ptr() as it will be reused as weights get
                # gathered and reduced, but param.ds_id is unique across all ZeRO weights
                # (and shared params will have the same param.ds_id).
                if param.ds_id in shared_params:
                    # Shared weights: reuse the tensor already stored under the first name.
                    state_dict[name] = state_dict[shared_params[param.ds_id]]
                else:
                    state_dict[name] = param.detach().cpu()
                    shared_params[param.ds_id] = name

            # Now the buffers - not sure if we need to take care of potentially
            # shared weights here.
            for name, buf in self.model.module.named_buffers():
                if buf is not None and name not in self.model.module._non_persistent_buffers_set:
                    state_dict[name] = buf.detach().cpu()

    # Re-gather the persistent parameters that partition_all_parameters() partitioned
    # above; this must run on every rank.
    if len(self.model.optimizer.persistent_parameters) > 0:
        self.model.optimizer.persistent_parameters[0].all_gather(self.model.optimizer.persistent_parameters)

    if torch.distributed.get_rank() == 0:
        torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
    torch.distributed.barrier()
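To double-check the fix, I reload the consolidated checkpoint into a plain Hugging Face model and re-run evaluation; the validation loss then matches the value from the end of training. A rough sketch of the reload step (the checkpoint path and the eval loop are placeholders, not the exact code I ran):

import os

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholders: point these at the actual base model and the checkpoint
# directory written by save_model() above.
base_model = "THUDM/chatglm2-6b"
checkpoint_dir = "output/checkpoint-1000"

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModel.from_pretrained(base_model, trust_remote_code=True, torch_dtype=torch.float16)

# Load the consolidated state dict saved by the rewritten save_model above.
state_dict = torch.load(os.path.join(checkpoint_dir, "pytorch_model.bin"), map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

model.eval()
# evaluate(model, val_dataloader) stands in for whatever loop produced the
# end-of-training validation loss; it should now report the same value.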