@@ -278,8 +278,10 @@ class UCBScoreFunction(eqx.Module):
278
278
279
279
The UCB acquisition value is the sum of the predicted mean based on completed
280
280
trials and the predicted standard deviation based on all trials, completed and
281
- pending (scaled by the UCB coefficient). This class follows the
282
- `acquisitions.ScoreFunction` protocol.
281
+ pending (scaled by the UCB coefficient). If `prior_acquisition` is not None,
282
+ the return value is the sum of the prior acquisition value and the UCB
283
+ acquisition value. This class follows the `acquisitions.ScoreFunction`
284
+ protocol.
283
285
284
286
Attributes:
285
287
predictive: Predictive model with cached Cholesky conditioned on completed
@@ -288,6 +290,7 @@ class UCBScoreFunction(eqx.Module):
288
290
on completed and pending trials.
289
291
ucb_coefficient: The UCB coefficient.
290
292
trust_region: Trust region.
293
+ prior_acquisition: An optional prior acquisition function.
291
294
scalarization_weights_rng: Random key for scalarization.
292
295
labels: Labels, shaped as [num_index_points, num_metrics].
293
296
num_scalarizations: Number of scalarizations.
@@ -297,6 +300,7 @@ class UCBScoreFunction(eqx.Module):
297
300
predictive_all_features : sp .UniformEnsemblePredictive
298
301
ucb_coefficient : jt .Float [jt .Array , '' ]
299
302
trust_region : Optional [acquisitions .TrustRegion ]
303
+ prior_acquisition : Callable [[types .ModelInput ], jax .Array ] | None
300
304
labels : types .PaddedArray
301
305
scalarizer : scalarization .Scalarization
302
306
@@ -306,6 +310,7 @@ def __init__(
306
310
predictive_all_features : sp .UniformEnsemblePredictive ,
307
311
ucb_coefficient : jt .Float [jt .Array , '' ],
308
312
trust_region : Optional [acquisitions .TrustRegion ],
313
+ prior_acquisition : Callable [[types .ModelInput ], jax .Array ] | None ,
309
314
scalarization_weights_rng : jax .Array ,
310
315
labels : types .PaddedArray ,
311
316
num_scalarizations : int = 1000 ,
@@ -314,6 +319,7 @@ def __init__(
314
319
self .predictive_all_features = predictive_all_features
315
320
self .ucb_coefficient = ucb_coefficient
316
321
self .trust_region = trust_region
322
+ self .prior_acquisition = prior_acquisition
317
323
self .labels = labels
318
324
self .scalarizer = acquisitions .create_hv_scalarization (
319
325
num_scalarizations , labels , scalarization_weights_rng
@@ -357,11 +363,16 @@ def score_with_aux(
357
363
scalarized_acq_values = _apply_trust_region (
358
364
self .trust_region , xs , scalarized_acq_values
359
365
)
360
- return scalarized_acq_values , {
366
+ aux = {
361
367
'mean' : mean ,
362
368
'stddev' : gprm .stddev (),
363
369
'stddev_from_all' : stddev_from_all ,
364
370
}
371
+ if self .prior_acquisition is not None :
372
+ prior_acq_values = self .prior_acquisition (xs )
373
+ scalarized_acq_values = prior_acq_values + scalarized_acq_values
374
+ aux ['prior_acq_values' ] = prior_acq_values
375
+ return scalarized_acq_values , aux
365
376
366
377
367
378
class PEScoreFunction (eqx .Module ):
@@ -370,8 +381,10 @@ class PEScoreFunction(eqx.Module):
370
381
The PE acquisition value is the predicted standard deviation (eq. (9)
371
382
in https://arxiv.org/pdf/1304.5350) based on all completed and active trials,
372
383
plus a penalty term that grows linearly in the amount of violation of the
373
- constraint `UCB(xs) >= threshold`. This class follows the
374
- `acquisitions.ScoreFunction` protocol.
384
+ constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the
385
+ returned value is the sum of the prior acquisition value and the PE
386
+ acquisition value. This class follows the `acquisitions.ScoreFunction`
387
+ protocol.
375
388
376
389
Attributes:
377
390
predictive: Predictive model with cached Cholesky conditioned on completed
@@ -383,6 +396,9 @@ class PEScoreFunction(eqx.Module):
383
396
values on `xs`.
384
397
penalty_coefficient: Multiplier on the constraint violation penalty.
385
398
trust_region:
399
+ prior_acquisition: An optional prior acquisition function.
400
+ multimetric_promising_region_penalty_type: The type of multimetric promising
401
+ region penalty.
386
402
387
403
Returns:
388
404
The Pure-Exploration acquisition value.
@@ -394,6 +410,7 @@ class PEScoreFunction(eqx.Module):
394
410
explore_ucb_coefficient : jt .Float [jt .Array , '' ]
395
411
penalty_coefficient : jt .Float [jt .Array , '' ]
396
412
trust_region : Optional [acquisitions .TrustRegion ]
413
+ prior_acquisition : Callable [[types .ModelInput ], jax .Array ] | None
397
414
multimetric_promising_region_penalty_type : (
398
415
MultimetricPromisingRegionPenaltyType
399
416
)
@@ -457,11 +474,16 @@ def score_with_aux(
457
474
acq_values = stddev_from_all + penalty
458
475
if self .trust_region is not None :
459
476
acq_values = _apply_trust_region (self .trust_region , xs , acq_values )
460
- return acq_values , {
477
+ aux = {
461
478
'mean' : mean ,
462
479
'stddev' : stddev ,
463
480
'stddev_from_all' : stddev_from_all ,
464
481
}
482
+ if self .prior_acquisition is not None :
483
+ prior_acq_values = self .prior_acquisition (xs )
484
+ acq_values = prior_acq_values + acq_values
485
+ aux ['prior_acq_values' ] = prior_acq_values
486
+ return acq_values , aux
465
487
466
488
467
489
def _logdet (matrix : jax .Array ):
@@ -486,8 +508,10 @@ class SetPEScoreFunction(eqx.Module):
486
508
predicted covariance matrix evaluated at the points (eq. (8) in
487
509
https://arxiv.org/pdf/1304.5350) based on all completed and active trials,
488
510
plus a penalty term that grows linearly in the amount of violation of the
489
- constraint `UCB(xs) >= threshold`. This class follows the
490
- `acquisitions.ScoreFunction` protocol.
511
+ constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the
512
+ returned value is the sum of the prior acquisition value and the PE
513
+ acquisition value. This class follows the `acquisitions.ScoreFunction`
514
+ protocol.
491
515
492
516
Attributes:
493
517
predictive: Predictive model with cached Cholesky conditioned on completed
@@ -499,6 +523,7 @@ class SetPEScoreFunction(eqx.Module):
499
523
values on `xs`.
500
524
penalty_coefficient: Multiplier on the constraint violation penalty.
501
525
trust_region:
526
+ prior_acquisition: An optional prior acquisition function.
502
527
503
528
Returns:
504
529
The Pure-Exploration acquisition value.
@@ -510,6 +535,7 @@ class SetPEScoreFunction(eqx.Module):
510
535
explore_ucb_coefficient : jt .Float [jt .Array , '' ]
511
536
penalty_coefficient : jt .Float [jt .Array , '' ]
512
537
trust_region : Optional [acquisitions .TrustRegion ]
538
+ prior_acquisition : Callable [[types .ModelInput ], jax .Array ] | None
513
539
514
540
def score (
515
541
self , xs : types .ModelInput , seed : Optional [jax .Array ] = None
@@ -549,11 +575,16 @@ def score_with_aux(
549
575
)
550
576
if self .trust_region is not None :
551
577
acq_values = _apply_trust_region_to_set (self .trust_region , xs , acq_values )
552
- return acq_values , {
578
+ aux = {
553
579
'mean' : mean ,
554
580
'stddev' : stddev ,
555
581
'stddev_from_all' : jnp .sqrt (jnp .diagonal (cov , axis1 = 1 , axis2 = 2 )),
556
582
}
583
+ if self .prior_acquisition is not None :
584
+ prior_acq_values = self .prior_acquisition (xs )
585
+ acq_values = prior_acq_values + acq_values
586
+ aux ['prior_acq_values' ] = prior_acq_values
587
+ return acq_values , aux
557
588
558
589
559
590
def default_ard_optimizer () -> optimizers .Optimizer [types .ParameterDict ]:
@@ -587,6 +618,14 @@ class method that takes `ModelInput` and returns a
587
618
observed.
588
619
rng: If not set, uses random numbers.
589
620
clear_jax_cache: If True, every `suggest` call clears the Jax cache.
621
+ padding_schedule: Configures what inputs (trials, features, labels) to pad
622
+ with what schedule. Useful for reducing JIT compilation passes. (Default
623
+ implies no padding.)
624
+ prior_acquisition: An optional prior acquisition function. If provided, the
625
+ suggestions will be generated by maximizing the sum of the prior
626
+ acquisition value and the GP-based acquisition value (UCB or PE). Useful
627
+ for biasing the suggestions towards a prior, e.g., being close to some
628
+ known parameter values.
590
629
"""
591
630
592
631
_problem : vz .ProblemStatement = attr .field (kw_only = False )
@@ -621,12 +660,13 @@ class method that takes `ModelInput` and returns a
621
660
factory = lambda : jax .random .PRNGKey (random .getrandbits (32 )), kw_only = True
622
661
)
623
662
_clear_jax_cache : bool = attr .field (default = False , kw_only = True )
624
- # Whether to pad all inputs, and what type of schedule to use. This is to
625
- # ensure fewer JIT compilation passes. (Default implies no padding.)
626
663
# TODO: Check padding does not affect designer behavior.
627
664
_padding_schedule : padding .PaddingSchedule = attr .field (
628
665
factory = padding .PaddingSchedule , kw_only = True
629
666
)
667
+ _prior_acquisition : Callable [[types .ModelInput ], jax .Array ] | None = (
668
+ attr .field (factory = lambda : None , kw_only = True )
669
+ )
630
670
631
671
default_eagle_config = es .EagleStrategyConfig (
632
672
visibility = 3.6782451729470043 ,
@@ -1003,6 +1043,7 @@ def _suggest_one(
1003
1043
predictive_all_features ,
1004
1044
ucb_coefficient = self ._config .ucb_coefficient ,
1005
1045
trust_region = tr if self ._use_trust_region else None ,
1046
+ prior_acquisition = self ._prior_acquisition ,
1006
1047
scalarization_weights_rng = scalarization_weights_rng ,
1007
1048
labels = data .labels ,
1008
1049
)
@@ -1014,6 +1055,7 @@ def _suggest_one(
1014
1055
ucb_coefficient = self ._config .ucb_coefficient ,
1015
1056
explore_ucb_coefficient = self ._config .explore_region_ucb_coefficient ,
1016
1057
trust_region = tr if self ._use_trust_region else None ,
1058
+ prior_acquisition = self ._prior_acquisition ,
1017
1059
multimetric_promising_region_penalty_type = (
1018
1060
self ._config .multimetric_promising_region_penalty_type
1019
1061
),
@@ -1083,6 +1125,10 @@ def _suggest_one(
1083
1125
'trust_radius' : f'{ tr .trust_radius } ' ,
1084
1126
'params' : f'{ model .params } ' ,
1085
1127
})
1128
+ if self ._prior_acquisition is not None :
1129
+ metadata .ns ('prior_acquisition' ).update (
1130
+ {'value' : f'{ aux ["prior_acq_values" ][0 ]} ' }
1131
+ )
1086
1132
metadata .ns ('timing' ).update (
1087
1133
{'time' : f'{ datetime .datetime .now () - start_time } ' }
1088
1134
)
@@ -1118,6 +1164,7 @@ def _suggest_batch_with_exploration(
1118
1164
ucb_coefficient = self ._config .ucb_coefficient ,
1119
1165
explore_ucb_coefficient = self ._config .explore_region_ucb_coefficient ,
1120
1166
trust_region = tr if self ._use_trust_region else None ,
1167
+ prior_acquisition = self ._prior_acquisition ,
1121
1168
)
1122
1169
1123
1170
acquisition_optimizer = self ._acquisition_optimizer_factory (self ._converter )
@@ -1180,6 +1227,10 @@ def _suggest_batch_with_exploration(
1180
1227
'trust_radius' : f'{ tr .trust_radius } ' ,
1181
1228
'params' : f'{ model .params } ' ,
1182
1229
})
1230
+ if self ._prior_acquisition is not None :
1231
+ metadata .ns ('prior_acquisition' ).update (
1232
+ {'value' : f'{ aux ["prior_acq_values" ][0 ]} ' }
1233
+ )
1183
1234
metadata .ns ('timing' ).update ({'time' : f'{ end_time - start_time } ' })
1184
1235
suggestions .append (
1185
1236
vz .TrialSuggestion (
0 commit comments