Skip to content

Commit 80d9542

Browse files
author
Francisco Muñoz
committed
test: refactor tests to delete the error for unavailable backends
1 parent a70cf8d commit 80d9542

File tree

1 file changed

+99
-137
lines changed

1 file changed

+99
-137
lines changed

test/test_bregman.py

Lines changed: 99 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -825,22 +825,18 @@ def test_wasserstein_bary_2d(nx, method):
825825

826826
# wasserstein
827827
reg = 1e-2
828-
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
829-
with pytest.raises(NotImplementedError):
830-
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
831-
else:
832-
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
833-
A, reg, method=method, verbose=True, log=True
834-
)
835-
bary_wass = nx.to_numpy(
836-
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
837-
)
828+
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
829+
A, reg, method=method, verbose=True, log=True
830+
)
831+
bary_wass = nx.to_numpy(
832+
ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
833+
)
838834

839-
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
840-
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
835+
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
836+
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
841837

842-
# help in checking if log and verbose do not bug the function
843-
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
838+
# help in checking if log and verbose do not bug the function
839+
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
844840

845841

846842
@pytest.skip_backend("tf")
@@ -856,27 +852,23 @@ def test_wasserstein_bary_2d_dtype_device(nx, method):
856852

857853
# wasserstein
858854
reg = 1e-2
859-
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
860-
with pytest.raises(NotImplementedError):
861-
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
862-
else:
863-
# Compute the barycenter with numpy
864-
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
865-
A, reg, method=method, verbose=True, log=True
866-
)
867-
# Compute the barycenter with the backend
868-
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
869-
# Convert the backend result to numpy, to compare with the numpy result
870-
bary_wass = nx.to_numpy(bary_wass_b)
855+
# Compute the barycenter with numpy
856+
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
857+
A, reg, method=method, verbose=True, log=True
858+
)
859+
# Compute the barycenter with the backend
860+
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
861+
# Convert the backend result to numpy, to compare with the numpy result
862+
bary_wass = nx.to_numpy(bary_wass_b)
871863

872-
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
873-
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
864+
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
865+
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
874866

875-
# help in checking if log and verbose do not bug the function
876-
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
867+
# help in checking if log and verbose do not bug the function
868+
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
877869

878-
# Test that the dtype and device are the same after the computation
879-
nx.assert_same_dtype_device(Ab, bary_wass_b)
870+
# Test that the dtype and device are the same after the computation
871+
nx.assert_same_dtype_device(Ab, bary_wass_b)
880872

881873

882874
@pytest.mark.skipif(not tf, reason="tf not installed")
@@ -894,37 +886,6 @@ def test_wasserstein_bary_2d_device_tf(method):
894886

895887
# wasserstein
896888
reg = 1e-2
897-
if method == "sinkhorn_log":
898-
with pytest.raises(NotImplementedError):
899-
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
900-
else:
901-
# Compute the barycenter with numpy
902-
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
903-
A, reg, method=method, verbose=True, log=True
904-
)
905-
# Compute the barycenter with the backend
906-
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
907-
# Convert the backend result to numpy, to compare with the numpy result
908-
bary_wass = nx.to_numpy(bary_wass_b)
909-
910-
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
911-
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
912-
913-
# help in checking if log and verbose do not bug the function
914-
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
915-
916-
# Test that the dtype and device are the same after the computation
917-
nx.assert_same_dtype_device(Ab, bary_wass_b)
918-
919-
# Check that everything happens on the GPU
920-
Ab = nx.from_numpy(A)
921-
922-
# wasserstein
923-
reg = 1e-2
924-
if method == "sinkhorn_log":
925-
with pytest.raises(NotImplementedError):
926-
ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
927-
else:
928889
# Compute the barycenter with numpy
929890
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
930891
A, reg, method=method, verbose=True, log=True
@@ -943,9 +904,32 @@ def test_wasserstein_bary_2d_device_tf(method):
943904
# Test that the dtype and device are the same after the computation
944905
nx.assert_same_dtype_device(Ab, bary_wass_b)
945906

946-
# Check this only if GPU is available
947-
if len(tf.config.list_physical_devices("GPU")) > 0:
948-
assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")
907+
# Check that everything happens on the GPU
908+
Ab = nx.from_numpy(A)
909+
910+
# wasserstein
911+
reg = 1e-2
912+
# Compute the barycenter with numpy
913+
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(
914+
A, reg, method=method, verbose=True, log=True
915+
)
916+
# Compute the barycenter with the backend
917+
bary_wass_b = ot.bregman.convolutional_barycenter2d(Ab, reg, method=method)
918+
# Convert the backend result to numpy, to compare with the numpy result
919+
bary_wass = nx.to_numpy(bary_wass_b)
920+
921+
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
922+
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
923+
924+
# help in checking if log and verbose do not bug the function
925+
ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
926+
927+
# Test that the dtype and device are the same after the computation
928+
nx.assert_same_dtype_device(Ab, bary_wass_b)
929+
930+
# Check this only if GPU is available
931+
if len(tf.config.list_physical_devices("GPU")) > 0:
932+
assert nx.dtype_device(bary_wass_b)[1].startswith("GPU")
949933

950934

951935
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
@@ -957,22 +941,18 @@ def test_wasserstein_bary_2d_debiased(nx, method):
957941

958942
# wasserstein
959943
reg = 1e-2
960-
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
961-
with pytest.raises(NotImplementedError):
962-
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
963-
else:
964-
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
965-
A, reg, method=method, verbose=True, log=True
966-
)
967-
bary_wass = nx.to_numpy(
968-
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
969-
)
944+
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
945+
A, reg, method=method, verbose=True, log=True
946+
)
947+
bary_wass = nx.to_numpy(
948+
ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
949+
)
970950

971-
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
972-
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
951+
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
952+
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
973953

974-
# help in checking if log and verbose do not bug the function
975-
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
954+
# help in checking if log and verbose do not bug the function
955+
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
976956

977957

978958
@pytest.skip_backend("tf")
@@ -988,31 +968,25 @@ def test_wasserstein_bary_2d_debiased_dtype_device(nx, method):
988968

989969
# wasserstein
990970
reg = 1e-2
991-
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
992-
with pytest.raises(NotImplementedError):
993-
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
994-
else:
995-
# Compute the barycenter with numpy
996-
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
997-
A, reg, method=method, verbose=True, log=True
998-
)
999-
# Compute the barycenter with the backend
1000-
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
1001-
Ab, reg, method=method
1002-
)
1003-
# Convert the backend result to numpy, to compare with the numpy result
1004-
bary_wass = nx.to_numpy(bary_wass_b)
971+
# Compute the barycenter with numpy
972+
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
973+
A, reg, method=method, verbose=True, log=True
974+
)
975+
# Compute the barycenter with the backend
976+
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
977+
Ab, reg, method=method
978+
)
979+
# Convert the backend result to numpy, to compare with the numpy result
980+
bary_wass = nx.to_numpy(bary_wass_b)
1005981

1006-
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
1007-
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
982+
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
983+
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
1008984

1009-
# help in checking if log and verbose do not bug the function
1010-
ot.bregman.convolutional_barycenter2d_debiased(
1011-
A, reg, log=True, verbose=True
1012-
)
985+
# help in checking if log and verbose do not bug the function
986+
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
1013987

1014-
# Test that the dtype and device are the same after the computation
1015-
nx.assert_same_dtype_device(Ab, bary_wass_b)
988+
# Test that the dtype and device are the same after the computation
989+
nx.assert_same_dtype_device(Ab, bary_wass_b)
1016990

1017991

1018992
@pytest.mark.skipif(not tf, reason="tf not installed")
@@ -1030,41 +1004,6 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
10301004

10311005
# wasserstein
10321006
reg = 1e-2
1033-
if method == "sinkhorn_log":
1034-
with pytest.raises(NotImplementedError):
1035-
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
1036-
else:
1037-
# Compute the barycenter with numpy
1038-
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
1039-
A, reg, method=method, verbose=True, log=True
1040-
)
1041-
# Compute the barycenter with the backend
1042-
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(
1043-
Ab, reg, method=method
1044-
)
1045-
# Convert the backend result to numpy, to compare with the numpy result
1046-
bary_wass = nx.to_numpy(bary_wass_b)
1047-
1048-
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
1049-
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
1050-
1051-
# help in checking if log and verbose do not bug the function
1052-
ot.bregman.convolutional_barycenter2d_debiased(
1053-
A, reg, log=True, verbose=True
1054-
)
1055-
1056-
# Test that the dtype and device are the same after the computation
1057-
nx.assert_same_dtype_device(Ab, bary_wass_b)
1058-
1059-
# Check that everything happens on the GPU
1060-
Ab = nx.from_numpy(A)
1061-
1062-
# wasserstein
1063-
reg = 1e-2
1064-
if method == "sinkhorn_log":
1065-
with pytest.raises(NotImplementedError):
1066-
ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
1067-
else:
10681007
# Compute the barycenter with numpy
10691008
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
10701009
A, reg, method=method, verbose=True, log=True
@@ -1077,6 +1016,29 @@ def test_wasserstein_bary_2d_debiased_device_tf(method):
10771016
bary_wass = nx.to_numpy(bary_wass_b)
10781017

10791018
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
1019+
np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
1020+
1021+
# help in checking if log and verbose do not bug the function
1022+
ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
1023+
1024+
# Test that the dtype and device are the same after the computation
1025+
nx.assert_same_dtype_device(Ab, bary_wass_b)
1026+
1027+
# Check that everything happens on the GPU
1028+
Ab = nx.from_numpy(A)
1029+
1030+
# wasserstein
1031+
reg = 1e-2
1032+
# Compute the barycenter with numpy
1033+
bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(
1034+
A, reg, method=method, verbose=True, log=True
1035+
)
1036+
# Compute the barycenter with the backend
1037+
bary_wass_b = ot.bregman.convolutional_barycenter2d_debiased(Ab, reg, method=method)
1038+
# Convert the backend result to numpy, to compare with the numpy result
1039+
bary_wass = nx.to_numpy(bary_wass_b)
1040+
1041+
np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
10801042

10811043

10821044
def test_unmix(nx):

0 commit comments

Comments
 (0)