model.py
import re
from typing import Dict, Set

import torch
import torch.nn as nn
from peft import PeftModel, PeftType


def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
    """Return the entries of `names` that belong to the LoRA adapter `adapter_name`, rewritten to base-layer keys."""
    config = model.peft_config[adapter_name]
    if config.peft_type != PeftType.LORA:
        raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
    # to_return = lora_state_dict(model, bias=model.peft_config.bias)
    # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
    # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
    bias = config.bias
    if bias == "none":
        to_return = {k for k in names if "lora_" in k}
    elif bias == "all":
        to_return = {k for k in names if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = set()
        for k in names:
            if "lora_" in k:
                to_return.add(k)
                bias_name = k.split("lora_")[0] + "bias"
                if bias_name in names:
                    to_return.add(bias_name)
    else:
        raise NotImplementedError
    to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
    if config.use_dora:
        # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
        # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
        # we want the state_dict format not to change, we remove the "weight" part.
        new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"

        def renamed_dora_weights(k):
            if k.endswith(new_dora_suffix):
                k = k[:-7]  # remove ".weight"
            return k

        to_return = {renamed_dora_weights(k) for k in to_return}

    to_return = {re.sub(rf"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
    return to_return
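
# Hedged illustration, not part of the original module: a minimal sketch of the key rewriting
# performed by extract_lora_layers. The module path "encoder.q_proj" and the helper name are
# assumptions for illustration only; this function is defined here but never called by this file.
def _example_extract_lora_layers(peft_model: PeftModel) -> Set[str]:
    names = {
        "encoder.q_proj.base_layer.weight",
        "encoder.q_proj.lora_A.default.weight",
        "encoder.q_proj.lora_B.default.weight",
    }
    # With bias == "none", only the LoRA entries are kept and both are rewritten to the
    # base-layer path, yielding {"encoder.q_proj.base_layer"}.
    return extract_lora_layers(peft_model, names)
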
class PeftUnwrapMixin:
    def __init__(self, peft_model: PeftModel):
        self.base_model = peft_model.get_base_model()
        # peft does not affect buffers
        self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
        potential_lora_weights = set()
        for n in self.lora_layers:
            potential_lora_weights.add(f"{n}.weight")
            potential_lora_weights.add(f"{n}.bias")
        self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
        self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}

    def named_parameters(self):
        for n, p in self.base_model.named_parameters():
            if n in self.lora_param_to_origin_param:
                n = self.lora_param_to_origin_param[n]
            yield n, p

    def named_buffers(self):
        return self.base_model.named_buffers()

    @property
    def _modules(self):
        return self.base_model._modules

    @property
    def _non_persistent_buffers_set(self):
        return self.base_model._non_persistent_buffers_set

    def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        new_state_dict = {}
        for k, v in state_dict.items():
            if k in self.origin_param_to_lora_param:
                k = self.origin_param_to_lora_param[k]
            new_state_dict[k] = v
        return new_state_dict

    def state_dict(self):
        state_dict = {}
        for k, v in self.base_model.state_dict().items():
            if k in self.lora_param_to_origin_param:
                k = self.lora_param_to_origin_param[k]
            state_dict[k] = v
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        state_dict = self.patch_state_dict(state_dict)
        self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)

    def __hash__(self):
        return hash(self.base_model)
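
# Hedged illustration, not part of the original module: a minimal sketch of how PeftUnwrapMixin
# could be used to save and reload a checkpoint under the original (non-PEFT) parameter names.
# The default path "peft_ckpt.pt" is an assumption; this function is defined here but never
# called by this file.
def _example_peft_checkpoint_roundtrip(peft_model: PeftModel, path: str = "peft_ckpt.pt") -> None:
    unwrapped = PeftUnwrapMixin(peft_model)
    # state_dict() strips the "base_layer." segment from LoRA-wrapped parameter keys
    torch.save(unwrapped.state_dict(), path)
    # load_state_dict() re-inserts it via patch_state_dict() before loading into the base model
    unwrapped.load_state_dict(torch.load(path))
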
class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def unwrap(self, unwrap_peft: bool = True):
        """
        Unwrap the model to return the original model for checkpoint saving/loading.
        """
        if isinstance(self.module, ModelWrapper):
            model = self.module.unwrap()
        else:
            model = self.module
        if unwrap_peft and isinstance(model, PeftModel):
            model = PeftUnwrapMixin(model)
        return model

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
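
# Hedged illustration, not part of the original module: a minimal sketch of unwrap() on a
# nested wrapper. The toy nn.Linear module is an assumption for illustration only; this
# function is defined here but never called by this file.
def _example_model_wrapper_unwrap() -> nn.Module:
    inner = nn.Linear(4, 4)
    wrapped = ModelWrapper(ModelWrapper(inner))
    # unwrap() recurses through nested ModelWrapper instances and returns the innermost module
    assert wrapped.unwrap() is inner
    return wrapped.unwrap()
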
class AMPModelMixin:
    """This mixin class defines the interface for AMP training."""

    def update_master_params(self):
        """
        Update the master parameters for AMP training.
        """