@@ -73,6 +73,7 @@ def __init__( # noqa: PLR0913
73
73
prior : str | Path | C | Mapping [str , Any ] | None = None ,
74
74
perturb_prior : float | None = None ,
75
75
value_metric : str | None = None ,
76
+ value_metric_test : str | None = None ,
76
77
cost_metric : str | None = None ,
77
78
):
78
79
"""Initialize the benchmark.
@@ -97,19 +98,30 @@ def __init__( # noqa: PLR0913
97
98
as the probability of swapping the value for a random one.
98
99
value_metric: The metric to use for this benchmark. Uses
99
100
the default metric from the Result if None.
101
+ value_metric_test: The metric to use as a test metric for this benchmark.
102
+ Uses the default test metric from the Result if left as None, and
103
+ if there is no default test metric, will return None.
100
104
cost_metric: The cost to use for this benchmark. Uses
101
105
the default cost from the Result if None.
102
106
"""
103
107
if value_metric is None :
104
108
value_metric = result_type .default_value_metric
109
+ if value_metric_test is None :
110
+ value_metric_test = result_type .default_value_metric_test
105
111
106
112
if cost_metric is None :
107
113
cost_metric = result_type .default_cost_metric
108
114
115
+ # Ensure that the result type actually has an atrribute called value_metric
116
+ if value_metric is None :
117
+ assert getattr (self .Result , "value_metric" , None ) is not None
118
+ value_metric = self .Result .value_metric
119
+
109
120
self .name = name
110
121
self .seed = seed
111
122
self .space = space
112
123
self .value_metric = value_metric
124
+ self .value_metric_test : str | None = value_metric_test
113
125
self .cost_metric = cost_metric
114
126
self .fidelity_range : tuple [F , F , F ] = fidelity_range
115
127
self .fidelity_name = fidelity_name
@@ -121,10 +133,6 @@ def __init__( # noqa: PLR0913
121
133
for metric_name , metric in self .Result .metric_defs .items ()
122
134
}
123
135
124
- if value_metric is None :
125
- assert getattr (self .Result , "value_metric" , None ) is not None
126
- value_metric = self .Result .value_metric
127
-
128
136
self ._prior_arg = prior
129
137
130
138
# NOTE: This is handled entirely by subclasses as it requires knowledge
@@ -250,6 +258,7 @@ def query(
250
258
* ,
251
259
at : F | None = None ,
252
260
value_metric : str | None = None ,
261
+ value_metric_test : str | None = None ,
253
262
cost_metric : str | None = None ,
254
263
) -> R :
255
264
"""Submit a query and get a result.
@@ -260,11 +269,17 @@ def query(
260
269
value_metric: The metric to use for this result. Uses
261
270
the value metric passed in to the constructor if not specified,
262
271
otherwise the default metric from the Result if None.
272
+ value_metric: The metric to use for this result. Uses
273
+ the value metric passed in to the constructor if not specified,
274
+ otherwise the default metric from the Result if None.
275
+ value_metric_test: The metric to use for this result. Uses
276
+ the value metric passed in to the constructor if not specified,
277
+ otherwise the default metric from the Result if None. If that
278
+ is still None, then the `value_metric_test` will be None as well.
263
279
cost_metric: The metric to use for this result. Uses
264
280
the cost metric passed in to the constructor if not specified,
265
281
otherwise the default metric from the Result if None.
266
282
267
-
268
283
Returns:
269
284
The result of the query
270
285
"""
@@ -282,13 +297,19 @@ def query(
282
297
__config = {k : __config .get (v , v ) for k , v in _reverse_renames .items ()}
283
298
284
299
value_metric = value_metric if value_metric is not None else self .value_metric
300
+ value_metric_test = (
301
+ value_metric_test
302
+ if value_metric_test is not None
303
+ else self .value_metric_test
304
+ )
285
305
cost_metric = cost_metric if cost_metric is not None else self .cost_metric
286
306
287
307
return self .Result .from_dict (
288
308
config = config ,
289
309
fidelity = at ,
290
310
result = self ._objective_function (__config , at = at ),
291
311
value_metric = str (value_metric ),
312
+ value_metric_test = value_metric_test ,
292
313
cost_metric = str (cost_metric ),
293
314
renames = self ._result_renames ,
294
315
)
@@ -301,6 +322,7 @@ def trajectory(
301
322
to : F | None = None ,
302
323
step : F | None = None ,
303
324
value_metric : str | None = None ,
325
+ value_metric_test : str | None = None ,
304
326
cost_metric : str | None = None ,
305
327
) -> list [R ]:
306
328
"""Get the full trajectory of a configuration.
@@ -313,6 +335,10 @@ def trajectory(
313
335
value_metric: The metric to use for this result. Uses
314
336
the value metric passed in to the constructor if not specified,
315
337
otherwise the default metric from the Result if None.
338
+ value_metric_test: The metric to use for this result. Uses
339
+ the value metric passed in to the constructor if not specified,
340
+ otherwise the default metric from the Result if None. If that
341
+ is still None, then the `value_metric_test` will be None as well.
316
342
cost_metric: The metric to use for this result. Uses
317
343
the cost metric passed in to the constructor if not specified,
318
344
otherwise the default metric from the Result if None.
@@ -330,6 +356,11 @@ def trajectory(
330
356
__config = {k : __config .get (v , v ) for k , v in _reverse_renames .items ()}
331
357
332
358
value_metric = value_metric if value_metric is not None else self .value_metric
359
+ value_metric_test = (
360
+ value_metric_test
361
+ if value_metric_test is not None
362
+ else self .value_metric_test
363
+ )
333
364
cost_metric = cost_metric if cost_metric is not None else self .cost_metric
334
365
335
366
return [
@@ -338,6 +369,7 @@ def trajectory(
338
369
fidelity = fidelity ,
339
370
result = result ,
340
371
value_metric = str (value_metric ),
372
+ value_metric_test = value_metric_test ,
341
373
cost_metric = str (cost_metric ),
342
374
renames = self ._result_renames ,
343
375
)
0 commit comments