@@ -314,6 +314,64 @@ def test_vmap_matmul(self):
314314 expected = mx .addmm (mx .moveaxis (c , 2 , 0 ), a , mx .moveaxis (b , 1 , 0 ))
315315 self .assertTrue (mx .allclose (out , expected ))
316316
317+ def test_vmap_svd (self ):
318+ a = mx .random .uniform (shape = (3 , 4 , 2 ))
319+
320+ cpu_svd = lambda x : mx .linalg .svd (x , stream = mx .cpu )
321+
322+ # Vmap over the first axis (this is already supported natively by the primitive).
323+ Us , Ss , Vts = mx .vmap (cpu_svd , in_axes = (0 ,))(a )
324+ self .assertEqual (Us .shape , (a .shape [0 ], a .shape [1 ], a .shape [1 ]))
325+ self .assertEqual (Ss .shape , (a .shape [0 ], a .shape [2 ]))
326+ self .assertEqual (Vts .shape , (a .shape [0 ], a .shape [2 ], a .shape [2 ]))
327+
328+ for i in range (a .shape [0 ]):
329+ M = a [i ]
330+ U , S , Vt = Us [i ], Ss [i ], Vts [i ]
331+ self .assertTrue (
332+ mx .allclose (U [:, : len (S )] @ mx .diag (S ) @ Vt , M , rtol = 1e-5 , atol = 1e-7 )
333+ )
334+
335+ # Vmap over the second axis.
336+ Us , Ss , Vts = mx .vmap (cpu_svd , in_axes = (1 ,))(a )
337+ self .assertEqual (Us .shape , (a .shape [1 ], a .shape [0 ], a .shape [0 ]))
338+ self .assertEqual (Ss .shape , (a .shape [1 ], a .shape [2 ]))
339+ self .assertEqual (Vts .shape , (a .shape [1 ], a .shape [2 ], a .shape [2 ]))
340+
341+ for i in range (a .shape [1 ]):
342+ M = a [:, i , :]
343+ U , S , Vt = Us [i ], Ss [i ], Vts [i ]
344+ self .assertTrue (
345+ mx .allclose (U [:, : len (S )] @ mx .diag (S ) @ Vt , M , rtol = 1e-5 , atol = 1e-7 )
346+ )
347+
348+ def test_vmap_inverse (self ):
349+ a = mx .random .uniform (shape = (3 , 4 , 4 ))
350+
351+ cpu_inv = lambda x : mx .linalg .inv (x , stream = mx .cpu )
352+
353+ # Vmap over the first axis (this is already supported natively by the primitive).
354+ invs = mx .vmap (cpu_inv , in_axes = (0 ,))(a )
355+
356+ for i in range (a .shape [0 ]):
357+ self .assertTrue (
358+ mx .allclose (a [i ] @ invs [i ], mx .eye (a .shape [1 ]), rtol = 0 , atol = 1e-5 )
359+ )
360+
361+ a = mx .random .uniform (shape = (4 , 3 , 4 ))
362+
363+ # Without vmapping, each input matrix is not square.
364+ with self .assertRaises (ValueError ):
365+ mx .eval (cpu_inv (a ))
366+
367+ # Vmap over the second axis.
368+ invs = mx .vmap (cpu_inv , in_axes = (1 ,))(a )
369+
370+ for i in range (a .shape [1 ]):
371+ self .assertTrue (
372+ mx .allclose (a [:, i , :] @ invs [i ], mx .eye (a .shape [0 ]), rtol = 0 , atol = 1e-5 )
373+ )
374+
317375
318376if __name__ == "__main__" :
319377 unittest .main ()
0 commit comments