Skip to content

Batch OT losses (Sinkhorn + Gromov)#755

Merged
rflamary merged 37 commits intoPythonOT:masterfrom
KrzakalaPaul:ot-batch
Sep 16, 2025
Merged

Batch OT losses (Sinkhorn + Gromov)#755
rflamary merged 37 commits intoPythonOT:masterfrom
KrzakalaPaul:ot-batch

Conversation

@KrzakalaPaul
Copy link
Contributor

Types of changes

Add the ot.batch module for solving N optimal transport problems at the same time in parallel. The two main features are:

  • ot.batch.solve_batch, the API is the same as for ot.solve (with less features for now).
  • ot.batch.solve_gromov_batch, the API is the same as for ot.solve_gromov (with less features for now).

Some examples have been added to examples/batch

Motivation and context / Related issue

It is often the case that one needs to solve N optimal transport problem at the same time. At the moment, the only way to do this in POT was using a for loop. This is very inefficient for data stored on a GPU for instance.
Instead, it is now possible to use ot.batch for solving N problems at the same time with fully batch parallel operations. As demonstrated in examples/batch/demo_efficiency.py this can lead to a speed of 1 or 2 orders of magnitude.

How has this been tested (if it applies)

New tests have been added to test/batch/ to ensure that using ot.solve_batch yields the same results that using a for loops of ot.solve. Idem for ot.solve_batch_gromov.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@codecov
Copy link

codecov bot commented Aug 25, 2025

Codecov Report

❌ Patch coverage is 97.60956% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 97.16%. Comparing base (803d2ab) to head (e70ecd1).
⚠️ Report is 1 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff            @@
##           master     #755    +/-   ##
========================================
  Coverage   97.15%   97.16%            
========================================
  Files         101      107     +6     
  Lines       21386    21879   +493     
========================================
+ Hits        20778    21259   +481     
- Misses        608      620    +12     
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@rflamary rflamary changed the title Ot batch Batch OT losses (Sinkhonr + Gromov) Aug 27, 2025
@rflamary rflamary changed the title Batch OT losses (Sinkhonr + Gromov) Batch OT losses (Sinkhorn + Gromov) Aug 27, 2025
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @KrzakalaPaul here are a few comments to take into account please.

log_dual=True,
grad="detach",
):
r"""Solves the linear optimal transport problem using Bregman projections.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add entropic here this is not exact OT

symmetric=None,
M=None,
alpha=None,
epsilon=1e-2,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reg + mandatory

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call it reg_inner

tol=1e-5,
max_iter_inner=50,
tol_inner=1e-5,
grad="detach",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need a gradient by default

@rflamary rflamary merged commit 3a53dff into PythonOT:master Sep 16, 2025
18 checks passed
MatDag pushed a commit to MatDag/POT that referenced this pull request Sep 25, 2025
* linear ot implemented

* improve stopping criterion and assymetric case

* Add recompute_const and simplify the pipeline for the symmetric = False

* add tests

* update the examples and rename to follow the "ot.solve" naming conventions

* update realeases.md

* idem

* move set_grad_enabled to backend

* set_grad_enabled for quadratric solver

* update doc

* remove useless importation in doc

* Update references

* update example

* Remove classes in quadratic, move examples to backend, add potentials, remove context managers for grads. To do: improve doc and tests

* updat tests

* Massive improvement of the documentation for ot.batch

* cover (almost) all ot.batch with tests

* bug in the tests

* update docstring

* highlight that ot.batch is solving the entropic version

* removing yet another error in the docstring

* Add missing parameter recompute_const

* Remove png, add all backends and gradient mode to tests

* add the missing pytest

* change .sum() into nx.sum

* add missing backend

* yet another missing nx

* remove useless squeeze and add test for non-log bregman

* remove last_step from quadratic tests

* add missing tests and improve documentation

* proper unsqueeze test

* add unsqueeze to tensorflow

* solve double backprop issue in test_gradients_torch

---------

Co-authored-by: PaulKrzakala <paul.krzakala@gmail.com>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants