@@ -97,6 +97,16 @@ class MutateNormalizationType(enum.IntEnum):
97
97
UNNORMALIZED = 2
98
98
99
99
100
+ @enum .unique
101
+ class ContinuousFeaturePerturbationType (enum .IntEnum ):
102
+ """The type of perturbation to apply to continuous features."""
103
+
104
+ # Add the perturbation to the feature value.
105
+ ADDITIVE = 0
106
+ # Multiply the feature value by exp(perturbation).
107
+ MULTIPLICATIVE = 1
108
+
109
+
100
110
@struct .dataclass
101
111
class EagleStrategyConfig :
102
112
"""Eagle Strategy optimizer config.
@@ -106,6 +116,8 @@ class EagleStrategyConfig:
106
116
gravity: The maximum amount of attraction pull.
107
117
negative_gravity: The maximum amount of repulsion pull.
108
118
perturbation: The default amount of noise for perturbation.
119
+ continuous_feature_perturbation_type: The type of perturbation to apply to
120
+ continuous features.
109
121
categorical_perturbation_factor: A factor to apply on categorical params.
110
122
pure_categorical_perturbation_factor: A factor on purely categorical space.
111
123
prob_same_category_without_perturbation: Baseline probability of selecting
@@ -130,6 +142,9 @@ class EagleStrategyConfig:
130
142
negative_gravity : float = 0.008
131
143
# Perturbation
132
144
perturbation : float = 0.16
145
+ continuous_feature_perturbation_type : ContinuousFeaturePerturbationType = (
146
+ ContinuousFeaturePerturbationType .ADDITIVE
147
+ )
133
148
categorical_perturbation_factor : float = 1.0
134
149
pure_categorical_perturbation_factor : float = 30
135
150
prob_same_category_without_perturbation : float = 0.98
@@ -873,14 +888,28 @@ def _create_features(
873
888
features_changes_continuous = jnp .matmul (
874
889
scale , flat_features
875
890
) - flat_features_batch * jnp .sum (scale , axis = - 1 , keepdims = True )
876
-
877
- new_features_continuous = (
878
- features_batch .continuous
879
- + jnp .reshape (
880
- features_changes_continuous , features_batch .continuous .shape
881
- )
882
- + perturbations_batch .continuous
891
+ moved_features_continuous = features_batch .continuous + jnp .reshape (
892
+ features_changes_continuous , features_batch .continuous .shape
883
893
)
894
+ if (
895
+ self .config .continuous_feature_perturbation_type
896
+ == ContinuousFeaturePerturbationType .ADDITIVE
897
+ ):
898
+ new_features_continuous = (
899
+ moved_features_continuous + perturbations_batch .continuous
900
+ )
901
+ elif (
902
+ self .config .continuous_feature_perturbation_type
903
+ == ContinuousFeaturePerturbationType .MULTIPLICATIVE
904
+ ):
905
+ new_features_continuous = moved_features_continuous * jnp .exp (
906
+ perturbations_batch .continuous
907
+ )
908
+ else :
909
+ raise ValueError (
910
+ "Unsupported continuous feature perturbation type:"
911
+ f" { self .config .continuous_feature_perturbation_type } "
912
+ )
884
913
if self .max_categorical_size > 0 :
885
914
features_categorical_logits = (
886
915
self ._create_categorical_feature_logits (
0 commit comments