Skip to content

Commit 4d64042

Browse files
authored
[Feature] Add Trust-region algorithm searcher (#33)
* Tmp * Add test * Fix test code * Fix test code * Add trust region alg * Add test func * Fix minor bug * Fix cfg task name * Fix typo * Fix wrap immutable container * Fix base cfg * Fix disc crit * Fix cfg * Adjust hyperparam for disc * Fix test code * Fix typo * Fix alias * Add test code * Fix test code * Fix test code * Fix cont test function * Fix cont test function * Fix test function * Fix lint * Fix test code
1 parent c0246f7 commit 4d64042

16 files changed

Lines changed: 1391 additions & 87 deletions

File tree

.github/workflows/build_unit_test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
run: |
2727
pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
2828
pip install mmcv-full==1.4.7 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
29-
pip install nevergrad mmcls mmdet mmsegmentation protobuf==3.20
29+
pip install gpytorch nevergrad mmcls mmdet mmsegmentation protobuf==3.20
3030
pip install -e .
3131
- name: Run unittests and generate coverage report
3232
run: |
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
searcher = dict(
22
type='NevergradSearch',
3-
num_workers=16,
43
budget=256,
54
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
searcher = dict(type='TrustRegionSearcher', )

configs/mmtune/bbo_sphere_nevergrad_oneplusone.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
_base_ = ['./_base_/context/blackbox.py', './_base_/searcher/trust_region.py']
2+
3+
metric = 'result'
4+
mode = 'min'
5+
6+
space = {
7+
f'_variable{idx}': dict(type='Uniform', lower=-1.0, upper=1.0)
8+
for idx in range(8)
9+
}
10+
11+
task = dict(type='ContinuousTestFunction')
12+
13+
num_samples = 512
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
_base_ = ['./_base_/context/blackbox.py', './_base_/searcher/trust_region.py']
2+
3+
metric = 'result'
4+
mode = 'min'
5+
6+
space = {
7+
f'_variable{idx}':
8+
dict(type='Choice', categories=[0, 1], alias=['OFF', 'ON'])
9+
for idx in range(8)
10+
}
11+
12+
task = dict(type='DiscreteTestFunction')
13+
14+
num_samples = 512

mmtune/mm/tasks/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
from .base import BaseTask
22
from .blackbox import BlackBoxTask
33
from .builder import TASKS, build_task_processor
4+
from .cont_test_func import ContinuousTestFunction
5+
from .disc_test_func import DiscreteTestFunction
46
from .mmcls import MMClassification
57
from .mmdet import MMDetection
68
from .mmseg import MMSegmentation
79
from .mmtrainbase import MMTrainBasedTask
8-
from .sphere import Sphere
910

1011
__all__ = [
11-
'TASKS', 'build_task_processor', 'BaseTask', 'BlackBoxTask',
12-
'MMTrainBasedTask', 'MMClassification', 'MMDetection', 'MMSegmentation',
13-
'MMSegmentation', 'Sphere'
12+
'MMClassification',
13+
'DiscreteTestFunction',
14+
'ContinuousTestFunction',
15+
'TASKS',
16+
'build_task_processor',
17+
'BaseTask',
18+
'BlackBoxTask',
19+
'MMTrainBasedTask',
20+
'MMDetection',
21+
'MMSegmentation',
22+
'MMSegmentation',
1423
]

0 commit comments

Comments
 (0)