[WIP] Spectral-Grassmann OT#792
Conversation
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #792 +/- ##
==========================================
- Coverage 96.77% 95.77% -1.00%
==========================================
Files 107 108 +1
Lines 22342 22621 +279
==========================================
+ Hits 21622 21666 +44
- Misses 720 955 +235 🚀 New features to boost your workflow:
|
rflamary
left a comment
There was a problem hiding this comment.
Hello @osheasienna and @thibaut-germain this is a nice first step.
Here are below a few comments that we can discuss together
ot/sgot.py
Outdated
| return C | ||
|
|
||
|
|
||
| def metric( |
There was a problem hiding this comment.
| def metric( | |
| def sgot_metric( |
ot/sgot.py
Outdated
| return prod ** (q / 2) | ||
|
|
||
|
|
||
| def ot_plan(C, Ws=None, Wt=None, nx=None): |
There was a problem hiding this comment.
this function is not needed, this is two lines and the ormalization wrt ws and wt are not oK because it rcan retrun very weird things
ot/sgot.py
Outdated
| ### SPECTRAL-GRASSMANNIAN WASSERSTEIN METRIC ### | ||
| ##################################################################################################################################### | ||
| ##################################################################################################################################### | ||
| def cost( |
There was a problem hiding this comment.
| def cost( | |
| def sgot_cost_matrix( |
ot/sgot.py
Outdated
| imag_scale=1.0, | ||
| nx=None, | ||
| ): | ||
| """Compute the SGOT cost matrix between two spectral decompositions. |
There was a problem hiding this comment.
recall here the equation with eta and define with math teh different acceptable metrics
ot/sgot.py
Outdated
| raise ValueError(f"cost() expects Dt to be 1D (n,), got shape {Dt.shape}") | ||
| lam2 = Dt | ||
|
|
||
| lam1 = nx.astype(lam1, "complex128") |
There was a problem hiding this comment.
is that necessary? seems overkill to add a function to the backend for that . When and why does it fails?
test/test_sgot.py
Outdated
| logits_s = rng.randn(r) | ||
| logits_t = rng.randn(r) | ||
|
|
||
| Ws = np.exp(logits_s) |
There was a problem hiding this comment.
simpler and return only positive values
| Ws = np.exp(logits_s) | |
| Ws = rng.rand(r) |
test/test_sgot.py
Outdated
| """Create test_cost for each trial: sweep over HPs and run cost().""" | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
| n_trials = 10 | ||
| for _ in range(n_trials): |
test/test_sgot.py
Outdated
| def test_hyperparameter_sweep(): | ||
| grassmann_types = ["geodesic", "chordal", "procrustes", "martin"] | ||
|
|
||
| for _ in range(10): |
| This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. | ||
|
|
||
| #### New features | ||
| - Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) |
RELEASES.md
Outdated
| ## Upcomming 0.9.7.post1 | ||
|
|
||
| #### New features | ||
| The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920). |
There was a problem hiding this comment.
move this text to the new feature of 0.9.7.dev0 this is what we are working on. Also add a line in the Itemize with the PR number
rflamary
left a comment
There was a problem hiding this comment.
A few comments from talking together
ot/sgot.py
Outdated
| if grassman_metric == "procrustes": | ||
| return 2.0 * (1.0 - delta) | ||
| if grassman_metric == "martin": | ||
| return -nx.log(nx.clip(delta**2, eps, 1e300)) |
ot/sgot.py
Outdated
| C_grass = _grassmann_distance_squared(delta, grassman_metric=grassman_metric, nx=nx) | ||
|
|
||
| C2 = eta * C_lambda + (1.0 - eta) * C_grass | ||
| C = C2 ** (p / 2.0) |
There was a problem hiding this comment.
| C = C2 ** (p / 2.0) | |
| C = nx.real(C2) ** (p / 2.0) |
ot/sgot.py
Outdated
| q=1, | ||
| r=2, | ||
| grassman_metric="chordal", | ||
| real_scale=1.0, |
There was a problem hiding this comment.
lets call this eigen_scaling and set it to None by default
| nx=None, | ||
| ): | ||
| """Compute the SGOT metric between two spectral decompositions. | ||
|
|
There was a problem hiding this comment.
add equation that illustrate p q and r
test/test_sgot.py
Outdated
| import numpy as np | ||
| import pytest | ||
|
|
||
| from ot.backend import get_backend |
There was a problem hiding this comment.
| from ot.backend import get_backend | |
| from ot.backend import get_backend, torch, jax |
test/test_sgot.py
Outdated
| rng = np.random.RandomState(0) | ||
|
|
||
|
|
||
| def rand_complex(shape): |
There was a problem hiding this comment.
| def rand_complex(shape): | |
| def rand_complex(shape,rng): |
test/test_sgot.py
Outdated
| return real + 1j * imag | ||
|
|
||
|
|
||
| def random_atoms(d=8, r=4): |
There was a problem hiding this comment.
| def random_atoms(d=8, r=4): | |
| def random_atoms(d=8, r=4,seed=42): |
test/test_sgot.py
Outdated
|
|
||
|
|
||
| @pytest.mark.parametrize("backend_name", ["numpy", "torch", "jax"]) | ||
| def test_cost_backend_consistency(backend_name): |
There was a problem hiding this comment.
| def test_cost_backend_consistency(backend_name): | |
| def test_cost_backend_consistency(nx): |
test/test_sgot.py
Outdated
| # --------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_hyperparameter_sweep_cost(nx): |
There was a problem hiding this comment.
| def test_hyperparameter_sweep_cost(nx): | |
| def test_hyperparameter_sweep_cost(nx,grassmann_types,p,q,r,eta): |
ot/sgot.py
Outdated
| Ws = Ws / nx.sum(Ws) | ||
| Wt = Wt / nx.sum(Wt) | ||
|
|
||
| P = ot.emd2(Ws, Wt, nx.real(C)) |
There was a problem hiding this comment.
emd2 retruns directly obj no need to compute it again below
Types of changes
Adding sgot file in the ot folder.
Motivation and context / Related issue
Keep track of SGOT implementation in POT.
How has this been tested (if it applies)
Not tested yet.
PR checklist