Skip to content

Commit 0648848

Browse files
committed
fix tf backened + nx.sqrt + doc
1 parent afa3866 commit 0648848

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

ot/backend.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3120,10 +3120,15 @@ def randn(self, *size, type_as=None):
31203120
def randperm(self, size, type_as=None):
31213121
if not isinstance(size, int):
31223122
raise ValueError("size must be an integer")
3123+
local_seed = self.rng_.make_seeds(2)[0]
31233124
if type_as is None:
3124-
return self.rng_.shuffle(tf.range(size))
3125+
return tf.random.experimental.stateless_shuffle(
3126+
tf.range(size), seed=local_seed
3127+
)
31253128
else:
3126-
return self.rng_.shuffle(tf.range(size, dtype=type_as.dtype))
3129+
return tf.random.experimental.stateless_shuffle(
3130+
tf.range(size, dtype=type_as.dtype), seed=local_seed
3131+
)
31273132

31283133
def _convert_to_index_for_coo(self, tensor):
31293134
if isinstance(tensor, self.__type__):

ot/bregman/_empirical.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -833,18 +833,17 @@ def empirical_sinkhorn_nystroem(
833833
>>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
834834
>>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
835835
>>> empirical_sinkhorn_nystroem(X_s, X_t, reg, anchors, random_state=42)[:] # doctest: +ELLIPSIS
836-
array([[2.50000000e-01, 1.46537753e-01, 7.29587925e-10, 1.03462246e-01],
837-
[3.63816797e-10, 1.03462247e-01, 2.49999999e-01, 1.46537754e-01]])
836+
array([[0.125, 0.125, 0.125, 0.125],
837+
[0.125, 0.125, 0.125, 0.125]])
838838
839839
References
840840
----------
841841
842842
.. [80] Massively scalable Sinkhorn distances via the Nyström method, Jason Altschuler, Francis Bach, Alessandro Rudi, Jonathan Niles-Weed, NeurIPS 2019.
843843
844844
"""
845-
nx = get_backend(X_s, X_t)
846845
left_factor, right_factor = kernel_nystroem(
847-
X_s, X_t, anchors=anchors, sigma=nx.sqrt(reg / 2.0), random_state=random_state
846+
X_s, X_t, anchors=anchors, sigma=(reg / 2.0) ** (0.5), random_state=random_state
848847
)
849848
_, _, dict_log = sinkhorn_low_rank_kernel(
850849
K1=left_factor,
@@ -951,7 +950,7 @@ def empirical_sinkhorn_nystroem2(
951950
nx = get_backend(X_s, X_t)
952951
M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, False, nx=nx)
953952
left_factor, right_factor = kernel_nystroem(
954-
X_s, X_t, anchors=anchors, sigma=nx.sqrt(reg / 2.0), random_state=random_state
953+
X_s, X_t, anchors=anchors, sigma=(reg / 2.0) ** (0.5), random_state=random_state
955954
)
956955
if log:
957956
u, v, dict_log = sinkhorn_low_rank_kernel(

0 commit comments

Comments
 (0)