34
34
in addition to the "attraction" forces, meaning fireflies move towards the
35
35
bright spots as well as away from the dark spots. We also support non-decimal
36
36
parameter types (categorical, discrete, integer), and treat them uniquely when
37
- computing distance, adding pertrubation, and mutating fireflies.
37
+ computing distance, applying pertrubation, and mutating fireflies.
38
38
39
39
For more details, see the linked paper.
40
40
68
68
# Run the optimization.
69
69
trials = optimizer.optimize(problem_statement, objective_function)
70
70
"""
71
+
71
72
import dataclasses
72
73
# pylint: disable=g-long-lambda
73
74
@@ -97,6 +98,16 @@ class MutateNormalizationType(enum.IntEnum):
97
98
UNNORMALIZED = 2
98
99
99
100
101
+ @enum .unique
102
+ class ContinuousFeaturePerturbationType (enum .IntEnum ):
103
+ """The type of perturbation to apply to continuous features."""
104
+
105
+ # Add the perturbation to the feature value.
106
+ ADDITIVE = 0
107
+ # Multiply the feature value by exp(perturbation).
108
+ MULTIPLICATIVE = 1
109
+
110
+
100
111
@struct .dataclass
101
112
class EagleStrategyConfig :
102
113
"""Eagle Strategy optimizer config.
@@ -106,6 +117,8 @@ class EagleStrategyConfig:
106
117
gravity: The maximum amount of attraction pull.
107
118
negative_gravity: The maximum amount of repulsion pull.
108
119
perturbation: The default amount of noise for perturbation.
120
+ continuous_feature_perturbation_type: The type of perturbation to apply to
121
+ continuous features.
109
122
categorical_perturbation_factor: A factor to apply on categorical params.
110
123
pure_categorical_perturbation_factor: A factor on purely categorical space.
111
124
prob_same_category_without_perturbation: Baseline probability of selecting
@@ -130,6 +143,9 @@ class EagleStrategyConfig:
130
143
negative_gravity : float = 0.008
131
144
# Perturbation
132
145
perturbation : float = 0.16
146
+ continuous_feature_perturbation_type : ContinuousFeaturePerturbationType = (
147
+ ContinuousFeaturePerturbationType .ADDITIVE
148
+ )
133
149
categorical_perturbation_factor : float = 1.0
134
150
pure_categorical_perturbation_factor : float = 30
135
151
prob_same_category_without_perturbation : float = 0.98
@@ -873,14 +889,28 @@ def _create_features(
873
889
features_changes_continuous = jnp .matmul (
874
890
scale , flat_features
875
891
) - flat_features_batch * jnp .sum (scale , axis = - 1 , keepdims = True )
876
-
877
- new_features_continuous = (
878
- features_batch .continuous
879
- + jnp .reshape (
880
- features_changes_continuous , features_batch .continuous .shape
881
- )
882
- + perturbations_batch .continuous
892
+ moved_features_continuous = features_batch .continuous + jnp .reshape (
893
+ features_changes_continuous , features_batch .continuous .shape
883
894
)
895
+ if (
896
+ self .config .continuous_feature_perturbation_type
897
+ == ContinuousFeaturePerturbationType .ADDITIVE
898
+ ):
899
+ new_features_continuous = (
900
+ moved_features_continuous + perturbations_batch .continuous
901
+ )
902
+ elif (
903
+ self .config .continuous_feature_perturbation_type
904
+ == ContinuousFeaturePerturbationType .MULTIPLICATIVE
905
+ ):
906
+ new_features_continuous = moved_features_continuous * jnp .exp (
907
+ perturbations_batch .continuous
908
+ )
909
+ else :
910
+ raise ValueError (
911
+ "Unsupported continuous feature perturbation type:"
912
+ f" { self .config .continuous_feature_perturbation_type } "
913
+ )
884
914
if self .max_categorical_size > 0 :
885
915
features_categorical_logits = (
886
916
self ._create_categorical_feature_logits (
0 commit comments