@@ -3275,7 +3275,9 @@ std::tuple<array, array, array> quantize(
32753275 }
32763276
32773277 // Compute some constants used for the quantization
3278- int n_bins = (1 << bits) - 1 ; // 2**bits - 1
3278+ array n_bins ((1 << bits) - 1 , w.dtype ()); // 2**bits - 1
3279+ array eps (1e-7 , w.dtype ());
3280+ array zero (0 , w.dtype ());
32793281 int el_per_int = 32 / bits;
32803282 array shifts = power (array (2 , uint32), arange (0 , 32 , bits, uint32, s), s);
32813283 shifts = reshape (shifts, {1 , 1 , -1 }, s);
@@ -3299,16 +3301,22 @@ std::tuple<array, array, array> quantize(
32993301 reshape (w, {w.shape (0 ), w.shape (1 ) / group_size, group_size}, s);
33003302 array w_max = max (packed_w, /* axis= */ -1 , /* keepdims= */ true , s);
33013303 array w_min = min (packed_w, /* axis= */ -1 , /* keepdims= */ true , s);
3302- array scales = maximum (
3303- divide (subtract (w_max, w_min, s), array (n_bins, w.dtype ()), s),
3304- array (1e-7 , w.dtype ()),
3305- s);
3306- // making sure that 0 is represented exactly in the resulting quantization
3307- array biases = multiply (round (divide (w_min, scales, s), s), scales, s);
3304+
3305+ array mask = greater (abs (w_min, s), abs (w_max, s), s);
3306+ array scales = maximum (divide (subtract (w_max, w_min, s), n_bins, s), eps, s);
3307+ scales = where (mask, scales, negative (scales), s);
3308+ array edge = where (mask, w_min, w_max, s);
3309+ array q0 = round (divide (edge, scales, s), s);
3310+ scales = where (not_equal (q0, zero, s), divide (edge, q0, s), scales);
3311+ array biases = where (equal (q0, zero, s), zero, edge);
33083312
33093313 // Quantize and pack w
33103314 packed_w = astype (
3311- round (divide (subtract (packed_w, biases, s), scales, s), s), uint32);
3315+ clip (
3316+ round (divide (subtract (packed_w, biases, s), scales, s), s),
3317+ zero,
3318+ n_bins),
3319+ uint32);
33123320 packed_w = reshape (packed_w, {w.shape (0 ), -1 , el_per_int}, s);
33133321 packed_w = sum (
33143322 multiply (packed_w, shifts, s), /* axis= */ 2 , /* keepdims= */ false , s);
0 commit comments