Commit e347c26
fix regex
1 parent 29e78c9 commit e347c26


unsloth_zoo/peft_utils.py: 143 additions, 99 deletions
@@ -24,134 +24,156 @@
 ]

 import inspect
-import torch
 import os
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
-from collections import OrderedDict
 import re
+from collections import OrderedDict
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+
+import torch

 # Skip some modules sensitive to quantization
 SKIP_QUANTIZATION_MODULES = [
     "lm_head",
-    "multi_modal_projector", # Llama 3.2 Vision, Pixtral, Llava
-    "merger",                # Qwen2 VL
-    "modality_projection",   # Idefics, SmolVLM
-    "router",                # MoE Router
-    "gate",                  # MoE Router
+    "multi_modal_projector",  # Llama 3.2 Vision, Pixtral, Llava
+    "merger",  # Qwen2 VL
+    "modality_projection",  # Idefics, SmolVLM
+    "router",  # MoE Router
+    "gate",  # MoE Router
 ]

+
 def get_peft_regex(
     model,
-    finetune_vision_layers : bool = True,
-    finetune_language_layers : bool = True,
-    finetune_attention_modules : bool = True,
-    finetune_mlp_modules : bool = True,
-    target_modules : list[str] = None,
-    vision_tags : list[str] = ["vision", "image", "visual", "patch",],
-    language_tags : list[str] = ["language", "text",],
-    attention_tags : list[str] = ["self_attn", "attention", "attn",],
-    mlp_tags : list[str] = ["mlp", "feed_forward", "ffn", "dense",],
+    *,
+    finetune_vision_layers: bool = True,
+    finetune_language_layers: bool = True,
+    finetune_attention_modules: bool = True,
+    finetune_mlp_modules: bool = True,
+    target_modules: List[str] | None = None,
+    vision_tags: List[str] = ("vision", "image", "visual", "patch"),
+    language_tags: List[str] = ("language", "text"),
+    attention_tags: List[str] = ("self_attn", "attention", "attn"),
+    mlp_tags: List[str] = ("mlp", "feed_forward", "ffn", "dense"),
 ) -> str:
     """
-    Create a regex pattern to apply LoRA to only select layers of a model.
+    Build a **safe** regular‑expression that matches ONLY the *leaf*
+    `torch.nn.Linear` layers we want to adapt with LoRA.
+
+    The previous implementation matched any module path that merely
+    *contained* one of the projection names; after fused‑projection
+    rewrites this included helpers such as
+    `model.layers.3.mlp.gate_up_proj → ModuleDict`, which PEFT cannot
+    patch. We now anchor the name to the **last dot‑separated field**
+    so only genuine linear layers match.
     """
-    # All Unsloth Zoo code licensed under LGPLv3
-    if not finetune_vision_layers and not finetune_language_layers:
+    # — sanity checks --------------------------------------------------
+    if not (finetune_vision_layers or finetune_language_layers):
         raise RuntimeError(
-            "Unsloth: No layers to finetune - please select to finetune the vision and/or the language layers!"
+            "Select at least one of vision / language layers to finetune."
         )
-    if not finetune_attention_modules and not finetune_mlp_modules:
+    if not (finetune_attention_modules or finetune_mlp_modules):
         raise RuntimeError(
-            "Unsloth: No modules to finetune - please select to finetune the attention and/or the mlp modules!"
+            "Select at least one of attention / MLP modules to finetune."
         )
-    pass

-    from collections import Counter
-    # Get only linear layers
-    modules = model.named_modules()
-    linear_modules = [name for name, module in modules if isinstance(module, torch.nn.Linear)]
-    all_linear_modules = Counter(x.rsplit(".")[-1] for x in linear_modules)
+    # — collect all leaf‑names of *linear* layers ----------------------
+    linear_modules = [
+        name for name, mod in model.named_modules() if isinstance(mod, torch.nn.Linear)
+    ]
+    leaf_names = [path.rsplit(".", 1)[-1] for path in linear_modules]
+    leaf_counts = Counter(leaf_names)

-    # Isolate lm_head / projection matrices if count == 1
     if target_modules is None:
-        only_linear_modules = []
-        projection_modules = {}
-        for j, (proj, count) in enumerate(all_linear_modules.items()):
-            if count != 1:
-                only_linear_modules.append(proj)
-            else:
-                projection_modules[proj] = j
-        pass
+        # keep names that appear in *more* than one place
+        # (single‑occurrence heads are usually lm_head / projectors)
+        candidate_leafs = [n for n, c in leaf_counts.items() if c > 1]
     else:
-        assert(type(target_modules) is list)
-        only_linear_modules = list(target_modules)
-    pass
-
-    # Create regex matcher
-    regex_model_parts = []
-    if finetune_vision_layers: regex_model_parts += vision_tags
-    if finetune_language_layers: regex_model_parts += language_tags
-    regex_components = []
-    if finetune_attention_modules: regex_components += attention_tags
-    if finetune_mlp_modules: regex_components += mlp_tags
-
-    regex_model_parts = "|".join(regex_model_parts)
-    regex_components = "|".join(regex_components)
-
-    match_linear_modules = r"(?:" + "|".join(re.escape(x) for x in only_linear_modules) + r")"
-    regex_matcher = \
-        r".*?(?:" + regex_model_parts + \
-        r").*?(?:" + regex_components + \
-        r").*?" + match_linear_modules + ".*?"
-
-    # Also account for model.layers.0.self_attn/mlp type modules like Qwen
+        if not isinstance(target_modules, list):
+            raise TypeError("`target_modules` must be a list of strings.")
+        candidate_leafs = list(target_modules)
+
+    # — assemble regex parts ------------------------------------------
+    def _join(xs):
+        return "|".join(map(re.escape, xs)) or "$^"  # empty → no match
+
+    # which *part* of the model path (vision/language)
+    model_part_pat = (
+        _join(vision_tags if finetune_vision_layers else [])
+        + "|"
+        + _join(language_tags if finetune_language_layers else "")
+    )
+    # which *sub‑module* inside the block (attn/mlp)
+    component_pat = (
+        _join(attention_tags if finetune_attention_modules else [])
+        + "|"
+        + _join(mlp_tags if finetune_mlp_modules else "")
+    )
+
+    # exact leaf names – anchor to “preceded by dot or start” AND “end of string”
+    leaf_pat = r"(?:(?<=\.)|^)(?:" + _join(candidate_leafs) + r")$"
+
+    # full matcher
+    regex_matcher = (
+        r".*?(?:" + model_part_pat + r")"  # vision / language part
+        r".*?(?:" + component_pat + r")"  # attn / mlp component
+        r".*?" + leaf_pat  # leaf linear layer
+    )
+
+    # also allow Qwen‑style `model.layers.0.self_attn.q_proj` paths
     if finetune_language_layers:
-        regex_matcher = r"(?:" + regex_matcher + \
-            r")|(?:\bmodel\.layers\.[\d]{1,}\.(?:" + regex_components + \
-            r")\.(?:" + match_linear_modules + r"))"
-    pass
-
-    # Check if regex is wrong since model does not have vision parts
-    check = any(re.search(regex_matcher, name, flags = re.DOTALL) for name in linear_modules)
-    if not check:
-        regex_matcher = \
-            r".*?(?:" + regex_components + \
-            r").*?" + match_linear_modules + ".*?"
-    pass
-
-    # Final check to confirm if matches exist
-    check = any(re.search(regex_matcher, name, flags = re.DOTALL) for name in linear_modules)
-    if not check and target_modules is not None:
-        raise RuntimeError(
-            f"Unsloth: No layers to finetune? You most likely specified target_modules = {target_modules} incorrectly!"
+        regex_matcher = (
+            regex_matcher
+            + "|"
+            + r"(?:\bmodel\.layers\.\d+\.(?:"
+            + component_pat
+            + r")\."
+            + leaf_pat
+            + ")"
         )
-    elif not check:
+
+    # — verify we actually hit something ------------------------------
+    if not any(re.search(regex_matcher, n) for n in linear_modules):
         raise RuntimeError(
-            f"Unsloth: No layers to finetune for {model.config._name_or_path}. Please file a bug report!"
+            f"Unsloth: the generated regex matched **no** linear layers "
+            f"in {model.__class__.__name__}. "
+            f"Check your *tags* / *target_modules* settings."
         )
-    pass
     return regex_matcher
+
+
 pass


 def get_lora_layer_modules():
     # All Unsloth Zoo code licensed under LGPLv3
     import peft.tuners.lora
+
     path = os.path.split(peft.tuners.lora.__file__)[0]
     files = os.listdir(path)

     Linear_LoRA_Layers = []
     for file in files:
-        if file == "__init__.py" or not file.endswith(".py"): continue
+        if file == "__init__.py" or not file.endswith(".py"):
+            continue
         item = f"peft.tuners.lora.{file[:-len('.py')]}"
         exec(f"import {item}", locals(), globals())
         modules = dir(eval(item))
         modules = [x for x in modules if x.startswith("Linear") or x.endswith("Linear")]
-        if len(modules) == 0: continue
+        if len(modules) == 0:
+            continue
         exec(f"from {item} import ({', '.join(modules)})", locals(), globals())
-        Linear_LoRA_Layers += [(eval(x), item, x,) for x in modules]
+        Linear_LoRA_Layers += [
+            (
+                eval(x),
+                item,
+                x,
+            )
+            for x in modules
+        ]
     pass
     return tuple(Linear_LoRA_Layers)
+
+
 pass


@@ -164,16 +186,20 @@ def register_other_hooks(name1, name2, module, _hooks):
         other_hooks = []
         for value in old_hooks.values():
             qualname = getattr(value, "__qualname__", "")
-            name = getattr(value, "__name__", "")
-            if name1 in qualname or name2 in qualname: pass
-            elif name2 in name or name2 in name: pass
-            else: other_hooks.append(value)
+            name = getattr(value, "__name__", "")
+            if name1 in qualname or name2 in qualname:
+                pass
+            elif name2 in name or name2 in name:
+                pass
+            else:
+                other_hooks.append(value)
         pass
         # Keep none input requires grad hooks
         exec(f"module.{_hooks} = OrderedDict()")
         for hook in other_hooks:
             exec(f"module.register{_hooks[:-1]}(hook)")
         pass
+
     pass

     # Remove all previous forward hooks for gradient checkpointing
@@ -199,6 +225,7 @@ def requires_grad_post_hook(module, input, output):
             output.loss.requires_grad_(True)
         except Exception as _:
             raise RuntimeError("Unsloth: Failed to make output require gradients!")
+
     pass

     def requires_grad_pre_hook(module, input):
@@ -208,19 +235,22 @@ def requires_grad_pre_hook(module, input):
         elif type_input is tuple or type_input is list:
             if len(input) == 0:
                 raise RuntimeError("Unsloth: Failed to make input require gradients!")
-                # print(f" WARNING: Empty list input to {module.__class__.__name__}!") #
+                # print(f" WARNING: Empty list input to {module.__class__.__name__}!") #
                 # return
             if torch.is_floating_point(input[0]):
                 input[0].requires_grad_(True)
         else:
             raise RuntimeError("Unsloth: Failed to make input require gradients!")
+
     pass

     # Find 1st ever item which requires grad
     param = None
     for name, param in model.named_parameters():
-        if param.requires_grad: break
-    if param is None: return
+        if param.requires_grad:
+            break
+    if param is None:
+        return

     name = re.sub("\.([\d]{1,})\.", r"[\1].", name)
     name_components = name.split(".")
@@ -230,15 +260,18 @@ def requires_grad_pre_hook(module, input):

     final_where = None
     # Try getting previous parent module
-    for j in range(len(name_components)-1, 0, -1):
+    for j in range(len(name_components) - 1, 0, -1):
         name_curr = name_components[j]
-        name_pre = "model." + ".".join(name_components[:j])
+        name_pre = "model." + ".".join(name_components[:j])
         # Disable [\d] since it fails in gradient checkpointing
-        if re.search(r"\[[\d]{1,}\]", name_pre): continue
+        if re.search(r"\[[\d]{1,}\]", name_pre):
+            continue
         module = eval(name_pre)
         if hasattr(module, "forward"):
-            try: forward = inspect.getsource(module.forward)
-            except: continue
+            try:
+                forward = inspect.getsource(module.forward)
+            except:
+                continue

             # Normal self.language_model(...)
             if f"self.{name_curr}(" in forward:
@@ -250,7 +283,9 @@ def requires_grad_pre_hook(module, input):
             if f"in self.{module_list}:" in forward:
                 final_where = j
                 break
-            elif re.search(r"for [^\s]{3,} in self\." + module_list, forward) is not None:
+            elif (
+                re.search(r"for [^\s]{3,} in self\." + module_list, forward) is not None
+            ):
                 # Might have failed finding self.layers: like self.layers[...]:
                 final_where = j
                 break
@@ -274,11 +309,18 @@ def requires_grad_pre_hook(module, input):
     module_name = "model." + ".".join(name_components[:final_where])
     module = eval(module_name)

-    if hasattr(module, "config") and (module.config.__class__.__name__ in ("CLIPVisionConfig", "SiglipVisionConfig",)):
+    if hasattr(module, "config") and (
+        module.config.__class__.__name__
+        in (
+            "CLIPVisionConfig",
+            "SiglipVisionConfig",
+        )
+    ):
         # CLIP - backtrack to get_input_embeddings since requires_grad fails!
         old_module = model
         for module_name, module in model.named_modules():
-            if not hasattr(module, "get_input_embeddings"): break
+            if not hasattr(module, "get_input_embeddings"):
+                break
             old_module = module
         module = old_module
     pass
@@ -314,6 +356,8 @@ def requires_grad_pre_hook(module, input):
         )
         module.register_forward_pre_hook(requires_grad_pre_hook)
     pass
+
+
 pass

 # Unsloth Zoo - Utilities for Unsloth
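
Note: the new docstring in the first hunk describes the core change, namely that candidate leaf names are anchored to the last dot-separated field of the module path instead of being matched as substrings. Below is a minimal, self-contained sketch of that idea; the leaf names and module paths are hypothetical examples chosen for illustration, not taken from the commit.

import re

# Hypothetical candidate leaf names; the real list is derived from the
# model's named_modules() inside get_peft_regex.
candidate_leafs = ["q_proj", "up_proj", "gate_proj"]
leaf_alt = "|".join(map(re.escape, candidate_leafs))

substring_pat = r".*?(?:" + leaf_alt + r").*?"             # pre-fix behaviour: substring match
anchored_pat  = r".*?(?:(?<=\.)|^)(?:" + leaf_alt + r")$"  # post-fix behaviour: anchored to the last field

paths = [
    "model.layers.3.self_attn.q_proj",  # genuine leaf Linear
    "model.layers.3.mlp.gate_up_proj",  # fused-projection helper, as in the docstring's example
]
for p in paths:
    print(p, bool(re.search(substring_pat, p)), bool(re.search(anchored_pat, p)))
# The substring pattern matches both paths ("up_proj" occurs inside "gate_up_proj");
# the anchored pattern matches only the q_proj leaf.

Only the anchoring idea is shown here; the committed matcher additionally requires a vision/language tag and an attention/MLP tag earlier in the path, or falls back to the explicit model.layers.N. alternative for language-only models.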

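A second detail from the new _join helper is worth spelling out: when a tag group is disabled, the helper falls back to "$^", a pattern that cannot match anywhere inside a non-empty module name, so the disabled branch of the alternation simply drops out instead of matching everything. A small standalone check of that behaviour follows; the module name used is hypothetical and for illustration only.

import re

def _join(xs):
    # Same shape as the helper added in this commit: an empty tag list
    # collapses to "$^" ($ followed by ^), which cannot match inside a
    # non-empty module name.
    return "|".join(map(re.escape, xs)) or "$^"

print(_join(["self_attn", "attn"]))  # 'self_attn|attn'
print(_join([]))                     # '$^'
print(bool(re.search(r"(?:" + _join([]) + r")", "model.layers.0.mlp.up_proj")))  # False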