1919#include " EMD.h"
2020#include < cstdint>
2121#include < unordered_map>
22+ #include < vector>
23+
24+ namespace {
25+
/// Flags describing how a transport problem is handed to the network-simplex
/// solver and how its solution is read back.
struct SetupPolicy {
    /// True when the compressed problem size (n, m) equals the full size (n1, n2).
    bool full_support;
    /// Value passed as the solver's arc-mixing constructor argument.
    bool use_arc_mixing;
    /// True when the cost matrix may be handed to the solver as one dense pointer.
    bool use_dense_cost_pointer;
};

/// Derives the setup flags for one solve.
///
/// @param n  Number of source entries kept in the problem
///           (presumably those with non-zero weight; confirm against the caller).
/// @param m  Number of target entries kept in the problem (same caveat as n).
/// @param n1 Full number of rows of the cost matrix.
/// @param n2 Full number of columns of the cost matrix.
/// @param dense_cost_pointer_supported Whether the calling solver path accepts
///        a dense cost pointer at all.
/// @return The policy: arc mixing is enabled exactly when the support is not
///         full; the dense cost pointer is used only when it is supported and
///         the support is full.
inline SetupPolicy make_setup_policy(
    uint64_t n,
    uint64_t m,
    int n1,
    int n2,
    bool dense_cost_pointer_supported
) {
    const bool rows_complete = (n == static_cast<uint64_t>(n1));
    const bool cols_complete = (m == static_cast<uint64_t>(n2));
    const bool full_support = rows_complete && cols_complete;

    SetupPolicy policy;
    policy.full_support = full_support;
    policy.use_arc_mixing = !full_support;
    policy.use_dense_cost_pointer = dense_cost_pointer_supported && full_support;
    return policy;
}
45+
/// Assigns a cost to every arc of the n x m bipartite graph, one arc at a time.
///
/// Arc ids are visited in increasing order; arc (i, j) gets id i * m + j and
/// the cost D[indI[i] * n2 + indJ[j]].
///
/// @param net  Solver receiving the costs through setCost().
/// @param di   Graph used to turn an arc id into an arc handle (arcFromId()).
/// @param D    Cost matrix, indexed row-major with row stride n2.
/// @param n2   Row stride of D.
/// @param indI For each of the n kept sources, its row index in D.
/// @param indJ For each of the m kept targets, its column index in D.
/// @param n    Number of kept sources.
/// @param m    Number of kept targets.
template <typename NetType, typename DigraphType>
inline void setup_explicit_arc_costs(
    NetType& net,
    DigraphType& di,
    const double* D,
    int n2,
    const std::vector<uint64_t>& indI,
    const std::vector<uint64_t>& indJ,
    uint64_t n,
    uint64_t m
) {
    int64_t arc_id = 0;
    for (uint64_t i = 0; i < n; ++i) {
        // The row offset does not depend on j, so compute it once per row.
        const uint64_t row_offset = indI[i] * n2;
        for (uint64_t j = 0; j < m; ++j) {
            net.setCost(di.arcFromId(arc_id), D[row_offset + indJ[j]]);
            ++arc_id;
        }
    }
}
65+
/// Passes optional warm-start potentials to the solver, restricted to the
/// entries selected by indI / indJ.
///
/// Does nothing unless both alpha_init and beta_init are non-null.
///
/// @param net        Solver receiving the potentials via setWarmstartPotentials().
/// @param alpha_init Full-size source potentials, or nullptr.
/// @param beta_init  Full-size target potentials, or nullptr.
/// @param indI       For each of the n kept sources, its index in alpha_init.
/// @param indJ       For each of the m kept targets, its index in beta_init.
/// @param n          Number of kept sources.
/// @param m          Number of kept targets.
template <typename NetType>
inline void setup_warmstart_potentials(
    NetType& net,
    const double* alpha_init,
    const double* beta_init,
    const std::vector<uint64_t>& indI,
    const std::vector<uint64_t>& indJ,
    uint64_t n,
    uint64_t m
) {
    if (alpha_init == nullptr || beta_init == nullptr) {
        return;
    }

    std::vector<double> alpha_compressed(n);
    std::vector<double> beta_compressed(m);
    for (uint64_t i = 0; i < n; ++i) {
        alpha_compressed[i] = alpha_init[indI[i]];
    }
    for (uint64_t j = 0; j < m; ++j) {
        beta_compressed[j] = beta_init[indJ[j]];
    }

    // data() is well-defined on an empty vector, unlike &v[0].
    net.setWarmstartPotentials(
        alpha_compressed.data(),
        beta_compressed.data(),
        static_cast<int>(n),
        static_cast<int>(m)
    );
}
83+
/// Reads the solution straight out of the solver's internal arrays for the
/// full-support case (no index compression).
///
/// Potentials are read from net._pi counting back from its last node slot, and
/// flows from net._flow counting back from its last arc slot.
/// NOTE(review): this relies on the solver storing nodes and arcs in reverse
/// order, which presumably holds only when arc mixing is disabled; confirm
/// against the solver implementation.
///
/// @param net   Solver after run(); must expose nodeNum(), arcNum(), _pi and _flow.
/// @param D     Cost matrix, indexed with the same flat index as G.
/// @param G     Output transport plan; only entries with non-zero flow are written.
/// @param alpha Output source potentials (n entries), stored negated.
/// @param beta  Output target potentials (m entries).
/// @param cost  In/out accumulator; flow * cost is added for each non-zero arc.
/// @param n     Number of sources.
/// @param m     Number of targets.
template <typename NetType>
inline void extract_dense_full_support(
    const NetType& net,
    const double* D,
    double* G,
    double* alpha,
    double* beta,
    double* cost,
    uint64_t n,
    uint64_t m
) {
    const int node_total = net.nodeNum();
    const int last_pi = node_total - 1;

    for (uint64_t src = 0; src < n; ++src) {
        alpha[src] = -net._pi[last_pi - static_cast<int>(src)];
    }
    for (uint64_t dst = 0; dst < m; ++dst) {
        beta[dst] = net._pi[last_pi - static_cast<int>(n + dst)];
    }

    // Only non-zero entries are written: the caller is expected to pass a
    // zero-initialised G (the original note says this is done on the Python side).
    const int64_t arc_total = net.arcNum();
    for (int64_t slot = 0; slot < arc_total; ++slot) {
        const double flow = net._flow[slot];
        if (flow == 0.0) {
            continue;
        }
        // Flat index into D and G corresponding to this arc slot.
        const int64_t cell = arc_total - slot - 1;
        *cost += flow * D[cell];
        G[cell] = flow;
    }
}
115+
/// Reads the solution through the solver's public accessors and scatters it
/// back to full-size arrays using the compression indices.
///
/// @param net     Solver after run(); must expose potential() and flow().
/// @param di      Graph to iterate (first(), next(), source(), target()).
/// @param invalid Sentinel that an arc compares equal to once iteration is done.
/// @param D       Cost matrix, indexed row-major with row stride n2.
/// @param G       Output transport plan with the same layout as D; every arc's
///                cell is written, including zero flows.
/// @param alpha   Output source potentials, written at indI[i], stored negated.
/// @param beta    Output target potentials, written at indJ[j].
/// @param cost    In/out accumulator; flow * cost is added for every arc.
/// @param indI    For each kept source, its row index in the full problem.
/// @param indJ    For each kept target, its column index in the full problem.
/// @param n       Number of kept sources; target node ids are offset by n.
/// @param n2      Row stride of D and G.
template <typename NetType, typename DigraphType, typename InvalidType>
inline void extract_compressed_support(
    const NetType& net,
    DigraphType& di,
    InvalidType invalid,
    const double* D,
    double* G,
    double* alpha,
    double* beta,
    double* cost,
    const std::vector<uint64_t>& indI,
    const std::vector<uint64_t>& indJ,
    uint64_t n,
    int n2
) {
    for (uint64_t src = 0; src < n; ++src) {
        alpha[indI[src]] = -net.potential(src);
    }
    for (uint64_t dst = 0; dst < indJ.size(); ++dst) {
        beta[indJ[dst]] = net.potential(dst + n);
    }

    typename DigraphType::Arc a;
    for (di.first(a); a != invalid; di.next(a)) {
        const uint64_t i = di.source(a);
        const uint64_t j = di.target(a);
        // Target node ids follow the n source ids, hence the j - n shift.
        const uint64_t cell = indI[i] * n2 + indJ[j - n];
        const double flow = net.flow(a);
        *cost += flow * D[cell];
        G[cell] = flow;
    }
}
149+
150+ } // namespace
22151
23152
24153int EMD_wrap (int n1, int n2, double *X, double *Y, double *D, double *G,
@@ -52,12 +181,14 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
52181 }
53182 }
54183
55- // Define the graph
56-
184+ // Define graph and solver
57185 std::vector<uint64_t > indI (n), indJ (m);
58186 std::vector<double > weights1 (n), weights2 (m);
59187 Digraph di (n, m);
60- NetworkSimplexSimple<Digraph,double ,double , node_id_type> net (di, true , (int ) (n + m), n * m, maxIter);
188+ const SetupPolicy policy = make_setup_policy (n, m, n1, n2, true );
189+ NetworkSimplexSimple<Digraph,double ,double , node_id_type> net (
190+ di, policy.use_arc_mixing , (int ) (n + m), n * m, maxIter
191+ );
61192
62193 // Set supply and demand, don't account for 0 values (faster)
63194
@@ -84,51 +215,26 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
84215
85216 net.supplyMap (&weights1[0 ], (int ) n, &weights2[0 ], (int ) m);
86217
87- // Set the cost of each edge
88- int64_t idarc = 0 ;
89- for (uint64_t i=0 ; i<n; i++) {
90- for (uint64_t j=0 ; j<m; j++) {
91- double val=*(D+indI[i]*n2+indJ[j]);
92- net.setCost (di.arcFromId (idarc), val);
93- ++idarc;
94- }
95- }
96-
97- // Set warmstart potentials if provided
98- if (alpha_init != nullptr && beta_init != nullptr ) {
99- // Compress warmstart potentials to only non-zero entries
100- std::vector<double > alpha_compressed (n);
101- std::vector<double > beta_compressed (m);
102- for (uint64_t i = 0 ; i < n; i++) {
103- alpha_compressed[i] = alpha_init[indI[i]];
104- }
105- for (uint64_t j = 0 ; j < m; j++) {
106- beta_compressed[j] = beta_init[indJ[j]];
107- }
108- net.setWarmstartPotentials (&alpha_compressed[0 ], &beta_compressed[0 ], (int )n, (int )m);
218+ if (policy.use_dense_cost_pointer ) {
219+ net.setDenseCostMatrix (D, n2);
220+ } else {
221+ setup_explicit_arc_costs (net, di, D, n2, indI, indJ, n, m);
109222 }
110-
223+ setup_warmstart_potentials (net, alpha_init, beta_init, indI, indJ, n, m);
111224 // Solve the problem with the network simplex algorithm
112225
113226 int ret=net.run ();
114227
115- uint64_t i, j;
116228 if (ret==(int )net.OPTIMAL || ret==(int )net.MAX_ITER_REACHED ) {
117229 *cost = 0 ;
118- Arc a; di.first (a);
119- for (; a != INVALID; di.next (a)) {
120- i = di.source (a);
121- j = di.target (a);
122- double flow = net.flow (a);
123- *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
124- *(G+indI[i]*n2+indJ[j-n]) = flow;
125- *(alpha + indI[i]) = -net.potential (i);
126- *(beta + indJ[j-n]) = net.potential (j);
230+ if (policy.full_support ) {
231+ extract_dense_full_support (net, D, G, alpha, beta, cost, n, m);
232+ } else {
233+ extract_compressed_support (
234+ net, di, INVALID, D, G, alpha, beta, cost, indI, indJ, n, n2
235+ );
127236 }
128-
129237 }
130-
131-
132238 return ret;
133239}
134240
@@ -173,7 +279,10 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
173279 std::vector<uint64_t > indI (n), indJ (m);
174280 std::vector<double > weights1 (n), weights2 (m);
175281 Digraph di (n, m);
176- NetworkSimplexSimple<Digraph,double ,double , node_id_type> net (di, true , (int ) (n + m), n * m, maxIter, numThreads);
282+ const SetupPolicy policy = make_setup_policy (n, m, n1, n2, false );
283+ NetworkSimplexSimple<Digraph,double ,double , node_id_type> net (
284+ di, policy.use_arc_mixing , (int ) (n + m), n * m, maxIter, numThreads
285+ );
177286
178287 // Set supply and demand, don't account for 0 values (faster)
179288
@@ -200,37 +309,19 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
200309
201310 net.supplyMap (&weights1[0 ], (int ) n, &weights2[0 ], (int ) m);
202311
203- // Set the cost of each edge
204- int64_t idarc = 0 ;
205- for (uint64_t i=0 ; i<n; i++) {
206- for (uint64_t j=0 ; j<m; j++) {
207- double val=*(D+indI[i]*n2+indJ[j]);
208- net.setCost (di.arcFromId (idarc), val);
209- ++idarc;
210- }
211- }
212-
312+ setup_explicit_arc_costs (net, di, D, n2, indI, indJ, n, m);
213313
214314 // Solve the problem with the network simplex algorithm
215315
216316 int ret=net.run ();
217- uint64_t i, j;
218317 if (ret==(int )net.OPTIMAL || ret==(int )net.MAX_ITER_REACHED ) {
219318 *cost = 0 ;
220- Arc a; di.first (a);
221- for (; a != INVALID; di.next (a)) {
222- i = di.source (a);
223- j = di.target (a);
224- double flow = net.flow (a);
225- *cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
226- *(G+indI[i]*n2+indJ[j-n]) = flow;
227- *(alpha + indI[i]) = -net.potential (i);
228- *(beta + indJ[j-n]) = net.potential (j);
229- }
319+ extract_compressed_support (
320+ net, di, INVALID, D, G, alpha, beta, cost, indI, indJ, n, n2
321+ );
230322
231323 }
232324
233-
234325 return ret;
235326}
236327
@@ -370,7 +461,6 @@ int EMD_wrap_sparse(
370461 }
371462
372463 int ret = net.run ();
373-
374464 if (ret == (int )net.OPTIMAL || ret == (int )net.MAX_ITER_REACHED ) {
375465 *cost = 0 ;
376466 *n_flows_out = 0 ;
@@ -521,6 +611,6 @@ int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double
521611 }
522612 }
523613 }
524-
614+
525615 return ret;
526616}
0 commit comments