Commit 0c7498b

add pd1-tabular

Merge commit: 2 parents 911c1cb + f64118f

26 files changed: +2,521 -57 lines

src/mfpbench/__main__.py (+10)

@@ -129,6 +129,7 @@ def do(cls, args: argparse.Namespace) -> None:
             download=True,
             install=False,
             force=args.force,
+            workers=args.workers,
         )

     @override
@@ -149,6 +150,15 @@ def fill_parser(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser
             action="store_true",
             help="Print out the available benchmarks data sources",
         )
+        parser.add_argument(
+            "--workers",
+            type=int,
+            default=1,
+            help=(
+                "The number of workers to use for downloading"
+                " if the downloader supports it"
+            ),
+        )
         parser.add_argument(
             "--benchmark",
             choices=[
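For context on the new flag: a downloader that honours --workers typically fans its fetches out over a small pool. A minimal sketch of that pattern, using a hypothetical list of (url, destination) pairs rather than the actual mfpbench download code:

    from concurrent.futures import ThreadPoolExecutor
    from urllib.request import urlretrieve

    def download_all(urls: list[tuple[str, str]], workers: int = 1) -> None:
        """Fetch each (url, dest) pair, in parallel when workers > 1."""
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = [pool.submit(urlretrieve, url, dest) for url, dest in urls]
            for future in futures:
                future.result()  # re-raise any download error here

With the argument wired through do() as above, invoking the download command with --workers 4 hands workers=4 to the setup call; downloaders that do not support parallelism can simply ignore it.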

src/mfpbench/benchmark.py (+37 -5)

@@ -73,6 +73,7 @@ def __init__( # noqa: PLR0913
         prior: str | Path | C | Mapping[str, Any] | None = None,
         perturb_prior: float | None = None,
         value_metric: str | None = None,
+        value_metric_test: str | None = None,
         cost_metric: str | None = None,
     ):
         """Initialize the benchmark.
@@ -97,19 +98,30 @@ def __init__( # noqa: PLR0913
                 as the probability of swapping the value for a random one.
             value_metric: The metric to use for this benchmark. Uses
                 the default metric from the Result if None.
+            value_metric_test: The metric to use as a test metric for this benchmark.
+                Uses the default test metric from the Result if left as None, and
+                if there is no default test metric, will return None.
             cost_metric: The cost to use for this benchmark. Uses
                 the default cost from the Result if None.
         """
         if value_metric is None:
             value_metric = result_type.default_value_metric
+        if value_metric_test is None:
+            value_metric_test = result_type.default_value_metric_test

         if cost_metric is None:
             cost_metric = result_type.default_cost_metric

+        # Ensure that the result type actually has an attribute called value_metric
+        if value_metric is None:
+            assert getattr(self.Result, "value_metric", None) is not None
+            value_metric = self.Result.value_metric
+
         self.name = name
         self.seed = seed
         self.space = space
         self.value_metric = value_metric
+        self.value_metric_test: str | None = value_metric_test
         self.cost_metric = cost_metric
         self.fidelity_range: tuple[F, F, F] = fidelity_range
         self.fidelity_name = fidelity_name
@@ -121,10 +133,6 @@ def __init__( # noqa: PLR0913
             for metric_name, metric in self.Result.metric_defs.items()
         }

-        if value_metric is None:
-            assert getattr(self.Result, "value_metric", None) is not None
-            value_metric = self.Result.value_metric
-
         self._prior_arg = prior

         # NOTE: This is handled entirely by subclasses as it requires knowledge
@@ -250,6 +258,7 @@ def query(
         *,
         at: F | None = None,
         value_metric: str | None = None,
+        value_metric_test: str | None = None,
         cost_metric: str | None = None,
     ) -> R:
         """Submit a query and get a result.
@@ -260,11 +269,17 @@ def query(
             value_metric: The metric to use for this result. Uses
                 the value metric passed in to the constructor if not specified,
                 otherwise the default metric from the Result if None.
+            value_metric_test: The metric to use for this result. Uses
+                the value metric passed in to the constructor if not specified,
+                otherwise the default metric from the Result if None. If that
+                is still None, then the `value_metric_test` will be None as well.
             cost_metric: The metric to use for this result. Uses
                 the cost metric passed in to the constructor if not specified,
                 otherwise the default metric from the Result if None.
-
         Returns:
             The result of the query
         """
@@ -282,13 +297,19 @@ def query(
         __config = {k: __config.get(v, v) for k, v in _reverse_renames.items()}

         value_metric = value_metric if value_metric is not None else self.value_metric
+        value_metric_test = (
+            value_metric_test
+            if value_metric_test is not None
+            else self.value_metric_test
+        )
         cost_metric = cost_metric if cost_metric is not None else self.cost_metric

         return self.Result.from_dict(
             config=config,
             fidelity=at,
             result=self._objective_function(__config, at=at),
             value_metric=str(value_metric),
+            value_metric_test=value_metric_test,
             cost_metric=str(cost_metric),
             renames=self._result_renames,
         )
@@ -301,6 +322,7 @@ def trajectory(
         to: F | None = None,
         step: F | None = None,
         value_metric: str | None = None,
+        value_metric_test: str | None = None,
         cost_metric: str | None = None,
     ) -> list[R]:
         """Get the full trajectory of a configuration.
@@ -313,6 +335,10 @@ def trajectory(
             value_metric: The metric to use for this result. Uses
                 the value metric passed in to the constructor if not specified,
                 otherwise the default metric from the Result if None.
+            value_metric_test: The metric to use for this result. Uses
+                the value metric passed in to the constructor if not specified,
+                otherwise the default metric from the Result if None. If that
+                is still None, then the `value_metric_test` will be None as well.
             cost_metric: The metric to use for this result. Uses
                 the cost metric passed in to the constructor if not specified,
                 otherwise the default metric from the Result if None.
@@ -330,6 +356,11 @@ def trajectory(
         __config = {k: __config.get(v, v) for k, v in _reverse_renames.items()}

         value_metric = value_metric if value_metric is not None else self.value_metric
+        value_metric_test = (
+            value_metric_test
+            if value_metric_test is not None
+            else self.value_metric_test
+        )
         cost_metric = cost_metric if cost_metric is not None else self.cost_metric

         return [
@@ -338,6 +369,7 @@ def trajectory(
             fidelity=fidelity,
             result=result,
             value_metric=str(value_metric),
+            value_metric_test=value_metric_test,
             cost_metric=str(cost_metric),
             renames=self._result_renames,
         )
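Taken together, these changes give value_metric_test the same resolution order as value_metric: an explicit argument to query()/trajectory() wins, then the value given to the constructor (which itself falls back to the Result class default), and finally None when no default exists. A standalone sketch of that chain, with illustrative names only:

    def resolve_test_metric(
        call_arg: str | None,    # passed to query()/trajectory()
        ctor_value: str | None,  # self.value_metric_test, already defaulted in __init__
    ) -> str | None:
        # Mirrors: value_metric_test if value_metric_test is not None
        #          else self.value_metric_test
        return call_arg if call_arg is not None else ctor_value

    assert resolve_test_metric("test_acc", "test_balanced_accuracy") == "test_acc"
    assert resolve_test_metric(None, "test_balanced_accuracy") == "test_balanced_accuracy"
    assert resolve_test_metric(None, None) is None  # Result defines no test default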

src/mfpbench/get.py (+10 -3)

@@ -4,16 +4,16 @@
 from typing import TYPE_CHECKING, Any

 from mfpbench.jahs import JAHSBenchmark
-from mfpbench.lcbench_tabular import (
-    LCBenchTabularBenchmark,
-)
+from mfpbench.lcbench_tabular import LCBenchTabularBenchmark
+from mfpbench.nb201_tabular.benchmark import NB201TabularBenchmark
 from mfpbench.pd1 import (
     PD1cifar100_wideresnet_2048,
     PD1imagenet_resnet_512,
     PD1lm1b_transformer_2048,
     PD1translatewmt_xformer_64,
     PD1uniref50_transformer_128,
 )
+from mfpbench.pd1_tabular import PD1TabularBenchmark
 from mfpbench.synthetic.hartmann import (
     MFHartmann3Benchmark,
     MFHartmann3BenchmarkBad,
@@ -26,6 +26,7 @@
     MFHartmann6BenchmarkModerate,
     MFHartmann6BenchmarkTerrible,
 )
+from mfpbench.taskset_tabular import TaskSetTabularBenchmark
 from mfpbench.yahpo import (
     IAMLglmnetBenchmark,
     IAMLrangerBenchmark,
@@ -84,6 +85,12 @@
     "imagenet_resnet_512": PD1imagenet_resnet_512,
     # LCBenchTabular
     "lcbench_tabular": LCBenchTabularBenchmark,
+    # PD1Tabular
+    "pd1_tabular": PD1TabularBenchmark,
+    # TaskSetTabular
+    "taskset_tabular": TaskSetTabularBenchmark,
+    # nb201 tabular
+    "nb201_tabular": NB201TabularBenchmark,
 }
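With these registry entries, the new tabular benchmarks are reachable by name. A sketch of a lookup through the package's top-level get() helper; whether extra arguments (such as a task id) are required for pd1_tabular is not confirmed by this diff:

    import mfpbench

    # Resolves "pd1_tabular" to PD1TabularBenchmark via the mapping above.
    benchmark = mfpbench.get("pd1_tabular")
    config = benchmark.sample()
    result = benchmark.query(config)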

src/mfpbench/jahs/benchmark.py (+1)

@@ -56,6 +56,7 @@ class JAHSConfig(Config):
 class JAHSResult(Result[JAHSConfig, int]):
     default_value_metric: ClassVar[str] = "valid_acc"
     default_cost_metric: ClassVar[str] = "runtime"
+    default_value_metric_test: ClassVar[str] = "test_acc"

     metric_defs: ClassVar[Mapping[str, Metric]] = {
         "runtime": Metric(minimize=True, bounds=(0, np.inf)),

src/mfpbench/lcbench_tabular/benchmark.py (+5 -2)

@@ -140,14 +140,15 @@ class LCBenchTabularConfig(TabularConfig):
 class LCBenchTabularResult(Result[LCBenchTabularConfig, int]):
     metric_defs: ClassVar[Mapping[str, Metric]] = {
         "val_accuracy": Metric(minimize=False, bounds=(0, 100)),
-        "val_balanced_accuracy": Metric(minimize=False, bounds=(0, 100)),
+        "val_balanced_accuracy": Metric(minimize=False, bounds=(0, 1)),
         "val_cross_entropy": Metric(minimize=True, bounds=(0, np.inf)),
         "test_accuracy": Metric(minimize=False, bounds=(0, 100)),
-        "test_balanced_accuracy": Metric(minimize=False, bounds=(0, 100)),
+        "test_balanced_accuracy": Metric(minimize=False, bounds=(0, 1)),
         "test_cross_entropy": Metric(minimize=True, bounds=(0, np.inf)),
         "time": Metric(minimize=True, bounds=(0, np.inf)),
     }
     default_value_metric: ClassVar[str] = "val_balanced_accuracy"
+    default_value_metric_test: ClassVar[str] = "test_balanced_accuracy"
     default_cost_metric: ClassVar[str] = "time"

     time: Metric.Value
@@ -214,6 +215,7 @@ def __init__(
         prior: str | Path | LCBenchTabularConfig | Mapping[str, Any] | None = None,
         perturb_prior: float | None = None,
         value_metric: str | None = None,
+        value_metric_test: str | None = None,
         cost_metric: str | None = None,
     ) -> None:
         """Initialize the benchmark.
@@ -282,6 +284,7 @@ def __init__(
             result_type=LCBenchTabularResult,
             config_type=LCBenchTabularConfig,
             value_metric=value_metric,
+            value_metric_test=value_metric_test,
             cost_metric=cost_metric,
             space=space,
             seed=seed,
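The bounds correction matters because LCBench reports balanced accuracy as a fraction in [0, 1]: metric.py defines an OutOfBoundsError and metric values are interpreted relative to their declared bounds, so the old (0, 100) range would place every achievable score at the very bottom of the scale. A small illustration, constructing Metric exactly as the diff does:

    from mfpbench.metric import Metric

    balanced_accuracy = Metric(minimize=False, bounds=(0, 1))
    value = balanced_accuracy.as_value(0.87)
    # Under bounds=(0, 1), 0.87 sits near the top of the range; under the old
    # bounds=(0, 100), the same 0.87 would look like a near-worst-case score
    # in any scoring computed relative to the bounds.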

src/mfpbench/metric.py (+3)

@@ -3,6 +3,7 @@
 from dataclasses import dataclass, field

 import numpy as np
+import pandas as pd


 class OutOfBoundsError(ValueError):
@@ -38,6 +39,8 @@ def as_value(self, value: float) -> Metric.Value:
         Returns:
             The metric value.
         """
+        if pd.isna(value):
+            value = np.inf
         return Metric.Value(value=value, definition=self)

     @property
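The new guard means NaN entries in tabular result data (e.g. from diverged training runs) are coerced to +inf before being wrapped, rather than propagating NaN through comparisons. A quick check of the behaviour, assuming the bounds check treats the upper bound as inclusive:

    import numpy as np
    from mfpbench.metric import Metric

    loss = Metric(minimize=True, bounds=(0, np.inf))
    value = loss.as_value(float("nan"))
    assert value.value == np.inf  # NaN becomes the worst value for a minimized metric

Note that +inf is only the pessimal value for metrics that are minimized; for a maximized metric this coercion would mark a NaN as the best possible score, so callers relying on it should keep that in mind.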

src/mfpbench/nb201_tabular/__init__.py

Whitespace-only changes.
