@@ -3246,7 +3246,8 @@ std::vector<array> QuantizedMatmul::vjp(
32463246 cotangents[0 ],
32473247 primals[1 ],
32483248 primals[2 ],
3249- primals[3 ],
3249+ mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3 ])
3250+ : std::nullopt ,
32503251 !transpose_,
32513252 group_size_,
32523253 bits_,
@@ -3260,7 +3261,7 @@ std::vector<array> QuantizedMatmul::vjp(
32603261 " [QuantizedMatmul::vjp] no gradient wrt the quantized weights." );
32613262 } else {
32623263 if (mode_ == QuantizationMode::Mxfp4) {
3263- throw std::runtime_error (
3264+ throw std::invalid_argument (
32643265 " [QuantizedMatmul::vjp] no gradient wrt scales with mxfp4 quantization." );
32653266 }
32663267 if (!dsb) {
@@ -3305,7 +3306,8 @@ std::vector<array> QuantizedMatmul::jvp(
33053306 tangents[0 ],
33063307 primals[1 ],
33073308 primals[2 ],
3308- primals[3 ],
3309+ mode_ == QuantizationMode::Affine ? std::optional<array>(primals[3 ])
3310+ : std::nullopt ,
33093311 transpose_,
33103312 group_size_,
33113313 bits_,
@@ -3346,9 +3348,11 @@ std::vector<array> GatherQMM::vjp(
33463348 auto & x = primals[0 ];
33473349 auto & w = primals[1 ];
33483350 auto & scales = primals[2 ];
3349- auto & biases = primals[3 ];
3350- auto & lhs_indices = primals[4 ];
3351- auto & rhs_indices = primals[5 ];
3351+ auto & lhs_indices = primals[primals.size () - 2 ];
3352+ auto & rhs_indices = primals[primals.size () - 1 ];
3353+ auto biases = (mode_ == QuantizationMode::Affine)
3354+ ? std::optional<array>(primals[3 ])
3355+ : std::nullopt ;
33523356
33533357 int M = cotan.shape (-2 );
33543358 int N = cotan.shape (-1 );
@@ -3401,7 +3405,7 @@ std::vector<array> GatherQMM::vjp(
34013405 " [GatherQMM::vjp] no gradient wrt the quantized weights." );
34023406 } else {
34033407 if (mode_ == QuantizationMode::Mxfp4) {
3404- throw std::runtime_error (
3408+ throw std::invalid_argument (
34053409 " [GatherQMM::vjp] no gradient wrt scales with mxfp4 quantization." );
34063410 }
34073411
@@ -3432,7 +3436,7 @@ std::vector<array> GatherQMM::vjp(
34323436 dequantize (
34333437 w,
34343438 ones_like (scales, stream ()),
3435- zeros_like (biases, stream ()),
3439+ zeros_like (* biases, stream ()),
34363440 group_size_,
34373441 bits_,
34383442 quantization_mode_to_string (mode_),
0 commit comments