@@ -1,5 +1,4 @@
-import json
-from dataclasses import asdict, dataclass, field
+from dataclasses import dataclass, field
 from typing import Literal, Optional
 
 
@@ -9,22 +8,40 @@ class FreezeArguments:
     Arguments pertaining to the freeze (partial-parameter) training.
     """
 
-    name_module_trainable: str = field(
+    freeze_trainable_layers: int = field(
+        default=2,
+        metadata={
+            "help": (
+                "The number of trainable layers for freeze (partial-parameter) fine-tuning. "
+                "Positive numbers mean the last n layers are set as trainable, "
+                "negative numbers mean the first n layers are set as trainable."
+            )
+        },
+    )
+    freeze_trainable_modules: str = field(
         default="all",
         metadata={
-            "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
-                  Use commas to separate multiple modules. \
-                  Use "all" to specify all the available modules. \
-                  LLaMA choices: ["mlp", "self_attn"], \
-                  BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
-                  Qwen choices: ["mlp", "attn"], \
-                  InternLM2 choices: ["feed_forward", "attention"], \
-                  Others choices: the same as LLaMA."""
+            "help": (
+                "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the available modules. "
+                "LLaMA choices: [`mlp`, `self_attn`], "
+                "BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], "
+                "Qwen choices: [`mlp`, `attn`], "
+                "InternLM2 choices: [`feed_forward`, `attention`], "
+                "Others choices: the same as LLaMA."
+            )
         },
     )
-    num_layer_trainable: int = field(
-        default=2,
-        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
+    freeze_extra_modules: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Name(s) of modules apart from hidden layers to be set as trainable "
+                "for freeze (partial-parameter) fine-tuning. "
+                "Use commas to separate multiple modules."
+            )
+        },
     )
 
 
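Note: a minimal sketch of how the renamed freeze arguments introduced in this hunk might be filled in. `FreezeArguments` and its fields come from this diff, but the concrete values and module names below are illustrative assumptions, not taken from the change itself.

    args = FreezeArguments(
        freeze_trainable_layers=-2,                # negative: the first 2 layers become trainable
        freeze_trainable_modules="mlp,self_attn",  # comma-separated module names, or "all"
        freeze_extra_modules="embed_tokens",       # illustrative module outside the hidden layers
    )
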
@@ -37,7 +54,11 @@ class LoraArguments:
     additional_target: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
+            "help": (
+                "Name(s) of modules apart from LoRA layers to be set as trainable "
+                "and saved in the final checkpoint. "
+                "Use commas to separate multiple modules."
+            )
         },
     )
     lora_alpha: Optional[int] = field(
@@ -55,15 +76,17 @@ class LoraArguments:
     lora_target: str = field(
         default="all",
         metadata={
-            "help": """Name(s) of target modules to apply LoRA. \
-                  Use commas to separate multiple modules. \
-                  Use "all" to specify all the linear modules. \
-                  LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                  BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
-                  Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                  Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
-                  InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
-                  Others choices: the same as LLaMA."""
+            "help": (
+                "Name(s) of target modules to apply LoRA. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules. "
+                "LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
+                "BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], "
+                "Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
+                "Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], "
+                "InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], "
+                "Others choices: the same as LLaMA."
+            )
         },
     )
     loraplus_lr_ratio: Optional[float] = field(
@@ -177,8 +200,10 @@ class GaloreArguments:
     galore_target: str = field(
         default="all",
         metadata={
-            "help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
-                  Use "all" to specify all the linear modules."""
+            "help": (
+                "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
         },
     )
     galore_rank: int = field(
@@ -238,16 +263,20 @@ class BAdamArgument:
     badam_mask_mode: Literal["adjacent", "scatter"] = field(
         default="adjacent",
         metadata={
-            "help": """The mode of the mask for BAdam optimizer. \
-                  `adjacent` means that the trainable parameters are adjacent to each other, \
-                  `scatter` means that trainable parameters are randomly choosed from the weight."""
+            "help": (
+                "The mode of the mask for BAdam optimizer. "
+                "`adjacent` means that the trainable parameters are adjacent to each other, "
+                "`scatter` means that trainable parameters are randomly chosen from the weight."
+            )
         },
     )
     badam_verbose: int = field(
         default=0,
         metadata={
-            "help": """The verbosity level of BAdam optimizer. \
-                  0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
+            "help": (
+                "The verbosity level of BAdam optimizer. "
+                "0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
+            )
         },
     )
 
@@ -285,7 +314,8 @@ def split_arg(arg):
                 return [item.strip() for item in arg.split(",")]
             return arg
 
-        self.name_module_trainable = split_arg(self.name_module_trainable)
+        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
         self.lora_alpha = self.lora_alpha or self.lora_rank * 2
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
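Note: the comma-splitting that `split_arg` applies to the renamed freeze fields (and to the LoRA fields shown as context above) can be summarised by the standalone sketch below. The `isinstance` guard is an assumption inferred from the two return paths visible in the context lines, and in the real file the helper is nested inside `__post_init__`.

    def split_arg(arg):
        # Comma-separated strings become a list of stripped names;
        # non-string values (e.g. the None default of freeze_extra_modules) pass through unchanged.
        if isinstance(arg, str):
            return [item.strip() for item in arg.split(",")]
        return arg

    assert split_arg("mlp, self_attn") == ["mlp", "self_attn"]
    assert split_arg("all") == ["all"]
    assert split_arg(None) is None
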
@@ -315,17 +345,3 @@ def split_arg(arg):
 
         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
             raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
-
-    def save_to_json(self, json_path: str):
-        r"""Saves the content of this instance in JSON format inside `json_path`."""
-        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
-        with open(json_path, "w", encoding="utf-8") as f:
-            f.write(json_string)
-
-    @classmethod
-    def load_from_json(cls, json_path: str):
-        r"""Creates an instance from the content of `json_path`."""
-        with open(json_path, "r", encoding="utf-8") as f:
-            text = f.read()
-
-        return cls(**json.loads(text))
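Note: this last hunk removes `save_to_json` and `load_from_json` (and, with the import hunk above, the `json` and `asdict` imports they needed). Callers that relied on JSON round-tripping could reproduce it with a small standalone helper mirroring the removed code; the function names here are hypothetical.

    import json
    from dataclasses import asdict

    def save_args_to_json(args, json_path: str) -> None:
        # Same behaviour as the removed save_to_json: sorted, indented JSON plus a trailing newline.
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(asdict(args), indent=2, sort_keys=True) + "\n")

    def load_args_from_json(cls, json_path: str):
        # Same behaviour as the removed load_from_json: rebuild the dataclass from the JSON file.
        with open(json_path, "r", encoding="utf-8") as f:
            return cls(**json.loads(f.read()))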