
Commit 6e6267f

committed
fix #3694
Former-commit-id: 2a67ab3
1 parent: a84f155

File tree

7 files changed (+132 -76 lines)


examples/extras/llama_pro/llama3_freeze_sft.yaml (+2 -2)

@@ -5,8 +5,8 @@ model_name_or_path: models/llama3-8b-instruct-pro
 stage: sft
 do_train: true
 finetuning_type: freeze
-name_module_trainable: all
-num_layer_trainable: 8
+freeze_trainable_layers: 8
+freeze_trainable_modules: all
 use_llama_pro: true

 # dataset

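For reference, a minimal sketch (not part of the commit) of reading the renamed keys back in Python; the inline config mirrors the updated example above and assumes PyYAML is available:

import yaml  # assumes PyYAML is installed

# Inline copy of the updated freeze-tuning block from the example config.
CONFIG_TEXT = """
finetuning_type: freeze
freeze_trainable_layers: 8     # per the new help text: positive = last n layers, negative = first n
freeze_trainable_modules: all  # or a comma-separated list such as "mlp,self_attn"
use_llama_pro: true            # with LLaMA Pro, training targets the expanded blocks
"""

config = yaml.safe_load(CONFIG_TEXT)
print(config["freeze_trainable_layers"])   # 8
print(config["freeze_trainable_modules"])  # 'all'
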
scripts/llama_pro.py (+2 -3)

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Performs block expansion for LLaMA, Mistral or Qwen1.5 models.
+# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
 # Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
 # Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py

@@ -106,8 +106,7 @@ def block_expansion(
     print("Fine-tune this model with:")
     print(" --model_name_or_path {} \\".format(output_dir))
     print(" --finetuning_type freeze \\")
-    print(" --name_module_trainable all \\")
-    print(" --num_layer_trainable {} \\".format(num_expand))
+    print(" --freeze_trainable_layers {} \\".format(num_expand))
     print(" --use_llama_pro")

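As a quick reference (illustrative values, not part of the commit), the pair of flags the script used to suggest collapses into a single one, because `all` is now the default of `freeze_trainable_modules`:

num_expand = 8  # hypothetical number of blocks inserted by llama_pro.py

# Flags the script printed before this commit ...
old_flags = {"name_module_trainable": "all", "num_layer_trainable": num_expand}

# ... and the replacement printed now; freeze_trainable_modules defaults to "all",
# so it no longer needs to be passed explicitly.
new_flags = {"freeze_trainable_layers": num_expand}

print(old_flags, "->", new_flags)
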
src/llmtuner/hparams/finetuning_args.py (+62 -46)

@@ -1,5 +1,4 @@
-import json
-from dataclasses import asdict, dataclass, field
+from dataclasses import dataclass, field
 from typing import Literal, Optional


@@ -9,22 +8,40 @@ class FreezeArguments:
     Arguments pertaining to the freeze (partial-parameter) training.
     """

-    name_module_trainable: str = field(
+    freeze_trainable_layers: int = field(
+        default=2,
+        metadata={
+            "help": (
+                "The number of trainable layers for freeze (partial-parameter) fine-tuning. "
+                "Positive numbers mean the last n layers are set as trainable, "
+                "negative numbers mean the first n layers are set as trainable."
+            )
+        },
+    )
+    freeze_trainable_modules: str = field(
         default="all",
         metadata={
-            "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
-                    Use commas to separate multiple modules. \
-                    Use "all" to specify all the available modules. \
-                    LLaMA choices: ["mlp", "self_attn"], \
-                    BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
-                    Qwen choices: ["mlp", "attn"], \
-                    InternLM2 choices: ["feed_forward", "attention"], \
-                    Others choices: the same as LLaMA."""
+            "help": (
+                "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the available modules. "
+                "LLaMA choices: [`mlp`, `self_attn`], "
+                "BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], "
+                "Qwen choices: [`mlp`, `attn`], "
+                "InternLM2 choices: [`feed_forward`, `attention`], "
+                "Others choices: the same as LLaMA."
+            )
         },
     )
-    num_layer_trainable: int = field(
-        default=2,
-        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
+    freeze_extra_modules: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Name(s) of modules apart from hidden layers to be set as trainable "
+                "for freeze (partial-parameter) fine-tuning. "
+                "Use commas to separate multiple modules."
+            )
+        },
     )


@@ -37,7 +54,11 @@ class LoraArguments:
     additional_target: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
+            "help": (
+                "Name(s) of modules apart from LoRA layers to be set as trainable "
+                "and saved in the final checkpoint. "
+                "Use commas to separate multiple modules."
+            )
         },
     )
     lora_alpha: Optional[int] = field(

@@ -55,15 +76,17 @@ class LoraArguments:
     lora_target: str = field(
         default="all",
         metadata={
-            "help": """Name(s) of target modules to apply LoRA. \
-                    Use commas to separate multiple modules. \
-                    Use "all" to specify all the linear modules. \
-                    LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                    BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
-                    Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                    Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
-                    InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
-                    Others choices: the same as LLaMA."""
+            "help": (
+                "Name(s) of target modules to apply LoRA. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules. "
+                "LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
+                "BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], "
+                "Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
+                "Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], "
+                "InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], "
+                "Others choices: the same as LLaMA."
+            )
         },
     )
     loraplus_lr_ratio: Optional[float] = field(

@@ -177,8 +200,10 @@ class GaloreArguments:
     galore_target: str = field(
         default="all",
         metadata={
-            "help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
-                    Use "all" to specify all the linear modules."""
+            "help": (
+                "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
         },
     )
     galore_rank: int = field(

@@ -238,16 +263,20 @@ class BAdamArgument:
     badam_mask_mode: Literal["adjacent", "scatter"] = field(
         default="adjacent",
         metadata={
-            "help": """The mode of the mask for BAdam optimizer. \
-                    `adjacent` means that the trainable parameters are adjacent to each other, \
-                    `scatter` means that trainable parameters are randomly choosed from the weight."""
+            "help": (
+                "The mode of the mask for BAdam optimizer. "
+                "`adjacent` means that the trainable parameters are adjacent to each other, "
+                "`scatter` means that trainable parameters are randomly choosed from the weight."
+            )
         },
     )
     badam_verbose: int = field(
         default=0,
         metadata={
-            "help": """The verbosity level of BAdam optimizer. \
-                    0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
+            "help": (
+                "The verbosity level of BAdam optimizer. "
+                "0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
+            )
         },
     )

@@ -285,7 +314,8 @@ def split_arg(arg):
                 return [item.strip() for item in arg.split(",")]
             return arg

-        self.name_module_trainable = split_arg(self.name_module_trainable)
+        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
         self.lora_alpha = self.lora_alpha or self.lora_rank * 2
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)

@@ -315,17 +345,3 @@ def split_arg(arg):

         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
             raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
-
-    def save_to_json(self, json_path: str):
-        r"""Saves the content of this instance in JSON format inside `json_path`."""
-        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
-        with open(json_path, "w", encoding="utf-8") as f:
-            f.write(json_string)
-
-    @classmethod
-    def load_from_json(cls, json_path: str):
-        r"""Creates an instance from the content of `json_path`."""
-        with open(json_path, "r", encoding="utf-8") as f:
-            text = f.read()
-
-        return cls(**json.loads(text))

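A self-contained sketch (the dataclass and helper below only mirror the renamed fields; they are not the library's public API, and the module names in the usage lines are illustrative) of how the new freeze arguments behave, including the comma-splitting applied in `__post_init__`:

from dataclasses import dataclass
from typing import List, Optional, Union


def split_arg(arg: Optional[str]) -> Union[List[str], None]:
    # Mirrors the splitting done in FinetuningArguments.__post_init__.
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg


@dataclass
class FreezeArgsSketch:
    freeze_trainable_layers: int = 2             # was num_layer_trainable; sign picks last (+) or first (-) layers
    freeze_trainable_modules: str = "all"        # was name_module_trainable; "all" or comma-separated names
    freeze_extra_modules: Optional[str] = None   # new: non-hidden-layer modules to unfreeze as well


args = FreezeArgsSketch(freeze_trainable_modules="mlp,self_attn", freeze_extra_modules="embed_tokens, norm")
print(split_arg(args.freeze_trainable_modules))  # ['mlp', 'self_attn']
print(split_arg(args.freeze_extra_modules))      # ['embed_tokens', 'norm']
print(split_arg(None))                           # None (optional field left unset)
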
src/llmtuner/model/adapter.py (+29 -13)

@@ -1,3 +1,4 @@
+import re
 from typing import TYPE_CHECKING

 import torch

@@ -68,37 +69,52 @@ def init_adapter(
             raise ValueError("Current model does not support freeze tuning.")

         if finetuning_args.use_llama_pro:
-            if num_layers % finetuning_args.num_layer_trainable != 0:
+            if num_layers % finetuning_args.freeze_trainable_layers != 0:
                 raise ValueError(
                     "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
-                        num_layers, finetuning_args.num_layer_trainable
+                        num_layers, finetuning_args.freeze_trainable_layers
                     )
                 )

-            stride = num_layers // finetuning_args.num_layer_trainable
+            stride = num_layers // finetuning_args.freeze_trainable_layers
             trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
-        elif finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
-            trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
+        elif finetuning_args.freeze_trainable_layers > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
+            trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
         else:  # fine-tuning the first n layers if num_layer_trainable < 0
-            trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
+            trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))

-        freeze_modules = {"all"}
-        for name, _ in model.named_modules():
+        hidden_modules = set()
+        non_hidden_modules = set()
+        for name, _ in model.named_parameters():
             if ".0." in name:
-                freeze_modules.add(name.split(".0.")[-1].split(".")[0])
+                hidden_modules.add(name.split(".0.")[-1].split(".")[0])
             elif ".1." in name:  # MoD starts from layer 1
-                freeze_modules.add(name.split(".1.")[-1].split(".")[0])
+                hidden_modules.add(name.split(".1.")[-1].split(".")[0])
+
+            if re.search(r"\.\d+\.", name) is None:
+                non_hidden_modules.add(name.split(".")[-2])

         trainable_layers = []
-        for module_name in finetuning_args.name_module_trainable:
-            if module_name not in freeze_modules:
+        for module_name in finetuning_args.freeze_trainable_modules:
+            if module_name != "all" and module_name not in hidden_modules:
                 raise ValueError(
-                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
+                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
                 )

             for idx in trainable_layer_ids:
                 trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))

+        if finetuning_args.freeze_extra_modules:
+            for module_name in finetuning_args.freeze_extra_modules:
+                if module_name not in non_hidden_modules:
+                    raise ValueError(
+                        "Module {} is not found, please choose from {}".format(
+                            module_name, ", ".join(non_hidden_modules)
+                        )
+                    )
+
+                trainable_layers.append(module_name)
+
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
                 if cast_trainable_params_to_fp32:

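To make the renamed behaviour concrete, here is a standalone sketch (an illustrative helper, not the library code) of the layer-index selection that `init_adapter` performs above:

from typing import List


def select_trainable_layer_ids(num_layers: int, freeze_trainable_layers: int, use_llama_pro: bool = False) -> List[int]:
    # Mirrors the three branches in init_adapter.
    if use_llama_pro:
        if num_layers % freeze_trainable_layers != 0:
            raise ValueError("num_layers must be divisible by freeze_trainable_layers.")
        stride = num_layers // freeze_trainable_layers
        # Evenly spaced indices, i.e. the blocks inserted by block expansion.
        return list(range(stride - 1, num_layers + stride - 1, stride))
    if freeze_trainable_layers > 0:  # positive: the last n layers
        return list(range(max(0, num_layers - freeze_trainable_layers), num_layers))
    # Negative: the first |n| layers.
    return list(range(min(-freeze_trainable_layers, num_layers)))


print(select_trainable_layer_ids(32, 8))                      # [24, 25, ..., 31]
print(select_trainable_layer_ids(32, -2))                     # [0, 1]
print(select_trainable_layer_ids(40, 8, use_llama_pro=True))  # [4, 9, 14, 19, 24, 29, 34, 39]
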
src/llmtuner/webui/components/train.py (+8 -4)

@@ -124,13 +124,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Accordion(open=False) as freeze_tab:
         with gr.Row():
-            num_layer_trainable = gr.Slider(minimum=1, maximum=128, value=2, step=1)
-            name_module_trainable = gr.Textbox(value="all")
+            freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1)
+            freeze_trainable_modules = gr.Textbox(value="all")
+            freeze_extra_modules = gr.Textbox()

-    input_elems.update({num_layer_trainable, name_module_trainable})
+    input_elems.update({freeze_trainable_layers, freeze_trainable_modules, freeze_extra_modules})
     elem_dict.update(
         dict(
-            freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
+            freeze_tab=freeze_tab,
+            freeze_trainable_layers=freeze_trainable_layers,
+            freeze_trainable_modules=freeze_trainable_modules,
+            freeze_extra_modules=freeze_extra_modules,
         )
     )

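For context, a standalone Gradio sketch (the accordion title and component labels are assumed from the locale strings below; the real tab is wired up inside `create_train_tab`) of the updated freeze accordion with its three inputs:

import gradio as gr  # assumes gradio is installed

with gr.Blocks() as demo:
    with gr.Accordion("Freeze tuning configurations", open=False):
        with gr.Row():
            # The slider now spans negative values so the first-n-layers mode is reachable from the UI.
            freeze_trainable_layers = gr.Slider(minimum=-128, maximum=128, value=2, step=1, label="Trainable layers")
            freeze_trainable_modules = gr.Textbox(value="all", label="Trainable modules")
            freeze_extra_modules = gr.Textbox(label="Extra modules (optional)")

demo.launch()  # illustrative only; the project's web UI assembles these components itself
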
src/llmtuner/webui/locales.py (+26 -6)

@@ -572,24 +572,24 @@
             "label": "部分参数微调设置",
         },
     },
-    "num_layer_trainable": {
+    "freeze_trainable_layers": {
         "en": {
             "label": "Trainable layers",
-            "info": "The number of trainable layers.",
+            "info": "Number of the last(+)/first(-) hidden layers to be set as trainable.",
         },
         "ru": {
             "label": "Обучаемые слои",
-            "info": "Количество обучаемых слоев.",
+            "info": "Количество последних (+)/первых (-) скрытых слоев, которые будут установлены как обучаемые.",
         },
         "zh": {
             "label": "可训练层数",
-            "info": "可训练模型层的数量。",
+            "info": "最末尾(+)/最前端(-)可训练隐藏层的数量。",
         },
     },
-    "name_module_trainable": {
+    "freeze_trainable_modules": {
         "en": {
             "label": "Trainable modules",
-            "info": "The name of trainable modules. Use commas to separate multiple modules.",
+            "info": "Name(s) of trainable modules. Use commas to separate multiple modules.",
         },
         "ru": {
             "label": "Обучаемые модули",

@@ -600,6 +600,26 @@
             "info": "可训练模块的名称。使用英文逗号分隔多个名称。",
         },
     },
+    "freeze_extra_modules": {
+        "en": {
+            "label": "Extra modules (optional)",
+            "info": (
+                "Name(s) of modules apart from hidden layers to be set as trainable. "
+                "Use commas to separate multiple modules."
+            ),
+        },
+        "ru": {
+            "label": "Дополнительные модули (опционально)",
+            "info": (
+                "Имена модулей, кроме скрытых слоев, которые следует установить в качестве обучаемых. "
+                "Используйте запятые для разделения нескольких модулей."
+            ),
+        },
+        "zh": {
+            "label": "额外模块(非必填)",
+            "info": "除隐藏层以外的可训练模块名称。使用英文逗号分隔多个名称。",
+        },
+    },
     "lora_tab": {
         "en": {
             "label": "LoRA configurations",

src/llmtuner/webui/runner.py (+3 -2)

@@ -146,8 +146,9 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
         )

         if args["finetuning_type"] == "freeze":
-            args["num_layer_trainable"] = get("train.num_layer_trainable")
-            args["name_module_trainable"] = get("train.name_module_trainable")
+            args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
+            args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
+            args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None
         elif args["finetuning_type"] == "lora":
             args["lora_rank"] = get("train.lora_rank")
             args["lora_alpha"] = get("train.lora_alpha")
