Skip to content

Commit ec8578d

Browse files
authored
Fix quantization of all 0s (#1028)
1 parent d0dbfe0 commit ec8578d

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mlx/ops.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3274,7 +3274,10 @@ std::tuple<array, array, array> quantize(
32743274
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
32753275
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
32763276
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
3277-
array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s);
3277+
array delta = maximum(
3278+
divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s),
3279+
array(1e-7, w.dtype()),
3280+
s);
32783281
array scales = squeeze(delta, -1, s);
32793282
array biases = squeeze(w_min, -1, s);
32803283

python/tests/test_quantized.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ def test_quantize_dequantize(self):
1818
eps = 1e-6
1919
self.assertTrue((errors <= (scales[..., None] + eps)).all())
2020

21+
# test quantize/dequantize 0s
22+
a = mx.zeros((256, 512))
23+
for gs in [32, 64, 128]:
24+
for b in [2, 4, 8]:
25+
w_q, scales, biases = mx.quantize(a, gs, b)
26+
a_hat = mx.dequantize(w_q, scales, biases, gs, b)
27+
self.assertTrue(mx.all(a_hat == 0))
28+
2129
def test_qmm(self):
2230
key = mx.random.key(0)
2331
k1, k2 = mx.random.split(key)

0 commit comments

Comments
 (0)