
Commit fd1c081

stable cumprod grad at 0 (#1167)
1 parent 76b6cec commit fd1c081

File tree

  mlx/primitives.cpp
  python/tests/test_autograd.py

2 files changed: +119 −6 lines changed

mlx/primitives.cpp

Lines changed: 46 additions & 6 deletions
@@ -2748,12 +2748,52 @@ std::vector<array> Scan::vjp(
   if (reduce_type_ == Scan::Sum) {
     return {cumsum(cotangents[0], axis_, !reverse_, inclusive_, stream())};
   } else if (reduce_type_ == Scan::Prod) {
-    // TODO: Make it numerically stable when we introduce where()
-    auto prod = outputs[0];
-    auto partial_grads = multiply(prod, cotangents[0], stream());
-    auto accum_grads =
-        cumsum(partial_grads, axis_, !reverse_, inclusive_, stream());
-    return {divide(accum_grads, primals[0], stream())};
+    auto in = primals[0];
+    // Find the location of the first 0 and set it to 1:
+    // - A: Exclusive cumprod
+    // - B: Inclusive cumprod
+    // - Find the location that is 0 in B and not 0 in A
+    // Compute the gradient by:
+    // - Compute the regular gradient for everything before the first zero
+    // - Set the first zero to 1 and redo the computation, use this for the
+    //   gradient of the first zero
+    // - Everything after the first zero has a gradient of 0
+
+    // Get inclusive and exclusive cum prods
+    auto cprod_exclusive = cumprod(in, axis_, reverse_, !inclusive_, stream());
+    auto cprod_inclusive = outputs[0];
+    if (!inclusive_) {
+      std::swap(cprod_exclusive, cprod_inclusive);
+    }
+
+    // Make the mask for the first zero
+    auto z = array(0, in.dtype());
+    auto eq_zero = equal(cprod_inclusive, z, stream());
+    auto first_zero =
+        logical_and(eq_zero, not_equal(cprod_exclusive, z, stream()), stream());
+
+    auto to_partial_grad = [this, &cotangents](const array& arr) {
+      return cumsum(
+          multiply(arr, cotangents[0], stream()),
+          axis_,
+          !reverse_,
+          inclusive_,
+          stream());
+    };
+
+    auto cprod_with_one = cumprod(
+        where(first_zero, array(1, in.dtype()), in, stream()),
+        axis_,
+        reverse_,
+        inclusive_,
+        stream());
+    auto grad_with_one = to_partial_grad(cprod_with_one);
+    auto grad = divide(to_partial_grad(outputs[0]), in, stream());
+    return {where(
+        first_zero,
+        grad_with_one,
+        where(eq_zero, z, grad, stream()),
+        stream())};
   } else {
     // Can probably be implemented by equals and then cummax to make the mask
     throw std::runtime_error("VJP is not implemented for cumulative min/max");
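
The change above replaces the old VJP, which divided the accumulated cotangents by primals[0] and so produced inf/nan whenever any input element is exactly 0. The new code finds the first zero along the scan axis (the inclusive product is 0 there while the exclusive product is not), recomputes the cumulative product with that element replaced by 1 to obtain its gradient, and forces the gradient to 0 for everything after it. Below is a minimal NumPy sketch of the same idea, not part of the commit; it covers only the forward inclusive case, and cumprod_vjp and its helper names are illustrative rather than MLX API.

import numpy as np

def cumprod_vjp(x, cotan):
    # Inclusive and exclusive cumulative products (forward, inclusive scan only).
    incl = np.cumprod(x)
    excl = np.concatenate(([1.0], incl[:-1]))

    # The first zero is where the inclusive product vanishes while the
    # exclusive product (everything strictly before it) is still non-zero.
    eq_zero = incl == 0
    first_zero = eq_zero & (excl != 0)

    # Reverse cumulative sum of value * cotangent, the analogue of
    # to_partial_grad in the C++ above for the non-reversed case.
    def partial_grad(vals):
        return np.cumsum((vals * cotan)[::-1])[::-1]

    # Regular gradient: d cumprod_k / d x_i = cumprod_k / x_i for i <= k.
    with np.errstate(divide="ignore", invalid="ignore"):
        grad = partial_grad(incl) / x

    # Gradient of the first zero: redo the cumprod with that entry set to 1.
    incl_with_one = np.cumprod(np.where(first_zero, 1.0, x))
    grad_with_one = partial_grad(incl_with_one)

    # Everything after the first zero gets gradient 0.
    return np.where(first_zero, grad_with_one, np.where(eq_zero, 0.0, grad))

x = np.array([2.0, 0.0, 2.0, 2.0, 3.0])
print(cumprod_vjp(x, np.ones_like(x)))  # [ 1. 38.  0.  0.  0.]

Running it on [2, 0, 2, 2, 3] with a cotangent of ones reproduces the expected gradient [1, 38, 0, 0, 0] from the second test case below.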

python/tests/test_autograd.py

Lines changed: 73 additions & 0 deletions
@@ -423,6 +423,79 @@ def fun(x, y):
         grad = mx.grad(fun)(mx.array(1.0), mx.array(1.0))
         self.assertEqual(grad.item(), 1.0)
 
+    def test_cumprod_grad(self):
+        def fun(y):
+            return mx.cumprod(y).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([20.0, 38.0, 18.0, 16.0, 8.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 38.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([8.0, 14.0, 6.0, 4.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 14.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([1.0, 6.0, 0.0, 0.0, 0.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, inclusive=False, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 12.0, 15.0, 11.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 12.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        def fun(y):
+            return mx.cumprod(y, reverse=True).sum()
+
+        y = mx.array([2.0, 1.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([12.0, 36.0, 24.0, 27.0, 19.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 2.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 36.0, 6.0, 9.0, 7.0])
+        self.assertTrue(mx.allclose(out, expected))
+
+        y = mx.array([2.0, 0.0, 2.0, 0.0, 3.0])
+        out = mx.grad(fun)(y)
+        expected = mx.array([0.0, 0.0, 0.0, 9.0, 1.0])
+        self.assertTrue(mx.allclose(out, expected))
+
 
 if __name__ == "__main__":
     unittest.main()
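
For reference, and not part of the diff: each expected vector above is the gradient of sum(cumprod(y)); for the plain inclusive case this is grad_i = sum over k >= i of cumprod(y)_k / y_i. A quick NumPy central-difference check (illustrative only, not MLX code) recovers the first expected vector:

import numpy as np

def f(y):
    return np.cumprod(y).sum()

y = np.array([2.0, 1.0, 2.0, 2.0, 3.0])
eps = 1e-6
# Central differences along each coordinate direction.
grad = np.array([
    (f(y + eps * np.eye(5)[i]) - f(y - eps * np.eye(5)[i])) / (2 * eps)
    for i in range(5)
])
print(np.round(grad, 3))  # [20. 38. 18. 16.  8.] -- the first expected vector

The zero-containing cases are exactly where the old divide-by-input formula broke down, which is what the first-zero masking in the C++ change fixes.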
