Skip to content

Commit 216e70d

Browse files
vizier-teamcopybara-github
vizier-team
authored andcommitted
Update the Eagle Strategy to support multiplicative continuous feature perturbation.
PiperOrigin-RevId: 723098266
1 parent 7ce0ae8 commit 216e70d

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

vizier/_src/algorithms/optimizers/eagle_strategy.py

+34-7
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ class MutateNormalizationType(enum.IntEnum):
9797
UNNORMALIZED = 2
9898

9999

100+
@enum.unique
101+
class ContinuousFeaturePerturbationType(enum.IntEnum):
102+
"""The type of perturbation to apply to continuous features."""
103+
104+
ADDITIVE = 0
105+
MULTIPLICATIVE = 1
106+
107+
100108
@struct.dataclass
101109
class EagleStrategyConfig:
102110
"""Eagle Strategy optimizer config.
@@ -106,6 +114,8 @@ class EagleStrategyConfig:
106114
gravity: The maximum amount of attraction pull.
107115
negative_gravity: The maximum amount of repulsion pull.
108116
perturbation: The default amount of noise for perturbation.
117+
continuous_feature_perturbation_type: The type of perturbation to apply to
118+
continuous features.
109119
categorical_perturbation_factor: A factor to apply on categorical params.
110120
pure_categorical_perturbation_factor: A factor on purely categorical space.
111121
prob_same_category_without_perturbation: Baseline probability of selecting
@@ -130,6 +140,9 @@ class EagleStrategyConfig:
130140
negative_gravity: float = 0.008
131141
# Perturbation
132142
perturbation: float = 0.16
143+
continuous_feature_perturbation_type: ContinuousFeaturePerturbationType = (
144+
ContinuousFeaturePerturbationType.ADDITIVE
145+
)
133146
categorical_perturbation_factor: float = 1.0
134147
pure_categorical_perturbation_factor: float = 30
135148
prob_same_category_without_perturbation: float = 0.98
@@ -873,14 +886,28 @@ def _create_features(
873886
features_changes_continuous = jnp.matmul(
874887
scale, flat_features
875888
) - flat_features_batch * jnp.sum(scale, axis=-1, keepdims=True)
876-
877-
new_features_continuous = (
878-
features_batch.continuous
879-
+ jnp.reshape(
880-
features_changes_continuous, features_batch.continuous.shape
881-
)
882-
+ perturbations_batch.continuous
889+
moved_features_continuous = features_batch.continuous + jnp.reshape(
890+
features_changes_continuous, features_batch.continuous.shape
883891
)
892+
if (
893+
self.config.continuous_feature_perturbation_type
894+
== ContinuousFeaturePerturbationType.ADDITIVE
895+
):
896+
new_features_continuous = (
897+
moved_features_continuous + perturbations_batch.continuous
898+
)
899+
elif (
900+
self.config.continuous_feature_perturbation_type
901+
== ContinuousFeaturePerturbationType.MULTIPLICATIVE
902+
):
903+
new_features_continuous = moved_features_continuous * jnp.exp(
904+
perturbations_batch.continuous
905+
)
906+
else:
907+
raise ValueError(
908+
"Unsupported continuous feature perturbation type:"
909+
f" {self.config.continuous_feature_perturbation_type}"
910+
)
884911
if self.max_categorical_size > 0:
885912
features_categorical_logits = (
886913
self._create_categorical_feature_logits(

0 commit comments

Comments
 (0)