Skip to content

Commit fe47a10

Browse files
vizier-teamcopybara-github
vizier-team
authored andcommitted
Supports a prior acquisition function in GP_UCB_PE
PiperOrigin-RevId: 723223169
1 parent 7ce0ae8 commit fe47a10

File tree

4 files changed

+271
-20
lines changed

4 files changed

+271
-20
lines changed

vizier/_src/algorithms/designers/gp_ucb_pe.py

+62-11
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,10 @@ class UCBScoreFunction(eqx.Module):
278278
279279
The UCB acquisition value is the sum of the predicted mean based on completed
280280
trials and the predicted standard deviation based on all trials, completed and
281-
pending (scaled by the UCB coefficient). This class follows the
282-
`acquisitions.ScoreFunction` protocol.
281+
pending (scaled by the UCB coefficient). If `prior_acquisition` is not None,
282+
the return value is the sum of the prior acquisition value and the UCB
283+
acquisition value. This class follows the `acquisitions.ScoreFunction`
284+
protocol.
283285
284286
Attributes:
285287
predictive: Predictive model with cached Cholesky conditioned on completed
@@ -288,6 +290,7 @@ class UCBScoreFunction(eqx.Module):
288290
on completed and pending trials.
289291
ucb_coefficient: The UCB coefficient.
290292
trust_region: Trust region.
293+
prior_acquisition: An optional prior acquisition function.
291294
scalarization_weights_rng: Random key for scalarization.
292295
labels: Labels, shaped as [num_index_points, num_metrics].
293296
num_scalarizations: Number of scalarizations.
@@ -297,6 +300,7 @@ class UCBScoreFunction(eqx.Module):
297300
predictive_all_features: sp.UniformEnsemblePredictive
298301
ucb_coefficient: jt.Float[jt.Array, '']
299302
trust_region: Optional[acquisitions.TrustRegion]
303+
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None
300304
labels: types.PaddedArray
301305
scalarizer: scalarization.Scalarization
302306

@@ -306,6 +310,7 @@ def __init__(
306310
predictive_all_features: sp.UniformEnsemblePredictive,
307311
ucb_coefficient: jt.Float[jt.Array, ''],
308312
trust_region: Optional[acquisitions.TrustRegion],
313+
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None,
309314
scalarization_weights_rng: jax.Array,
310315
labels: types.PaddedArray,
311316
num_scalarizations: int = 1000,
@@ -314,6 +319,7 @@ def __init__(
314319
self.predictive_all_features = predictive_all_features
315320
self.ucb_coefficient = ucb_coefficient
316321
self.trust_region = trust_region
322+
self.prior_acquisition = prior_acquisition
317323
self.labels = labels
318324
self.scalarizer = acquisitions.create_hv_scalarization(
319325
num_scalarizations, labels, scalarization_weights_rng
@@ -357,11 +363,16 @@ def score_with_aux(
357363
scalarized_acq_values = _apply_trust_region(
358364
self.trust_region, xs, scalarized_acq_values
359365
)
360-
return scalarized_acq_values, {
366+
aux = {
361367
'mean': mean,
362368
'stddev': gprm.stddev(),
363369
'stddev_from_all': stddev_from_all,
364370
}
371+
if self.prior_acquisition is not None:
372+
prior_acq_values = self.prior_acquisition(xs)
373+
scalarized_acq_values = prior_acq_values + scalarized_acq_values
374+
aux['prior_acq_values'] = prior_acq_values
375+
return scalarized_acq_values, aux
365376

366377

367378
class PEScoreFunction(eqx.Module):
@@ -370,8 +381,10 @@ class PEScoreFunction(eqx.Module):
370381
The PE acquisition value is the predicted standard deviation (eq. (9)
371382
in https://arxiv.org/pdf/1304.5350) based on all completed and active trials,
372383
plus a penalty term that grows linearly in the amount of violation of the
373-
constraint `UCB(xs) >= threshold`. This class follows the
374-
`acquisitions.ScoreFunction` protocol.
384+
constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the
385+
returned value is the sum of the prior acquisition value and the PE
386+
acquisition value. This class follows the `acquisitions.ScoreFunction`
387+
protocol.
375388
376389
Attributes:
377390
predictive: Predictive model with cached Cholesky conditioned on completed
@@ -383,6 +396,9 @@ class PEScoreFunction(eqx.Module):
383396
values on `xs`.
384397
penalty_coefficient: Multiplier on the constraint violation penalty.
385398
trust_region:
399+
prior_acquisition: An optional prior acquisition function.
400+
multimetric_promising_region_penalty_type: The type of multimetric promising
401+
region penalty.
386402
387403
Returns:
388404
The Pure-Exploration acquisition value.
@@ -394,6 +410,7 @@ class PEScoreFunction(eqx.Module):
394410
explore_ucb_coefficient: jt.Float[jt.Array, '']
395411
penalty_coefficient: jt.Float[jt.Array, '']
396412
trust_region: Optional[acquisitions.TrustRegion]
413+
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None
397414
multimetric_promising_region_penalty_type: (
398415
MultimetricPromisingRegionPenaltyType
399416
)
@@ -457,11 +474,16 @@ def score_with_aux(
457474
acq_values = stddev_from_all + penalty
458475
if self.trust_region is not None:
459476
acq_values = _apply_trust_region(self.trust_region, xs, acq_values)
460-
return acq_values, {
477+
aux = {
461478
'mean': mean,
462479
'stddev': stddev,
463480
'stddev_from_all': stddev_from_all,
464481
}
482+
if self.prior_acquisition is not None:
483+
prior_acq_values = self.prior_acquisition(xs)
484+
acq_values = prior_acq_values + acq_values
485+
aux['prior_acq_values'] = prior_acq_values
486+
return acq_values, aux
465487

466488

467489
def _logdet(matrix: jax.Array):
@@ -486,8 +508,10 @@ class SetPEScoreFunction(eqx.Module):
486508
predicted covariance matrix evaluated at the points (eq. (8) in
487509
https://arxiv.org/pdf/1304.5350) based on all completed and active trials,
488510
plus a penalty term that grows linearly in the amount of violation of the
489-
constraint `UCB(xs) >= threshold`. This class follows the
490-
`acquisitions.ScoreFunction` protocol.
511+
constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the
512+
returned value is the sum of the prior acquisition value and the PE
513+
acquisition value. This class follows the `acquisitions.ScoreFunction`
514+
protocol.
491515
492516
Attributes:
493517
predictive: Predictive model with cached Cholesky conditioned on completed
@@ -499,6 +523,7 @@ class SetPEScoreFunction(eqx.Module):
499523
values on `xs`.
500524
penalty_coefficient: Multiplier on the constraint violation penalty.
501525
trust_region:
526+
prior_acquisition: An optional prior acquisition function.
502527
503528
Returns:
504529
The Pure-Exploration acquisition value.
@@ -510,6 +535,7 @@ class SetPEScoreFunction(eqx.Module):
510535
explore_ucb_coefficient: jt.Float[jt.Array, '']
511536
penalty_coefficient: jt.Float[jt.Array, '']
512537
trust_region: Optional[acquisitions.TrustRegion]
538+
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None
513539

514540
def score(
515541
self, xs: types.ModelInput, seed: Optional[jax.Array] = None
@@ -549,11 +575,16 @@ def score_with_aux(
549575
)
550576
if self.trust_region is not None:
551577
acq_values = _apply_trust_region_to_set(self.trust_region, xs, acq_values)
552-
return acq_values, {
578+
aux = {
553579
'mean': mean,
554580
'stddev': stddev,
555581
'stddev_from_all': jnp.sqrt(jnp.diagonal(cov, axis1=1, axis2=2)),
556582
}
583+
if self.prior_acquisition is not None:
584+
prior_acq_values = self.prior_acquisition(xs)
585+
acq_values = prior_acq_values + acq_values
586+
aux['prior_acq_values'] = prior_acq_values
587+
return acq_values, aux
557588

558589

559590
def default_ard_optimizer() -> optimizers.Optimizer[types.ParameterDict]:
@@ -587,6 +618,14 @@ class method that takes `ModelInput` and returns a
587618
observed.
588619
rng: If not set, uses random numbers.
589620
clear_jax_cache: If True, every `suggest` call clears the Jax cache.
621+
padding_schedule: Configures what inputs (trials, features, labels) to pad
622+
with what schedule. Useful for reducing JIT compilation passes. (Default
623+
implies no padding.)
624+
prior_acquisition: An optional prior acquisition function. If provided, the
625+
suggestions will be generated by maximizing the sum of the prior
626+
acquisition value and the GP-based acquisition value (UCB or PE). Useful
627+
for biasing the suggestions towards a prior, e.g., being close to some
628+
known parameter values.
590629
"""
591630

592631
_problem: vz.ProblemStatement = attr.field(kw_only=False)
@@ -621,12 +660,13 @@ class method that takes `ModelInput` and returns a
621660
factory=lambda: jax.random.PRNGKey(random.getrandbits(32)), kw_only=True
622661
)
623662
_clear_jax_cache: bool = attr.field(default=False, kw_only=True)
624-
# Whether to pad all inputs, and what type of schedule to use. This is to
625-
# ensure fewer JIT compilation passes. (Default implies no padding.)
626663
# TODO: Check padding does not affect designer behavior.
627664
_padding_schedule: padding.PaddingSchedule = attr.field(
628665
factory=padding.PaddingSchedule, kw_only=True
629666
)
667+
_prior_acquisition: Callable[[types.ModelInput], jax.Array] | None = (
668+
attr.field(factory=lambda: None, kw_only=True)
669+
)
630670

631671
default_eagle_config = es.EagleStrategyConfig(
632672
visibility=3.6782451729470043,
@@ -1003,6 +1043,7 @@ def _suggest_one(
10031043
predictive_all_features,
10041044
ucb_coefficient=self._config.ucb_coefficient,
10051045
trust_region=tr if self._use_trust_region else None,
1046+
prior_acquisition=self._prior_acquisition,
10061047
scalarization_weights_rng=scalarization_weights_rng,
10071048
labels=data.labels,
10081049
)
@@ -1014,6 +1055,7 @@ def _suggest_one(
10141055
ucb_coefficient=self._config.ucb_coefficient,
10151056
explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
10161057
trust_region=tr if self._use_trust_region else None,
1058+
prior_acquisition=self._prior_acquisition,
10171059
multimetric_promising_region_penalty_type=(
10181060
self._config.multimetric_promising_region_penalty_type
10191061
),
@@ -1083,6 +1125,10 @@ def _suggest_one(
10831125
'trust_radius': f'{tr.trust_radius}',
10841126
'params': f'{model.params}',
10851127
})
1128+
if self._prior_acquisition is not None:
1129+
metadata.ns('prior_acquisition').update(
1130+
{'value': f'{aux["prior_acq_values"][0]}'}
1131+
)
10861132
metadata.ns('timing').update(
10871133
{'time': f'{datetime.datetime.now() - start_time}'}
10881134
)
@@ -1118,6 +1164,7 @@ def _suggest_batch_with_exploration(
11181164
ucb_coefficient=self._config.ucb_coefficient,
11191165
explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
11201166
trust_region=tr if self._use_trust_region else None,
1167+
prior_acquisition=self._prior_acquisition,
11211168
)
11221169

11231170
acquisition_optimizer = self._acquisition_optimizer_factory(self._converter)
@@ -1180,6 +1227,10 @@ def _suggest_batch_with_exploration(
11801227
'trust_radius': f'{tr.trust_radius}',
11811228
'params': f'{model.params}',
11821229
})
1230+
if self._prior_acquisition is not None:
1231+
metadata.ns('prior_acquisition').update(
1232+
{'value': f'{aux["prior_acq_values"][0]}'}
1233+
)
11831234
metadata.ns('timing').update({'time': f'{end_time - start_time}'})
11841235
suggestions.append(
11851236
vz.TrialSuggestion(

vizier/_src/algorithms/designers/gp_ucb_pe_test.py

+114-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from vizier._src.algorithms.designers import quasi_random
2929
from vizier._src.algorithms.optimizers import eagle_strategy as es
3030
from vizier._src.algorithms.optimizers import vectorized_base as vb
31+
from vizier._src.jax import types
3132
from vizier._src.jax.models import multitask_tuned_gp_models
3233
from vizier.jax import optimizers
3334
from vizier.pyvizier.converters import padding
@@ -328,7 +329,7 @@ def test_on_flat_space(
328329
# single-metric case because the acquisition value in the multi-metric
329330
# case is randomly scalarized.
330331
if num_metrics == 1:
331-
self.assertAlmostEqual(mean + 10.0 * stddev_from_all, acq)
332+
self.assertAlmostEqual(mean + 10.0 * stddev_from_all, acq, places=5)
332333
self.assertTrue(use_ucb)
333334
continue
334335

@@ -449,6 +450,118 @@ def test_ucb_overwrite(self):
449450
)
450451
self.assertTrue(use_ucb)
451452

453+
@parameterized.parameters(
454+
dict(optimize_set_acquisition_for_exploration=False),
455+
dict(optimize_set_acquisition_for_exploration=True),
456+
)
457+
def test_prior_acquisition(
458+
self, optimize_set_acquisition_for_exploration: bool
459+
):
460+
problem = vz.ProblemStatement(
461+
test_studies.flat_continuous_space_with_scaling()
462+
)
463+
problem.metric_information.append(
464+
vz.MetricInformation(
465+
name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE
466+
)
467+
)
468+
vectorized_optimizer_factory = vb.VectorizedOptimizerFactory(
469+
strategy_factory=es.VectorizedEagleStrategyFactory(),
470+
max_evaluations=100,
471+
)
472+
473+
def dummy_prior_acquisition(xs: types.ModelInput):
474+
return np.ones(xs.continuous.shape[0]) * 12345.0
475+
476+
designer = gp_ucb_pe.VizierGPUCBPEBandit(
477+
problem,
478+
acquisition_optimizer_factory=vectorized_optimizer_factory,
479+
metadata_ns='gp_ucb_pe_bandit_test',
480+
num_seed_trials=1,
481+
config=gp_ucb_pe.UCBPEConfig(
482+
ucb_coefficient=10.0,
483+
explore_region_ucb_coefficient=0.5,
484+
cb_violation_penalty_coefficient=10.0,
485+
ucb_overwrite_probability=0.0,
486+
pe_overwrite_probability=0.0,
487+
signal_to_noise_threshold=0.0,
488+
optimize_set_acquisition_for_exploration=(
489+
optimize_set_acquisition_for_exploration
490+
),
491+
),
492+
padding_schedule=padding.PaddingSchedule(
493+
num_trials=padding.PaddingType.MULTIPLES_OF_10
494+
),
495+
prior_acquisition=dummy_prior_acquisition,
496+
rng=jax.random.PRNGKey(1),
497+
)
498+
499+
trial_id = 1
500+
batch_size = 3
501+
iters = 2
502+
rng = jax.random.PRNGKey(1)
503+
all_trials = []
504+
# Simulates a batch suggestion loop that completes a full batch of
505+
# suggestions before asking for the next batch.
506+
for _ in range(iters):
507+
suggestions = designer.suggest(count=batch_size)
508+
self.assertLen(suggestions, batch_size)
509+
completed_trials = []
510+
for suggestion in suggestions:
511+
problem.search_space.assert_contains(suggestion.parameters)
512+
trial_id += 1
513+
measurement = vz.Measurement()
514+
for mi in problem.metric_information:
515+
measurement.metrics[mi.name] = float(
516+
jax.random.uniform(
517+
rng,
518+
minval=mi.min_value_or(lambda: -10.0),
519+
maxval=mi.max_value_or(lambda: 10.0),
520+
)
521+
)
522+
rng, _ = jax.random.split(rng)
523+
completed_trials.append(
524+
suggestion.to_trial(trial_id).complete(measurement)
525+
)
526+
all_trials.extend(completed_trials)
527+
designer.update(
528+
completed=abstractions.CompletedTrials(completed_trials),
529+
all_active=abstractions.ActiveTrials(),
530+
)
531+
532+
self.assertLen(all_trials, iters * batch_size)
533+
534+
for idx, trial in enumerate(all_trials):
535+
if idx < batch_size:
536+
# Skips the first batch of suggestions, which are generated by the
537+
# seeding designer, not acquisition function optimization.
538+
continue
539+
mean, stddev, stddev_from_all, acq, use_ucb = _extract_predictions(
540+
trial.metadata.ns('gp_ucb_pe_bandit_test')
541+
)
542+
prior_acq_value = float(
543+
trial.metadata.ns('gp_ucb_pe_bandit_test')
544+
.ns('prior_acquisition')
545+
.get('value')
546+
)
547+
self.assertEqual(prior_acq_value, 12345.0)
548+
if idx % batch_size == 0:
549+
# The first suggestion in a batch is expected to be generated by UCB,
550+
# and the acquisition value is expected to be the sum of UCB and the
551+
# prior acquisition value.
552+
self.assertTrue(use_ucb)
553+
self.assertAlmostEqual(
554+
mean + 10.0 * stddev + prior_acq_value,
555+
acq,
556+
)
557+
else:
558+
# Later suggestions in a batch are expected to be generated by PE,
559+
# and the acquisition value is expected to be the sum of PE and the
560+
# prior acquisition value.
561+
self.assertFalse(use_ucb)
562+
if not optimize_set_acquisition_for_exploration:
563+
self.assertAlmostEqual(stddev_from_all + prior_acq_value, acq)
564+
452565

453566
if __name__ == '__main__':
454567
jax.config.update('jax_enable_x64', True)

0 commit comments

Comments
 (0)