Skip to content

Commit a1ae913

Browse files
authored
added functionality + test to add a custom callback and interrupt the… (#843)
* added functionality + test to add a custom callback and interrupt the solving * fix types
1 parent 140f9e6 commit a1ae913

File tree

2 files changed

+50
-16
lines changed

2 files changed

+50
-16
lines changed

pulp/apis/highs_api.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
import subprocess
55
from math import inf
6-
from typing import List
6+
from typing import List, Optional
77

88
from .. import constants
99
from .core import LpSolver, LpSolver_CMD, PulpSolverError
@@ -271,13 +271,17 @@ def readsol(self, filename):
271271
return values
272272

273273

274+
highspy = None
275+
276+
274277
class HiGHS(LpSolver):
275278
name = "HiGHS"
276279

277280
try:
278281
global highspy
279282
import highspy # type: ignore[import-not-found, import-untyped, unused-ignore]
280283
except:
284+
hscb = None
281285

282286
def available(self):
283287
"""True if the solver is available"""
@@ -288,13 +292,7 @@ def actualSolve(self, lp, callback=None):
288292
raise PulpSolverError("HiGHS: Not Available")
289293

290294
else:
291-
# Note(maciej): It was surprising to me that higshpy wasn't logging out of the box,
292-
# even with the different logging options set. This callback seems to work, but there
293-
# are probably better ways of doing this ¯\_(ツ)_/¯
294-
DEFAULT_CALLBACK = lambda logType, logMsg, callbackValue: print(
295-
f"[{logType.name}] {logMsg}"
296-
)
297-
DEFAULT_CALLBACK_VALUE = ""
295+
hscb = highspy.cb # type: ignore[attr-defined, unused-ignore]
298296

299297
def __init__(
300298
self,
@@ -305,21 +303,23 @@ def __init__(
305303
gapRel=None,
306304
threads=None,
307305
timeLimit=None,
306+
callbacksToActivate: Optional[List[highspy.cb.HighsCallbackType]] = None,
308307
**solverParams,
309308
):
310309
"""
311310
:param bool mip: if False, assume LP even if integer variables
312311
:param bool msg: if False, no log is shown
313-
:param tuple callbackTuple: Tuple of log callback function (see DEFAULT_CALLBACK above for definition)
314-
and callbackValue (tag embedded in every callback)
312+
:param tuple callbackTuple: Tuple of callback function and callbackValue (see tests for an example)
315313
:param float gapRel: relative gap tolerance for the solver to stop (in fraction)
316314
:param float gapAbs: absolute gap tolerance for the solver to stop
317315
:param int threads: sets the maximum number of threads
318316
:param float timeLimit: maximum time for solver (in seconds)
319317
:param dict solverParams: list of named options to pass directly to the HiGHS solver
318+
:param callbacksToActivate: list of callback types to start
320319
"""
321320
super().__init__(mip=mip, msg=msg, timeLimit=timeLimit, **solverParams)
322321
self.callbackTuple = callbackTuple
322+
self.callbacksToActivate = callbacksToActivate
323323
self.gapAbs = gapAbs
324324
self.gapRel = gapRel
325325
self.threads = threads
@@ -333,12 +333,12 @@ def callSolver(self, lp):
333333
def createAndConfigureSolver(self, lp):
334334
lp.solverModel = highspy.Highs()
335335

336-
if self.msg and self.callbackTuple:
337-
callbackTuple = self.callbackTuple or (
338-
HiGHS.DEFAULT_CALLBACK,
339-
HiGHS.DEFAULT_CALLBACK_VALUE,
340-
)
341-
lp.solverModel.setLogCallback(*callbackTuple)
336+
if self.callbackTuple:
337+
lp.solverModel.setCallback(*self.callbackTuple)
338+
339+
if self.callbacksToActivate:
340+
for cb_type in self.callbacksToActivate:
341+
lp.solverModel.startCallback(cb_type)
342342

343343
if not self.msg:
344344
lp.solverModel.setOptionValue("output_flag", False)
@@ -465,6 +465,10 @@ def findSolutionValues(self, lp):
465465
constants.LpStatusOptimal,
466466
constants.LpSolutionIntegerFeasible,
467467
),
468+
HighsModelStatus.kInterrupt: (
469+
constants.LpStatusOptimal,
470+
constants.LpSolutionIntegerFeasible,
471+
),
468472
HighsModelStatus.kTimeLimit: (
469473
constants.LpStatusOptimal,
470474
constants.LpSolutionIntegerFeasible,

pulp/tests/test_pulp.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2194,6 +2194,36 @@ class SCIP_PYTest(BaseSolverTest.PuLPTest):
21942194
class HiGHS_PYTest(BaseSolverTest.PuLPTest):
21952195
solveInst = HiGHS
21962196

2197+
def test_callback(self):
2198+
prob = create_bin_packing_problem(bins=40, seed=99)
2199+
2200+
# we pass a list as data to the tuple, so we can edit it.
2201+
# then we count the number of calls and stop the solving
2202+
# for more information on the callback, see: github.com/ERGO-Code/HiGHS @ examples/call_highs_from_python
2203+
def user_callback(
2204+
callback_type, message, data_out, data_in, user_callback_data
2205+
):
2206+
#
2207+
if callback_type == HiGHS.hscb.HighsCallbackType.kCallbackMipInterrupt:
2208+
print(
2209+
f"userInterruptCallback(type {callback_type}); "
2210+
f"data {user_callback_data};"
2211+
f"message: {message};"
2212+
f"objective {data_out.objective_function_value:.4g};"
2213+
)
2214+
print(f"Dual bound = {data_out.mip_dual_bound:.4g}")
2215+
print(f"Primal bound = {data_out.mip_primal_bound:.4g}")
2216+
print(f"Gap = {data_out.mip_gap:.4g}")
2217+
if isinstance(user_callback_data, list):
2218+
user_callback_data.append(1)
2219+
data_in.user_interrupt = len(user_callback_data) > 5
2220+
2221+
solver = HiGHS(
2222+
callbackTuple=(user_callback, []),
2223+
callbacksToActivate=[HiGHS.hscb.HighsCallbackType.kCallbackMipInterrupt],
2224+
)
2225+
status = prob.solve(solver)
2226+
21972227

21982228
class HiGHS_CMDTest(BaseSolverTest.PuLPTest):
21992229
solveInst = HiGHS_CMD

0 commit comments

Comments
 (0)