
Commit 4fe2fa2

jbochi and awni authored
GGUF: Avoid dequantization when format is compatible (#426)
* GGUF: Don't dequantize q4_1
* Fix weight order. First in low bits
* Add unpacking for q4_0
* Don't dequantize q8_0
* rebase quants and split file
* don't quantize every weight
* reapply patch
* error handling

---------

Co-authored-by: Awni Hannun <[email protected]>
1 parent 37fc9db commit 4fe2fa2

File tree

5 files changed (+210, -20 lines)

mlx/io/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/safetensor.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/gguf.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/gguf_quants.cpp
 )
 
 MESSAGE(STATUS "Downloading json")

mlx/io/gguf.cpp

Lines changed: 31 additions & 18 deletions
@@ -1,17 +1,10 @@
-// Copyright © 2023 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <cstdint>
 #include <cstring>
 #include <numeric>
 
-#include "mlx/io.h"
-#include "mlx/primitives.h"
-#include "mlx/transforms.h"
-#include "mlx/utils.h"
-
-extern "C" {
-#include <gguflib.h>
-}
+#include <mlx/io/gguf.h>
 
 namespace mlx::core {
 
@@ -52,7 +45,16 @@ std::optional<Dtype> gguf_type_to_dtype(const uint32_t& gguf_type) {
   }
 }
 
-std::pair<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
+std::vector<int> get_shape(const gguf_tensor& tensor) {
+  std::vector<int> shape;
+  // The dimension order in GGML is the reverse of the order used in MLX.
+  for (int i = tensor.ndim - 1; i >= 0; i--) {
+    shape.push_back(tensor.dim[i]);
+  }
+  return shape;
+}
+
+std::tuple<allocator::Buffer, Dtype> extract_tensor_data(gguf_tensor* tensor) {
   std::optional<Dtype> equivalent_dtype = gguf_type_to_dtype(tensor->type);
   // If there's an equivalent type, we can simply copy.
   if (equivalent_dtype.has_value()) {
@@ -203,16 +205,27 @@ std::unordered_map<std::string, MetaData> load_metadata(gguf_ctx* ctx) {
 std::unordered_map<std::string, array> load_arrays(gguf_ctx* ctx) {
   std::unordered_map<std::string, array> array_map;
   gguf_tensor tensor;
+
+  auto check_insert = [](auto inserted) {
+    if (!inserted.second) {
+      std::ostringstream msg;
+      msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
+          << " this can happend when loading quantized tensors.";
+      throw std::runtime_error(msg.str());
+    }
+  };
+
   while (gguf_get_tensor(ctx, &tensor)) {
-    std::vector<int> shape;
-    // The dimension order in GGML is the reverse of the order used in MLX.
-    for (int i = tensor.ndim - 1; i >= 0; i--) {
-      shape.push_back(tensor.dim[i]);
+    if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1 ||
+        tensor.type == GGUF_TYPE_Q8_0) {
+      gguf_load_quantized(array_map, tensor);
+    } else {
+      std::string name = std::string(tensor.name, tensor.namelen);
+
+      const auto& [data, dtype] = extract_tensor_data(&tensor);
+      array loaded_array = array(data, get_shape(tensor), dtype);
+      array_map.insert({name, loaded_array});
     }
-    const auto& [data, dtype] = extract_tensor_data(&tensor);
-    array loaded_array = array(data, shape, dtype);
-    std::string name = std::string(tensor.name, tensor.namelen);
-    array_map.insert({name, loaded_array});
   }
   return array_map;
 }
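With this change, load_arrays leaves Q4_0, Q4_1, and Q8_0 tensors packed and routes them through gguf_load_quantized, while every other tensor keeps the plain copy path; the shape handling is factored into get_shape, which reverses GGML's dimension order (fastest-varying dimension first) into MLX's convention. A minimal standalone sketch of that reversal follows; FakeTensor is a stand-in struct used only so the snippet compiles on its own (it is not gguflib's gguf_tensor), and the dimensions are made-up example values.

// Sketch of the dimension-order reversal performed by get_shape().
// FakeTensor mimics the ndim/dim fields the loader reads from gguflib.
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeTensor {
  uint32_t ndim;
  uint64_t dim[4]; // GGML order: fastest-varying dimension first.
};

std::vector<int> reversed_shape(const FakeTensor& t) {
  std::vector<int> shape;
  for (int i = t.ndim - 1; i >= 0; i--) {
    shape.push_back(static_cast<int>(t.dim[i]));
  }
  return shape;
}

int main() {
  // A tensor stored by GGML as {4096, 32000} becomes MLX shape {32000, 4096}.
  FakeTensor t{2, {4096, 32000, 0, 0}};
  for (int d : reversed_shape(t)) {
    std::cout << d << " "; // prints: 32000 4096
  }
  std::cout << std::endl;
}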

mlx/io/gguf.h

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+// Copyright © 2023-2024 Apple Inc.
+#pragma once
+
+#include "mlx/io.h"
+#include "mlx/primitives.h"
+#include "mlx/transforms.h"
+#include "mlx/utils.h"
+
+extern "C" {
+#include <gguflib.h>
+}
+
+namespace mlx::core {
+
+std::vector<int> get_shape(const gguf_tensor& tensor);
+void gguf_load_quantized(
+    std::unordered_map<std::string, array>& a,
+    const gguf_tensor& tensor);
+
+} // namespace mlx::core

mlx/io/gguf_quants.cpp

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <cstdint>
+#include <cstring>
+
+#include <mlx/io/gguf.h>
+
+namespace mlx::core {
+
+void unpack_32_4(uint8_t* data, int8_t* dst) {
+  for (int64_t j = 0; j < 16; ++j) {
+    uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
+    if (j % 2 != 0) {
+      x <<= 4;
+    }
+    dst[j / 2] += x;
+  }
+  // Last 16 weights are in the higher bits
+  for (int64_t j = 0; j < 16; ++j) {
+    uint8_t x = (data[j + 2] >> 4);
+    if (j % 2 != 0) {
+      x <<= 4;
+    }
+    dst[8 + j / 2] += x;
+  }
+}
+
+// Extracts (weight, scales, biases) from Q4_0 tensors.
+// Data layout is: |16 bit scale|32 x 4bit weights|.
+void extract_q4_0_data(
+    const gguf_tensor& tensor,
+    array& weights_arr,
+    array& scales_arr,
+    array& biases_arr) {
+  const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
+  auto data = static_cast<uint8_t*>(tensor.weights_data);
+  auto weights = weights_arr.data<int8_t>();
+  auto scales = scales_arr.data<float16_t>();
+  auto biases = biases_arr.data<float16_t>();
+  for (int64_t i = 0; i < scales_arr.size(); i++) {
+    scales[i] = *((float16_t*)data);
+    biases[i] = -8 * scales[i];
+    unpack_32_4(data, weights);
+    weights += 16;
+    data += bytes_per_block;
+  }
+}
+
+// Extracts (weight, scales, biases) from Q4_1 tensors.
+// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
+void extract_q4_1_data(
+    const gguf_tensor& tensor,
+    array& weights_arr,
+    array& scales_arr,
+    array& biases_arr) {
+  const uint64_t bytes_per_block =
+      20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
+  auto data = static_cast<uint8_t*>(tensor.weights_data);
+  auto weights = weights_arr.data<int8_t>();
+  auto scales = scales_arr.data<float16_t>();
+  auto biases = biases_arr.data<float16_t>();
+  for (int64_t i = 0; i < scales_arr.size(); i++) {
+    scales[i] = *((float16_t*)data);
+    biases[i] = *((float16_t*)(data) + 1);
+    unpack_32_4(data, weights);
+    weights += 16;
+    data += bytes_per_block;
+  }
+}
+
+// Extracts (weight, scales, biases) from Q8_0 tensors.
+// Data layout is: |16 bit scale|32 x 8bit weights|.
+void extract_q8_0_data(
+    const gguf_tensor& tensor,
+    array& weights_arr,
+    array& scales_arr,
+    array& biases_arr) {
+  const uint64_t weights_per_block = 32;
+  const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
+  auto data = static_cast<uint8_t*>(tensor.weights_data);
+  auto weights = weights_arr.data<int8_t>();
+  auto scales = scales_arr.data<float16_t>();
+  auto biases = biases_arr.data<float16_t>();
+  for (int64_t i = 0; i < scales_arr.size(); i++) {
+    uint8_t* block_data = data + i * bytes_per_block;
+    scales[i] = *((float16_t*)block_data);
+    biases[i] = -128 * scales[i];
+    for (int64_t j = 0; j < weights_per_block; ++j) {
+      uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
+      // Original data is in int8_t, so we add a bias of -128 and invert the
+      // first bit.
+      x ^= 1 << 7;
+      weights[i * weights_per_block + j] = x;
+    }
+  }
+}
+
+void gguf_load_quantized(
+    std::unordered_map<std::string, array>& a,
+    const gguf_tensor& tensor) {
+  uint64_t weights_per_byte;
+  if (tensor.type == GGUF_TYPE_Q4_0 || tensor.type == GGUF_TYPE_Q4_1) {
+    weights_per_byte = 2;
+  } else { // tensor.type == GGUF_TYPE_Q8_0
+    weights_per_byte = 1;
+  }
+
+  std::string name = std::string(tensor.name, tensor.namelen);
+  std::vector<int> shape = get_shape(tensor);
+  const uint64_t weights_per_block = 32;
+  if (shape[shape.size() - 1] % weights_per_block != 0) {
+    std::ostringstream msg;
+    msg << "[load_gguf] tensor " << name
+        << "has incompatible last dim shape: " << shape[shape.size() - 1];
+    throw std::runtime_error(msg.str());
+  }
+  const uint64_t num_blocks = tensor.num_weights / weights_per_block;
+
+  std::vector<int> weights_shape = shape;
+  weights_shape.back() /= (weights_per_byte * 4);
+
+  array weights(std::move(weights_shape), uint32, nullptr, {});
+  weights.set_data(allocator::malloc(weights.nbytes()));
+
+  // For scales and bias
+  shape[shape.size() - 1] = shape[shape.size() - 1] / weights_per_block;
+  array scales(shape, float16, nullptr, {});
+  array biases(std::move(shape), float16, nullptr, {});
+  scales.set_data(allocator::malloc(scales.nbytes()));
+  biases.set_data(allocator::malloc(biases.nbytes()));
+
+  if (tensor.type == GGUF_TYPE_Q4_0) {
+    extract_q4_0_data(tensor, weights, scales, biases);
+  } else if (tensor.type == GGUF_TYPE_Q4_1) {
+    extract_q4_1_data(tensor, weights, scales, biases);
+  } else if (tensor.type == GGUF_TYPE_Q8_0) {
+    extract_q8_0_data(tensor, weights, scales, biases);
+  }
+
+  a.insert({name, weights});
+
+  auto check_insert = [](auto inserted) {
+    if (!inserted.second) {
+      std::ostringstream msg;
+      msg << "[load_gguf] Duplicate parameter name " << inserted.first->second
+          << " this can happend when loading quantized tensors.";
+      throw std::runtime_error(msg.str());
+    }
+  };
+
+  const std::string weight_suffix = ".weight";
+  const std::string name_prefix =
+      name.substr(0, name.length() - weight_suffix.length());
+  check_insert(a.insert({name_prefix + ".scales", scales}));
+  check_insert(a.insert({name_prefix + ".biases", biases}));
+}
+
+} // namespace mlx::core
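The loader keeps each quantized tensor in MLX's group-wise format with a group size of 32: the 4- or 8-bit codes are packed into uint32 words (8 codes per word for Q4_0/Q4_1, 4 per word for Q8_0, which is why weights_shape.back() is divided by weights_per_byte * 4), and each group of 32 gets one float16 scale and bias. With bias = -8 * scale for Q4_0, and -128 * scale for Q8_0's shifted int8 codes, the affine dequantization w = scale * q + bias reproduces GGML's scale * (q - 8) and scale * q_int8. The sketch below dequantizes a single 18-byte Q4_0 block on the CPU under that layout assumption; half_to_float and dequantize_q4_0_block are hypothetical helpers written only for this snippet, not functions from the commit or from gguflib, and the fp16 decoding is a simplified stand-in for a real half-precision type.

// Sketch: dequantize one 18-byte Q4_0 block (2-byte fp16 scale + 32 packed
// 4-bit codes) into 32 floats, mirroring the layout that extract_q4_0_data()
// and unpack_32_4() assume (first 16 weights in the low nibbles).
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Decode an IEEE 754 half stored in two little-endian bytes.
// Simplified: handles normal values only (no subnormals/inf/NaN).
float half_to_float(const uint8_t* p) {
  uint16_t h;
  std::memcpy(&h, p, 2);
  int sign = (h >> 15) & 1;
  int exp = (h >> 10) & 0x1F;
  int mant = h & 0x3FF;
  float value = std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -value : value;
}

std::vector<float> dequantize_q4_0_block(const uint8_t* block) {
  float scale = half_to_float(block); // bytes 0-1: fp16 scale
  float bias = -8.0f * scale;         // implicit zero-point of Q4_0
  std::vector<float> out(32);
  for (int j = 0; j < 32; ++j) {
    uint8_t byte = block[2 + (j % 16)]; // 16 bytes hold 32 nibbles
    uint8_t q = (j < 16) ? (byte & 0x0F) : (byte >> 4); // low nibbles first
    out[j] = scale * q + bias; // == scale * (q - 8), GGML's Q4_0 formula
  }
  return out;
}

int main() {
  // Example block: scale = 1.0 (fp16 0x3C00), every nibble set to 9,
  // so each dequantized weight should be 1.0 * (9 - 8) = 1.0.
  uint8_t block[18] = {0x00, 0x3C};
  for (int i = 2; i < 18; ++i) {
    block[i] = 0x99;
  }
  for (float w : dequantize_q4_0_block(block)) {
    std::cout << w << " "; // prints thirty-two 1s
  }
  std::cout << std::endl;
}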

tests/metal_tests.cpp

Lines changed: 0 additions & 2 deletions
@@ -500,15 +500,13 @@ TEST_CASE("test metal enable/disable cache") {
   auto buf = a.malloc(size, false);
   auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr());
   unsigned char first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
-  printf("first byte: %d\n", first_byte);
 
   // Release a
   a.free(buf);
 
   // If release successfully, the first byte should be different from the
   // first byte before release
   unsigned char new_first_byte = *reinterpret_cast<unsigned char*>(buf_ptr);
-  printf("new first byte: %d\n", new_first_byte);
 
   CHECK_NE(new_first_byte, first_byte);
 }
