Skip to content

Commit c0b1b76

Browse files
Merge pull request #177 from stochasticai/toan/fix_int4
fix: int4 loading model
2 parents d8bbb07 + 4df89b6 commit c0b1b76

File tree

1 file changed

+30
-20
lines changed

1 file changed

+30
-20
lines changed

src/xturing/engines/llama_engine.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import os
22
from pathlib import Path
33
from typing import Any, Dict, List, Optional, Tuple, Union
4-
import transformers
54

65
import torch
6+
import transformers
77
from torch import nn
88

99
from xturing.engines.causal import CausalEngine, CausalLoraEngine
1010
from xturing.engines.llama_utils import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
1111
from xturing.engines.lora_engine import prepare_model_for_int8_training
12-
from xturing.engines.quant_utils import make_quant, autotune_warmup
12+
from xturing.engines.quant_utils import autotune_warmup, make_quant
1313
from xturing.utils.hub import ModelHub
1414

15+
1516
class LLamaEngine(CausalEngine):
1617
config_name: str = "llama_engine"
1718

@@ -102,24 +103,28 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
102103
target_modules=["q_proj", "v_proj"],
103104
)
104105

def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    """Recursively collect submodules of *module* whose exact type is in *layers*.

    Args:
        module: Root ``nn.Module`` to search.
        layers: Module types to match. The comparison is an exact ``type``
            check (deliberately not ``isinstance``), so subclasses of the
            listed types are NOT matched — quantization replaces only the
            plain layer types.
        name: Dotted attribute-path prefix accumulated during recursion;
            leave as ``""`` for the top-level call.

    Returns:
        Dict mapping each matching submodule's dotted path
        (e.g. ``"model.layers.0.self_attn.q_proj"``) to the submodule itself.
    """
    # Tuple default instead of the original mutable list default — a classic
    # Python pitfall. Membership via ``in`` is identical for tuple and list,
    # and callers passing their own list are unaffected.
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        res.update(
            find_layers(
                child,
                layers=layers,
                name=name + "." + child_name if name != "" else child_name,
            )
        )
    return res
114118

119+
115120
class LlamaLoraInt4Engine(CausalLoraEngine):
116121
config_name: str = "llama_lora_int4_engine"
117122

118123
def __init__(self, weights_path: Optional[Union[str, Path]] = None):
119-
model_name = "decapoda-research/llama-7b-hf"
124+
model_name = "decapoda-research/llama-7b-hf"
120125

121126
if weights_path is None:
122-
weights_path = ModelHub().load("x/llama_lora_int4")
127+
weights_path = ModelHub().load("x/llama_lora_int4")
123128

124129
config = LlamaConfig.from_pretrained(model_name)
125130

@@ -129,10 +134,10 @@ def __init__(self, weights_path: Optional[Union[str, Path]] = None):
129134

130135
def noop(*args, **kwargs):
131136
pass
132-
133-
torch.nn.init.kaiming_uniform_ = noop
134-
torch.nn.init.uniform_ = noop
135-
torch.nn.init.normal_ = noop
137+
138+
torch.nn.init.kaiming_uniform_ = noop
139+
torch.nn.init.uniform_ = noop
140+
torch.nn.init.normal_ = noop
136141

137142
torch.set_default_dtype(torch.half)
138143
transformers.modeling_utils._init_weights = False
@@ -143,18 +148,23 @@ def noop(*args, **kwargs):
143148

144149
layers = find_layers(model)
145150

146-
for name in ['lm_head']:
151+
for name in ["lm_head"]:
147152
if name in layers:
148153
del layers[name]
149-
154+
150155
wbits = 4
151156
groupsize = 128
152-
warmup_autotune=True
153-
157+
warmup_autotune = True
158+
154159
make_quant(model, layers, wbits, groupsize)
155-
156160

157-
model.load_state_dict(torch.load(weights_path / Path("pytorch_model.bin")), strict=False)
161+
state_dict = torch.load(
162+
weights_path / Path("pytorch_model.bin"), map_location="cpu"
163+
)
164+
new_state_dict = {}
165+
for key, value in state_dict.items():
166+
new_state_dict[key[6:]] = value
167+
model.load_state_dict(new_state_dict, strict=False)
158168

159169
if warmup_autotune:
160170
autotune_warmup(model)
@@ -171,12 +181,12 @@ def noop(*args, **kwargs):
171181
tokenizer.pad_token_id = tokenizer.eos_token_id
172182

173183
super().__init__(
174-
model=model,
184+
model=model,
175185
tokenizer=tokenizer,
176186
target_modules=[
177187
"q_proj",
178188
"v_proj",
179-
]
189+
],
180190
)
181191

182192
torch.nn.init.kaiming_uniform_ = saved_kaiming_uniform_

0 commit comments

Comments
 (0)