model_utils.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import math
import torch
from transformers import (
AutoConfig,
AutoModel,
)
from huggingface_hub import snapshot_download
from transformers.deepspeed import HfDeepSpeedConfig
from .reward_model import RewardModel


def create_hf_model(model_class,
                    model_name_or_path,
                    tokenizer,
                    ds_config=None,
                    rlhf_training=False,
                    disable_dropout=False):
    """Create a Hugging Face model from a pretrained checkpoint (or from config
    only during RLHF training) and align its padding and vocab size with the
    tokenizer."""
    model_config = AutoConfig.from_pretrained(model_name_or_path)
    if disable_dropout:
        model_config.dropout = 0.0
    # Note: dschf is defined in function scope to avoid global effects
    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None
    if rlhf_training:
        # the weight loading is handled by create_critic_model
        model = model_class.from_config(model_config)
    else:
        model = model_class.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path),
            config=model_config,
            trust_remote_code=True)

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    model.resize_token_embeddings(
        int(8 * math.ceil(len(tokenizer) / 8.0)))  # make the vocab size a multiple of 8

    return model


def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False,
                        disable_dropout=False):
    """Create a critic/reward model by wrapping a base HF model with RewardModel
    and, during RLHF training, loading its pretrained reward-model weights."""
    # The OPT model family always puts a padding token at the beginning of the
    # sequence; we did not see this in other models, but it may not be a general rule.
    critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                   ds_config, rlhf_training, disable_dropout)
    critic_model = RewardModel(
        critic_model,
        tokenizer,
        num_padding_at_beginning=num_padding_at_beginning)

    if rlhf_training:
        if not os.path.isdir(model_name_or_path):
            model_name_or_path = snapshot_download(model_name_or_path)
        # the critic model needs to load its reward-model weights here
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(
            model_ckpt_path
        ), f"Cannot find model checkpoint at {model_ckpt_path}"
        critic_model.load_state_dict(
            torch.load(model_ckpt_path, map_location='cpu'))

    return critic_model
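

# Example usage (a minimal sketch, not part of the original module): the
# checkpoint name and argument values below are illustrative assumptions. It
# shows how create_hf_model / create_critic_model are typically invoked to
# build the actor and critic in DeepSpeed-Chat-style RLHF training.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "facebook/opt-1.3b"  # assumed example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Actor/policy model: loads pretrained weights directly.
    actor = create_hf_model(AutoModelForCausalLM,
                            model_name,
                            tokenizer,
                            ds_config=None,
                            disable_dropout=True)

    # Critic/reward model: wraps the base model with a RewardModel head.
    # OPT models carry one padding token at the start of each sequence.
    critic = create_critic_model(model_name,
                                 tokenizer,
                                 ds_config=None,
                                 num_padding_at_beginning=1,
                                 rlhf_training=False,
                                 disable_dropout=True)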