Skip to content

Commit ef40472

Browse files
committed
move param_combinations to cpmpy.solvers.utils
1 parent 0539a39 commit ef40472

File tree

4 files changed

+52
-38
lines changed

4 files changed

+52
-38
lines changed

cpmpy/solvers/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,17 @@
2323
2424
CPM_ortools
2525
CPM_pysat
26+
27+
=================
28+
List of functions
29+
=================
30+
.. autosummary::
31+
:nosignatures:
32+
33+
param_combinations
2634
"""
2735

28-
from .utils import builtin_solvers, get_supported_solvers
36+
from .utils import builtin_solvers, get_supported_solvers, param_combinations
2937
from .ortools import CPM_ortools
3038
from .pysat import CPM_pysat
3139
# from minizinc import CPMpyMiniZinc # closed for maintenance

cpmpy/solvers/utils.py

+33
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
:nosignatures:
1818
1919
get_supported_solvers
20+
param_combinations
2021
"""
2122

2223
#from .minizinc import CPMpyMiniZinc # closed for maintenance
@@ -31,5 +32,37 @@ def get_supported_solvers():
3132
"""
3233
return [sv for sv in builtin_solvers if sv.supported()]
3334

35+
def param_combinations(all_params, remaining_keys=None, cur_params=None):
36+
"""
37+
Recursively yield all combinations of param values
38+
39+
For example usage, see `examples/advanced/hyperparameter_search.py`
40+
41+
- all_params is a dict of {key: list} items, e.g.:
42+
{'val': [1,2], 'opt': [True,False]}
43+
44+
- output is an generator over all {key:value} combinations
45+
of the keys and values. For the example above:
46+
generator([{'val':1,'opt':True},{'val':1,'opt':False},{'val':2,'opt':True},{'val':2,'opt':False}])
47+
"""
48+
if remaining_keys is None or cur_params is None:
49+
# init
50+
remaining_keys = list(all_params.keys())
51+
cur_params = dict()
52+
53+
cur_key = remaining_keys[0]
54+
myresults = [] # (runtime, cur_params)
55+
for cur_value in all_params[cur_key]:
56+
cur_params[cur_key] = cur_value
57+
if len(remaining_keys) == 1:
58+
# terminal, return copy
59+
yield dict(cur_params)
60+
else:
61+
# recursive call
62+
yield from param_combinations(all_params,
63+
remaining_keys=remaining_keys[1:],
64+
cur_params=cur_params)
65+
66+
3467
# Order matters! first is default, then tries second, etc...
3568
builtin_solvers=[CPM_ortools,CPM_pysat]

docs/solver_parameters.md

+9-6
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,19 @@ s.solve(num_search_workers=8, log_search_progress=True)
2929
## Hyperparameter search across different parameters
3030
Because CPMpy offers programmatic access to the solver API, hyperparameter search can be straightforwardly done with little overhead between the calls.
3131

32-
The full example is in [examples/advanced/hyperparameter_search.py](examples/advanced/hyperparameter_search.py), here is a relevant excrept:
32+
The cpmpy.solvers module has a helper function `param_combinations` that generates all parameter combinations of an input, which can then be looped over.
33+
34+
The example is in [examples/advanced/hyperparameter_search.py](examples/advanced/hyperparameter_search.py), the key part is:
3335

3436
```python
37+
from cpmpy.solvers import CPM_ortools, param_combinations
38+
3539
params = {'cp_model_probing_level': [0,1,2,3],
3640
'linearization_level': [0,1,2],
3741
'symmetry_level': [0,1,2]}
3842

39-
configs = gridsearch(model, CPM_ortools, params)
40-
41-
best = configs[0]
42-
print("Best config:", best[1])
43-
print(" with runtime:", round(best[0],2))
43+
for params in param_combinations(all_params):
44+
s = CPM_ortools(model)
45+
s.solve(**params)
46+
print(s.status().runtime, "seconds for config", params)
4447
```

examples/advanced/hyperparameter_search.py

+1-31
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
import sys
77
from cpmpy import *
8-
from cpmpy.solvers import CPM_ortools
8+
from cpmpy.solvers import CPM_ortools, param_combinations
99
from cpmpy.transformations.flatten_model import flatten_model
1010

1111
def main():
@@ -45,36 +45,6 @@ def main():
4545
# With default parameters: 0.13
4646

4747

48-
def param_combinations(all_params, remaining_keys=None, cur_params=None):
49-
"""
50-
Recursively yield all combinations of param values
51-
52-
- all_params is a dict of {key: list} items, e.g.:
53-
{'val': [1,2], 'opt': [True,False]}
54-
55-
- output is an generator over all {key:value} combinations
56-
of the keys and values. For the example above:
57-
generator([{'val':1,'opt':True},{'val':1,'opt':False},{'val':2,'opt':True},{'val':2,'opt':False}])
58-
"""
59-
if remaining_keys is None or cur_params is None:
60-
# init
61-
remaining_keys = list(all_params.keys())
62-
cur_params = dict()
63-
64-
cur_key = remaining_keys[0]
65-
myresults = [] # (runtime, cur_params)
66-
for cur_value in all_params[cur_key]:
67-
cur_params[cur_key] = cur_value
68-
if len(remaining_keys) == 1:
69-
# terminal, return copy
70-
yield dict(cur_params)
71-
else:
72-
# recursive call
73-
yield from param_combinations(all_params,
74-
remaining_keys=remaining_keys[1:],
75-
cur_params=cur_params)
76-
77-
7848

7949
def nqueens(n=8):
8050
""" N-queens problem

0 commit comments

Comments
 (0)