 import torch
 import torch.nn as nn

+wd = Path(__file__).parent.resolve()
+sys.path.append(str(wd))
+
 from gguf import GGUFValueType, ReaderTensor
 from quantize import (
     group_dequantize_tensor_from_qparams,
     pack_scales_and_zeros,
     WeightOnlyInt4Linear,
 )

-from build.gguf_util import F16, F32, Q4_0, Q6_K
-
-wd = Path(__file__).parent.resolve()
-sys.path.append(str(wd))
-
+from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
 from model import ModelArgs, Transformer


 logger: logging.Logger = logging.getLogger(__name__)


-@dataclass
-class AttentionArgs:
-    head_count: int
-    head_count_kv: int
-    layer_norm_rms_epsilon: float
-
-
-@dataclass
-class RopeArgs:
-    dimension_count: int | None = None
-    freq_base: float | None = None
-
-
-@dataclass
-class GGUFModelArgs:
-    arch: str
-    embedding_length: int
-    block_count: int
-    feed_forward_length: int
-    vocab_size: int
-    attention: AttentionArgs
-    rope: RopeArgs
-
-
-@dataclass
-class GGUFWeights:
-    tensors: list[ReaderTensor]
-
-
-def _create_pt_model(
-    gguf_model_args: GGUFModelArgs,
-) -> nn.Module:
-    llama_model_args = ModelArgs(
-        dim=gguf_model_args.embedding_length,
-        n_layers=gguf_model_args.block_count,
-        n_heads=gguf_model_args.attention.head_count,
-        n_local_heads=gguf_model_args.attention.head_count_kv,
-        vocab_size=gguf_model_args.vocab_size,
-        norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
-        hidden_dim=gguf_model_args.feed_forward_length,
-    )
-    pt_model = Transformer(llama_model_args)
-    pt_model.eval()
-    return pt_model
-
-
 _name_replacements = [
     ("blk", "layers"),
     ("token_embd", "tok_embeddings"),
@@ -102,29 +55,6 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
     return result


-def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
-    arch = metadata["general.architecture"]
-    assert (
-        arch == "llama"
-    ), f"Only general.architecture=llama is supported, but got general.architecture={arch}"
-    return GGUFModelArgs(
-        arch=arch,
-        embedding_length=metadata[f"{arch}.embedding_length"],
-        block_count=metadata[f"{arch}.block_count"],
-        feed_forward_length=metadata[f"{arch}.feed_forward_length"],
-        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
-        attention=AttentionArgs(
-            head_count=metadata[f"{arch}.attention.head_count"],
-            head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
-            layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
-        ),
-        rope=RopeArgs(
-            freq_base=metadata.get(f"{arch}.rope.freq_base", None),
-            dimension_count=metadata.get(f"{arch}.rope.dimension_count", None),
-        ),
-    )
-
-
 def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any:
     if fqn == "":
         return module
@@ -153,74 +83,6 @@ def _fqn_last(fqn: str) -> str:
     return atoms[-1]


-def load_weights(
-    pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], inner_k_tiles=8
-) -> None:
-    fqns = []
-    for fqn in pt_model.state_dict():
-        assert _fqn_last(fqn) == "weight"
-        fqns.append(_fqn_up(fqn))
-
-    state_dict = {}
-    for fqn in fqns:
-        mod = _fqn_lookup(fqn, pt_model)
-
-        t = weight_map[f"{fqn}.weight"]
-
-        if (
-            isinstance(mod, torch.nn.Linear)
-            and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
-        ):
-            assert not mod.bias
-            out_features = mod.out_features
-            in_features = mod.in_features
-            assert all(t.shape == (in_features, out_features))
-
-            q, s, z = Q4_0.unpack(t)
-            scales_and_zeros = pack_scales_and_zeros(s, z)
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
-            )
-
-            state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
-            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
-
-            parent = _fqn_lookup(_fqn_up(fqn), pt_model)
-            setattr(
-                parent,
-                _fqn_last(fqn),
-                WeightOnlyInt4Linear(
-                    "cpu",  # TODO: should --device work for gguf load? (yes?!)
-                    in_features,
-                    out_features,
-                    bias=False,
-                    groupsize=Q4_0.groupsize,
-                    inner_k_tiles=inner_k_tiles,
-                ),
-            )
-        else:
-            # All other weights are dequantized to float
-            if t.tensor_type == gguf.GGMLQuantizationType.Q4_0:
-                as_float = group_dequantize_tensor_from_qparams(
-                    *Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize
-                )
-            elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K:
-                as_float = group_dequantize_tensor_from_qparams(
-                    *Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize
-                )
-            elif t.tensor_type == gguf.GGMLQuantizationType.F16:
-                as_float = F16.unpack(t)
-            elif t.tensor_type == gguf.GGMLQuantizationType.F32:
-                as_float = F32.unpack(t)
-            else:
-                raise ValueError(f"Unsupported tensor type {t.tensor_type}")
-
-            state_dict[f"{fqn}.weight"] = as_float.to("cpu")
-
-    pt_model.load_state_dict(state_dict)
-    return pt_model
-
-
 def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
     metadata: dict[str, Any] = {}

@@ -244,34 +106,103 @@ def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
     return metadata


-def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
+def load_model(gguf_file: str) -> torch.nn.Module:
     """
-    Load a LLaMa model from a GGUF file and return a PT nn.Module.
+    Parses the GGUF file and returns an nn.Module on meta device.
     """
-    if not Path(gguf_file).is_file():
-        raise ValueError(f"Could not find file {gguf_file}")

     logger.info("Parsing GGUF metadata.")
     reader = gguf.GGUFReader(gguf_file, "r")
     metadata = _get_metadata(reader)
-    model_args = _build_model_args(metadata)
+
+    arch = metadata["general.architecture"]
     assert (
-        model_args.arch == "llama"
+        arch == "llama"
     ), "Only LLaMa models are supported by this converter."

-    logger.info("Creating initial PT model.")
-    pt_model = _create_pt_model(model_args)
+    model_args = ModelArgs(
+        dim=metadata[f"{arch}.embedding_length"],
+        n_layers=metadata[f"{arch}.block_count"],
+        n_heads=metadata[f"{arch}.attention.head_count"],
+        n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
+        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
+        norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
+        hidden_dim=metadata[f"{arch}.feed_forward_length"],
+    )

-    logger.info("Reading GGUF weights.")
-    gguf_weights = GGUFWeights(tensors=reader.tensors)
+    # TODO: what to do with rope args like
+    # metadata.get(f"{arch}.rope.freq_base", None)
+    # metadata.get(f"{arch}.rope.dimension_count", None)

-    logger.info("Building GGUF weight map.")
-    # map from fqn in pt_model to gguf tensor
+    with torch.device("meta"):
+        model = Transformer(model_args)
+    return model
+
+
+def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles=8) -> torch.nn.Module:
+    """
+    Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
+    that can be loaded into it.
+
+    When load_as_quantized, the method tries to preserve the GGUF quantization when it
+    is natively supported by PyTorch, otherwise it converts quantized tensors to FP32.
+    """
+
+    model = load_model(gguf_file)
+
+    reader = gguf.GGUFReader(gguf_file, "r")
     weight_map = {
         _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
-        for tensor in gguf_weights.tensors
+        for tensor in reader.tensors
     }

-    logger.info("Loading weights into state_dict")
-    pt_model = load_weights(pt_model, weight_map, inner_k_tiles=8)
-    return pt_model
+    state_dict = {}
+    for fqn in weight_map:
+        assert _fqn_last(fqn) == "weight"
+        fqn = _fqn_up(fqn)
+
+        mod = _fqn_lookup(fqn, model)
+        t = weight_map[f"{fqn}.weight"]
+
+        if (
+            isinstance(mod, torch.nn.Linear)
+            and t.tensor_type == gguf.GGMLQuantizationType.Q4_0
+            and load_as_quantized
+        ):
+            assert not mod.bias
+            out_features = mod.out_features
+            in_features = mod.in_features
+            assert all(t.shape == (in_features, out_features))
+
+            q, s, z = Q4_0.unpack(t)
+            scales_and_zeros = pack_scales_and_zeros(s, z)
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                q, inner_k_tiles
+            )
+
+            state_dict[f"{fqn}.weight"] = weight_int4pack
+            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
+
+            parent = _fqn_lookup(_fqn_up(fqn), model)
+            setattr(
+                parent,
+                _fqn_last(fqn),
+                WeightOnlyInt4Linear(
+                    "meta",
+                    in_features,
+                    out_features,
+                    bias=False,
+                    groupsize=Q4_0.groupsize,
+                    inner_k_tiles=inner_k_tiles,
+                ),
+            )
+        else:
+            state_dict[f"{fqn}.weight"] = to_float(t)
+
+    return model, state_dict
+
+
+def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module:
+    model, state_dict = load_model_and_state_dict(gguf_file, load_as_quantized=True)
+    model.load_state_dict(state_dict, assign=True)
+    return model
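
For reference, a minimal usage sketch of the new entry points (the GGUF path below is a placeholder, not part of the diff). load_llama_from_gguf_file builds the model on the meta device and then materializes it by loading the converted state_dict with assign=True, which swaps in the loaded tensors instead of copying into the meta parameters:

# Minimal usage sketch; "model-q4_0.gguf" is a placeholder path.
# One-shot helper: Q4_0 linear weights stay int4-packed via WeightOnlyInt4Linear.
model = load_llama_from_gguf_file("model-q4_0.gguf")

# Or keep the two steps explicit, e.g. to dequantize everything to float instead:
model, state_dict = load_model_and_state_dict("model-q4_0.gguf", load_as_quantized=False)
model.load_state_dict(state_dict, assign=True)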