@@ -97,6 +97,14 @@ class MutateNormalizationType(enum.IntEnum):
97
97
UNNORMALIZED = 2
98
98
99
99
100
+ @enum .unique
101
+ class ContinuousFeaturePerturbationType (enum .IntEnum ):
102
+ """The type of perturbation to apply to continuous features."""
103
+
104
+ ADDITIVE = 0
105
+ MULTIPLICATIVE = 1
106
+
107
+
100
108
@struct .dataclass
101
109
class EagleStrategyConfig :
102
110
"""Eagle Strategy optimizer config.
@@ -106,6 +114,8 @@ class EagleStrategyConfig:
106
114
gravity: The maximum amount of attraction pull.
107
115
negative_gravity: The maximum amount of repulsion pull.
108
116
perturbation: The default amount of noise for perturbation.
117
+ continuous_feature_perturbation_type: The type of perturbation to apply to
118
+ continuous features.
109
119
categorical_perturbation_factor: A factor to apply on categorical params.
110
120
pure_categorical_perturbation_factor: A factor on purely categorical space.
111
121
prob_same_category_without_perturbation: Baseline probability of selecting
@@ -130,6 +140,9 @@ class EagleStrategyConfig:
130
140
negative_gravity : float = 0.008
131
141
# Perturbation
132
142
perturbation : float = 0.16
143
+ continuous_feature_perturbation_type : ContinuousFeaturePerturbationType = (
144
+ ContinuousFeaturePerturbationType .ADDITIVE
145
+ )
133
146
categorical_perturbation_factor : float = 1.0
134
147
pure_categorical_perturbation_factor : float = 30
135
148
prob_same_category_without_perturbation : float = 0.98
@@ -873,14 +886,28 @@ def _create_features(
873
886
features_changes_continuous = jnp .matmul (
874
887
scale , flat_features
875
888
) - flat_features_batch * jnp .sum (scale , axis = - 1 , keepdims = True )
876
-
877
- new_features_continuous = (
878
- features_batch .continuous
879
- + jnp .reshape (
880
- features_changes_continuous , features_batch .continuous .shape
881
- )
882
- + perturbations_batch .continuous
889
+ moved_features_continuous = features_batch .continuous + jnp .reshape (
890
+ features_changes_continuous , features_batch .continuous .shape
883
891
)
892
+ if (
893
+ self .config .continuous_feature_perturbation_type
894
+ == ContinuousFeaturePerturbationType .ADDITIVE
895
+ ):
896
+ new_features_continuous = (
897
+ moved_features_continuous + perturbations_batch .continuous
898
+ )
899
+ elif (
900
+ self .config .continuous_feature_perturbation_type
901
+ == ContinuousFeaturePerturbationType .MULTIPLICATIVE
902
+ ):
903
+ new_features_continuous = moved_features_continuous * jnp .exp (
904
+ perturbations_batch .continuous
905
+ )
906
+ else :
907
+ raise ValueError (
908
+ "Unsupported continuous feature perturbation type:"
909
+ f" { self .config .continuous_feature_perturbation_type } "
910
+ )
884
911
if self .max_categorical_size > 0 :
885
912
features_categorical_logits = (
886
913
self ._create_categorical_feature_logits (
0 commit comments