Skip to content

Commit 6cc253a

Browse files
committed
Update
1 parent 3fd2526 commit 6cc253a

File tree

15 files changed

+591
-275
lines changed

15 files changed

+591
-275
lines changed

Makefile

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ PYTHON := python3
22
PYTEST := pytest
33
PIP := $(PYTHON) -m pip
44
MYPY := $(PYTHON) -m mypy
5-
PYTEST_ARGS := -W ignore::DeprecationWarning -vv --log-level=DEBUG
5+
PYTEST_ARGS := -W ignore::DeprecationWarning -vv --log-level=DEBUG tests
66
VERSION := 0.2
77

88
all: docs test
@@ -41,9 +41,9 @@ reformat:
4141
$(PYTHON) -m black .
4242

4343
test:
44-
rm -rf .mypy_cache
45-
$(MYPY) -p miplearn
46-
$(MYPY) -p tests
44+
# rm -rf .mypy_cache
45+
# $(MYPY) -p miplearn
46+
# $(MYPY) -p tests
4747
$(PYTEST) $(PYTEST_ARGS)
4848

4949
.PHONY: test test-watch docs install dist

miplearn/benchmark.py

+123-142
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
import logging
66
import os
7-
from typing import Dict, List, Any, Optional
7+
from typing import Dict, List, Any, Optional, Callable
88

99
import pandas as pd
1010

1111
from miplearn.components.component import Component
1212
from miplearn.instance.base import Instance
13-
from miplearn.solvers.learning import LearningSolver
13+
from miplearn.solvers.learning import LearningSolver, FileInstanceWrapper
1414
from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
1515
from sklearn.utils._testing import ignore_warnings
1616
from sklearn.exceptions import ConvergenceWarning
@@ -43,37 +43,24 @@ def __init__(self, solvers: Dict[str, LearningSolver]) -> None:
4343

4444
def parallel_solve(
4545
self,
46-
instances: List[Instance],
46+
filenames: List[str],
47+
build_model: Callable,
4748
n_jobs: int = 1,
4849
n_trials: int = 1,
4950
progress: bool = False,
5051
) -> None:
51-
"""
52-
Solves the given instances in parallel and collect benchmark statistics.
53-
54-
Parameters
55-
----------
56-
instances: List[Instance]
57-
List of instances to solve. This can either be a list of instances
58-
already loaded in memory, or a list of filenames pointing to pickled (and
59-
optionally gzipped) files.
60-
n_jobs: int
61-
List of instances to solve in parallel at a time.
62-
n_trials: int
63-
How many times each instance should be solved.
64-
"""
6552
self._silence_miplearn_logger()
66-
trials = instances * n_trials
53+
trials = filenames * n_trials
6754
for (solver_name, solver) in self.solvers.items():
6855
results = solver.parallel_solve(
6956
trials,
57+
build_model,
7058
n_jobs=n_jobs,
71-
label="solve (%s)" % solver_name,
72-
discard_outputs=True,
59+
label="benchmark (%s)" % solver_name,
7360
progress=progress,
7461
)
7562
for i in range(len(trials)):
76-
idx = i % len(instances)
63+
idx = i % len(filenames)
7764
results[i]["Solver"] = solver_name
7865
results[i]["Instance"] = idx
7966
self.results = self.results.append(pd.DataFrame([results[i]]))
@@ -93,21 +80,15 @@ def write_csv(self, filename: str) -> None:
9380

9481
def fit(
9582
self,
96-
instances: List[Instance],
83+
filenames: List[str],
84+
build_model: Callable,
85+
progress: bool = False,
9786
n_jobs: int = 1,
98-
progress: bool = True,
9987
) -> None:
100-
"""
101-
Trains all solvers with the provided training instances.
102-
103-
Parameters
104-
----------
105-
instances: List[Instance]
106-
List of training instances.
107-
n_jobs: int
108-
Number of parallel processes to use.
109-
"""
110-
components: List[Component] = []
88+
components = []
89+
instances: List[Instance] = [
90+
FileInstanceWrapper(f, build_model, mode="r") for f in filenames
91+
]
11192
for (solver_name, solver) in self.solvers.items():
11293
if solver_name == "baseline":
11394
continue
@@ -128,6 +109,114 @@ def _restore_miplearn_logger(self) -> None:
128109
miplearn_logger = logging.getLogger("miplearn")
129110
miplearn_logger.setLevel(self.prev_log_level)
130111

112+
def write_svg(
113+
self,
114+
output: Optional[str] = None,
115+
) -> None:
116+
import matplotlib.pyplot as plt
117+
import pandas as pd
118+
import seaborn as sns
119+
120+
sns.set_style("whitegrid")
121+
sns.set_palette("Blues_r")
122+
groups = self.results.groupby("Instance")
123+
best_lower_bound = groups["mip_lower_bound"].transform("max")
124+
best_upper_bound = groups["mip_upper_bound"].transform("min")
125+
self.results["Relative lower bound"] = self.results["mip_lower_bound"] / best_lower_bound
126+
self.results["Relative upper bound"] = self.results["mip_upper_bound"] / best_upper_bound
127+
128+
if (self.results["mip_sense"] == "min").any():
129+
primal_column = "Relative upper bound"
130+
obj_column = "mip_upper_bound"
131+
predicted_obj_column = "Objective: Predicted upper bound"
132+
else:
133+
primal_column = "Relative lower bound"
134+
obj_column = "mip_lower_bound"
135+
predicted_obj_column = "Objective: Predicted lower bound"
136+
137+
palette = {
138+
"baseline": "#9b59b6",
139+
"ml-exact": "#3498db",
140+
"ml-heuristic": "#95a5a6",
141+
}
142+
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
143+
nrows=2,
144+
ncols=2,
145+
figsize=(8, 8),
146+
)
147+
148+
# Wallclock time
149+
sns.stripplot(
150+
x="Solver",
151+
y="mip_wallclock_time",
152+
data=self.results,
153+
ax=ax1,
154+
jitter=0.25,
155+
palette=palette,
156+
size=2.0,
157+
)
158+
sns.barplot(
159+
x="Solver",
160+
y="mip_wallclock_time",
161+
data=self.results,
162+
ax=ax1,
163+
errwidth=0.0,
164+
alpha=0.4,
165+
palette=palette,
166+
)
167+
ax1.set(ylabel="Wallclock time (s)")
168+
169+
# Gap
170+
sns.stripplot(
171+
x="Solver",
172+
y="Gap",
173+
jitter=0.25,
174+
data=self.results[self.results["Solver"] != "ml-heuristic"],
175+
ax=ax2,
176+
palette=palette,
177+
size=2.0,
178+
)
179+
ax2.set(ylabel="Relative MIP gap")
180+
181+
# Relative primal bound
182+
sns.stripplot(
183+
x="Solver",
184+
y=primal_column,
185+
jitter=0.25,
186+
data=self.results[self.results["Solver"] == "ml-heuristic"],
187+
ax=ax3,
188+
palette=palette,
189+
size=2.0,
190+
)
191+
sns.scatterplot(
192+
x=obj_column,
193+
y=predicted_obj_column,
194+
hue="Solver",
195+
data=self.results[self.results["Solver"] == "ml-exact"],
196+
ax=ax4,
197+
palette=palette,
198+
size=2.0,
199+
)
200+
201+
# Predicted vs actual primal bound
202+
xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
203+
ax4.plot(
204+
[-1e10, 1e10],
205+
[-1e10, 1e10],
206+
ls="-",
207+
color="#cccccc",
208+
)
209+
ax4.set_xlim(xlim)
210+
ax4.set_ylim(xlim)
211+
ax4.get_legend().remove()
212+
ax4.set(
213+
ylabel="Predicted optimal value",
214+
xlabel="Actual optimal value",
215+
)
216+
217+
fig.tight_layout()
218+
plt.savefig(output)
219+
131220

132221
@ignore_warnings(category=ConvergenceWarning)
133222
def run_benchmarks(
@@ -173,111 +262,3 @@ def run_benchmarks(
173262
plot(benchmark.results)
174263

175264

176-
def plot(
177-
results: pd.DataFrame,
178-
output: Optional[str] = None,
179-
) -> None:
180-
import matplotlib.pyplot as plt
181-
import pandas as pd
182-
import seaborn as sns
183-
184-
sns.set_style("whitegrid")
185-
sns.set_palette("Blues_r")
186-
groups = results.groupby("Instance")
187-
best_lower_bound = groups["mip_lower_bound"].transform("max")
188-
best_upper_bound = groups["mip_upper_bound"].transform("min")
189-
results["Relative lower bound"] = results["mip_lower_bound"] / best_lower_bound
190-
results["Relative upper bound"] = results["mip_upper_bound"] / best_upper_bound
191-
192-
if (results["mip_sense"] == "min").any():
193-
primal_column = "Relative upper bound"
194-
obj_column = "mip_upper_bound"
195-
predicted_obj_column = "Objective: Predicted upper bound"
196-
else:
197-
primal_column = "Relative lower bound"
198-
obj_column = "mip_lower_bound"
199-
predicted_obj_column = "Objective: Predicted lower bound"
200-
201-
palette = {
202-
"baseline": "#9b59b6",
203-
"ml-exact": "#3498db",
204-
"ml-heuristic": "#95a5a6",
205-
}
206-
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
207-
nrows=2,
208-
ncols=2,
209-
figsize=(8, 8),
210-
)
211-
212-
# Wallclock time
213-
sns.stripplot(
214-
x="Solver",
215-
y="mip_wallclock_time",
216-
data=results,
217-
ax=ax1,
218-
jitter=0.25,
219-
palette=palette,
220-
size=2.0,
221-
)
222-
sns.barplot(
223-
x="Solver",
224-
y="mip_wallclock_time",
225-
data=results,
226-
ax=ax1,
227-
errwidth=0.0,
228-
alpha=0.4,
229-
palette=palette,
230-
)
231-
ax1.set(ylabel="Wallclock time (s)")
232-
233-
# Gap
234-
sns.stripplot(
235-
x="Solver",
236-
y="Gap",
237-
jitter=0.25,
238-
data=results[results["Solver"] != "ml-heuristic"],
239-
ax=ax2,
240-
palette=palette,
241-
size=2.0,
242-
)
243-
ax2.set(ylabel="Relative MIP gap")
244-
245-
# Relative primal bound
246-
sns.stripplot(
247-
x="Solver",
248-
y=primal_column,
249-
jitter=0.25,
250-
data=results[results["Solver"] == "ml-heuristic"],
251-
ax=ax3,
252-
palette=palette,
253-
size=2.0,
254-
)
255-
sns.scatterplot(
256-
x=obj_column,
257-
y=predicted_obj_column,
258-
hue="Solver",
259-
data=results[results["Solver"] == "ml-exact"],
260-
ax=ax4,
261-
palette=palette,
262-
size=2.0,
263-
)
264-
265-
# Predicted vs actual primal bound
266-
xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
267-
ax4.plot(
268-
[-1e10, 1e10],
269-
[-1e10, 1e10],
270-
ls="-",
271-
color="#cccccc",
272-
)
273-
ax4.set_xlim(xlim)
274-
ax4.set_ylim(ylim)
275-
ax4.get_legend().remove()
276-
ax4.set(
277-
ylabel="Predicted value",
278-
xlabel="Actual value",
279-
)
280-
281-
fig.tight_layout()
282-
if output is not None:
283-
plt.savefig(output)

miplearn/classifiers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
3838
np.float16,
3939
np.float32,
4040
np.float64,
41-
], f"x_train.dtype shoule be float. Found {x_train.dtype} instead."
41+
], f"x_train.dtype should be float. Found {x_train.dtype} instead."
4242
assert y_train.dtype == np.bool8
4343
assert len(x_train.shape) == 2
4444
assert len(y_train.shape) == 2

0 commit comments

Comments
 (0)