Commit e3516e4

metascroy, guangy10, mikekgfb, and mergennachin authored
Gguf cleanup (#230)
* clean up gguf loading. Move model loading to meta.
* remove cpu
* Fix CI and validation scripts (#154)
* missing device (#232)
* Use generator args to group all arguments to generator (#231)
* prompt
* chat_mode, num_samples
* Move more generator args to use dataclass (#233)
* prompt
* chat_mode, num_samples
* move more args
* more gen args
* update
* args
* undo some changes
* typos
* Minor lint fixes (#236)
* remove redundancy & remove int4 linear test from ET tests (#237)
* remove redundancy
* no int4 linear on ET
* small changes

---------

Co-authored-by: Guang Yang <[email protected]>
Co-authored-by: Michael Gschwind <[email protected]>
Co-authored-by: Mergen Nachin <[email protected]>
1 parent 5a226bd commit e3516e4
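The first bullet ("Move model loading to meta") refers to PyTorch's meta-device initialization pattern. A minimal sketch of that pattern, using a generic nn.Linear as a stand-in for the repo's Transformer and assuming a recent PyTorch that supports load_state_dict(assign=True):

import torch
import torch.nn as nn

# Build the module under the meta device: parameters get shapes/dtypes but no storage.
with torch.device("meta"):
    lin = nn.Linear(8, 8, bias=False)

# Stand-in for tensors that would actually be read from a checkpoint or GGUF file.
real_weights = {"weight": torch.zeros(8, 8)}

# assign=True swaps the meta parameters for the loaded tensors instead of copying into them.
lin.load_state_dict(real_weights, assign=True)
print(next(lin.parameters()).device)  # -> cpu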

File tree

3 files changed (+106 -175 lines)


build/builder.py (+14 -14)
@@ -154,21 +154,13 @@ def device_sync(device):
 sys.path.append(str(wd))


-def _load_model(builder_args):
-    if builder_args.gguf_path:
-        model = Transformer.from_gguf(builder_args.gguf_path)
-
-        # TODO: to take advantage of mmap, maybe we write converted gguf to file
-        # and read back in?
-        # TODO: should we add check that builder_args.precision is aligned with quant scheme, e.g., bfloat16
-        # is needed for int4
-        model = model.to(device=builder_args.device, dtype=builder_args.precision)
-        return model.eval()
-    else:
-        return _load_model_not_gguf(builder_args)
+def _load_model_gguf(builder_args):
+    assert builder_args.gguf_path
+    model = Transformer.from_gguf(builder_args.gguf_path)
+    return model


-def _load_model_not_gguf(builder_args):
+def _load_model_default(builder_args):
     assert not builder_args.gguf_path

     with torch.device("meta"):
@@ -218,9 +210,17 @@ def _load_model_not_gguf(builder_args):

     model.load_state_dict(checkpoint, assign=True, strict=False)

+    return model
+
+
+def _load_model(builder_args):
+    if builder_args.gguf_path:
+        model = _load_model_gguf(builder_args)
+    else:
+        model = _load_model_default(builder_args)
+
     if builder_args.use_tp:
         from tp import apply_tp
-
         print("Applying tensor parallel to model ...")
         apply_tp(model)

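To summarize the refactor above: each loader now only constructs the model, and _load_model becomes a thin dispatcher that applies tensor parallelism in one place. A hedged, self-contained sketch of that dispatch pattern follows; the BuilderArgs fields and loader bodies are placeholders, not the repo's actual code.

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class BuilderArgsSketch:
    # Hypothetical subset of the repo's builder args; real fields may differ.
    gguf_path: Optional[str] = None
    use_tp: bool = False


def load_model_gguf_sketch(args: BuilderArgsSketch) -> nn.Module:
    assert args.gguf_path
    return nn.Linear(8, 8, bias=False)  # stands in for Transformer.from_gguf(args.gguf_path)


def load_model_default_sketch(args: BuilderArgsSketch) -> nn.Module:
    assert not args.gguf_path
    with torch.device("meta"):
        return nn.Linear(8, 8, bias=False)  # stands in for meta-device construction + checkpoint load


def load_model_sketch(args: BuilderArgsSketch) -> nn.Module:
    # One dispatch point; tensor parallelism (use_tp) would be applied here afterwards.
    if args.gguf_path:
        return load_model_gguf_sketch(args)
    return load_model_default_sketch(args)


print(load_model_sketch(BuilderArgsSketch()))  # default path: parameters on the meta device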

build/gguf_loader.py (+89 -158)
@@ -17,69 +17,22 @@
 import torch
 import torch.nn as nn

+wd = Path(__file__).parent.resolve()
+sys.path.append(str(wd))
+
 from gguf import GGUFValueType, ReaderTensor
 from quantize import (
     group_dequantize_tensor_from_qparams,
     pack_scales_and_zeros,
     WeightOnlyInt4Linear,
 )

-from build.gguf_util import F16, F32, Q4_0, Q6_K
-
-wd = Path(__file__).parent.resolve()
-sys.path.append(str(wd))
-
+from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
 from model import ModelArgs, Transformer

 logger: logging.Logger = logging.getLogger(__name__)


-@dataclass
-class AttentionArgs:
-    head_count: int
-    head_count_kv: int
-    layer_norm_rms_epsilon: float
-
-
-@dataclass
-class RopeArgs:
-    dimension_count: int | None = None
-    freq_base: float | None = None
-
-
-@dataclass
-class GGUFModelArgs:
-    arch: str
-    embedding_length: int
-    block_count: int
-    feed_forward_length: int
-    vocab_size: int
-    attention: AttentionArgs
-    rope: RopeArgs
-
-
-@dataclass
-class GGUFWeights:
-    tensors: list[ReaderTensor]
-
-
-def _create_pt_model(
-    gguf_model_args: GGUFModelArgs,
-) -> nn.Module:
-    llama_model_args = ModelArgs(
-        dim=gguf_model_args.embedding_length,
-        n_layers=gguf_model_args.block_count,
-        n_heads=gguf_model_args.attention.head_count,
-        n_local_heads=gguf_model_args.attention.head_count_kv,
-        vocab_size=gguf_model_args.vocab_size,
-        norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
-        hidden_dim=gguf_model_args.feed_forward_length,
-    )
-    pt_model = Transformer(llama_model_args)
-    pt_model.eval()
-    return pt_model
-
-
 _name_replacements = [
     ("blk", "layers"),
     ("token_embd", "tok_embeddings"),
@@ -102,29 +55,6 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
     return result


-def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
-    arch = metadata["general.architecture"]
-    assert (
-        arch == "llama"
-    ), f"Only general.architecture=llama is supported, but got general.architecture={arch}"
-    return GGUFModelArgs(
-        arch=arch,
-        embedding_length=metadata[f"{arch}.embedding_length"],
-        block_count=metadata[f"{arch}.block_count"],
-        feed_forward_length=metadata[f"{arch}.feed_forward_length"],
-        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
-        attention=AttentionArgs(
-            head_count=metadata[f"{arch}.attention.head_count"],
-            head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
-            layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
-        ),
-        rope=RopeArgs(
-            freq_base=metadata.get(f"{arch}.rope.freq_base", None),
-            dimension_count=metadata.get(f"{arch}.rope.dimension_count", None),
-        ),
-    )
-
-
 def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any:
     if fqn == "":
         return module
@@ -153,74 +83,6 @@ def _fqn_last(fqn: str) -> str:
     return atoms[-1]


-def load_weights(
-    pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], inner_k_tiles=8
-) -> None:
-    fqns = []
-    for fqn in pt_model.state_dict():
-        assert _fqn_last(fqn) == "weight"
-        fqns.append(_fqn_up(fqn))
-
-    state_dict = {}
-    for fqn in fqns:
-        mod = _fqn_lookup(fqn, pt_model)
-
-        t = weight_map[f"{fqn}.weight"]
-
-        if (
-            isinstance(mod, torch.nn.Linear)
-            and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
-        ):
-            assert not mod.bias
-            out_features = mod.out_features
-            in_features = mod.in_features
-            assert all(t.shape == (in_features, out_features))
-
-            q, s, z = Q4_0.unpack(t)
-            scales_and_zeros = pack_scales_and_zeros(s, z)
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
-            )
-
-            state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
-            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
-
-            parent = _fqn_lookup(_fqn_up(fqn), pt_model)
-            setattr(
-                parent,
-                _fqn_last(fqn),
-                WeightOnlyInt4Linear(
-                    "cpu",  # TODO: should --device work for gguf load? (yes?!)
-                    in_features,
-                    out_features,
-                    bias=False,
-                    groupsize=Q4_0.groupsize,
-                    inner_k_tiles=inner_k_tiles,
-                ),
-            )
-        else:
-            # All other weights are dequantized to float
-            if t.tensor_type == gguf.GGMLQuantizationType.Q4_0:
-                as_float = group_dequantize_tensor_from_qparams(
-                    *Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize
-                )
-            elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K:
-                as_float = group_dequantize_tensor_from_qparams(
-                    *Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize
-                )
-            elif t.tensor_type == gguf.GGMLQuantizationType.F16:
-                as_float = F16.unpack(t)
-            elif t.tensor_type == gguf.GGMLQuantizationType.F32:
-                as_float = F32.unpack(t)
-            else:
-                raise ValueError(f"Unsupported tensor type {t.tensor_type}")
-
-            state_dict[f"{fqn}.weight"] = as_float.to("cpu")
-
-    pt_model.load_state_dict(state_dict)
-    return pt_model
-
-
 def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
     metadata: dict[str, Any] = {}

@@ -244,34 +106,103 @@ def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
     return metadata


-def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
+def load_model(gguf_file: str) -> torch.nn.Module:
     """
-    Load a LLaMa model from a GGUF file and return a PT nn.Module.
+    Parses the GGUF file and returns an nn.Module on meta device.
     """
-    if not Path(gguf_file).is_file():
-        raise ValueError(f"Could not find file {gguf_file}")

     logger.info("Parsing GGUF metadata.")
     reader = gguf.GGUFReader(gguf_file, "r")
     metadata = _get_metadata(reader)
-    model_args = _build_model_args(metadata)
+
+    arch = metadata["general.architecture"]
     assert (
-        model_args.arch == "llama"
+        arch == "llama"
     ), "Only LLaMa models are supported by this converter."

-    logger.info("Creating initial PT model.")
-    pt_model = _create_pt_model(model_args)
+    model_args = ModelArgs(
+        dim=metadata[f"{arch}.embedding_length"],
+        n_layers=metadata[f"{arch}.block_count"],
+        n_heads=metadata[f"{arch}.attention.head_count"],
+        n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
+        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
+        norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
+        hidden_dim=metadata[f"{arch}.feed_forward_length"],
+    )

-    logger.info("Reading GGUF weights.")
-    gguf_weights = GGUFWeights(tensors=reader.tensors)
+    # TODO: what to do with rope args like
+    # metadata.get(f"{arch}.rope.freq_base", None)
+    # metadata.get(f"{arch}.rope.dimension_count", None)

-    logger.info("Building GGUF weight map.")
-    # map from fqn in pt_model to gguf tensor
+    with torch.device("meta"):
+        model = Transformer(model_args)
+    return model
+
+
+def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
+    """
+    Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
+    that can be loaded into it.
+
+    When load_as_quantized, the method tries to preserve the GGUF quantization when it
+    is natively supported by PyTorch, otherwise it converts quantized tensors to FP32.
+    """
+
+    model = load_model(gguf_file)
+
+    reader = gguf.GGUFReader(gguf_file, "r")
     weight_map = {
         _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
-        for tensor in gguf_weights.tensors
+        for tensor in reader.tensors
     }

-    logger.info("Loading weights into state_dict")
-    pt_model = load_weights(pt_model, weight_map, inner_k_tiles=8)
-    return pt_model
+    state_dict = {}
+    for fqn in weight_map:
+        assert _fqn_last(fqn) == "weight"
+        fqn = _fqn_up(fqn)
+
+        mod = _fqn_lookup(fqn, model)
+        t = weight_map[f"{fqn}.weight"]
+
+        if (
+            isinstance(mod, torch.nn.Linear)
+            and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
+            and load_as_quantized
+        ):
+            assert not mod.bias
+            out_features = mod.out_features
+            in_features = mod.in_features
+            assert all(t.shape == (in_features, out_features))
+
+            q, s, z = Q4_0.unpack(t)
+            scales_and_zeros = pack_scales_and_zeros(s, z)
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                q, inner_k_tiles
+            )
+
+            state_dict[f"{fqn}.weight"] = weight_int4pack
+            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
+
+            parent = _fqn_lookup(_fqn_up(fqn), model)
+            setattr(
+                parent,
+                _fqn_last(fqn),
+                WeightOnlyInt4Linear(
+                    "meta",
+                    in_features,
+                    out_features,
+                    bias=False,
+                    groupsize=Q4_0.groupsize,
+                    inner_k_tiles=inner_k_tiles,
+                ),
+            )
+        else:
+            state_dict[f"{fqn}.weight"] = to_float(t)
+
+    return model, state_dict
+
+
+def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
+    model, state_dict = load_model_and_state_dict(gguf_file, load_as_quantized=True)
+    model.load_state_dict(state_dict, assign=True)
+    return model
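
A hedged usage sketch of the two new entry points defined above; the GGUF file name is a placeholder, and the import assumes the repo root is on sys.path:

from build.gguf_loader import load_model, load_model_and_state_dict

# Structure only: a Transformer whose parameters live on the meta device.
meta_model = load_model("model.gguf")  # placeholder path

# Structure plus weights: with load_as_quantized=True, Q4_0 linears keep an
# int4-packed weight (plus scales_and_zeros); other tensors go through to_float().
model, state_dict = load_model_and_state_dict("model.gguf", load_as_quantized=True)
model.load_state_dict(state_dict, assign=True)  # assign=True materializes the meta parameters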

build/model.py (+3 -3)
@@ -247,9 +247,9 @@ def from_params(cls, params_path: str):

     @classmethod
     def from_gguf(cls, gguf_path: str):
-        from build.gguf_loader import load_llama_from_gguf_file
-
-        model = load_llama_from_gguf_file(gguf_path)
+        from build.gguf_loader import load_model_and_state_dict
+        model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
+        model.load_state_dict(state_dict, assign=True)
         return model

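
End to end, the public entry point keeps its signature but now loads through the state_dict path. A hedged sketch (placeholder file name; the import path depends on how build/ is placed on sys.path, mirroring the import in build/gguf_loader.py):

from model import Transformer

model = Transformer.from_gguf("llama.Q4_0.gguf")  # placeholder GGUF file
# Supported Q4_0 linears come back already int4-packed; note that in this commit the
# builder's GGUF branch no longer calls .to(device=..., dtype=...) on the result.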
