A high-performance, minimal JAX port is available here (recommended):
https://github.com/peterhalmos/HiRef/tree/main
This is the repository for "Hierarchical Refinement: Optimal Transport to Infinity and Beyond" (ICML 2025), which scales optimal transport linearly in space and log-linearly in time using a hierarchical strategy that constructs multiscale partitions from low-rank optimal transport.
In the section below, we detail the usage of Hierarchical Refinement, which complements the simple example notebooks:
- [refinement_demo_nb.ipynb](notebooks/refinement_demo_nb.ipynb)
- [refinement_demo_nb_fast.ipynb](notebooks/refinement_demo_nb_fast.ipynb)
We additionally detail how to run Hierarchical Refinement with acceleration, which we recommend using and intend to set as a default in the future.
*Figure: The Hierarchical Refinement algorithm. Low-rank optimal transport progressively refines the partitions from the previous (coarser) scale.*
*Figure: Examples of HiRef JAX bijections on varied datasets.*
Hierarchical Refinement (HiRef) requires only two `n × d`-dimensional point clouds `X` and `Y` (torch tensors) as input.
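For instance, a minimal synthetic setup might look like the following sketch (the sizes, shift, and `device` choice here are hypothetical placeholders, not part of the repository):

```python
import torch

# Hypothetical synthetic data: two point clouds of n = 4096 points in d = 2 dimensions.
n, d = 4096, 2
device = "cuda" if torch.cuda.is_available() else "cpu"
X = torch.randn(n, d, device=device)
Y = torch.randn(n, d, device=device) + 1.0  # shifted copy so the alignment is non-trivial
```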
Before running HiRef, call the rank-annealing scheduler to find a sequence of ranks that minimizes the number of calls to the low-rank optimal transport subroutine while remaining under a machine-specific maximal rank. Its parameters are:

- `n`: the size of the dataset
- `hierarchy_depth` (κ): the depth of the hierarchy of levels used in the refinement strategy
- `max_Q`: the maximal terminal rank at the base case
- `max_rank`: the maximal rank of the intermediate sub-problems
Import the rank-annealing module and compute the rank schedule:

```python
import rank_annealing

rank_schedule = rank_annealing.optimal_rank_schedule(
    n=n, hierarchy_depth=hierarchy_depth, max_Q=max_Q, max_rank=max_rank
)
```
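As a rough sanity check, one would expect the per-level ranks to multiply out to at least `n`, so that the recursion terminates in singleton (or near-singleton) clusters; this invariant is an assumption about the schedule's contract, not a documented guarantee:

```python
import math

# Assumed invariant: the product of the per-level ranks covers the dataset size.
assert math.prod(rank_schedule) >= n
```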
Import `HR_OT` and initialize the class using only the point clouds (you can additionally input the cost `C` if desired), along with any relevant parameters (e.g., `sq_Euclidean`) for your problem.
```python
import HR_OT

hrot = HR_OT.HierarchicalRefinementOT.init_from_point_clouds(
    X, Y, rank_schedule, base_rank=1, device=device
)
```
Run the solver and return the paired tuples from `X` and `Y` (the bijective Monge map between the datasets):
```python
Gamma_hrot = hrot.run(return_as_coupling=False)
```
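If a dense coupling is preferred instead of paired tuples, the same flag can presumably be flipped; note that this materializes an n × n matrix, which forfeits the linear-space advantage:

```python
# Returns a dense n-by-n coupling matrix rather than the list of matched pairs.
# Only advisable for small n, since this requires O(n^2) memory.
P_hrot = hrot.run(return_as_coupling=True)
```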
To print the Optimal Transport (OT) cost, simply call:
```python
cost_hrot = hrot.compute_OT_cost()
print(f"Refinement Cost: {cost_hrot.item()}")
```
There are a number of ways to accelerate Hierarchical Refinement. First, one may cap the number of iterations the solver runs, e.g. when speed matters more than solution optimality, by lowering the `max_iter` and `max_inneriters_balanced` parameters from their conservative default values. In addition, one may use a more lightweight low-rank OT solver (`src/LR_mini.LROT_LR_opt` for a low-rank cost matrix and `src/LR_mini.LROT_opt` for a full cost matrix) by simply passing the updated solvers at initialization. All of these changes can be implemented with the following modifications to the code above:
```python
import HR_OT
import LR_mini

solver_params = {
    'max_iter': 45,
    'min_iter': 30,
    'max_inneriters_balanced': 60
}

hrot_lr = HR_OT.HierarchicalRefinementOT.init_from_point_clouds(
    X, Y,
    rank_schedule, base_rank=1,
    device=device,
    solver_params=solver_params,
    solver=LR_mini.LROT_LR_opt,
    solver_full=LR_mini.LROT_opt
)
```
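To gauge the speed-up on your own data, the accelerated instance can be timed directly, using only the calls shown above:

```python
import time

t0 = time.perf_counter()
Gamma_fast = hrot_lr.run(return_as_coupling=False)
elapsed = time.perf_counter() - t0

print(f"Accelerated run: {elapsed:.2f} s")
print(f"Refinement Cost: {hrot_lr.compute_OT_cost().item()}")
```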
An implementation demonstrating this acceleration for the squared Euclidean cost can be found in [refinement_demo_nb_fast.ipynb](notebooks/refinement_demo_nb_fast.ipynb) for reference, with an example alignment of two point clouds of 200k points each.
For questions, discussions, or collaboration inquiries, feel free to reach out at [email protected] or [email protected].
All experiments are available in the folder `HiRef/experiments/`. To generate the datasets used for the mouse-embryo experiment, the raw Stereo-Seq mouse embryo slices (E9.5–E16.5) may be accessed from the MOSTA database and converted to a pre-processed form using the script `HiRef/data/preprocess_embryo.py`. While the default hyperparameter settings for HiRef have changed, the exact experiments and hyperparameter settings used are available on OpenReview. At the time of benchmarking, the default epsilon of Sinkhorn in ott-jax was 0.05, which has since been modified. Note that low-rank solvers such as LOT and FRLC are non-convex and use randomized initial conditions, so the solution of HiRef may exhibit slight variation between runs.
If you found Hierarchical Refinement to be useful in your work, feel free to cite the paper:
```bibtex
@inproceedings{
  halmos2025hierarchical,
  title={Hierarchical Refinement: Optimal Transport to Infinity and Beyond},
  author={Peter Halmos and Julian Gold and Xinhao Liu and Benjamin Raphael},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=EBNgREMoVD}
}
```