 import logging
 import os
-from typing import Dict, List, Any, Optional
+from typing import Dict, List, Any, Optional, Callable
 
 import pandas as pd
 
 from miplearn.components.component import Component
 from miplearn.instance.base import Instance
-from miplearn.solvers.learning import LearningSolver
+from miplearn.solvers.learning import LearningSolver, FileInstanceWrapper
 from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
 from sklearn.utils._testing import ignore_warnings
 from sklearn.exceptions import ConvergenceWarning
@@ -43,37 +43,24 @@ def __init__(self, solvers: Dict[str, LearningSolver]) -> None:
 
     def parallel_solve(
         self,
-        instances: List[Instance],
+        filenames: List[str],
+        build_model: Callable,
         n_jobs: int = 1,
         n_trials: int = 1,
         progress: bool = False,
     ) -> None:
-        """
-        Solves the given instances in parallel and collect benchmark statistics.
-
-        Parameters
-        ----------
-        instances: List[Instance]
-            List of instances to solve. This can either be a list of instances
-            already loaded in memory, or a list of filenames pointing to pickled (and
-            optionally gzipped) files.
-        n_jobs: int
-            List of instances to solve in parallel at a time.
-        n_trials: int
-            How many times each instance should be solved.
-        """
         self._silence_miplearn_logger()
-        trials = instances * n_trials
+        trials = filenames * n_trials
         for (solver_name, solver) in self.solvers.items():
             results = solver.parallel_solve(
                 trials,
+                build_model,
                 n_jobs=n_jobs,
-                label="solve (%s)" % solver_name,
-                discard_outputs=True,
+                label="benchmark (%s)" % solver_name,
                 progress=progress,
             )
             for i in range(len(trials)):
-                idx = i % len(instances)
+                idx = i % len(filenames)
                 results[i]["Solver"] = solver_name
                 results[i]["Instance"] = idx
                 self.results = self.results.append(pd.DataFrame([results[i]]))
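To make the new signature concrete, here is a minimal usage sketch. It assumes the enclosing class is miplearn's BenchmarkRunner (whose __init__ appears in the hunk header above); the import paths, file names, and the build_tsp_model callback are hypothetical placeholders, with build_model standing for a user-supplied function that rebuilds the optimization model from the data stored in each file:

from miplearn import BenchmarkRunner, LearningSolver  # import path assumed
from my_project import build_tsp_model  # hypothetical user callback

# Benchmark every configured solver on pre-generated instance files.
benchmark = BenchmarkRunner(
    solvers={
        "baseline": LearningSolver(),  # default settings, for brevity
        "ml-exact": LearningSolver(),
    }
)
benchmark.parallel_solve(
    filenames=["instances/test-00000.h5", "instances/test-00001.h5"],  # placeholders
    build_model=build_tsp_model,
    n_jobs=4,     # solve four instances at a time
    n_trials=3,   # solve each instance three times
    progress=True,
)
benchmark.write_csv("results.csv")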
@@ -93,21 +80,15 @@ def write_csv(self, filename: str) -> None:
 
     def fit(
         self,
-        instances: List[Instance],
+        filenames: List[str],
+        build_model: Callable,
+        progress: bool = False,
         n_jobs: int = 1,
-        progress: bool = True,
     ) -> None:
-        """
-        Trains all solvers with the provided training instances.
-
-        Parameters
-        ----------
-        instances: List[Instance]
-            List of training instances.
-        n_jobs: int
-            Number of parallel processes to use.
-        """
-        components: List[Component] = []
+        components = []
+        instances: List[Instance] = [
+            FileInstanceWrapper(f, build_model, mode="r") for f in filenames
+        ]
         for (solver_name, solver) in self.solvers.items():
             if solver_name == "baseline":
                 continue
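A matching training call might look like this; fit now wraps each path in a FileInstanceWrapper internally, so callers pass raw file names plus the same (hypothetical) build_model callback used above:

# Train the ML components of every non-baseline solver on previously
# solved training files; paths below are placeholders.
benchmark.fit(
    filenames=["instances/train-00000.h5", "instances/train-00001.h5"],
    build_model=build_tsp_model,
    n_jobs=4,
)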
@@ -128,6 +109,114 @@ def _restore_miplearn_logger(self) -> None:
         miplearn_logger = logging.getLogger("miplearn")
         miplearn_logger.setLevel(self.prev_log_level)
 
+    def write_svg(
+        self,
+        output: Optional[str] = None,
+    ) -> None:
+        import matplotlib.pyplot as plt
+        import pandas as pd
+        import seaborn as sns
+
+        sns.set_style("whitegrid")
+        sns.set_palette("Blues_r")
+        groups = self.results.groupby("Instance")
+        best_lower_bound = groups["mip_lower_bound"].transform("max")
+        best_upper_bound = groups["mip_upper_bound"].transform("min")
+        self.results["Relative lower bound"] = self.results["mip_lower_bound"] / best_lower_bound
+        self.results["Relative upper bound"] = self.results["mip_upper_bound"] / best_upper_bound
+
+        if (self.results["mip_sense"] == "min").any():
+            primal_column = "Relative upper bound"
+            obj_column = "mip_upper_bound"
+            predicted_obj_column = "Objective: Predicted upper bound"
+        else:
+            primal_column = "Relative lower bound"
+            obj_column = "mip_lower_bound"
+            predicted_obj_column = "Objective: Predicted lower bound"
+
+        palette = {
+            "baseline": "#9b59b6",
+            "ml-exact": "#3498db",
+            "ml-heuristic": "#95a5a6",
+        }
+        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
+            nrows=2,
+            ncols=2,
+            figsize=(8, 8),
+        )
+
+        # Wallclock time
+        sns.stripplot(
+            x="Solver",
+            y="mip_wallclock_time",
+            data=self.results,
+            ax=ax1,
+            jitter=0.25,
+            palette=palette,
+            size=2.0,
+        )
+        sns.barplot(
+            x="Solver",
+            y="mip_wallclock_time",
+            data=self.results,
+            ax=ax1,
+            errwidth=0.0,
+            alpha=0.4,
+            palette=palette,
+        )
+        ax1.set(ylabel="Wallclock time (s)")
+
+        # Gap
+        sns.stripplot(
+            x="Solver",
+            y="Gap",
+            jitter=0.25,
+            data=self.results[self.results["Solver"] != "ml-heuristic"],
+            ax=ax2,
+            palette=palette,
+            size=2.0,
+        )
+        ax2.set(ylabel="Relative MIP gap")
+
+        # Relative primal bound
+        sns.stripplot(
+            x="Solver",
+            y=primal_column,
+            jitter=0.25,
+            data=self.results[self.results["Solver"] == "ml-heuristic"],
+            ax=ax3,
+            palette=palette,
+            size=2.0,
+        )
+        sns.scatterplot(
+            x=obj_column,
+            y=predicted_obj_column,
+            hue="Solver",
+            data=self.results[self.results["Solver"] == "ml-exact"],
+            ax=ax4,
+            palette=palette,
+            size=2.0,
+        )
+
+        # Predicted vs actual primal bound
+        xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
+        ax4.plot(
+            [-1e10, 1e10],
+            [-1e10, 1e10],
+            ls="-",
+            color="#cccccc",
+        )
+        ax4.set_xlim(xlim)
+        ax4.set_ylim(xlim)
+        ax4.get_legend().remove()
+        ax4.set(
+            ylabel="Predicted optimal value",
+            xlabel="Actual optimal value",
+        )
+
+        fig.tight_layout()
+        plt.savefig(output)
+
 
 @ignore_warnings(category=ConvergenceWarning)
 def run_benchmarks(
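The write_svg method added above takes over from the module-level plot helper removed in the next hunk, reading self.results directly instead of an explicit DataFrame argument. Note that plt.savefig(output) is now called unconditionally, so despite the Optional[str] = None default a real path should be supplied. A minimal call, with a hypothetical output name:

# Render the four-panel benchmark chart (wallclock time, MIP gap,
# relative primal bound, predicted vs. actual optimal value).
benchmark.write_svg(output="benchmark.svg")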
@@ -173,111 +262,3 @@ def run_benchmarks(
     plot(benchmark.results)
 
 
-def plot(
-    results: pd.DataFrame,
-    output: Optional[str] = None,
-) -> None:
-    import matplotlib.pyplot as plt
-    import pandas as pd
-    import seaborn as sns
-
-    sns.set_style("whitegrid")
-    sns.set_palette("Blues_r")
-    groups = results.groupby("Instance")
-    best_lower_bound = groups["mip_lower_bound"].transform("max")
-    best_upper_bound = groups["mip_upper_bound"].transform("min")
-    results["Relative lower bound"] = results["mip_lower_bound"] / best_lower_bound
-    results["Relative upper bound"] = results["mip_upper_bound"] / best_upper_bound
-
-    if (results["mip_sense"] == "min").any():
-        primal_column = "Relative upper bound"
-        obj_column = "mip_upper_bound"
-        predicted_obj_column = "Objective: Predicted upper bound"
-    else:
-        primal_column = "Relative lower bound"
-        obj_column = "mip_lower_bound"
-        predicted_obj_column = "Objective: Predicted lower bound"
-
-    palette = {
-        "baseline": "#9b59b6",
-        "ml-exact": "#3498db",
-        "ml-heuristic": "#95a5a6",
-    }
-    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
-        nrows=2,
-        ncols=2,
-        figsize=(8, 8),
-    )
-
-    # Wallclock time
-    sns.stripplot(
-        x="Solver",
-        y="mip_wallclock_time",
-        data=results,
-        ax=ax1,
-        jitter=0.25,
-        palette=palette,
-        size=2.0,
-    )
-    sns.barplot(
-        x="Solver",
-        y="mip_wallclock_time",
-        data=results,
-        ax=ax1,
-        errwidth=0.0,
-        alpha=0.4,
-        palette=palette,
-    )
-    ax1.set(ylabel="Wallclock time (s)")
-
-    # Gap
-    sns.stripplot(
-        x="Solver",
-        y="Gap",
-        jitter=0.25,
-        data=results[results["Solver"] != "ml-heuristic"],
-        ax=ax2,
-        palette=palette,
-        size=2.0,
-    )
-    ax2.set(ylabel="Relative MIP gap")
-
-    # Relative primal bound
-    sns.stripplot(
-        x="Solver",
-        y=primal_column,
-        jitter=0.25,
-        data=results[results["Solver"] == "ml-heuristic"],
-        ax=ax3,
-        palette=palette,
-        size=2.0,
-    )
-    sns.scatterplot(
-        x=obj_column,
-        y=predicted_obj_column,
-        hue="Solver",
-        data=results[results["Solver"] == "ml-exact"],
-        ax=ax4,
-        palette=palette,
-        size=2.0,
-    )
-
-    # Predicted vs actual primal bound
-    xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
-    ax4.plot(
-        [-1e10, 1e10],
-        [-1e10, 1e10],
-        ls="-",
-        color="#cccccc",
-    )
-    ax4.set_xlim(xlim)
-    ax4.set_ylim(ylim)
-    ax4.get_legend().remove()
-    ax4.set(
-        ylabel="Predicted value",
-        xlabel="Actual value",
-    )
-
-    fig.tight_layout()
-    if output is not None:
-        plt.savefig(output)