Skip to content

Commit 7bc3239

Browse files
[MRG] Optimize simplex setup (#796)
* Implementation of warmstart for network simplex can make use of precomputed potentials from Sinkhorn or even related simplex * optimise initial setup for warmstart using heap * Update Releases and test file * working on optimizing the setup step so we lose less time on creating all the necessary variables * changed some interfaces and function names * small doc fix * making setup for dense faster * Add warmstart potentials to sparse and lazy solver also * Adding warmstart to missing functions * small wrapper issue in solve_sample * Removed timing and added preprocessing in case Python includes 0 demand or supply * Updated release * delete useless file and remove timing code in OpenMP version * Fixed functions and added PR numbers in release file * Cleaned up and re-added some code I deleted by accident * Changed names and replaced some code with existing functions --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 69d0c96 commit 7bc3239

File tree

6 files changed

+374
-176
lines changed

6 files changed

+374
-176
lines changed

RELEASES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
77
#### New features
88
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
99
- Add Warmstart feature to the EMD solver for existing potentials (PR #793)
10-
- Add Warmstart potentials feature to the EMD solver for lazy and sparse solver
10+
- Add Warmstart potentials feature to the EMD solver for lazy and sparse solver (PR #795)
11+
- Faster init and result retrieval for EMD solver (PR #796)
1112
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
1213
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
1314
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)

ot/lp/EMD_wrapper.cpp

Lines changed: 154 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,135 @@
1919
#include "EMD.h"
2020
#include <cstdint>
2121
#include <unordered_map>
22+
#include <vector>
23+
24+
namespace {
25+
26+
// Switches that decide how the solver is configured and how its results are
// read back after the run.
struct SetupPolicy {
    bool full_support;            // n == n1 and m == n2: no index was dropped
    bool use_arc_mixing;          // forwarded to the solver constructor
    bool use_dense_cost_pointer;  // hand D to the solver instead of per-arc costs
};

// Derives the setup switches from the problem dimensions.
//
// n, m    sizes of the problem actually handed to the solver
// n1, n2  sizes of the full input problem
// dense_cost_pointer_supported
//         whether the calling code path is able to use the dense cost pointer
//
// The support is "full" when both reduced sizes equal the full sizes. Arc
// mixing is enabled exactly when the support is not full, and the dense cost
// pointer is used only when it is both supported and the support is full.
inline SetupPolicy make_setup_policy(
    uint64_t n,
    uint64_t m,
    int n1,
    int n2,
    bool dense_cost_pointer_supported
) {
    const bool all_sources_kept = (static_cast<uint64_t>(n1) == n);
    const bool all_targets_kept = (static_cast<uint64_t>(n2) == m);
    const bool nothing_dropped = all_sources_kept && all_targets_kept;

    return SetupPolicy{
        nothing_dropped,
        !nothing_dropped,
        nothing_dropped && dense_cost_pointer_supported
    };
}
45+
46+
// Assigns a cost to every arc of the n-by-m problem, one setCost call per arc.
//
// Arc ids are consumed in increasing order while (i, j) runs row by row, so
// arc id i * m + j receives D[indI[i] * n2 + indJ[j]].
//
// net         solver receiving the costs through setCost
// di          graph used to turn an arc id into an arc (arcFromId)
// D           cost array, indexed with a row stride of n2
// n2          row stride of D
// indI, indJ  for each kept row / column, its index in D
// n, m        number of rows / columns to visit
template <typename NetType, typename DigraphType>
inline void setup_explicit_arc_costs(
    NetType& net,
    DigraphType& di,
    const double* D,
    int n2,
    const std::vector<uint64_t>& indI,
    const std::vector<uint64_t>& indJ,
    uint64_t n,
    uint64_t m
) {
    int64_t next_arc = 0;
    for (uint64_t row = 0; row < n; ++row) {
        for (uint64_t col = 0; col < m; ++col) {
            const uint64_t flat = indI[row] * n2 + indJ[col];
            net.setCost(di.arcFromId(next_arc), D[flat]);
            next_arc += 1;
        }
    }
}
65+
66+
// Passes warm-start potentials to the solver, restricted to the kept indices.
//
// Does nothing unless BOTH alpha_init and beta_init are non-null. Otherwise
// gathers alpha_init[indI[i]] for i < n and beta_init[indJ[j]] for j < m into
// temporary buffers and passes them to net.setWarmstartPotentials together
// with the two sizes.
//
// net         solver receiving the potentials
// alpha_init  source-side potentials over the full index range, or nullptr
// beta_init   target-side potentials over the full index range, or nullptr
// indI, indJ  for each kept source / target, its index in the full arrays
// n, m        number of kept sources / targets
template <typename NetType>
inline void setup_warmstart_potentials(
    NetType& net,
    const double* alpha_init,
    const double* beta_init,
    const std::vector<uint64_t>& indI,
    const std::vector<uint64_t>& indJ,
    uint64_t n,
    uint64_t m
) {
    if (alpha_init == nullptr || beta_init == nullptr) return;

    std::vector<double> alpha_compressed(n);
    std::vector<double> beta_compressed(m);
    for (uint64_t i = 0; i < n; ++i) alpha_compressed[i] = alpha_init[indI[i]];
    for (uint64_t j = 0; j < m; ++j) beta_compressed[j] = beta_init[indJ[j]];

    // data() is well-defined on an empty vector, whereas &v[0] is not; this
    // keeps the n == 0 / m == 0 case free of undefined behaviour.
    net.setWarmstartPotentials(
        alpha_compressed.data(),
        beta_compressed.data(),
        static_cast<int>(n),
        static_cast<int>(m)
    );
}
83+
84+
// Reads the solution back when no index was dropped, using the solver's raw
// _pi and _flow arrays directly.
//
// Potentials are read from _pi counting down from the last node:
//   alpha[i] = -_pi[nodeNum - 1 - i]        for i < n
//   beta[j]  =  _pi[nodeNum - 1 - (n + j)]  for j < m
//
// Flows are read for arc ids 0 .. arcNum - 1 in increasing order; arc id a is
// written to cell arcNum - 1 - a of G and priced with the same cell of D.
// Arcs whose flow compares equal to 0.0 are skipped, so their cells of G keep
// whatever the caller put there (the original note states that G is already
// zero-initialized in Python).
//
// cost is accumulated into (+=), not reset; the caller sets the start value.
template <typename NetType>
inline void extract_dense_full_support(
    const NetType& net,
    const double* D,
    double* G,
    double* alpha,
    double* beta,
    double* cost,
    uint64_t n,
    uint64_t m
) {
    const int node_count = net.nodeNum();
    const int last_pi = node_count - 1;

    for (uint64_t row = 0; row < n; ++row) {
        const int slot = last_pi - static_cast<int>(row);
        alpha[row] = -net._pi[slot];
    }
    for (uint64_t col = 0; col < m; ++col) {
        const int slot = last_pi - static_cast<int>(n + col);
        beta[col] = net._pi[slot];
    }

    // The cell cursor runs backwards while the arc id runs forwards.
    const int64_t arc_count = net.arcNum();
    int64_t cell = arc_count - 1;  // row-major index in D/G
    for (int64_t arc = 0; arc < arc_count; ++arc, --cell) {
        const double shipped = net._flow[arc];
        if (shipped != 0.0) {
            *cost += shipped * D[cell];
            G[cell] = shipped;
        }
    }
}
115+
116+
// Reads the solution back through the solver's public accessors and scatters
// it to the full-size output arrays via the index maps.
//
// Potentials:
//   alpha[indI[i]] = -net.potential(i)      for i < n
//   beta[indJ[j]]  =  net.potential(j + n)  for j < indJ.size()
//
// Flows: every arc from di.first(...) until it compares equal to `invalid` is
// visited. With s = di.source(arc) and t = di.target(arc), the arc's flow is
// written to G[indI[s] * n2 + indJ[t - n]] (zero flows included) and
// flow * D[same cell] is added to *cost.
//
// cost is accumulated into (+=), not reset; the caller sets the start value.
template <typename NetType, typename DigraphType, typename InvalidType>
inline void extract_compressed_support(
    const NetType& net,
    DigraphType& di,
    InvalidType invalid,
    const double* D,
    double* G,
    double* alpha,
    double* beta,
    double* cost,
    const std::vector<uint64_t>& indI,
    const std::vector<uint64_t>& indJ,
    uint64_t n,
    int n2
) {
    for (uint64_t src = 0; src < n; ++src) {
        alpha[indI[src]] = -net.potential(src);
    }
    for (uint64_t dst = 0; dst < indJ.size(); ++dst) {
        beta[indJ[dst]] = net.potential(dst + n);
    }

    typename DigraphType::Arc arc;
    for (di.first(arc); arc != invalid; di.next(arc)) {
        const uint64_t tail = di.source(arc);
        const uint64_t head = di.target(arc);
        const double shipped = net.flow(arc);
        // Target node ids are offset by n relative to positions in indJ.
        const uint64_t cell = indI[tail] * n2 + indJ[head - n];
        *cost += shipped * D[cell];
        G[cell] = shipped;
    }
}
149+
150+
} // namespace
22151

23152

24153
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
@@ -52,12 +181,14 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
52181
}
53182
}
54183

55-
// Define the graph
56-
184+
// Define graph and solver
57185
std::vector<uint64_t> indI(n), indJ(m);
58186
std::vector<double> weights1(n), weights2(m);
59187
Digraph di(n, m);
60-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter);
188+
const SetupPolicy policy = make_setup_policy(n, m, n1, n2, true);
189+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(
190+
di, policy.use_arc_mixing, (int) (n + m), n * m, maxIter
191+
);
61192

62193
// Set supply and demand, don't account for 0 values (faster)
63194

@@ -84,51 +215,26 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
84215

85216
net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
86217

87-
// Set the cost of each edge
88-
int64_t idarc = 0;
89-
for (uint64_t i=0; i<n; i++) {
90-
for (uint64_t j=0; j<m; j++) {
91-
double val=*(D+indI[i]*n2+indJ[j]);
92-
net.setCost(di.arcFromId(idarc), val);
93-
++idarc;
94-
}
95-
}
96-
97-
// Set warmstart potentials if provided
98-
if (alpha_init != nullptr && beta_init != nullptr) {
99-
// Compress warmstart potentials to only non-zero entries
100-
std::vector<double> alpha_compressed(n);
101-
std::vector<double> beta_compressed(m);
102-
for (uint64_t i = 0; i < n; i++) {
103-
alpha_compressed[i] = alpha_init[indI[i]];
104-
}
105-
for (uint64_t j = 0; j < m; j++) {
106-
beta_compressed[j] = beta_init[indJ[j]];
107-
}
108-
net.setWarmstartPotentials(&alpha_compressed[0], &beta_compressed[0], (int)n, (int)m);
218+
if (policy.use_dense_cost_pointer) {
219+
net.setDenseCostMatrix(D, n2);
220+
} else {
221+
setup_explicit_arc_costs(net, di, D, n2, indI, indJ, n, m);
109222
}
110-
223+
setup_warmstart_potentials(net, alpha_init, beta_init, indI, indJ, n, m);
111224
// Solve the problem with the network simplex algorithm
112225

113226
int ret=net.run();
114227

115-
uint64_t i, j;
116228
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
117229
*cost = 0;
118-
Arc a; di.first(a);
119-
for (; a != INVALID; di.next(a)) {
120-
i = di.source(a);
121-
j = di.target(a);
122-
double flow = net.flow(a);
123-
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
124-
*(G+indI[i]*n2+indJ[j-n]) = flow;
125-
*(alpha + indI[i]) = -net.potential(i);
126-
*(beta + indJ[j-n]) = net.potential(j);
230+
if (policy.full_support) {
231+
extract_dense_full_support(net, D, G, alpha, beta, cost, n, m);
232+
} else {
233+
extract_compressed_support(
234+
net, di, INVALID, D, G, alpha, beta, cost, indI, indJ, n, n2
235+
);
127236
}
128-
129237
}
130-
131-
132238
return ret;
133239
}
134240

@@ -173,7 +279,10 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
173279
std::vector<uint64_t> indI(n), indJ(m);
174280
std::vector<double> weights1(n), weights2(m);
175281
Digraph di(n, m);
176-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter, numThreads);
282+
const SetupPolicy policy = make_setup_policy(n, m, n1, n2, false);
283+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(
284+
di, policy.use_arc_mixing, (int) (n + m), n * m, maxIter, numThreads
285+
);
177286

178287
// Set supply and demand, don't account for 0 values (faster)
179288

@@ -200,37 +309,19 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
200309

201310
net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
202311

203-
// Set the cost of each edge
204-
int64_t idarc = 0;
205-
for (uint64_t i=0; i<n; i++) {
206-
for (uint64_t j=0; j<m; j++) {
207-
double val=*(D+indI[i]*n2+indJ[j]);
208-
net.setCost(di.arcFromId(idarc), val);
209-
++idarc;
210-
}
211-
}
212-
312+
setup_explicit_arc_costs(net, di, D, n2, indI, indJ, n, m);
213313

214314
// Solve the problem with the network simplex algorithm
215315

216316
int ret=net.run();
217-
uint64_t i, j;
218317
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
219318
*cost = 0;
220-
Arc a; di.first(a);
221-
for (; a != INVALID; di.next(a)) {
222-
i = di.source(a);
223-
j = di.target(a);
224-
double flow = net.flow(a);
225-
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
226-
*(G+indI[i]*n2+indJ[j-n]) = flow;
227-
*(alpha + indI[i]) = -net.potential(i);
228-
*(beta + indJ[j-n]) = net.potential(j);
229-
}
319+
extract_compressed_support(
320+
net, di, INVALID, D, G, alpha, beta, cost, indI, indJ, n, n2
321+
);
230322

231323
}
232324

233-
234325
return ret;
235326
}
236327

@@ -370,7 +461,6 @@ int EMD_wrap_sparse(
370461
}
371462

372463
int ret = net.run();
373-
374464
if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
375465
*cost = 0;
376466
*n_flows_out = 0;
@@ -521,6 +611,6 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double
521611
}
522612
}
523613
}
524-
614+
525615
return ret;
526616
}

0 commit comments

Comments
 (0)