Skip to content

Commit e164e78

Browse files
authored
[MRG] Warmstart for exact sparse and lazy solvers (#795)
* Implementation of warmstart for network simplex can make use off precomputed potentials from sinkhorn or even related simplex * optimise initial setup for watmstart using heap * Update Releases and test file * changed some interfaces and function names * small doc fix * Add warmstart potentials to sparse and lazy solver also * Adding warmstart to missing functions * small wrapper issue in solve_sample
1 parent 83d05f9 commit e164e78

File tree

8 files changed

+244
-17
lines changed

8 files changed

+244
-17
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
77
#### New features
88
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
99
- Add Warmstart feature to the EMD solver for existing potentials (PR #793)
10+
- Add Warmstart potentials feature to the EMD solver for lazy and sparse solver
1011
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
1112
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
1213
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)

ot/lp/EMD.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ int EMD_wrap_sparse(
4848
double *alpha, // Output: dual variables for sources (n1)
4949
double *beta, // Output: dual variables for targets (n2)
5050
double *cost, // Output: total transportation cost
51-
uint64_t maxIter // Maximum iterations for solver
51+
uint64_t maxIter, // Maximum iterations for solver
52+
double *alpha_init, // Initial dual variables for sources (warmstart)
53+
double *beta_init // Initial dual variables for targets (warmstart)
5254
);
5355

5456
int EMD_wrap_lazy(
@@ -64,7 +66,9 @@ int EMD_wrap_lazy(
6466
double *alpha, // Output: dual variables for sources (n1)
6567
double *beta, // Output: dual variables for targets (n2)
6668
double *cost, // Output: total transportation cost
67-
uint64_t maxIter // Maximum iterations for solver
69+
uint64_t maxIter, // Maximum iterations for solver
70+
double *alpha_init, // Initial dual variables for sources (warmstart)
71+
double *beta_init // Initial dual variables for targets (warmstart)
6872
);
6973

7074

ot/lp/EMD_wrapper.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,9 @@ int EMD_wrap_sparse(
253253
double *alpha,
254254
double *beta,
255255
double *cost,
256-
uint64_t maxIter
256+
uint64_t maxIter,
257+
double *alpha_init,
258+
double *beta_init
257259
) {
258260
using namespace lemon;
259261

@@ -351,6 +353,22 @@ int EMD_wrap_sparse(
351353
}
352354
}
353355

356+
// Initialize warmstart if provided
357+
if (alpha_init != nullptr && beta_init != nullptr) {
358+
// Map original indices to graph indices for warmstart
359+
std::vector<double> alpha_filtered(n);
360+
std::vector<double> beta_filtered(m);
361+
for (uint64_t i = 0; i < n; i++) {
362+
uint64_t orig_i = indI[i];
363+
alpha_filtered[i] = alpha_init[orig_i];
364+
}
365+
for (uint64_t j = 0; j < m; j++) {
366+
uint64_t orig_j = indJ[j];
367+
beta_filtered[j] = beta_init[orig_j];
368+
}
369+
net.setWarmstartPotentials(&alpha_filtered[0], &beta_filtered[0], n, m);
370+
}
371+
354372
int ret = net.run();
355373

356374
if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
@@ -389,7 +407,7 @@ int EMD_wrap_sparse(
389407

390408
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b,
391409
int dim, int metric, double *G, double *alpha, double *beta,
392-
double *cost, uint64_t maxIter) {
410+
double *cost, uint64_t maxIter, double *alpha_init, double *beta_init) {
393411
using namespace lemon;
394412
typedef FullBipartiteDigraph Digraph;
395413
DIGRAPH_TYPEDEFS(Digraph);
@@ -454,6 +472,22 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double
454472
// Enable lazy cost computation - costs will be computed on-the-fly
455473
net.setLazyCost(&coords_a_filtered[0], &coords_b_filtered[0], dim, metric, n, m);
456474

475+
// Initialize warmstart if provided
476+
if (alpha_init != nullptr && beta_init != nullptr) {
477+
// Map original indices to graph indices for warmstart
478+
std::vector<double> alpha_filtered(n);
479+
std::vector<double> beta_filtered(m);
480+
for (int i = 0; i < n; i++) {
481+
int orig_i = idx_a[i];
482+
alpha_filtered[i] = alpha_init[orig_i];
483+
}
484+
for (int j = 0; j < m; j++) {
485+
int orig_j = idx_b[j];
486+
beta_filtered[j] = beta_init[orig_j];
487+
}
488+
net.setWarmstartPotentials(&alpha_filtered[0], &beta_filtered[0], n, m);
489+
}
490+
457491
// Run solver
458492
int ret = net.run();
459493

ot/lp/_network_simplex.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,7 @@ def emd2_lazy(
820820
return_matrix=True,
821821
center_dual=True,
822822
check_marginals=True,
823+
potentials_init=None,
823824
):
824825
r"""Solves the Earth Movers distance problem with lazy cost computation and returns the loss
825826
@@ -873,6 +874,9 @@ def emd2_lazy(
873874
If True, centers the dual potential using :py:func:`ot.lp.center_ot_dual`
874875
check_marginals: bool, optional (default=True)
875876
If True, checks that the marginals mass are equal
877+
potentials_init : tuple of (ns,) and (nt,) arrays, optional
878+
Initial dual potentials (u, v) to warmstart the solver. If provided,
879+
the solver starts from these potentials instead of a cold start.
876880
877881
Returns
878882
-------
@@ -942,8 +946,18 @@ def emd2_lazy(
942946
)
943947
b_np = b_np * a_np.sum() / b_np.sum()
944948

949+
# Handle warmstart potentials
950+
alpha_init_np = None
951+
beta_init_np = None
952+
if potentials_init is not None:
953+
alpha_init, beta_init = potentials_init
954+
alpha_init_np = nx.to_numpy(alpha_init)
955+
beta_init_np = nx.to_numpy(beta_init)
956+
alpha_init_np = np.asarray(alpha_init_np, dtype=np.float64, order="C")
957+
beta_init_np = np.asarray(beta_init_np, dtype=np.float64, order="C")
958+
945959
G, cost, u, v, result_code = emd_c_lazy(
946-
a_np, b_np, X_a_np, X_b_np, metric, numItermax
960+
a_np, b_np, X_a_np, X_b_np, metric, numItermax, alpha_init_np, beta_init_np
947961
)
948962

949963
if center_dual:

ot/lp/emd_wrap.pyx

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ import warnings
2222
cdef extern from "EMD.h":
2323
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
2424
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
25-
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil
26-
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
25+
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
26+
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
2727
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2828

2929

@@ -233,7 +233,9 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
233233
np.ndarray[uint64_t, ndim=1, mode="c"] edge_sources,
234234
np.ndarray[uint64_t, ndim=1, mode="c"] edge_targets,
235235
np.ndarray[double, ndim=1, mode="c"] edge_costs,
236-
uint64_t max_iter):
236+
uint64_t max_iter,
237+
np.ndarray[double, ndim=1, mode="c"] alpha_init=None,
238+
np.ndarray[double, ndim=1, mode="c"] beta_init=None):
237239
"""
238240
Sparse EMD solver using cost matrix in COO (Coordinate) sparse format.
239241
@@ -255,6 +257,10 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
255257
Cost for each edge (non-zero values in COO format)
256258
max_iter : uint64_t
257259
Maximum number of iterations
260+
alpha_init : (n1,) array, float64, optional
261+
Initial dual variables for sources (warmstart)
262+
beta_init : (n2,) array, float64, optional
263+
Initial dual variables for targets (warmstart)
258264
259265
Returns
260266
-------
@@ -287,6 +293,12 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
287293
cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1)
288294
cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2)
289295

296+
cdef double* alpha_init_ptr = NULL
297+
cdef double* beta_init_ptr = NULL
298+
if alpha_init is not None and beta_init is not None:
299+
alpha_init_ptr = <double*> alpha_init.data
300+
beta_init_ptr = <double*> beta_init.data
301+
290302
with nogil:
291303
result_code = EMD_wrap_sparse(
292304
n1, n2,
@@ -295,7 +307,8 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
295307
<uint64_t*> edge_sources.data, <uint64_t*> edge_targets.data, <double*> edge_costs.data,
296308
<uint64_t*> flow_sources.data, <uint64_t*> flow_targets.data, <double*> flow_values.data,
297309
&n_flows_out,
298-
<double*> alpha.data, <double*> beta.data, &cost, max_iter
310+
<double*> alpha.data, <double*> beta.data, &cost, max_iter,
311+
alpha_init_ptr, beta_init_ptr
299312
)
300313

301314
# Trim to actual number of flows
@@ -308,7 +321,7 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
308321

309322
@cython.boundscheck(False)
310323
@cython.wraparound(False)
311-
def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] coords_a, np.ndarray[double, ndim=2, mode="c"] coords_b, str metric='sqeuclidean', uint64_t max_iter=100000):
324+
def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] coords_a, np.ndarray[double, ndim=2, mode="c"] coords_b, str metric='sqeuclidean', uint64_t max_iter=100000, np.ndarray[double, ndim=1, mode="c"] alpha_init=None, np.ndarray[double, ndim=1, mode="c"] beta_init=None):
312325
"""Solves the Earth Movers distance problem with lazy cost computation from coordinates."""
313326
cdef int n1 = coords_a.shape[0]
314327
cdef int n2 = coords_b.shape[0]
@@ -339,6 +352,13 @@ def emd_c_lazy(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1
339352
a = np.ones((n1,)) / n1
340353
if not len(b):
341354
b = np.ones((n2,)) / n2
355+
356+
cdef double* alpha_init_ptr = NULL
357+
cdef double* beta_init_ptr = NULL
358+
if alpha_init is not None and beta_init is not None:
359+
alpha_init_ptr = <double*> alpha_init.data
360+
beta_init_ptr = <double*> beta_init.data
361+
342362
with nogil:
343-
result_code = EMD_wrap_lazy(n1, n2, <double*> a.data, <double*> b.data, <double*> coords_a.data, <double*> coords_b.data, dim, metric_code, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
363+
result_code = EMD_wrap_lazy(n1, n2, <double*> a.data, <double*> b.data, <double*> coords_a.data, <double*> coords_b.data, dim, metric_code, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
344364
return G, cost, alpha, beta, result_code

ot/lp/network_simplex_simple.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -679,10 +679,10 @@ namespace lemon {
679679
if (!_lazy_cost) {
680680
return _cost[arc_id];
681681
} else {
682-
// For artificial arcs (>= _arc_num), return 0
683-
// These are not real transport arcs
682+
// For artificial arcs (>= _arc_num), return stored cost
683+
// (0 for positive supply, ART_COST for negative supply)
684684
if (arc_id >= _arc_num) {
685-
return 0;
685+
return _cost[arc_id];
686686
}
687687
// Compute lazily from coordinates
688688
// _source and _target use reversed node numbering: _node_id(n) = _node_num - n - 1
@@ -1138,7 +1138,13 @@ namespace lemon {
11381138

11391139
for (ArcsType e = 0; e < _arc_num; ++e) {
11401140
_state[e] = STATE_LOWER;
1141-
Cost c = _cost[e];
1141+
Cost c;
1142+
if (_lazy_cost) {
1143+
// Compute cost on-the-fly for lazy mode
1144+
c = getCostForArc(e);
1145+
} else {
1146+
c = _cost[e];
1147+
}
11421148
if (c > ART_COST) ART_COST = c;
11431149
Cost rc = fabs(c + _pi[_source[e]] - _pi[_target[e]]);
11441150
if ((ArcsType)maxheap.size() < K) {
@@ -1436,10 +1442,11 @@ namespace lemon {
14361442
while (u != _root) {
14371443
ArcsType e = _pred[u];
14381444
int v = _parent[u];
1445+
Cost c = getCostForArc(e);
14391446
if (_forward[u]) {
1440-
_pi[u] = _pi[v] - _cost[e];
1447+
_pi[u] = _pi[v] - c;
14411448
} else {
1442-
_pi[u] = _pi[v] + _cost[e];
1449+
_pi[u] = _pi[v] + c;
14431450
}
14441451
u = _thread[u];
14451452
}

ot/solvers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,6 +1772,7 @@ def solve_sample(
17721772
numItermax=max_iter if max_iter is not None else 100000,
17731773
log=True,
17741774
return_matrix=True,
1775+
potentials_init=potentials_init,
17751776
)
17761777

17771778
res = OTResult(

0 commit comments

Comments
 (0)