@@ -82,6 +82,8 @@ def estimate_maximum_singular_value(
82
82
matrix_transpose = BCSR .from_bcoo (matrix .to_bcoo ().T )
83
83
elif isinstance (matrix , BCOO ):
84
84
matrix_transpose = BCSR .from_bcoo (matrix .T )
85
+ elif isinstance (matrix , jnp .ndarray ):
86
+ matrix_transpose = matrix .T
85
87
number_of_power_iterations = 0
86
88
87
89
def cond_fun (state ):
@@ -401,6 +403,7 @@ def initialize_solver_status(
401
403
scaled_problem : ScaledQpProblem ,
402
404
initial_primal_solution : jnp .array ,
403
405
initial_dual_solution : jnp .array ,
406
+ is_lp : bool = True ,
404
407
) -> PdhgSolverState :
405
408
"""Initialize the solver status for PDHG.
406
409
@@ -451,9 +454,11 @@ def initialize_solver_status(
451
454
self ._norm_A = estimate_maximum_singular_value (scaled_qp .constraint_matrix )[
452
455
0
453
456
]
454
- self ._norm_Q = estimate_maximum_singular_value (scaled_qp .objective_matrix )[
455
- 0
456
- ]
457
+ self ._norm_Q = jax .lax .cond (
458
+ is_lp ,
459
+ lambda : 0.0 ,
460
+ lambda : estimate_maximum_singular_value (scaled_qp .objective_matrix )[0 ],
461
+ )
457
462
step_size = 1.0 # Placeholder for step size.
458
463
459
464
if self .warm_start :
@@ -551,10 +556,14 @@ def take_step(
551
556
extrapolation_coefficient = solver_state .solutions_count / (
552
557
solver_state .solutions_count + 1.0
553
558
)
554
- step_size = self .calculate_constant_step_size (
555
- solver_state .primal_weight ,
556
- solver_state .solutions_count ,
557
- solver_state .step_size ,
559
+ step_size = jax .lax .cond (
560
+ problem .is_lp ,
561
+ lambda : solver_state .step_size ,
562
+ lambda : self .calculate_constant_step_size (
563
+ solver_state .primal_weight ,
564
+ solver_state .solutions_count ,
565
+ solver_state .step_size ,
566
+ ),
558
567
)
559
568
delta_primal , delta_primal_product , delta_dual = compute_next_solution (
560
569
problem , solver_state , step_size , extrapolation_coefficient
@@ -959,7 +968,10 @@ def optimize(
959
968
logger .info ("Preconditioning Time (seconds): %.2e" , precondition_time )
960
969
961
970
solver_state , last_restart_info = self .initialize_solver_status (
962
- scaled_problem , initial_primal_solution , initial_dual_solution
971
+ scaled_problem ,
972
+ initial_primal_solution ,
973
+ initial_dual_solution ,
974
+ original_problem .is_lp ,
963
975
)
964
976
965
977
# Iteration loop
0 commit comments