Skip to content

Commit 0e07b31

Browse files
committed
small copilot edits
1 parent 2b624f5 commit 0e07b31

File tree

1 file changed

+62
-74
lines changed

1 file changed

+62
-74
lines changed

src/lp.cpp

Lines changed: 62 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ using namespace arma;
77

88
/**
99
* Two-Phase Simplex Method Implementation
10-
*
10+
*
1111
* Solves: minimize c'x subject to Ax = b, x >= 0
12-
*
12+
*
1313
* Phase 1: Find initial feasible basis using artificial variables
1414
* Phase 2: Optimize objective from feasible basis
1515
*/
@@ -32,12 +32,12 @@ static int find_entering_variable(const rowvec& reduced_costs, double tol) {
3232
* Find leaving variable using minimum ratio test with Bland's rule for ties
3333
* Returns -1 if unbounded
3434
*/
35-
static int find_leaving_variable(const mat& tableau, int entering_col,
35+
static int find_leaving_variable(const mat& tableau, int entering_col,
3636
const uvec& basis, double tol) {
3737
int m = tableau.n_rows - 1; // Exclude objective row
3838
int leaving_row = -1;
3939
double min_ratio = std::numeric_limits<double>::infinity();
40-
40+
4141
for (int i = 0; i < m; ++i) {
4242
double pivot_elem = tableau(i, entering_col);
4343
if (pivot_elem > tol) {
@@ -53,7 +53,7 @@ static int find_leaving_variable(const mat& tableau, int entering_col,
5353
}
5454
}
5555
}
56-
56+
5757
return leaving_row;
5858
}
5959

@@ -63,7 +63,7 @@ static int find_leaving_variable(const mat& tableau, int entering_col,
6363
static void pivot(mat& tableau, int pivot_row, int pivot_col) {
6464
double pivot_elem = tableau(pivot_row, pivot_col);
6565
tableau.row(pivot_row) /= pivot_elem;
66-
66+
6767
int m = tableau.n_rows;
6868
for (int i = 0; i < m; ++i) {
6969
if (i != pivot_row) {
@@ -75,78 +75,78 @@ static void pivot(mat& tableau, int pivot_row, int pivot_col) {
7575

7676
/**
7777
* Phase 1: Find initial feasible basis using artificial variables
78-
*
78+
*
7979
* Solves: minimize sum(artificial variables)
8080
* Returns: basis indices and whether feasible solution exists
8181
*/
82-
static bool phase1_simplex(const mat& A, const vec& b, uvec& basis,
82+
static bool phase1_simplex(const mat& A, const vec& b, uvec& basis,
8383
mat& tableau, double tol, int max_iter = 10000) {
8484
uword m = A.n_rows; // Number of constraints
8585
uword n = A.n_cols; // Number of original variables
86-
86+
8787
// Build Phase 1 tableau: [A | I | b]
8888
// Variables: [original vars | artificial vars | RHS]
8989
// Objective: minimize sum of artificial variables
9090
tableau = mat(m + 1, n + m + 1, fill::zeros);
91-
91+
9292
// Constraint rows
9393
tableau(span(0, m-1), span(0, n-1)) = A;
9494
tableau(span(0, m-1), span(n, n+m-1)) = eye<mat>(m, m);
9595
tableau(span(0, m-1), n + m) = b;
96-
96+
9797
// Objective row: minimize sum of artificial variables
9898
// After adding constraints, reduced costs for artificial vars
9999
tableau(m, span(n, n+m-1)).fill(1.0);
100-
100+
101101
// Initial basis: artificial variables
102102
basis = uvec(m);
103103
for (uword i = 0; i < m; ++i) {
104104
basis(i) = n + i;
105105
}
106-
106+
107107
// Update objective row by subtracting constraint rows
108108
for (uword i = 0; i < m; ++i) {
109109
tableau.row(m) -= tableau.row(i);
110110
}
111-
111+
112112
// Run simplex on Phase 1 problem
113113
int iter = 0;
114114
while (iter < max_iter) {
115115
// Get reduced costs (objective row, excluding RHS)
116116
rowvec reduced_costs = tableau(m, span(0, n + m - 1));
117-
117+
118118
// Find entering variable
119119
int entering = find_entering_variable(reduced_costs, tol);
120120
if (entering == -1) {
121121
// Optimal solution found for Phase 1
122122
break;
123123
}
124-
124+
125125
// Find leaving variable
126126
int leaving = find_leaving_variable(tableau, entering, basis, tol);
127127
if (leaving == -1) {
128128
// Unbounded (shouldn't happen in Phase 1 with artificial vars)
129129
return false;
130130
}
131-
131+
132132
// Pivot
133133
pivot(tableau, leaving, entering);
134134
basis(leaving) = entering;
135-
135+
136136
++iter;
137137
}
138-
138+
139139
if (iter >= max_iter) {
140140
return false;
141141
}
142-
142+
143143
// Check if feasible: objective value should be ~0
144144
double phase1_obj = -tableau(m, n + m);
145145
if (phase1_obj > tol) {
146146
// Infeasible
147147
return false;
148148
}
149-
149+
150150
// Remove artificial variables from basis if present
151151
for (uword i = 0; i < m; ++i) {
152152
if (basis(i) >= n) {
@@ -160,52 +160,52 @@ static bool phase1_simplex(const mat& A, const vec& b, uvec& basis,
160160
}
161161
}
162162
}
163-
163+
164164
return true;
165165
}
166166

167167
/**
168168
* Phase 2: Optimize objective function from feasible basis
169169
*/
170-
static LPResult phase2_simplex(const vec& c, const mat& tableau_p1, const uvec& basis_p1,
170+
static LPResult phase2_simplex(const vec& c, const mat& tableau_p1, const uvec& basis_p1,
171171
uword n_orig, double tol, int max_iter = 10000) {
172172
LPResult result;
173173
uword m = basis_p1.n_elem;
174-
174+
175175
// Remove artificial variable columns from tableau
176176
// Keep only: original variables (0..n_orig-1) and RHS (last column)
177177
uword rhs_col = tableau_p1.n_cols - 1;
178178
mat tableau(m + 1, n_orig + 1);
179179
tableau(span(0, m-1), span(0, n_orig-1)) = tableau_p1(span(0, m-1), span(0, n_orig-1));
180180
tableau(span(0, m-1), n_orig) = tableau_p1(span(0, m-1), rhs_col);
181-
181+
182182
// Copy basis (but artificial vars should have been pivoted out already)
183183
uvec basis = basis_p1;
184-
184+
185185
// Build Phase 2 objective row
186186
tableau.row(m).zeros();
187187
tableau(m, span(0, n_orig - 1)) = c.t();
188-
188+
189189
// Update reduced costs for current basis
190190
for (uword i = 0; i < m; ++i) {
191191
if (basis(i) < n_orig) {
192192
tableau.row(m) -= c(basis(i)) * tableau.row(i);
193193
}
194194
}
195-
195+
196196
// Run simplex for Phase 2
197197
int iter = 0;
198198
while (iter < max_iter) {
199199
// Get reduced costs
200200
rowvec reduced_costs = tableau(m, span(0, n_orig - 1));
201-
201+
202202
// Find entering variable
203203
int entering = find_entering_variable(reduced_costs, tol);
204204
if (entering == -1) {
205205
// Optimal solution found
206206
result.status = LPStatus::OPTIMAL;
207207
result.objval = -tableau(m, n_orig);
208-
208+
209209
// Extract solution
210210
result.solution = vec(n_orig, fill::zeros);
211211
for (uword i = 0; i < m; ++i) {
@@ -215,22 +215,22 @@ static LPResult phase2_simplex(const vec& c, const mat& tableau_p1, const uvec&
215215
}
216216
return result;
217217
}
218-
218+
219219
// Find leaving variable
220220
int leaving = find_leaving_variable(tableau, entering, basis, tol);
221221
if (leaving == -1) {
222222
// Unbounded
223223
result.status = LPStatus::UNBOUNDED;
224224
return result;
225225
}
226-
226+
227227
// Pivot
228228
pivot(tableau, leaving, entering);
229229
basis(leaving) = entering;
230-
230+
231231
++iter;
232232
}
233-
233+
234234
// Max iterations reached
235235
result.status = LPStatus::ERROR;
236236
return result;
@@ -241,16 +241,16 @@ static LPResult phase2_simplex(const vec& c, const mat& tableau_p1, const uvec&
241241
*/
242242
LPResult solve_lp_simplex(const vec& c, const mat& A, const vec& b, double tol) {
243243
LPResult result;
244-
244+
245245
// Validate inputs
246246
if (A.n_rows != b.n_elem || A.n_cols != c.n_elem) {
247247
result.status = LPStatus::ERROR;
248248
return result;
249249
}
250-
250+
251251
uword m = A.n_rows;
252252
uword n = A.n_cols;
253-
253+
254254
// Check for negative RHS and flip constraints if needed
255255
mat A_work = A;
256256
vec b_work = b;
@@ -260,24 +260,24 @@ LPResult solve_lp_simplex(const vec& c, const mat& A, const vec& b, double tol)
260260
b_work(i) = -b_work(i);
261261
}
262262
}
263-
263+
264264
uvec basis;
265265
mat tableau;
266-
266+
267267
// Phase 1: Find feasible basis
268268
if (!phase1_simplex(A_work, b_work, basis, tableau, tol)) {
269269
result.status = LPStatus::INFEASIBLE;
270270
return result;
271271
}
272-
272+
273273
// Phase 2: Optimize
274274
result = phase2_simplex(c, tableau, basis, n, tol);
275-
275+
276276
return result;
277277
}
278278

279279
/**
280-
* Compute bounds on contrasts using custom LP solver
280+
* Compute bounds on contrasts using custom LP solver (optimized version)
281281
*/
282282
std::tuple<mat, mat> bounds_lp_contrast_cpp(
283283
const mat& x,
@@ -294,89 +294,77 @@ std::tuple<mat, mat> bounds_lp_contrast_cpp(
294294
uword n_y = y.n_cols; // Number of outcomes
295295
uword n_c = contr_m.n_cols; // Number of contrasts
296296
uword n_vars = n_y * n_x; // Number of variables (entries of B)
297-
297+
298298
mat res_min(n, n_c);
299299
mat res_max(n, n_c);
300-
301-
// Build constraint matrix structure (same for all observations and contrasts)
302-
// Constraints:
303-
// 1. B %*% x[i,] = y[i,] (n_y equality constraints)
304-
// 2. If sum_one: sum(B[,k]) = scale*(1-shift) for k=1..n_x (n_x constraints)
305-
// 3. If has_ub: B[j] <= ub (n_vars inequality -> converted to equality with slack)
306-
300+
301+
// Build constraint matrix structure (reused across observations)
307302
uword n_eq = n_y + (sum_one ? n_x : 0);
308303
uword n_ineq = has_ub ? n_vars : 0;
309304
uword n_constraints = n_eq + n_ineq;
310305
uword n_total_vars = n_vars + n_ineq; // Original + slack variables
311-
306+
312307
mat A(n_constraints, n_total_vars, fill::zeros);
313308
vec b(n_constraints);
314-
315-
// Build static parts of constraint matrix
316-
317-
// Part 2: sum_one constraints (if applicable)
309+
310+
// Part 2: sum_one constraints (static across observations)
318311
if (sum_one) {
319-
// Each column of B sums to scale*(1-shift)
320-
// Variables are in row-major order: B[0,0], B[0,1], ..., B[0,n_x-1], B[1,0], ...
321312
for (uword k = 0; k < n_x; ++k) {
322313
for (uword j = 0; j < n_y; ++j) {
323314
A(n_y + k, j * n_x + k) = 1.0;
324315
}
325316
b(n_y + k) = scale * (1.0 - shift);
326317
}
327318
}
328-
329-
// Part 3: Upper bound constraints (if applicable)
319+
320+
// Part 3: Upper bound constraints (static across observations)
330321
if (has_ub) {
331-
// B[idx] + slack[idx] = ub for idx=0..n_vars-1
332322
for (uword idx = 0; idx < n_vars; ++idx) {
333323
A(n_eq + idx, idx) = 1.0; // Original variable
334324
A(n_eq + idx, n_vars + idx) = 1.0; // Slack variable
335325
b(n_eq + idx) = ub;
336326
}
337327
}
338-
328+
339329
// Solve LP for each observation and contrast
340330
for (uword j = 0; j < n_c; ++j) {
341331
vec contr = contr_m.col(j);
342332
double obj_offset = sum(contr) * shift;
343-
344-
// Build objective for original variables
333+
334+
// Build objective for original variables (reused across observations)
345335
vec c(n_total_vars, fill::zeros);
346336
c(span(0, n_vars - 1)) = contr;
347-
337+
348338
for (uword i = 0; i < n; ++i) {
349-
// Part 1: Observation-specific constraints B %*% x[i,] = y[i,]
350-
// B is n_y x n_x in row-major order
351-
// B %*% x means: sum_k B[j,k] * x[k] = y[j] for each row j
339+
// Part 1: Update observation-specific constraints B %*% x[i,] = y[i,]
352340
for (uword row = 0; row < n_y; ++row) {
353341
for (uword col = 0; col < n_x; ++col) {
354342
A(row, row * n_x + col) = x(i, col);
355343
}
356344
b(row) = y(i, row);
357345
}
358-
359-
// Solve for minimum
346+
347+
// Solve for minimum and maximum using two-phase simplex
360348
LPResult sol_min = solve_lp_simplex(c, A, b);
349+
LPResult sol_max = solve_lp_simplex(-c, A, b);
350+
361351
if (sol_min.status == LPStatus::OPTIMAL) {
362352
res_min(i, j) = sol_min.objval + obj_offset;
363353
} else {
364354
res_min(i, j) = datum::nan;
365355
}
366-
367-
// Solve for maximum (negate objective)
368-
LPResult sol_max = solve_lp_simplex(-c, A, b);
356+
369357
if (sol_max.status == LPStatus::OPTIMAL) {
370358
res_max(i, j) = -sol_max.objval + obj_offset;
371359
} else {
372360
res_max(i, j) = datum::nan;
373361
}
374362
}
375363
}
376-
364+
377365
// Apply inverse transformation
378366
res_min = res_min * scale + shift;
379367
res_max = res_max * scale + shift;
380-
368+
381369
return std::make_tuple(res_min, res_max);
382370
}

0 commit comments

Comments
 (0)