@@ -825,22 +825,18 @@ def test_wasserstein_bary_2d(nx, method):
825825
826826 # wasserstein
827827 reg = 1e-2
828- if nx .__name__ in ("jax" , "tf" ) and method == "sinkhorn_log" :
829- with pytest .raises (NotImplementedError ):
830- ot .bregman .convolutional_barycenter2d (A_nx , reg , method = method )
831- else :
832- bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d (
833- A , reg , method = method , verbose = True , log = True
834- )
835- bary_wass = nx .to_numpy (
836- ot .bregman .convolutional_barycenter2d (A_nx , reg , method = method )
837- )
828+ bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d (
829+ A , reg , method = method , verbose = True , log = True
830+ )
831+ bary_wass = nx .to_numpy (
832+ ot .bregman .convolutional_barycenter2d (A_nx , reg , method = method )
833+ )
838834
839- np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
840- np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
835+ np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
836+ np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
841837
842- # help in checking if log and verbose do not bug the function
843- ot .bregman .convolutional_barycenter2d (A , reg , log = True , verbose = True )
838+ # help in checking if log and verbose do not bug the function
839+ ot .bregman .convolutional_barycenter2d (A , reg , log = True , verbose = True )
844840
845841
846842@pytest .skip_backend ("tf" )
@@ -856,27 +852,23 @@ def test_wasserstein_bary_2d_dtype_device(nx, method):
856852
857853 # wasserstein
858854 reg = 1e-2
859- if nx .__name__ in ("jax" , "tf" ) and method == "sinkhorn_log" :
860- with pytest .raises (NotImplementedError ):
861- ot .bregman .convolutional_barycenter2d (Ab , reg , method = method )
862- else :
863- # Compute the barycenter with numpy
864- bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d (
865- A , reg , method = method , verbose = True , log = True
866- )
867- # Compute the barycenter with the backend
868- bary_wass_b = ot .bregman .convolutional_barycenter2d (Ab , reg , method = method )
869- # Convert the backend result to numpy, to compare with the numpy result
870- bary_wass = nx .to_numpy (bary_wass_b )
855+ # Compute the barycenter with numpy
856+ bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d (
857+ A , reg , method = method , verbose = True , log = True
858+ )
859+ # Compute the barycenter with the backend
860+ bary_wass_b = ot .bregman .convolutional_barycenter2d (Ab , reg , method = method )
861+ # Convert the backend result to numpy, to compare with the numpy result
862+ bary_wass = nx .to_numpy (bary_wass_b )
871863
872- np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
873- np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
864+ np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
865+ np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
874866
875- # help in checking if log and verbose do not bug the function
876- ot .bregman .convolutional_barycenter2d (A , reg , log = True , verbose = True )
867+ # help in checking if log and verbose do not bug the function
868+ ot .bregman .convolutional_barycenter2d (A , reg , log = True , verbose = True )
877869
878- # Test that the dtype and device are the same after the computation
879- nx .assert_same_dtype_device (Ab , bary_wass_b )
870+ # Test that the dtype and device are the same after the computation
871+ nx .assert_same_dtype_device (Ab , bary_wass_b )
880872
881873
882874@pytest .mark .skipif (not tf , reason = "tf not installed" )
@@ -894,37 +886,6 @@ def test_wasserstein_bary_2d_device_tf(method):
894886
895887 # wasserstein
896888 reg = 1e-2
897- if method == "sinkhorn_log" :
898- with pytest .raises (NotImplementedError ):
899- ot .bregman .convolutional_barycenter2d (Ab , reg , method = method )
900- else :
901- # Compute the barycenter with numpy
902- bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d (
903- A , reg , method = method , verbose = True , log = True
904- )
905- # Compute the barycenter with the backend
906- bary_wass_b = ot .bregman .convolutional_barycenter2d (Ab , reg , method = method )
907- # Convert the backend result to numpy, to compare with the numpy result
908- bary_wass = nx .to_numpy (bary_wass_b )
909-
910- np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
911- np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
912-
913- # help in checking if log and verbose do not bug the function
914- ot .bregman .convolutional_barycenter2d (A , reg , log = True , verbose = True )
915-
916- # Test that the dtype and device are the same after the computation
917- nx .assert_same_dtype_device (Ab , bary_wass_b )
918-
919- # Check that everything happens on the GPU
920- Ab = nx .from_numpy (A )
921-
922- # wasserstein
923- reg = 1e-2
924- if method == "sinkhorn_log" :
925- with pytest .raises (NotImplementedError ):
926- ot .bregman .convolutional_barycenter2d (Ab , reg , method = method )
927- else :
928889 # Compute the barycenter with numpy
929890 bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d (
930891 A , reg , method = method , verbose = True , log = True
@@ -943,9 +904,32 @@ def test_wasserstein_bary_2d_device_tf(method):
943904 # Test that the dtype and device are the same after the computation
944905 nx .assert_same_dtype_device (Ab , bary_wass_b )
945906
946- # Check this only if GPU is available
947- if len (tf .config .list_physical_devices ("GPU" )) > 0 :
948- assert nx .dtype_device (bary_wass_b )[1 ].startswith ("GPU" )
907+ # Check that everything happens on the GPU
908+ Ab = nx .from_numpy (A )
909+
910+ # wasserstein
911+ reg = 1e-2
912+ # Compute the barycenter with numpy
913+ bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d (
914+ A , reg , method = method , verbose = True , log = True
915+ )
916+ # Compute the barycenter with the backend
917+ bary_wass_b = ot .bregman .convolutional_barycenter2d (Ab , reg , method = method )
918+ # Convert the backend result to numpy, to compare with the numpy result
919+ bary_wass = nx .to_numpy (bary_wass_b )
920+
921+ np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
922+ np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
923+
924+ # help in checking if log and verbose do not bug the function
925+ ot .bregman .convolutional_barycenter2d (A , reg , log = True , verbose = True )
926+
927+ # Test that the dtype and device are the same after the computation
928+ nx .assert_same_dtype_device (Ab , bary_wass_b )
929+
930+ # Check this only if GPU is available
931+ if len (tf .config .list_physical_devices ("GPU" )) > 0 :
932+ assert nx .dtype_device (bary_wass_b )[1 ].startswith ("GPU" )
949933
950934
951935@pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_log" ])
@@ -957,22 +941,18 @@ def test_wasserstein_bary_2d_debiased(nx, method):
957941
958942 # wasserstein
959943 reg = 1e-2
960- if nx .__name__ in ("jax" , "tf" ) and method == "sinkhorn_log" :
961- with pytest .raises (NotImplementedError ):
962- ot .bregman .convolutional_barycenter2d_debiased (A_nx , reg , method = method )
963- else :
964- bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d_debiased (
965- A , reg , method = method , verbose = True , log = True
966- )
967- bary_wass = nx .to_numpy (
968- ot .bregman .convolutional_barycenter2d_debiased (A_nx , reg , method = method )
969- )
944+ bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d_debiased (
945+ A , reg , method = method , verbose = True , log = True
946+ )
947+ bary_wass = nx .to_numpy (
948+ ot .bregman .convolutional_barycenter2d_debiased (A_nx , reg , method = method )
949+ )
970950
971- np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
972- np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
951+ np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
952+ np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
973953
974- # help in checking if log and verbose do not bug the function
975- ot .bregman .convolutional_barycenter2d_debiased (A , reg , log = True , verbose = True )
954+ # help in checking if log and verbose do not bug the function
955+ ot .bregman .convolutional_barycenter2d_debiased (A , reg , log = True , verbose = True )
976956
977957
978958@pytest .skip_backend ("tf" )
@@ -988,31 +968,25 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method):
988968
989969 # wasserstein
990970 reg = 1e-2
991- if nx .__name__ in ("jax" , "tf" ) and method == "sinkhorn_log" :
992- with pytest .raises (NotImplementedError ):
993- ot .bregman .convolutional_barycenter2d_debiased (Ab , reg , method = method )
994- else :
995- # Compute the barycenter with numpy
996- bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d_debiased (
997- A , reg , method = method , verbose = True , log = True
998- )
999- # Compute the barycenter with the backend
1000- bary_wass_b = ot .bregman .convolutional_barycenter2d_debiased (
1001- Ab , reg , method = method
1002- )
1003- # Convert the backend result to numpy, to compare with the numpy result
1004- bary_wass = nx .to_numpy (bary_wass_b )
971+ # Compute the barycenter with numpy
972+ bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d_debiased (
973+ A , reg , method = method , verbose = True , log = True
974+ )
975+ # Compute the barycenter with the backend
976+ bary_wass_b = ot .bregman .convolutional_barycenter2d_debiased (
977+ Ab , reg , method = method
978+ )
979+ # Convert the backend result to numpy, to compare with the numpy result
980+ bary_wass = nx .to_numpy (bary_wass_b )
1005981
1006- np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
1007- np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
982+ np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
983+ np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
1008984
1009- # help in checking if log and verbose do not bug the function
1010- ot .bregman .convolutional_barycenter2d_debiased (
1011- A , reg , log = True , verbose = True
1012- )
985+ # help in checking if log and verbose do not bug the function
986+ ot .bregman .convolutional_barycenter2d_debiased (A , reg , log = True , verbose = True )
1013987
1014- # Test that the dtype and device are the same after the computation
1015- nx .assert_same_dtype_device (Ab , bary_wass_b )
988+ # Test that the dtype and device are the same after the computation
989+ nx .assert_same_dtype_device (Ab , bary_wass_b )
1016990
1017991
1018992@pytest .mark .skipif (not tf , reason = "tf not installed" )
@@ -1030,41 +1004,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
10301004
10311005 # wasserstein
10321006 reg = 1e-2
1033- if method == "sinkhorn_log" :
1034- with pytest .raises (NotImplementedError ):
1035- ot .bregman .convolutional_barycenter2d_debiased (Ab , reg , method = method )
1036- else :
1037- # Compute the barycenter with numpy
1038- bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d_debiased (
1039- A , reg , method = method , verbose = True , log = True
1040- )
1041- # Compute the barycenter with the backend
1042- bary_wass_b = ot .bregman .convolutional_barycenter2d_debiased (
1043- Ab , reg , method = method
1044- )
1045- # Convert the backend result to numpy, to compare with the numpy result
1046- bary_wass = nx .to_numpy (bary_wass_b )
1047-
1048- np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
1049- np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
1050-
1051- # help in checking if log and verbose do not bug the function
1052- ot .bregman .convolutional_barycenter2d_debiased (
1053- A , reg , log = True , verbose = True
1054- )
1055-
1056- # Test that the dtype and device are the same after the computation
1057- nx .assert_same_dtype_device (Ab , bary_wass_b )
1058-
1059- # Check that everything happens on the GPU
1060- Ab = nx .from_numpy (A )
1061-
1062- # wasserstein
1063- reg = 1e-2
1064- if method == "sinkhorn_log" :
1065- with pytest .raises (NotImplementedError ):
1066- ot .bregman .convolutional_barycenter2d_debiased (Ab , reg , method = method )
1067- else :
10681007 # Compute the barycenter with numpy
10691008 bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d_debiased (
10701009 A , reg , method = method , verbose = True , log = True
@@ -1077,6 +1016,29 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
10771016 bary_wass = nx .to_numpy (bary_wass_b )
10781017
10791018 np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
1019+ np .testing .assert_allclose (bary_wass , bary_wass_np , atol = 1e-3 )
1020+
1021+ # help in checking if log and verbose do not bug the function
1022+ ot .bregman .convolutional_barycenter2d_debiased (A , reg , log = True , verbose = True )
1023+
1024+ # Test that the dtype and device are the same after the computation
1025+ nx .assert_same_dtype_device (Ab , bary_wass_b )
1026+
1027+ # Check that everything happens on the GPU
1028+ Ab = nx .from_numpy (A )
1029+
1030+ # wasserstein
1031+ reg = 1e-2
1032+ # Compute the barycenter with numpy
1033+ bary_wass_np , log_np = ot .bregman .convolutional_barycenter2d_debiased (
1034+ A , reg , method = method , verbose = True , log = True
1035+ )
1036+ # Compute the barycenter with the backend
1037+ bary_wass_b = ot .bregman .convolutional_barycenter2d_debiased (Ab , reg , method = method )
1038+ # Convert the backend result to numpy, to compare with the numpy result
1039+ bary_wass = nx .to_numpy (bary_wass_b )
1040+
1041+ np .testing .assert_allclose (1 , np .sum (bary_wass ), rtol = 1e-3 )
10801042
10811043
10821044def test_unmix (nx ):
0 commit comments