Skip to content

Commit 138ffbb

Browse files
authored
Merge pull request #15 from ZedongPeng/fix-bug
fix: rapdhg constant stepsize nan error
2 parents bf2ac62 + d18b0c9 commit 138ffbb

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

mpax/rapdhg.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def estimate_maximum_singular_value(
8282
matrix_transpose = BCSR.from_bcoo(matrix.to_bcoo().T)
8383
elif isinstance(matrix, BCOO):
8484
matrix_transpose = BCSR.from_bcoo(matrix.T)
85+
elif isinstance(matrix, jnp.ndarray):
86+
matrix_transpose = matrix.T
8587
number_of_power_iterations = 0
8688

8789
def cond_fun(state):
@@ -401,6 +403,7 @@ def initialize_solver_status(
401403
scaled_problem: ScaledQpProblem,
402404
initial_primal_solution: jnp.array,
403405
initial_dual_solution: jnp.array,
406+
is_lp: bool = True,
404407
) -> PdhgSolverState:
405408
"""Initialize the solver status for PDHG.
406409
@@ -451,9 +454,11 @@ def initialize_solver_status(
451454
self._norm_A = estimate_maximum_singular_value(scaled_qp.constraint_matrix)[
452455
0
453456
]
454-
self._norm_Q = estimate_maximum_singular_value(scaled_qp.objective_matrix)[
455-
0
456-
]
457+
self._norm_Q = jax.lax.cond(
458+
is_lp,
459+
lambda: 0.0,
460+
lambda: estimate_maximum_singular_value(scaled_qp.objective_matrix)[0],
461+
)
457462
step_size = 1.0 # Placeholder for step size.
458463

459464
if self.warm_start:
@@ -551,10 +556,14 @@ def take_step(
551556
extrapolation_coefficient = solver_state.solutions_count / (
552557
solver_state.solutions_count + 1.0
553558
)
554-
step_size = self.calculate_constant_step_size(
555-
solver_state.primal_weight,
556-
solver_state.solutions_count,
557-
solver_state.step_size,
559+
step_size = jax.lax.cond(
560+
problem.is_lp,
561+
lambda: solver_state.step_size,
562+
lambda: self.calculate_constant_step_size(
563+
solver_state.primal_weight,
564+
solver_state.solutions_count,
565+
solver_state.step_size,
566+
),
558567
)
559568
delta_primal, delta_primal_product, delta_dual = compute_next_solution(
560569
problem, solver_state, step_size, extrapolation_coefficient
@@ -959,7 +968,10 @@ def optimize(
959968
logger.info("Preconditioning Time (seconds): %.2e", precondition_time)
960969

961970
solver_state, last_restart_info = self.initialize_solver_status(
962-
scaled_problem, initial_primal_solution, initial_dual_solution
971+
scaled_problem,
972+
initial_primal_solution,
973+
initial_dual_solution,
974+
original_problem.is_lp,
963975
)
964976

965977
# Iteration loop

tests/rapdhg_test.py

+11
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ def test_rapdhg_lp():
3535
assert pytest.approx(result.primal_objective, rel=1e-2) == expected_obj
3636

3737

38+
def test_rapdhg_lp_constant_stepsize():
39+
"""Test the raPDHG solver on a sample LP problem."""
40+
for model_filename, expected_obj in lp_model_objs.items():
41+
gurobi_model = gp.read(pytest_cache_dir + "/" + model_filename)
42+
qp = create_qp_from_gurobi(gurobi_model)
43+
solver = raPDHG(adaptive_step_size=False, eps_abs=1e-6, eps_rel=1e-6)
44+
result = solver.optimize(qp)
45+
46+
assert pytest.approx(result.primal_objective, rel=1e-2) == expected_obj
47+
48+
3849
def test_rapdhg_qp():
3950
"""Test the raPDHG solver on a sample LP problem."""
4051
for model_filename, expected_obj in qp_model_objs.items():

0 commit comments

Comments
 (0)