Streaming Entropic Optimal Transport in PyTorch + Triton
FlashSinkhorn computes Sinkhorn OT using FlashAttention-style streaming. It never materializes the n×m cost matrix, which brings memory down from O(n²) to O(nd).
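Here is a pure-Python sketch of the streaming idea. It is not the library's Triton kernel. Each Sinkhorn potential update is a log-sum-exp over one row of costs. That log-sum-exp can be accumulated tile by tile with a running maximum and a rescaled running sum, the same online trick FlashAttention uses for softmax. Only two scalars per row are kept, so the row itself is never stored. The helper names `streaming_lse`, `dense_lse` and `score` are hypothetical.

```python
import math
import random


def dense_lse(values):
    """Reference log-sum-exp over a fully materialized list."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def streaming_lse(score_fn, n_cols, tile=4):
    """Log-sum-exp of score_fn(j) for j in range(n_cols), one tile at a time.

    Keeps only a running max `m` and a running sum `s` of exp(score - m),
    so the full row of scores is never stored.
    """
    m, s = -math.inf, 0.0
    for start in range(0, n_cols, tile):
        block = [score_fn(j) for j in range(start, min(start + tile, n_cols))]
        m_new = max(m, max(block))
        # Rescale the previous partial sum to the new max, then add this tile.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        m = m_new
    return m + math.log(s)


# Toy data: one source point x, ten target points y_j, target potential g, uniform weights b.
random.seed(0)
d, n_cols, eps = 3, 10, 0.5
x = [random.gauss(0, 1) for _ in range(d)]
ys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_cols)]
g = [random.gauss(0, 1) for _ in range(n_cols)]
log_b = [math.log(1.0 / n_cols)] * n_cols


def score(j):
    """(g_j - ||x - y_j||^2) / eps + log b_j, computed on the fly (never stored)."""
    cost = sum((xi - yi) ** 2 for xi, yi in zip(x, ys[j]))
    return (g[j] - cost) / eps + log_b[j]


# Sinkhorn potential update f(x) = -eps * LSE_j[(g_j - C(x, y_j)) / eps + log b_j]
f_streamed = -eps * streaming_lse(score, n_cols, tile=4)
f_dense = -eps * dense_lse([score(j) for j in range(n_cols)])
print(f_streamed, f_dense)
```

The streamed and dense results agree to floating-point precision. The tile size changes only the order of accumulation, not the result.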
- FlashSinkhorn kernels: a shifted-potential formulation inspired by FlashAttention. On A100 GPUs, it achieves up to 32× forward-pass and 161× end-to-end speedups over state-of-the-art online baselines on point-cloud OT.
- Fused Triton kernels for the forward pass, gradient, and HVP
- GeomLoss-compatible API (`SamplesLoss`)
- Analytic gradients (no backprop through Sinkhorn iterations)
- Hessian-vector products via a streaming CG solver
- Half-cost support (`half_cost=True`) for exact GeomLoss parity
- Unbalanced/semi-unbalanced OT via the `reach` parameter
- Large-D support (d > 1024) with a tiled gradient kernel
- Early stopping with a convergence threshold
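This sketch shows why no backprop through the iterations is needed. It rests on simple assumptions: balanced OT, uniform weights, cost C = ‖x - y‖², and a plain log-domain Sinkhorn in pure Python rather than FlashSinkhorn's kernels. At convergence, the envelope theorem gives the gradient of the entropic OT value directly from the transport plan π: ∇_{x_i} OT = 2 Σ_j π_ij (x_i - y_j). The toy below checks that formula against a central finite difference. All function names are hypothetical.

```python
import math
import random


def lse(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def sqdist(p, q):
    return sum((pi - qi) ** 2 for pi, qi in zip(p, q))


def sinkhorn(xs, ys, eps, iters=200):
    """Balanced log-domain Sinkhorn with uniform weights; returns (f, g, C)."""
    n, m = len(xs), len(ys)
    log_a, log_b = math.log(1.0 / n), math.log(1.0 / m)
    C = [[sqdist(p, q) for q in ys] for p in xs]  # dense cost: fine for a toy problem
    f, g = [0.0] * n, [0.0] * m
    for _ in range(iters):
        f = [-eps * lse([(g[j] - C[i][j]) / eps + log_b for j in range(m)]) for i in range(n)]
        g = [-eps * lse([(f[i] - C[i][j]) / eps + log_a for i in range(n)]) for j in range(m)]
    return f, g, C


def ot_value(xs, ys, eps):
    """Entropic OT value at convergence: <a, f> + <b, g>."""
    f, g, _ = sinkhorn(xs, ys, eps)
    return sum(f) / len(xs) + sum(g) / len(ys)


def analytic_grad(xs, ys, eps):
    """Envelope-theorem gradient w.r.t. xs: 2 * sum_j pi_ij * (x_i - y_j)."""
    f, g, C = sinkhorn(xs, ys, eps)
    n, m = len(xs), len(ys)
    grads = []
    for i in range(n):
        gi = [0.0] * len(xs[i])
        for j in range(m):
            # Transport plan entry pi_ij = a_i * b_j * exp((f_i + g_j - C_ij) / eps)
            pi_ij = math.exp((f[i] + g[j] - C[i][j]) / eps) / (n * m)
            for k in range(len(gi)):
                gi[k] += 2.0 * pi_ij * (xs[i][k] - ys[j][k])
        grads.append(gi)
    return grads


random.seed(1)
xs = [[random.random() for _ in range(2)] for _ in range(4)]
ys = [[random.random() for _ in range(2)] for _ in range(5)]
eps, h = 1.0, 1e-5

grad = analytic_grad(xs, ys, eps)

# Central finite difference on coordinate 1 of point 0 (re-solves Sinkhorn each time).
xp = [row[:] for row in xs]
xm = [row[:] for row in xs]
xp[0][1] += h
xm[0][1] -= h
fd = (ot_value(xp, ys, eps) - ot_value(xm, ys, eps)) / (2 * h)
print(grad[0][1], fd)
```

The gradient needs only the converged potentials and the plan they define. In this formula, the number of iterations affects neither the memory nor the size of the backward graph.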
Installation:

```bash
pip install flash-sinkhorn

# From source (development)
pip install -e ".[dev]"
```

Requirements: PyTorch ≥ 2.5, Triton ≥ 3.1, CUDA 12.x.
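A small standard-library check can confirm that an environment meets these minimums before you import the package. It reads installed versions with `importlib.metadata`. This is a hypothetical helper, not part of flash-sinkhorn. It assumes the distributions are named `torch` and `triton`; some builds ship Triton under a different distribution name. The CUDA version PyTorch was built with can be read separately from `torch.version.cuda`.

```python
from importlib import metadata

# Minimum versions from the requirements line above (assumed distribution names).
MINIMUMS = {"torch": (2, 5), "triton": (3, 1)}


def version_tuple(version):
    """Leading numeric components of a version string, e.g. '2.5.1+cu121' -> (2, 5, 1)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_requirements():
    """Map each package to (installed version or None, meets minimum?)."""
    report = {}
    for pkg, minimum in MINIMUMS.items():
        try:
            installed = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = (None, False)
            continue
        report[pkg] = (installed, version_tuple(installed) >= minimum)
    return report


for pkg, (installed, ok) in check_requirements().items():
    print(f"{pkg}: {installed or 'not installed'} ({'ok' if ok else 'too old or missing'})")
```

The helper compares integer tuples rather than raw strings. As strings, `"2.10" < "2.5"`, which would wrongly reject a newer PyTorch.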
Quick start:

```python
import torch
from flash_sinkhorn import SamplesLoss

x = torch.randn(4096, 64, device="cuda")
y = torch.randn(4096, 64, device="cuda")

loss = SamplesLoss(loss="sinkhorn", blur=0.1, debias=True)
cost = loss(x, y)
```

Gradients with respect to the input points:

```python
x = torch.randn(4096, 64, device="cuda", requires_grad=True)
y = torch.randn(4096, 64, device="cuda")

loss = SamplesLoss(loss="sinkhorn", blur=0.1, debias=True)
cost = loss(x, y)
grad_x = torch.autograd.grad(cost, x)[0]  # Analytic gradient
```

Use `half_cost=True` to match GeomLoss's cost convention:
```python
# FlashSinkhorn with half_cost matches GeomLoss exactly
flash_loss = SamplesLoss(loss="sinkhorn", blur=0.1, half_cost=True, debias=True)

# Equivalent GeomLoss call
# import geomloss
# geomloss_loss = geomloss.SamplesLoss(loss="sinkhorn", p=2, blur=0.1, debias=True)
```

For distributions with different total mass or with outliers:
```python
loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    debias=True,
    reach=1.0,  # Unbalanced OT with KL penalty
)
```

Different constraints for source vs target:
```python
loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    reach_x=1.0,   # Relax source marginal
    reach_y=None,  # Keep target marginal strict (balanced)
)
```

Early stopping:

```python
loss = SamplesLoss(
    loss="sinkhorn",
    blur=0.1,
    n_iters=100,
    threshold=1e-3,       # Stop when the change in potentials falls below threshold
    inner_iterations=10,  # Check convergence every N iterations
)
```

Hessian-vector products:

```python
x = torch.randn(4096, 64, device="cuda", requires_grad=True)
y = torch.randn(4096, 64, device="cuda")
v = torch.randn_like(x)

loss = SamplesLoss(loss="sinkhorn", blur=0.1)
cost = loss(x, y)

# First-order gradient (keep the graph for double backward)
grad_x = torch.autograd.grad(cost, x, create_graph=True)[0]

# HVP via double backward (uses streaming CG solver)
hvp = torch.autograd.grad((grad_x * v).sum(), x)[0]
```

FlashSinkhorn is a reformulated Sinkhorn kernel that uses shifted potentials inspired by FlashAttention. Per tile, it loads 67% fewer bias vectors (one instead of three) and performs 78% fewer elementwise operations. It also improves scalability on OT-based downstream tasks.
Standard Sinkhorn loads three bias vectors per tile (`g`, `log_b`, and `‖y‖²`). FlashSinkhorn instead precomputes a single fused bias `u = (g_shifted + eps * log(b)) / eps` and uses raw coordinates with an inline scale factor, which matches FlashAttention's score interface exactly.
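The fused bias follows from expanding the squared distance, ‖x - y‖² = ‖x‖² - 2 x·y + ‖y‖². In the pure-Python check below, the row term ‖x‖² moves out of the log-sum-exp. What remains per column is one scaled dot product plus one bias. The check assumes that `g_shifted = g - ‖y‖²` and that the inline scale factor is `2 / eps`. That is our reading of the conventions, not something stated above.

```python
import math
import random


def lse(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def dot(p, q):
    return sum(pi * qi for pi, qi in zip(p, q))


random.seed(0)
d, n_cols, eps = 4, 6, 0.3
x = [random.gauss(0, 1) for _ in range(d)]
ys = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n_cols)]
g = [random.gauss(0, 1) for _ in range(n_cols)]
b = [1.0 / n_cols] * n_cols

# Standard update: per column it needs g_j, log b_j and the cost ||x - y_j||^2.
f_standard = -eps * lse([
    (g[j] - sum((xi - yi) ** 2 for xi, yi in zip(x, ys[j]))) / eps + math.log(b[j])
    for j in range(n_cols)
])

# Shifted update (assumed conventions): g_shifted = g - ||y||^2 is folded into
# one precomputed bias u_j, and each score is scale * <x, y_j> + u_j.
u = [(g[j] - dot(ys[j], ys[j]) + eps * math.log(b[j])) / eps for j in range(n_cols)]
scale = 2.0 / eps
f_shifted = dot(x, x) - eps * lse([scale * dot(x, ys[j]) + u[j] for j in range(n_cols)])

print(f_standard, f_shifted)
```

The two forms agree to rounding error. In the shifted form, the per-column data reduces to `y_j` and a single scalar `u_j`.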
Symmetric solver (vs v0.2.0 GeomLoss-style kernel):
| n | v0.2.0 | v0.3.0 | Speedup |
|---|---|---|---|
| 50,000 | 1730 ms | 1450 ms | 1.19x |
| 10,000 | 88 ms | 61 ms | 1.43x |
| 5,000 | 25 ms | 24 ms | 1.04x |
Alternating solver (vs v0.2.0 OTT-style kernel, 10 iterations):
| n | v0.2.0 | v0.3.0 | Speedup |
|---|---|---|---|
| 50,000 | 137.9 ms | 102.6 ms | 1.34x |
| 20,000 | 25.7 ms | 21.7 ms | 1.19x |
| 10,000 | 8.9 ms | 8.3 ms | 1.07x |
Low-level FlashSinkhorn API:
```python
from flash_sinkhorn.kernels import (
    sinkhorn_flashstyle_symmetric,    # Full symmetric solver
    sinkhorn_flashstyle_alternating,  # Full alternating solver
    flashsinkhorn_symmetric_step,     # Single fused iteration
    apply_plan_vec_flashstyle,        # Transport plan @ vector (shifted potentials)
    apply_plan_mat_flashstyle,        # Transport plan @ matrix (shifted potentials)
)
```

`SamplesLoss` parameters:

```python
SamplesLoss(
    loss="sinkhorn",
    p=2,                  # Only p=2 supported (squared Euclidean)
    blur=0.05,            # Regularization: eps = blur^2
    debias=True,          # Debiased Sinkhorn divergence
    half_cost=False,      # If True, use ||x-y||²/2 to match GeomLoss
    reach=None,           # Unbalanced OT (None = balanced)
    reach_x=None,         # Semi-unbalanced: source marginal
    reach_y=None,         # Semi-unbalanced: target marginal
    scaling=0.5,          # Epsilon annealing factor
    n_iters=None,         # Max iterations (None = use scaling)
    threshold=None,       # Early stopping threshold
    inner_iterations=10,  # Check convergence every N iterations
)
```

Kernel imports:

```python
# FlashSinkhorn (recommended)
from flash_sinkhorn.kernels import (
    sinkhorn_flashstyle_symmetric,
    sinkhorn_flashstyle_alternating,
    apply_plan_vec_flashstyle,
    apply_plan_mat_flashstyle,
)

# Legacy kernels (still available)
from flash_sinkhorn.kernels.sinkhorn_triton_geomloss_sqeuclid import (
    sinkhorn_geomloss_online_potentials_sqeuclid,
)
from flash_sinkhorn.kernels.sinkhorn_triton_grad_sqeuclid import (
    sinkhorn_geomloss_online_grad_sqeuclid,
)
from flash_sinkhorn.hvp import hvp_x_sqeuclid_from_potentials
```

Cost convention:

- FlashSinkhorn default: `C(x,y) = ||x-y||²`
- GeomLoss default for p=2: `C(x,y) = ||x-y||²/2`
- Use `half_cost=True` to match GeomLoss
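The two conventions are not interchangeable at a fixed `blur`. In both, ε = blur², and halving the cost is equivalent to doubling ε on the full cost and then halving the result. The derivation below shows this. It writes OT_ε(C) for the minimum, over couplings π ∈ Π(a, b), of ⟨C, π⟩ + ε KL(π ‖ a ⊗ b):

```latex
\mathrm{OT}_\varepsilon\!\left(\tfrac{C}{2}\right)
= \min_{\pi \in \Pi(a,b)} \left\langle \tfrac{C}{2}, \pi \right\rangle
  + \varepsilon\,\mathrm{KL}(\pi \,\|\, a \otimes b)
= \tfrac{1}{2} \min_{\pi \in \Pi(a,b)} \langle C, \pi \rangle
  + 2\varepsilon\,\mathrm{KL}(\pi \,\|\, a \otimes b)
= \tfrac{1}{2}\,\mathrm{OT}_{2\varepsilon}(C)
```

Compare `half_cost=True` at a given `blur` with the full-cost problem at 2ε, that is, at `blur` × √2. The half-cost version gives half the value and the same optimal transport plan.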
FlashSinkhorn streams tiles of (x,y) and computes costs on-the-fly:
- Forward: O(nd) memory (no n×m cost matrix)
- Gradient: O(nd) memory (streaming accumulation)
- HVP: O(nd) memory (CG solver with streaming matvec)
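The HVP path solves a linear system with conjugate gradients, touching the matrix only through matrix-vector products. The generic matrix-free CG below, in pure Python, shows why memory stays at O(nd). The solver stores a handful of vectors and calls a `matvec` callback, which in FlashSinkhorn's case would be a streaming kernel. The callback here is a toy symmetric positive-definite operator. The exact system the library solves is not described above, so this illustrates the technique only. `cg` and `toy_matvec` are hypothetical names.

```python
import math


def cg(matvec, b, tol=1e-10, max_iter=100):
    """Solve A x = b for symmetric positive-definite A, given only v -> A v."""
    x = [0.0] * len(b)
    r = b[:]  # residual b - A x (x starts at zero)
    p = r[:]  # search direction
    rs = sum(ri * ri for ri in r)
    for _ in range(max_iter):
        if math.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


def toy_matvec(v):
    """Apply A = D + w w^T (SPD) without ever forming A: O(n) memory."""
    n = len(v)
    diag = [2.0 + i / n for i in range(n)]
    w = [1.0 / (i + 1) for i in range(n)]
    wv = sum(wi * vi for wi, vi in zip(w, v))
    return [di * vi + wi * wv for di, vi, wi in zip(diag, v, w)]


b = [1.0] * 8
x = cg(toy_matvec, b)
residual = max(abs(ai - bi) for ai, bi in zip(toy_matvec(x), b))
print(residual)
```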
Numerical stability:

- Uses `exp2`/`log2` for stable LSE computation
- Safe log/division guards against underflow
- TF32 enabled by default for ~2× speedup on A100/H100 (set `allow_tf32=False` for strict FP32)
- HVP (double backward) uses strict FP32 internally for numerical stability
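On NVIDIA GPUs the hardware exponential is base 2. The identity e^z = 2^(z·log₂e) lets a kernel compute the same log-sum-exp in base 2 and convert back with a factor of ln 2. The pure-Python check below illustrates that identity, not FlashSinkhorn's kernel code. It subtracts the maximum, as usual, so no exponential overflows.

```python
import math

LOG2E = 1.0 / math.log(2.0)  # log2(e)
LN2 = math.log(2.0)


def lse_natural(z):
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z))


def lse_base2(z):
    """Log-sum-exp using only 2**x and log2, converted back to the natural log."""
    t = [v * LOG2E for v in z]  # rescale scores into base 2
    m = max(t)                  # subtract the max so every 2**(v - m) <= 1
    s = sum(2.0 ** (v - m) for v in t)
    return LN2 * (m + math.log2(s))


z = [1000.0, 999.5, -3.0, 998.0]  # naive math.exp(1000.0) would overflow a float64
print(lse_natural(z), lse_base2(z))
```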
Compare FlashSinkhorn against GeomLoss (KeOps) and OTT-JAX.
Install benchmark dependencies:
```bash
pip install geomloss pykeops ott-jax "jax[cuda12]"
```

Run benchmarks:

```bash
# Forward pass benchmark
python -m flash_sinkhorn.bench.bench_forward --sizes 5000,10000,20000 --dims 64 --verify

# Backward pass benchmark
python -m flash_sinkhorn.bench.bench_backward --sizes 5000,10000,20000 --dims 64 --verify

# Quick test (small size)
python -m flash_sinkhorn.bench.bench_forward --sizes 5000 --dims 4 --verify

# Run only FlashSinkhorn (skip GeomLoss/OTT-JAX)
python -m flash_sinkhorn.bench.bench_forward --sizes 10000 --dims 64 --no-geomloss --no-ott
```

Results are saved to `output/paper_benchmarks/forward/` and `output/paper_benchmarks/backward/`.
If you find FlashSinkhorn useful in your research, please cite our paper:

```bibtex
@inproceedings{ye2026flashsinkhorn,
  title={FlashSinkhorn: IO-Aware Entropic Optimal Transport on GPU},
  author={Ye, Felix X.-F. and Li, Xingjie and Yu, An and Chang, Ming-Ching and Chu, Linsong and Wertheimer, Davis},
  booktitle={Proceedings of the 43rd International Conference on Machine Learning (ICML)},
  note={Spotlight},
  year={2026},
  url={https://arxiv.org/abs/2602.03067}
}
```

License: MIT
