Skip to content

Commit 4462eab

Browse files
vizier-teamcopybara-github
vizier-team
authored andcommitted
Update the Eagle Strategy to support multiplicative continuous feature perturbation.
PiperOrigin-RevId: 723098266
1 parent 7ce0ae8 commit 4462eab

File tree

2 files changed

+95
-8
lines changed

2 files changed

+95
-8
lines changed

vizier/_src/algorithms/optimizers/eagle_strategy.py

+38-8
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
in addition to the "attraction" forces, meaning fireflies move towards the
3535
bright spots as well as away from the dark spots. We also support non-decimal
3636
parameter types (categorical, discrete, integer), and treat them uniquely when
37-
computing distance, adding pertrubation, and mutating fireflies.
37+
computing distance, applying pertrubation, and mutating fireflies.
3838
3939
For more details, see the linked paper.
4040
@@ -68,6 +68,7 @@
6868
# Run the optimization.
6969
trials = optimizer.optimize(problem_statement, objective_function)
7070
"""
71+
7172
import dataclasses
7273
# pylint: disable=g-long-lambda
7374

@@ -97,6 +98,16 @@ class MutateNormalizationType(enum.IntEnum):
9798
UNNORMALIZED = 2
9899

99100

101+
@enum.unique
102+
class ContinuousFeaturePerturbationType(enum.IntEnum):
103+
"""The type of perturbation to apply to continuous features."""
104+
105+
# Add the perturbation to the feature value.
106+
ADDITIVE = 0
107+
# Multiply the feature value by exp(perturbation).
108+
MULTIPLICATIVE = 1
109+
110+
100111
@struct.dataclass
101112
class EagleStrategyConfig:
102113
"""Eagle Strategy optimizer config.
@@ -106,6 +117,8 @@ class EagleStrategyConfig:
106117
gravity: The maximum amount of attraction pull.
107118
negative_gravity: The maximum amount of repulsion pull.
108119
perturbation: The default amount of noise for perturbation.
120+
continuous_feature_perturbation_type: The type of perturbation to apply to
121+
continuous features.
109122
categorical_perturbation_factor: A factor to apply on categorical params.
110123
pure_categorical_perturbation_factor: A factor on purely categorical space.
111124
prob_same_category_without_perturbation: Baseline probability of selecting
@@ -130,6 +143,9 @@ class EagleStrategyConfig:
130143
negative_gravity: float = 0.008
131144
# Perturbation
132145
perturbation: float = 0.16
146+
continuous_feature_perturbation_type: ContinuousFeaturePerturbationType = (
147+
ContinuousFeaturePerturbationType.ADDITIVE
148+
)
133149
categorical_perturbation_factor: float = 1.0
134150
pure_categorical_perturbation_factor: float = 30
135151
prob_same_category_without_perturbation: float = 0.98
@@ -873,14 +889,28 @@ def _create_features(
873889
features_changes_continuous = jnp.matmul(
874890
scale, flat_features
875891
) - flat_features_batch * jnp.sum(scale, axis=-1, keepdims=True)
876-
877-
new_features_continuous = (
878-
features_batch.continuous
879-
+ jnp.reshape(
880-
features_changes_continuous, features_batch.continuous.shape
881-
)
882-
+ perturbations_batch.continuous
892+
moved_features_continuous = features_batch.continuous + jnp.reshape(
893+
features_changes_continuous, features_batch.continuous.shape
883894
)
895+
if (
896+
self.config.continuous_feature_perturbation_type
897+
== ContinuousFeaturePerturbationType.ADDITIVE
898+
):
899+
new_features_continuous = (
900+
moved_features_continuous + perturbations_batch.continuous
901+
)
902+
elif (
903+
self.config.continuous_feature_perturbation_type
904+
== ContinuousFeaturePerturbationType.MULTIPLICATIVE
905+
):
906+
new_features_continuous = moved_features_continuous * jnp.exp(
907+
perturbations_batch.continuous
908+
)
909+
else:
910+
raise ValueError(
911+
"Unsupported continuous feature perturbation type:"
912+
f" {self.config.continuous_feature_perturbation_type}"
913+
)
884914
if self.max_categorical_size > 0:
885915
features_categorical_logits = (
886916
self._create_categorical_feature_logits(

vizier/_src/algorithms/optimizers/eagle_strategy_test.py

+57
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from absl.testing import absltest
3232

3333
tfd = tfp.distributions
34+
ContinuousFeaturePerturbationType = (
35+
eagle_strategy.ContinuousFeaturePerturbationType
36+
)
3437

3538

3639
def _create_logits_vector_simple(
@@ -437,6 +440,60 @@ def test_optimize_with_eagle(self):
437440
score_fn=lambda x, _: -jnp.sum(x.continuous.padded_array, 1), count=1
438441
)
439442

443+
def test_continuous_feature_perturbation_type(self):
444+
optimizer_additive_perturbation = vb.VectorizedOptimizerFactory(
445+
strategy_factory=eagle_strategy.VectorizedEagleStrategyFactory(
446+
eagle_config=eagle_strategy.EagleStrategyConfig(
447+
continuous_feature_perturbation_type=(
448+
ContinuousFeaturePerturbationType.ADDITIVE
449+
)
450+
)
451+
),
452+
max_evaluations=50,
453+
)(self.converter)
454+
expected_count = 4
455+
new_features_additive_perturbation = optimizer_additive_perturbation(
456+
score_fn=lambda x, _: -jnp.sum(x.continuous.padded_array, 1),
457+
count=expected_count,
458+
).features
459+
self.assertSequenceEqual(
460+
new_features_additive_perturbation.continuous.shape,
461+
(expected_count, 1, 2),
462+
)
463+
self.assertSequenceEqual(
464+
new_features_additive_perturbation.categorical.shape,
465+
(expected_count, 1, 2),
466+
)
467+
optimizer_mult_perturbation = vb.VectorizedOptimizerFactory(
468+
strategy_factory=eagle_strategy.VectorizedEagleStrategyFactory(
469+
eagle_config=eagle_strategy.EagleStrategyConfig(
470+
continuous_feature_perturbation_type=(
471+
ContinuousFeaturePerturbationType.MULTIPLICATIVE
472+
)
473+
)
474+
),
475+
max_evaluations=50,
476+
)(self.converter)
477+
new_features_mult_perturbation = optimizer_mult_perturbation(
478+
score_fn=lambda x, _: -jnp.sum(x.continuous.padded_array, 1),
479+
count=expected_count,
480+
).features
481+
self.assertSequenceEqual(
482+
new_features_mult_perturbation.continuous.shape,
483+
(expected_count, 1, 2),
484+
)
485+
self.assertSequenceEqual(
486+
new_features_mult_perturbation.categorical.shape,
487+
(expected_count, 1, 2),
488+
)
489+
self.assertGreater(
490+
np.abs(
491+
new_features_additive_perturbation.continuous
492+
- new_features_mult_perturbation.continuous
493+
).max(),
494+
1e-1,
495+
)
496+
440497
def test_optimize_with_eagle_continuous_only(self):
441498
problem = vz.ProblemStatement()
442499
root = problem.search_space.select_root()

0 commit comments

Comments
 (0)