
Commit 506f868

fix llava rlhf
Former-commit-id: b3e33c7
1 parent 3b42f1a commit 506f868

File tree: 5 files changed (+79, -43 lines)


src/llmtuner/model/__init__.py (+2, -1)
@@ -1,5 +1,6 @@
 from .loader import load_config, load_model, load_tokenizer
-from .utils.misc import find_all_linear_modules, load_valuehead_params
+from .utils.misc import find_all_linear_modules
+from .utils.valuehead import load_valuehead_params
 
 
 __all__ = [

src/llmtuner/model/loader.py (+4, -3)
@@ -7,9 +7,10 @@
 from ..extras.misc import count_parameters, try_download_model_from_ms
 from .adapter import init_adapter
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
-from .utils.misc import load_valuehead_params, register_autoclass
+from .utils.misc import register_autoclass
 from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .utils.unsloth import load_unsloth_pretrained_model
+from .utils.valuehead import load_valuehead_params
 
 
 if TYPE_CHECKING:
@@ -105,7 +106,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
 
     model = None
     lazy_load = False
@@ -130,7 +131,7 @@ def load_model(
         model = convert_pretrained_model_to_mod(model, config, model_args)
 
     if not lazy_load:
-        patch_model(model, tokenizer, model_args, is_trainable)
+        patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
         register_autoclass(config, model, tokenizer)
 
     model = init_adapter(config, model, model_args, finetuning_args, is_trainable)
src/llmtuner/model/patcher.py (+12, -4)
@@ -15,6 +15,7 @@
 from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
+from .utils.valuehead import configure_valuehead, prepare_valuehead_model
 from .utils.visual import autocast_projector_dtype
 
 
@@ -39,6 +40,7 @@ def patch_config(
     model_args: "ModelArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
+    add_valuehead: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
@@ -49,6 +51,9 @@ def patch_config(
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
 
+    if add_valuehead:
+        configure_valuehead(config)
+
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
         logger.info("Using KV cache for faster generation.")
@@ -73,7 +78,11 @@ def patch_config(
 
 
 def patch_model(
-    model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
+    model: "PreTrainedModel",
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    is_trainable: bool,
+    add_valuehead: bool,
 ) -> None:
     gen_config = model.generation_config  # check and fix generation config
     if not gen_config.do_sample and (
@@ -86,9 +95,8 @@ def patch_model(
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
-    if is_trainable and getattr(model.config, "model_type", None) == "chatglm":
-        setattr(model, "lm_head", model.transformer.output_layer)
-        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+    if add_valuehead:
+        prepare_valuehead_model(model)
 
     if model_args.resize_vocab:
         resize_embedding_layer(model, tokenizer)
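
As a hedged illustration of what the new configure_valuehead hook does when patch_config receives add_valuehead=True: LLaVA's config nests the language-model sizes, so the hook copies a width from the vision tower config onto config.hidden_size, presumably so downstream value-head construction (which reads config.hidden_size) gets a usable dimension. The repo id below is the public llava-hf checkpoint and is illustrative only; the rationale in the comments is inferred from the hook's body, not stated in the commit.

# Hedged illustration of configure_valuehead's effect on a LlavaConfig.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")  # illustrative repo id
if getattr(config, "model_type", None) == "llava":
    # Borrow the vision tower's MLP width (4096 for CLIP ViT-L/14-336), which happens
    # to match the 7B language model's hidden size, so code that sizes a value head
    # from config.hidden_size keeps working on a LLaVA config.
    setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
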

src/llmtuner/model/utils/misc.py (+2, -35)
@@ -1,18 +1,13 @@
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, List
 
 import torch
-from transformers import PreTrainedModel
-from transformers.utils import cached_file
 
-from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ...extras.logging import get_logger
 from .quantization import QuantizationMethod
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedTokenizer
-
-    from ...hparams import ModelArguments
+    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 
 
 logger = get_logger(__name__)
@@ -74,34 +69,6 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     return module_names
 
 
-def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
-    r"""
-    Loads value head parameters from Hugging Face Hub or local disk.
-
-    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
-    """
-    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
-
-    try:
-        from safetensors import safe_open
-
-        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
-        with safe_open(vhead_file, framework="pt", device="cpu") as f:
-            return {key: f.get_tensor(key) for key in f.keys()}
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
-
-    try:
-        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
-        return torch.load(vhead_file, map_location="cpu")
-    except Exception as err:
-        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
-
-    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
-    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
-    return None
-
-
 def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
     if "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()

src/llmtuner/model/utils/valuehead.py (+59, new file)
@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING, Dict
+
+import torch
+from transformers.utils import cached_file
+
+from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ...extras.logging import get_logger
+
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig, PreTrainedModel
+
+    from ...hparams import ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def configure_valuehead(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "llava":
+        setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
+
+
+def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
+    r"""
+    Loads value head parameters from Hugging Face Hub or local disk.
+
+    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
+    """
+    kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token}
+
+    try:
+        from safetensors import safe_open
+
+        vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
+        with safe_open(vhead_file, framework="pt", device="cpu") as f:
+            return {key: f.get_tensor(key) for key in f.keys()}
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
+
+    try:
+        vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
+        return torch.load(vhead_file, map_location="cpu")
+    except Exception as err:
+        logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
+
+    logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
+    logger.info("Ignore these messages if you are not resuming the training of a value head model.")
+    return None
+
+
+def prepare_valuehead_model(model: "PreTrainedModel") -> None:
+    if getattr(model.config, "model_type", None) == "llava":
+        setattr(model, "lm_head", model.language_model.get_output_embeddings())
+        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
+
+    if getattr(model.config, "model_type", None) == "chatglm":
+        setattr(model, "lm_head", model.transformer.output_layer)
+        setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
