@@ -1,17 +1,18 @@
 import os
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
-import transformers
 
 import torch
+import transformers
 from torch import nn
 
 from xturing.engines.causal import CausalEngine, CausalLoraEngine
 from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 from xturing.engines.lora_engine import prepare_model_for_int8_training
-from xturing.engines.quant_utils import make_quant, autotune_warmup
+from xturing.engines.quant_utils import autotune_warmup, make_quant
 from xturing.utils.hub import ModelHub
 
+
 class LLamaEngine(CausalEngine):
     config_name: str = "llama_engine"
 
@@ -102,24 +103,28 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
             target_modules=["q_proj", "v_proj"],
         )
 
-def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+
+def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
     if type(module) in layers:
         return {name: module}
     res = {}
     for name1, child in module.named_children():
-        res.update(find_layers(
-            child, layers=layers, name=name + '.' + name1 if name != '' else name1
-        ))
+        res.update(
+            find_layers(
+                child, layers=layers, name=name + "." + name1 if name != "" else name1
+            )
+        )
     return res
 
+
 class LlamaLoraInt4Engine(CausalLoraEngine):
     config_name: str = "llama_lora_int4_engine"
 
     def __init__(self, weights_path: Optional[Union[str, Path]] = None):
-        model_name = "decapoda-research/llama-7b-hf"
+        model_name = "decapoda-research/llama-7b-hf"
 
         if weights_path is None:
-            weights_path = ModelHub().load("x/llama_lora_int4")
+            weights_path = ModelHub().load("x/llama_lora_int4")
 
         config = LlamaConfig.from_pretrained(model_name)
 
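For reference, find_layers (reformatted in the hunk above) recursively walks a module tree and returns its nn.Conv2d/nn.Linear submodules keyed by their dotted path; that dict is what make_quant later consumes when swapping in quantized layers (after "lm_head" is removed). A minimal sketch of the behaviour, using an invented toy module:

import torch.nn as nn

# Toy module purely for illustration; the names "Toy", "proj" and "act" are made up.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.act = nn.ReLU()

# With the find_layers defined above:
#   find_layers(Toy())  ->  {"proj": Linear(in_features=8, out_features=8, bias=True)}
# The ReLU is skipped because it is not in the default layers list, and nested
# children receive dotted keys such as "block.proj".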
@@ -129,10 +134,10 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
 
         def noop(*args, **kwargs):
             pass
-
-        torch.nn.init.kaiming_uniform_ = noop
-        torch.nn.init.uniform_ = noop
-        torch.nn.init.normal_ = noop
+
+        torch.nn.init.kaiming_uniform_ = noop
+        torch.nn.init.uniform_ = noop
+        torch.nn.init.normal_ = noop
 
         torch.set_default_dtype(torch.half)
         transformers.modeling_utils._init_weights = False
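The code around this hunk uses a common trick for building a large model whose weights will be overwritten anyway: the torch initializers are temporarily swapped for no-ops so that constructing the LLaMA model from config skips random weight initialization, and the saved originals (saved_kaiming_uniform_ and friends) are restored at the end of __init__, as the last hunk below shows. A generic, stand-alone sketch of that pattern, not specific to this file:

import torch

# Save the real initializers, install no-ops while the model skeleton is built,
# then restore them so later code sees the normal behaviour again.
saved = (torch.nn.init.kaiming_uniform_, torch.nn.init.uniform_, torch.nn.init.normal_)

def _noop(*args, **kwargs):
    pass

torch.nn.init.kaiming_uniform_ = _noop
torch.nn.init.uniform_ = _noop
torch.nn.init.normal_ = _noop
try:
    ...  # construct the model here; real weights are loaded from a checkpoint later
finally:
    (torch.nn.init.kaiming_uniform_,
     torch.nn.init.uniform_,
     torch.nn.init.normal_) = saved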
@@ -143,18 +148,23 @@ def noop(*args, **kwargs):
 
         layers = find_layers(model)
 
-        for name in ['lm_head']:
+        for name in ["lm_head"]:
             if name in layers:
                 del layers[name]
-
+
         wbits = 4
         groupsize = 128
-        warmup_autotune = True
-
+        warmup_autotune = True
+
         make_quant(model, layers, wbits, groupsize)
-
 
-        model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False)
+        state_dict = torch.load(
+            weights_path / Path("pytorch_model.bin"), map_location="cpu"
+        )
+        new_state_dict = {}
+        for key, value in state_dict.items():
+            new_state_dict[key[6:]] = value
+        model.load_state_dict(new_state_dict, strict=False)
 
         if warmup_autotune:
             autotune_warmup(model)
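The new loading path reads the checkpoint onto the CPU and strips the first six characters of every key before passing the result to load_state_dict; presumably this drops a fixed wrapper prefix (six characters matches "model.", though the diff itself does not name the prefix). A stand-alone sketch of the same remapping under that assumption, with an illustrative checkpoint path and a conditional strip instead of an unconditional slice:

import torch

# Assumption: checkpoint keys carry a fixed "model." prefix (six characters,
# matching key[6:] in the diff); the diff does not state this explicitly.
PREFIX = "model."

state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # illustrative path
new_state_dict = {
    (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
    for key, value in state_dict.items()
}
# The remapped dict is then loaded with strict=False (as in the diff), so keys
# present on only one side do not raise an error.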
@@ -171,12 +181,12 @@ def noop(*args, **kwargs):
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
         super().__init__(
-            model=model,
+            model=model,
             tokenizer=tokenizer,
             target_modules=[
                 "q_proj",
                 "v_proj",
-            ]
+            ],
         )
 
         torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_