Skip to content

Commit 8ce49cd

Browse files
authored
fix quantized vjp for mxfp4 (#2555)
1 parent 9c68b50 commit 8ce49cd

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

mlx/primitives.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,7 +3246,8 @@ std::vector<array> QuantizedMatmul::vjp(
32463246
cotangents[0],
32473247
primals[1],
32483248
primals[2],
3249-
primals[3],
3249+
mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
3250+
: std::nullopt,
32503251
!transpose_,
32513252
group_size_,
32523253
bits_,
@@ -3260,7 +3261,7 @@ std::vector<array> QuantizedMatmul::vjp(
32603261
"[QuantizedMatmul::vjp] no gradient wrt the quantized weights.");
32613262
} else {
32623263
if (mode_ == QuantizationMode::Mxfp4) {
3263-
throw std::runtime_error(
3264+
throw std::invalid_argument(
32643265
"[QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization.");
32653266
}
32663267
if (!dsb) {
@@ -3305,7 +3306,8 @@ std::vector<array> QuantizedMatmul::jvp(
33053306
tangents[0],
33063307
primals[1],
33073308
primals[2],
3308-
primals[3],
3309+
mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3])
3310+
: std::nullopt,
33093311
transpose_,
33103312
group_size_,
33113313
bits_,
@@ -3346,9 +3348,11 @@ std::vector<array> GatherQMM::vjp(
33463348
auto& x = primals[0];
33473349
auto& w = primals[1];
33483350
auto& scales = primals[2];
3349-
auto& biases = primals[3];
3350-
auto& lhs_indices = primals[4];
3351-
auto& rhs_indices = primals[5];
3351+
auto& lhs_indices = primals[primals.size() - 2];
3352+
auto& rhs_indices = primals[primals.size() - 1];
3353+
auto biases = (mode_ == QuantizationMode::Affine)
3354+
? std::optional<array>(primals[3])
3355+
: std::nullopt;
33523356

33533357
int M = cotan.shape(-2);
33543358
int N = cotan.shape(-1);
@@ -3401,7 +3405,7 @@ std::vector<array> GatherQMM::vjp(
34013405
"[GatherQMM::vjp] no gradient wrt the quantized weights.");
34023406
} else {
34033407
if (mode_ == QuantizationMode::Mxfp4) {
3404-
throw std::runtime_error(
3408+
throw std::invalid_argument(
34053409
"[GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization.");
34063410
}
34073411

@@ -3432,7 +3436,7 @@ std::vector<array> GatherQMM::vjp(
34323436
dequantize(
34333437
w,
34343438
ones_like(scales, stream()),
3435-
zeros_like(biases, stream()),
3439+
zeros_like(*biases, stream()),
34363440
group_size_,
34373441
bits_,
34383442
quantization_mode_to_string(mode_),

python/tests/test_quantized.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,37 @@ def mm(sb, x, wq):
842842
num_ds = (out_up - out_down) / (2 * eps)
843843
self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
844844

845+
def test_mxfp4_vjp_scales_throws(self):
846+
mx.random.seed(0)
847+
x = mx.random.normal(shape=(2, 512))
848+
w = mx.random.normal(shape=(512, 512))
849+
wq, s = mx.quantize(w, bits=4, group_size=32, mode="mxfp4")
850+
851+
def mm(s, x, wq):
852+
return mx.quantized_matmul(
853+
x, wq, s, bits=4, group_size=32, mode="mxfp4"
854+
).sum()
855+
856+
# Should raise
857+
with self.assertRaises(ValueError):
858+
ds = mx.grad(mm)(s, x, wq)
859+
860+
rhs_indices = mx.array(0)
861+
with self.assertRaises(ValueError):
862+
863+
def gmm(s, x, wq):
864+
return mx.gather_qmm(
865+
x,
866+
wq,
867+
s,
868+
rhs_indices=rhs_indices,
869+
bits=4,
870+
group_size=32,
871+
mode="mxfp4",
872+
).sum()
873+
874+
ds = mx.grad(gmm)(s, x, wq)
875+
845876

846877
if __name__ == "__main__":
847878
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)