@@ -289,6 +289,30 @@ def test_projections_stiefel():
289289 np .matmul (P_T , P ), np .array ([np .eye (2 ) for k in range (n_projs )])
290290 )
291291
292+ rng = np .random .RandomState (0 )
293+
294+ projections = ot .sliced .get_projections_sphere (3 , n_projs , seed = rng )
295+ projections_T = np .transpose (projections , [0 , 2 , 1 ])
296+
297+ np .testing .assert_almost_equal (
298+ np .matmul (projections_T , projections ),
299+ np .array ([np .eye (2 ) for k in range (n_projs )]),
300+ )
301+
302+ # np.testing.assert_almost_equal(projections, P)
303+
304+
305+ def test_projections_sphere_to_circle ():
306+ rng = np .random .RandomState (0 )
307+
308+ n_projs = 500
309+ x = rng .randn (100 , 3 )
310+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
311+
312+ x_projs , _ = ot .sliced .projection_sphere_to_circle (x , n_projs )
313+ assert x_projs .shape == (n_projs , 100 )
314+ assert np .all (x_projs >= 0 ) and np .all (x_projs < 1 )
315+
292316
293317def test_sliced_sphere_same_dist ():
294318 n = 100
@@ -506,3 +530,153 @@ def test_sliced_sphere_unif_backend_type_devices(nx):
506530 valb = ot .sliced_wasserstein_sphere_unif (xb )
507531
508532 nx .assert_same_dtype_device (xb , valb )
533+
534+
535+ def test_linear_sliced_sphere_same_dist ():
536+ n = 100
537+ rng = np .random .RandomState (0 )
538+
539+ x = rng .randn (n , 3 )
540+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
541+ u = ot .utils .unif (n )
542+
543+ res = ot .linear_sliced_wasserstein_sphere (x , x , u , u , 10 , seed = rng )
544+ np .testing .assert_almost_equal (res , 0.0 )
545+
546+
547+ def test_linear_sliced_sphere_same_proj ():
548+ n_projections = 10
549+ n = 100
550+ rng = np .random .RandomState (0 )
551+
552+ x = rng .randn (n , 3 )
553+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
554+
555+ y = rng .randn (n , 3 )
556+ y = y / np .sqrt (np .sum (y ** 2 , - 1 , keepdims = True ))
557+
558+ seed = 42
559+
560+ cost1 , log1 = ot .linear_sliced_wasserstein_sphere (
561+ x , y , seed = seed , n_projections = n_projections , log = True
562+ )
563+ cost2 , log2 = ot .linear_sliced_wasserstein_sphere (
564+ x , y , seed = seed , n_projections = n_projections , log = True
565+ )
566+
567+ assert np .allclose (log1 ["projections" ], log2 ["projections" ])
568+ assert np .isclose (cost1 , cost2 )
569+
570+
571+ def test_linear_sliced_sphere_bad_shapes ():
572+ n = 100
573+ rng = np .random .RandomState (0 )
574+
575+ x = rng .randn (n , 3 )
576+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
577+
578+ y = rng .randn (n , 4 )
579+ y = y / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
580+
581+ u = ot .utils .unif (n )
582+
583+ with pytest .raises (ValueError ):
584+ _ = ot .linear_sliced_wasserstein_sphere (x , y , u , u , 10 , seed = rng )
585+
586+
587+ def test_linear_sliced_sphere_values_on_the_sphere ():
588+ n = 100
589+ rng = np .random .RandomState (0 )
590+
591+ x = rng .randn (n , 3 )
592+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
593+
594+ y = rng .randn (n , 4 )
595+
596+ u = ot .utils .unif (n )
597+
598+ with pytest .raises (ValueError ):
599+ _ = ot .linear_sliced_wasserstein_sphere (x , y , u , u , 10 , seed = rng )
600+
601+
602+ def test_linear_sliced_sphere_log ():
603+ n = 100
604+ rng = np .random .RandomState (0 )
605+
606+ x = rng .randn (n , 4 )
607+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
608+ y = rng .randn (n , 4 )
609+ y = y / np .sqrt (np .sum (y ** 2 , - 1 , keepdims = True ))
610+ u = ot .utils .unif (n )
611+
612+ res , log = ot .linear_sliced_wasserstein_sphere (x , y , u , u , 10 , seed = rng , log = True )
613+ assert len (log ) == 2
614+ projections = log ["projections" ]
615+ projected_emds = log ["projected_emds" ]
616+
617+ assert projections .shape [0 ] == len (projected_emds ) == 10
618+ for emd in projected_emds :
619+ assert emd > 0
620+
621+
622+ def test_linear_sliced_sphere_different_dists ():
623+ n = 100
624+ rng = np .random .RandomState (0 )
625+
626+ x = rng .randn (n , 3 )
627+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
628+
629+ u = ot .utils .unif (n )
630+ y = rng .randn (n , 3 )
631+ y = y / np .sqrt (np .sum (y ** 2 , - 1 , keepdims = True ))
632+
633+ res = ot .linear_sliced_wasserstein_sphere (x , y , u , u , 10 , seed = rng )
634+ assert res > 0.0
635+
636+
637+ def test_1d_linear_sliced_sphere_equals_emd ():
638+ n = 100
639+ m = 120
640+ rng = np .random .RandomState (0 )
641+
642+ x = rng .randn (n , 2 )
643+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
644+ x_coords = (np .arctan2 (- x [:, 1 ], - x [:, 0 ]) + np .pi ) / (2 * np .pi )
645+ a = rng .uniform (0 , 1 , n )
646+ a /= a .sum ()
647+
648+ y = rng .randn (m , 2 )
649+ y = y / np .sqrt (np .sum (y ** 2 , - 1 , keepdims = True ))
650+ y_coords = (np .arctan2 (- y [:, 1 ], - y [:, 0 ]) + np .pi ) / (2 * np .pi )
651+ u = ot .utils .unif (m )
652+
653+ res = ot .linear_sliced_wasserstein_sphere (x , y , a , u , 100 , seed = 42 )
654+ expected = ot .linear_circular_ot (x_coords .T , y_coords .T , a , u )
655+
656+ np .testing .assert_almost_equal (res ** 2 , expected , decimal = 5 )
657+
658+
659+ def test_linear_sliced_sphere_backend_type_devices (nx ):
660+ n = 100
661+ rng = np .random .RandomState (0 )
662+
663+ x = rng .randn (n , 3 )
664+ x = x / np .sqrt (np .sum (x ** 2 , - 1 , keepdims = True ))
665+
666+ y = rng .randn (2 * n , 3 )
667+ y = y / np .sqrt (np .sum (y ** 2 , - 1 , keepdims = True ))
668+
669+ sw_np , log = ot .linear_sliced_wasserstein_sphere (x , y , log = True )
670+ P = log ["projections" ]
671+
672+ for tp in nx .__type_list__ :
673+ print (nx .dtype_device (tp ))
674+
675+ xb , yb = nx .from_numpy (x , y , type_as = tp )
676+
677+ valb = ot .linear_sliced_wasserstein_sphere (
678+ xb , yb , projections = nx .from_numpy (P , type_as = tp )
679+ )
680+
681+ nx .assert_same_dtype_device (xb , valb )
682+ np .testing .assert_almost_equal (sw_np , nx .to_numpy (valb ))
0 commit comments