]

import inspect
- import torch
import os
- from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
- from collections import OrderedDict
import re
+ from collections import Counter, OrderedDict
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+
+ import torch

# Skip some modules sensitive to quantization
SKIP_QUANTIZATION_MODULES = [
    "lm_head",
-     "multi_modal_projector", # Llama 3.2 Vision, Pixtral, Llava
-     "merger",                # Qwen2 VL
-     "modality_projection",   # Idefics, SmolVLM
-     "router",                # MoE Router
-     "gate",                  # MoE Router
+     "multi_modal_projector",  # Llama 3.2 Vision, Pixtral, Llava
+     "merger",  # Qwen2 VL
+     "modality_projection",  # Idefics, SmolVLM
+     "router",  # MoE Router
+     "gate",  # MoE Router
]

+
def get_peft_regex(
    model,
-     finetune_vision_layers: bool = True,
-     finetune_language_layers: bool = True,
-     finetune_attention_modules: bool = True,
-     finetune_mlp_modules: bool = True,
-     target_modules: list[str] = None,
-     vision_tags: list[str] = ["vision", "image", "visual", "patch",],
-     language_tags: list[str] = ["language", "text",],
-     attention_tags: list[str] = ["self_attn", "attention", "attn",],
-     mlp_tags: list[str] = ["mlp", "feed_forward", "ffn", "dense",],
+     *,
+     finetune_vision_layers: bool = True,
+     finetune_language_layers: bool = True,
+     finetune_attention_modules: bool = True,
+     finetune_mlp_modules: bool = True,
+     target_modules: List[str] | None = None,
+     vision_tags: List[str] = ("vision", "image", "visual", "patch"),
+     language_tags: List[str] = ("language", "text"),
+     attention_tags: List[str] = ("self_attn", "attention", "attn"),
+     mlp_tags: List[str] = ("mlp", "feed_forward", "ffn", "dense"),
) -> str:
    """
-     Create a regex pattern to apply LoRA to only select layers of a model.
+     Build a **safe** regular expression that matches ONLY the *leaf*
+     `torch.nn.Linear` layers we want to adapt with LoRA.
+
+     The previous implementation matched any module path that merely
+     *contained* one of the projection names; after fused-projection
+     rewrites this included helpers such as
+     `model.layers.3.mlp.gate_up_proj → ModuleDict`, which PEFT cannot
+     patch. We now anchor the name to the **last dot-separated field**
+     so only genuine linear layers match.
    """
58
- # All Unsloth Zoo code licensed under LGPLv3
59
- if not finetune_vision_layers and not finetune_language_layers :
69
+ # — sanity checks --------------------------------------------------
70
+ if not ( finetune_vision_layers or finetune_language_layers ) :
60
71
raise RuntimeError (
61
- "Unsloth: No layers to finetune - please select to finetune the vision and/or the language layers! "
72
+ "Select at least one of vision / language layers to finetune. "
62
73
)
63
- if not finetune_attention_modules and not finetune_mlp_modules :
74
+ if not ( finetune_attention_modules or finetune_mlp_modules ) :
64
75
raise RuntimeError (
65
- "Unsloth: No modules to finetune - please select to finetune the attention and/or the mlp modules! "
76
+ "Select at least one of attention / MLP modules to finetune. "
66
77
)
67
- pass
68
78
-     from collections import Counter
-     # Get only linear layers
-     modules = model.named_modules()
-     linear_modules = [name for name, module in modules if isinstance(module, torch.nn.Linear)]
-     all_linear_modules = Counter(x.rsplit(".")[-1] for x in linear_modules)
+     # -- collect all leaf-names of *linear* layers ----------------------
+     linear_modules = [
+         name for name, mod in model.named_modules() if isinstance(mod, torch.nn.Linear)
+     ]
+     leaf_names = [path.rsplit(".", 1)[-1] for path in linear_modules]
+     leaf_counts = Counter(leaf_names)

-     # Isolate lm_head / projection matrices if count == 1
    if target_modules is None:
-         only_linear_modules = []
-         projection_modules = {}
-         for j, (proj, count) in enumerate(all_linear_modules.items()):
-             if count != 1:
-                 only_linear_modules.append(proj)
-             else:
-                 projection_modules[proj] = j
-         pass
+         # keep names that appear in *more* than one place
+         # (single-occurrence heads are usually lm_head / projectors)
+         candidate_leafs = [n for n, c in leaf_counts.items() if c > 1]
    else:
-         assert(type(target_modules) is list)
-         only_linear_modules = list(target_modules)
-         pass
-
-     # Create regex matcher
-     regex_model_parts = []
-     if finetune_vision_layers: regex_model_parts += vision_tags
-     if finetune_language_layers: regex_model_parts += language_tags
-     regex_components = []
-     if finetune_attention_modules: regex_components += attention_tags
-     if finetune_mlp_modules: regex_components += mlp_tags
-
-     regex_model_parts = "|".join(regex_model_parts)
-     regex_components = "|".join(regex_components)
-
-     match_linear_modules = r"(?:" + "|".join(re.escape(x) for x in only_linear_modules) + r")"
-     regex_matcher = \
-         r".*?(?:" + regex_model_parts + \
-         r").*?(?:" + regex_components + \
-         r").*?" + match_linear_modules + ".*?"
-
-     # Also account for model.layers.0.self_attn/mlp type modules like Qwen
+         if not isinstance(target_modules, list):
+             raise TypeError("`target_modules` must be a list of strings.")
+         candidate_leafs = list(target_modules)
+
+     # -- assemble regex parts ------------------------------------------
+     def _join(xs):
+         return "|".join(map(re.escape, xs)) or "$^"  # empty → no match
+
+     # which *part* of the model path (vision/language)
+     model_part_pat = (
+         _join(vision_tags if finetune_vision_layers else [])
+         + "|"
+         + _join(language_tags if finetune_language_layers else "")
+     )
+     # which *sub-module* inside the block (attn/mlp)
+     component_pat = (
+         _join(attention_tags if finetune_attention_modules else [])
+         + "|"
+         + _join(mlp_tags if finetune_mlp_modules else "")
+     )
+
+     # exact leaf names -- anchor to "preceded by dot or start" AND "end of string"
+     leaf_pat = r"(?:(?<=\.)|^)(?:" + _join(candidate_leafs) + r")$"
+
+     # full matcher
+     regex_matcher = (
+         r".*?(?:" + model_part_pat + r")"  # vision / language part
+         r".*?(?:" + component_pat + r")"  # attn / mlp component
+         r".*?" + leaf_pat  # leaf linear layer
+     )
+
+     # also allow Qwen-style `model.layers.0.self_attn.q_proj` paths
    if finetune_language_layers:
-         regex_matcher = r"(?:" + regex_matcher + \
-             r")|(?:\bmodel\.layers\.[\d]{1,}\.(?:" + regex_components + \
-             r")\.(?:" + match_linear_modules + r"))"
-         pass
-
-     # Check if regex is wrong since model does not have vision parts
-     check = any(re.search(regex_matcher, name, flags = re.DOTALL) for name in linear_modules)
-     if not check:
-         regex_matcher = \
-             r".*?(?:" + regex_components + \
-             r").*?" + match_linear_modules + ".*?"
-     pass
-
-     # Final check to confirm if matches exist
-     check = any(re.search(regex_matcher, name, flags = re.DOTALL) for name in linear_modules)
-     if not check and target_modules is not None:
-         raise RuntimeError(
-             f"Unsloth: No layers to finetune? You most likely specified target_modules = {target_modules} incorrectly!"
+         regex_matcher = (
+             regex_matcher
+             + "|"
+             + r"(?:\bmodel\.layers\.\d+\.(?:"
+             + component_pat
+             + r")\."
+             + leaf_pat
+             + ")"
        )
-     elif not check:
+
+     # -- verify we actually hit something ------------------------------
+     if not any(re.search(regex_matcher, n) for n in linear_modules):
        raise RuntimeError(
-             f"Unsloth: No layers to finetune for {model.config._name_or_path}. Please file a bug report!"
+             f"Unsloth: the generated regex matched **no** linear layers "
+             f"in {model.__class__.__name__}. "
+             f"Check your *tags* / *target_modules* settings."
        )
-     pass
    return regex_matcher
+
+
pass

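# --- Editor's note (illustrative sketch, not part of this commit) -------------
# What the leaf-anchored pattern changes, using made-up module names; `re` is
# already imported at the top of this file.
demo_leaf_pat = r"(?:(?<=\.)|^)(?:gate_proj|up_proj|down_proj)$"
demo_matcher = r".*?(?:language).*?(?:mlp).*?" + demo_leaf_pat

# a genuine leaf Linear still matches ...
assert re.search(demo_matcher, "model.language_model.layers.3.mlp.up_proj")
# ... but a fused helper container no longer does (the old substring match did)
assert not re.search(demo_matcher, "model.language_model.layers.3.mlp.gate_up_proj")

# Typical downstream use (hedged: assumes PEFT, which treats a string
# target_modules as a regex pattern):
#     from peft import LoraConfig
#     config = LoraConfig(r=16, lora_alpha=16, target_modules=get_peft_regex(model))
# ------------------------------------------------------------------------------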
def get_lora_layer_modules():
    # All Unsloth Zoo code licensed under LGPLv3
    import peft.tuners.lora
+
    path = os.path.split(peft.tuners.lora.__file__)[0]
    files = os.listdir(path)

    Linear_LoRA_Layers = []
    for file in files:
-         if file == "__init__.py" or not file.endswith(".py"): continue
+         if file == "__init__.py" or not file.endswith(".py"):
+             continue
        item = f"peft.tuners.lora.{file[:-len('.py')]}"
        exec(f"import {item}", locals(), globals())
        modules = dir(eval(item))
        modules = [x for x in modules if x.startswith("Linear") or x.endswith("Linear")]
-         if len(modules) == 0: continue
+         if len(modules) == 0:
+             continue
        exec(f"from {item} import ({', '.join(modules)})", locals(), globals())
-         Linear_LoRA_Layers += [(eval(x), item, x,) for x in modules]
+         Linear_LoRA_Layers += [
+             (
+                 eval(x),
+                 item,
+                 x,
+             )
+             for x in modules
+         ]
    pass
    return tuple(Linear_LoRA_Layers)
+
+
pass

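# --- Editor's note (illustrative sketch, not part of this commit) -------------
# One hedged way the tuple returned above can be used: collect the LoRA Linear
# classes so a module can be tested for "already LoRA-patched". Assumes peft is
# installed; `some_module` and the variable names are examples only.
#
#     Linear_LoRA_Layers = get_lora_layer_modules()
#     lora_linear_classes = tuple(cls for cls, _module_path, _cls_name in Linear_LoRA_Layers)
#     is_lora_linear = isinstance(some_module, lora_linear_classes)
# ------------------------------------------------------------------------------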
@@ -164,16 +186,20 @@ def register_other_hooks(name1, name2, module, _hooks):
    other_hooks = []
    for value in old_hooks.values():
        qualname = getattr(value, "__qualname__", "")
-         name = getattr(value, "__name__", "")
-         if name1 in qualname or name2 in qualname: pass
-         elif name2 in name or name2 in name: pass
-         else: other_hooks.append(value)
+         name = getattr(value, "__name__", "")
+         if name1 in qualname or name2 in qualname:
+             pass
+         elif name1 in name or name2 in name:
+             pass
+         else:
+             other_hooks.append(value)
    pass
    # Keep none input requires grad hooks
    exec(f"module.{_hooks} = OrderedDict()")
    for hook in other_hooks:
        exec(f"module.register{_hooks[:-1]}(hook)")
    pass
+
pass

# Remove all previous forward hooks for gradient checkpointing
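# --- Editor's note (illustrative sketch, not part of this commit) -------------
# The same "keep every hook we did not install" idea for forward hooks, written
# without exec(). Like the code above it relies on torch.nn.Module's private
# `_forward_hooks` OrderedDict; the function and argument names are made up.
#
#     def keep_foreign_forward_hooks(module, ours=("requires_grad_post_hook",)):
#         foreign = [h for h in module._forward_hooks.values()
#                    if getattr(h, "__name__", "") not in ours]
#         module._forward_hooks.clear()
#         for h in foreign:
#             module.register_forward_hook(h)  # re-register in original order
# ------------------------------------------------------------------------------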
@@ -199,6 +225,7 @@ def requires_grad_post_hook(module, input, output):
            output.loss.requires_grad_(True)
        except Exception as _:
            raise RuntimeError("Unsloth: Failed to make output require gradients!")
+
pass

def requires_grad_pre_hook(module, input):
@@ -208,19 +235,22 @@ def requires_grad_pre_hook(module, input):
    elif type_input is tuple or type_input is list:
        if len(input) == 0:
            raise RuntimeError("Unsloth: Failed to make input require gradients!")
-             # print(f" WARNING: Empty list input to {module.__class__.__name__}!") #
+             # print(f" WARNING: Empty list input to {module.__class__.__name__}!") #
            # return
        if torch.is_floating_point(input[0]):
            input[0].requires_grad_(True)
        else:
            raise RuntimeError("Unsloth: Failed to make input require gradients!")
+
pass

# Find 1st ever item which requires grad
param = None
for name, param in model.named_parameters():
-     if param.requires_grad: break
- if param is None: return
+     if param.requires_grad:
+         break
+ if param is None:
+     return

name = re.sub("\.([\d]{1,})\.", r"[\1].", name)
name_components = name.split(".")
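# --- Editor's note (illustrative, not part of this commit): the substitution
# above rewrites numeric ModuleList indices so the dotted parameter path becomes
# an eval()-able attribute/index expression, e.g.:
#
#     >>> re.sub(r"\.([\d]{1,})\.", r"[\1].", "layers.0.self_attn.q_proj.weight")
#     'layers[0].self_attn.q_proj.weight'
# ------------------------------------------------------------------------------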
@@ -230,15 +260,18 @@ def requires_grad_pre_hook(module, input):

final_where = None
# Try getting previous parent module
- for j in range(len(name_components)-1, 0, -1):
+ for j in range(len(name_components) - 1, 0, -1):
    name_curr = name_components[j]
-     name_pre  = "model." + ".".join(name_components[:j])
+     name_pre = "model." + ".".join(name_components[:j])
    # Disable [\d] since it fails in gradient checkpointing
-     if re.search(r"\[[\d]{1,}\]", name_pre): continue
+     if re.search(r"\[[\d]{1,}\]", name_pre):
+         continue
    module = eval(name_pre)
    if hasattr(module, "forward"):
-         try: forward = inspect.getsource(module.forward)
-         except: continue
+         try:
+             forward = inspect.getsource(module.forward)
+         except:
+             continue

        # Normal self.language_model(...)
        if f"self.{name_curr}(" in forward:
@@ -250,7 +283,9 @@ def requires_grad_pre_hook(module, input):
            if f"in self.{module_list}:" in forward:
                final_where = j
                break
-             elif re.search(r"for [^\s]{3,} in self\." + module_list, forward) is not None:
+             elif (
+                 re.search(r"for [^\s]{3,} in self\." + module_list, forward) is not None
+             ):
                # Might have failed finding self.layers: like self.layers[...]:
                final_where = j
                break
@@ -274,11 +309,18 @@ def requires_grad_pre_hook(module, input):
module_name = "model." + ".".join(name_components[:final_where])
module = eval(module_name)

- if hasattr(module, "config") and (module.config.__class__.__name__ in ("CLIPVisionConfig", "SiglipVisionConfig",)):
+ if hasattr(module, "config") and (
+     module.config.__class__.__name__
+     in (
+         "CLIPVisionConfig",
+         "SiglipVisionConfig",
+     )
+ ):
    # CLIP - backtrack to get_input_embeddings since requires_grad fails!
    old_module = model
    for module_name, module in model.named_modules():
-         if not hasattr(module, "get_input_embeddings"): break
+         if not hasattr(module, "get_input_embeddings"):
+             break
        old_module = module
    module = old_module
    pass
@@ -314,6 +356,8 @@ def requires_grad_pre_hook(module, input):
        )
        module.register_forward_pre_hook(requires_grad_pre_hook)
    pass
+
+
pass

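# --- Editor's note (background on standard PyTorch behaviour, not part of this
# commit): the hooks above exist because reentrant gradient checkpointing cannot
# backpropagate into a segment whose inputs do not require grad (PyTorch warns
# "Gradients will be None"), so parameter gradients inside the segment are lost.
#
#     import torch
#     from torch.utils.checkpoint import checkpoint
#
#     layer = torch.nn.Linear(4, 4)
#     x = torch.randn(2, 4)                      # requires_grad is False
#     out = checkpoint(layer, x, use_reentrant=True)
#     print(out.requires_grad)                   # False -> no grads reach layer.weight
# ------------------------------------------------------------------------------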
# Unsloth Zoo - Utilities for Unsloth