@@ -423,6 +423,79 @@ def fun(x, y):
423423 grad = mx .grad (fun )(mx .array (1.0 ), mx .array (1.0 ))
424424 self .assertEqual (grad .item (), 1.0 )
425425
426+ def test_cumprod_grad (self ):
427+ def fun (y ):
428+ return mx .cumprod (y ).sum ()
429+
430+ y = mx .array ([2.0 , 1.0 , 2.0 , 2.0 , 3.0 ])
431+ out = mx .grad (fun )(y )
432+ expected = mx .array ([20.0 , 38.0 , 18.0 , 16.0 , 8.0 ])
433+ self .assertTrue (mx .allclose (out , expected ))
434+
435+ y = mx .array ([2.0 , 0.0 , 2.0 , 2.0 , 3.0 ])
436+ out = mx .grad (fun )(y )
437+ expected = mx .array ([1.0 , 38.0 , 0.0 , 0.0 , 0.0 ])
438+ self .assertTrue (mx .allclose (out , expected ))
439+
440+ y = mx .array ([2.0 , 0.0 , 2.0 , 0.0 , 3.0 ])
441+ out = mx .grad (fun )(y )
442+ expected = mx .array ([1.0 , 6.0 , 0.0 , 0.0 , 0.0 ])
443+ self .assertTrue (mx .allclose (out , expected ))
444+
445+ def fun (y ):
446+ return mx .cumprod (y , inclusive = False ).sum ()
447+
448+ y = mx .array ([2.0 , 1.0 , 2.0 , 2.0 , 3.0 ])
449+ out = mx .grad (fun )(y )
450+ expected = mx .array ([8.0 , 14.0 , 6.0 , 4.0 , 0.0 ])
451+ self .assertTrue (mx .allclose (out , expected ))
452+
453+ y = mx .array ([2.0 , 0.0 , 2.0 , 2.0 , 3.0 ])
454+ out = mx .grad (fun )(y )
455+ expected = mx .array ([1.0 , 14.0 , 0.0 , 0.0 , 0.0 ])
456+ self .assertTrue (mx .allclose (out , expected ))
457+
458+ y = mx .array ([2.0 , 0.0 , 2.0 , 0.0 , 3.0 ])
459+ out = mx .grad (fun )(y )
460+ expected = mx .array ([1.0 , 6.0 , 0.0 , 0.0 , 0.0 ])
461+ self .assertTrue (mx .allclose (out , expected ))
462+
463+ def fun (y ):
464+ return mx .cumprod (y , inclusive = False , reverse = True ).sum ()
465+
466+ y = mx .array ([2.0 , 1.0 , 2.0 , 2.0 , 3.0 ])
467+ out = mx .grad (fun )(y )
468+ expected = mx .array ([0.0 , 12.0 , 12.0 , 15.0 , 11.0 ])
469+ self .assertTrue (mx .allclose (out , expected ))
470+
471+ y = mx .array ([2.0 , 0.0 , 2.0 , 2.0 , 3.0 ])
472+ out = mx .grad (fun )(y )
473+ expected = mx .array ([0.0 , 12.0 , 6.0 , 9.0 , 7.0 ])
474+ self .assertTrue (mx .allclose (out , expected ))
475+
476+ y = mx .array ([2.0 , 0.0 , 2.0 , 0.0 , 3.0 ])
477+ out = mx .grad (fun )(y )
478+ expected = mx .array ([0.0 , 0.0 , 0.0 , 9.0 , 1.0 ])
479+ self .assertTrue (mx .allclose (out , expected ))
480+
481+ def fun (y ):
482+ return mx .cumprod (y , reverse = True ).sum ()
483+
484+ y = mx .array ([2.0 , 1.0 , 2.0 , 2.0 , 3.0 ])
485+ out = mx .grad (fun )(y )
486+ expected = mx .array ([12.0 , 36.0 , 24.0 , 27.0 , 19.0 ])
487+ self .assertTrue (mx .allclose (out , expected ))
488+
489+ y = mx .array ([2.0 , 0.0 , 2.0 , 2.0 , 3.0 ])
490+ out = mx .grad (fun )(y )
491+ expected = mx .array ([0.0 , 36.0 , 6.0 , 9.0 , 7.0 ])
492+ self .assertTrue (mx .allclose (out , expected ))
493+
494+ y = mx .array ([2.0 , 0.0 , 2.0 , 0.0 , 3.0 ])
495+ out = mx .grad (fun )(y )
496+ expected = mx .array ([0.0 , 0.0 , 0.0 , 9.0 , 1.0 ])
497+ self .assertTrue (mx .allclose (out , expected ))
498+
426499
427500if __name__ == "__main__" :
428501 unittest .main ()
0 commit comments