Skip to content

notebook 示例中使用 save_pretrained 保存 model/checkpoints 的方法有问题 #104

@lalalabox

Description

@lalalabox

问题描述

在若干个 notebook 示例中,使用 save_pretrained 方法保存模型是不对的,可以重新加载这些模型进行检查,发现权重没有变化。

例如ChatGLM2_LoRA注释版.ipynb

# 仅仅保存lora可训练参数
# 覆盖了KerasModel中的load_ckpt和save_ckpt方法
def save_ckpt(self, ckpt_path='checkpoint', accelerator = None):
    unwrap_net = accelerator.unwrap_model(self.net)
    unwrap_net.save_pretrained(ckpt_path)

解决方法

👉根据accelerate官方文档中,保存model/checkpoints的方法👈,应该修正为

# 仅仅保存lora可训练参数
# 覆盖了KerasModel中的load_ckpt和save_ckpt方法
def save_ckpt(self, ckpt_path="checkpoint", accelerator=None):
    unwrap_net = accelerator.unwrap_model(self.net)
    unwrap_net.save_pretrained(
        ckpt_path,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(self.net),
    )

希望后来者少走弯路,我检查了好久才找到这个问题😂

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions