Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Releases


## 0.9.7.dev0

This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.
Expand All @@ -12,8 +13,13 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
<<<<<<< HEAD
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
=======
- Add cost functions between linear operators following
[A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920) (PR #792)
>>>>>>> 8d13c55 (edits as per PR #792)

#### Closed issues

Expand Down
110 changes: 106 additions & 4 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,46 @@ def clip(self, a, a_min=None, a_max=None):
"""
raise NotImplementedError()

def real(self, a):
"""
Return the real part of the tensor element-wise.

This function follows the api from :any:`numpy.real`

See: https://numpy.org/doc/stable/reference/generated/numpy.real.html
"""
raise NotImplementedError()

def imag(self, a):
"""
Return the imaginary part of the tensor element-wise.

This function follows the api from :any:`numpy.imag`

See: https://numpy.org/doc/stable/reference/generated/numpy.imag.html
"""
raise NotImplementedError()

def conj(self, a):
"""
Return the complex conjugate, element-wise.

This function follows the api from :any:`numpy.conj`

See: https://numpy.org/doc/stable/reference/generated/numpy.conj.html
"""
raise NotImplementedError()

def arccos(self, a):
"""
Trigonometric inverse cosine, element-wise.

This function follows the api from :any:`numpy.arccos`

See: https://numpy.org/doc/stable/reference/generated/numpy.arccos.html
"""
raise NotImplementedError()

def repeat(self, a, repeats, axis=None):
r"""
Repeats elements of a tensor.
Expand Down Expand Up @@ -1193,7 +1233,7 @@ def _from_numpy(self, a, type_as=None):
elif isinstance(a, float):
return a
else:
return a.astype(type_as.dtype)
return np.asarray(a, dtype=type_as.dtype)

def set_gradients(self, val, inputs, grads):
# No gradients for numpy
Expand Down Expand Up @@ -1313,6 +1353,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return np.clip(a, a_min, a_max)

def real(self, a):
return np.real(a)

def imag(self, a):
return np.imag(a)

def conj(self, a):
return np.conj(a)

def arccos(self, a):
return np.arccos(a)

def repeat(self, a, repeats, axis=None):
return np.repeat(a, repeats, axis)

Expand Down Expand Up @@ -1604,7 +1656,7 @@ def _from_numpy(self, a, type_as=None):
if type_as is None:
return jnp.array(a)
else:
return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
return self._change_device(jnp.asarray(a, dtype=type_as.dtype), type_as)

def set_gradients(self, val, inputs, grads):
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -1730,6 +1782,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return jnp.clip(a, a_min, a_max)

def real(self, a):
return jnp.real(a)

def imag(self, a):
return jnp.imag(a)

def conj(self, a):
return jnp.conj(a)

def arccos(self, a):
return jnp.arccos(a)

def repeat(self, a, repeats, axis=None):
return jnp.repeat(a, repeats, axis)

Expand Down Expand Up @@ -1803,7 +1867,9 @@ def randperm(self, size, type_as=None):
if not isinstance(size, int):
raise ValueError("size must be an integer")
if type_as is not None:
return jax.random.permutation(subkey, size).astype(type_as.dtype)
return jnp.asarray(
jax.random.permutation(subkey, size), dtype=type_as.dtype
)
else:
return jax.random.permutation(subkey, size)

Expand Down Expand Up @@ -2227,6 +2293,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return torch.clamp(a, a_min, a_max)

def real(self, a):
return torch.real(a)

def imag(self, a):
return torch.imag(a)

def conj(self, a):
return torch.conj(a)

def arccos(self, a):
return torch.acos(a)

def repeat(self, a, repeats, axis=None):
return torch.repeat_interleave(a, repeats, dim=axis)

Expand Down Expand Up @@ -2728,6 +2806,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return cp.clip(a, a_min, a_max)

def real(self, a):
return cp.real(a)

def imag(self, a):
return cp.imag(a)

def conj(self, a):
return cp.conj(a)

def arccos(self, a):
return cp.arccos(a)

def repeat(self, a, repeats, axis=None):
return cp.repeat(a, repeats, axis)

Expand Down Expand Up @@ -2819,7 +2909,7 @@ def randperm(self, size, type_as=None):
return self.rng_.permutation(size)
else:
with cp.cuda.Device(type_as.device):
return self.rng_.permutation(size).astype(type_as.dtype)
return cp.asarray(self.rng_.permutation(size), dtype=type_as.dtype)

def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
data = self.from_numpy(data)
Expand Down Expand Up @@ -3162,6 +3252,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return tnp.clip(a, a_min, a_max)

def real(self, a):
return tnp.real(a)

def imag(self, a):
return tnp.imag(a)

def conj(self, a):
return tnp.conj(a)

def arccos(self, a):
return tnp.arccos(a)

def repeat(self, a, repeats, axis=None):
return tnp.repeat(a, repeats, axis)

Expand Down
Loading