Skip to content

Commit e4df197

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Deprecate deduplicate setting for SOBOL Generator; Set GenerationNode.should_deduplicate=True by default
Summary: Previously, we had dedup settings on both `Generator` and `GenerationNode` level, which is unnecessary and error-prone for accidentally using 2 different dedup settings at the 2 levels (see D92094386) Given that only `SobolGenerator` has special dedup logic, we've decided to deprecate `deduplicate` setting in `Generator` and only rely on `GenerationNode` level deduplication. Furthermore, since we generally expect dedup to be True for OSS use cases and AutoML use cases, we set `GenerationNode.should_deduplicate=True` by default while explicitly overwriting GNode dedup to False for online use cases. **One Potential Issue** Currently `RandomAdapter._gen` solely relies on SOBOL generator to dedup (by passing in `generated_points` to sobol generator and then `rejection_sample`- https://fburl.com/code/zmqfhn8g) and if we remove that, anyone that's using standalone `RandomAdapter` (without a GNode) may get duplicated trials if they call `RandomAdapter.gen` multiple times. This is partially mitigated by advancing `init_position` on sobol generator -- there shouldn't be duplicated points for continuous parameters, but if there are discrete parameters we can be producing duplicated points after rounding. The recommended path is through GenerationStrategy/GenerationNode, where dedup is properly handled. But there's no enforcement on never using standalone RandomAdapter. (ps. `TorchAdapter` passes pending_observations through to the BoTorch acquisition function as X_pending so it doesn't have the same problem) Differential Revision: D92884352
1 parent 541c458 commit e4df197

File tree

18 files changed

+52
-403
lines changed

18 files changed

+52
-403
lines changed

ax/adapter/factory.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
def get_sobol(
3535
search_space: SearchSpace,
3636
seed: int | None = None,
37-
deduplicate: bool = False,
3837
init_position: int = 0,
3938
scramble: bool = True,
4039
fallback_to_sample_polytope: bool = False,
@@ -52,7 +51,6 @@ def get_sobol(
5251
Generators.SOBOL(
5352
experiment=Experiment(search_space=search_space),
5453
seed=seed,
55-
deduplicate=deduplicate,
5654
init_position=init_position,
5755
scramble=scramble,
5856
fallback_to_sample_polytope=fallback_to_sample_polytope,

ax/adapter/random.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from collections.abc import Mapping, Sequence
1111

12-
import numpy as np
1312
from ax.adapter.adapter_utils import (
1413
extract_parameter_constraints,
1514
extract_search_space_digest,
@@ -93,35 +92,6 @@ def _gen(
9392
linear_constraints = extract_parameter_constraints(
9493
search_space.parameter_constraints, self.parameters
9594
)
96-
# Extract generated points to deduplicate against.
97-
# Exclude out-of-design arms (which can only be manual arms
98-
# instead of adapter-generated arms).
99-
generated_points = None
100-
if self.generator.deduplicate:
101-
arms_to_deduplicate = self._experiment.arms_by_signature_for_deduplication
102-
generated_obs = [
103-
ObservationFeatures.from_arm(arm=arm)
104-
for arm in arms_to_deduplicate.values()
105-
if self._search_space.check_membership(parameterization=arm.parameters)
106-
]
107-
# Transform
108-
for t in self.transforms.values():
109-
generated_obs = t.transform_observation_features(generated_obs)
110-
# Add pending observations -- already transformed.
111-
generated_obs.extend(
112-
[obs for obs_list in pending_observations.values() for obs in obs_list]
113-
)
114-
if len(generated_obs) > 0:
115-
# Extract generated points array (n x d).
116-
generated_points = np.array(
117-
[
118-
[obs.parameters[p] for p in self.parameters]
119-
for obs in generated_obs
120-
]
121-
)
122-
# Take unique points only, since there may be duplicates coming
123-
# from pending observations for different metrics.
124-
generated_points = np.unique(generated_points, axis=0)
12595

12696
# Generate the candidates
12797
X, w = self.generator.gen(
@@ -131,7 +101,6 @@ def _gen(
131101
fixed_features=fixed_features_dict,
132102
model_gen_options=model_gen_options,
133103
rounding_func=transform_callback(self.parameters, self.transforms),
134-
generated_points=generated_points,
135104
)
136105
observation_features = parse_observation_features(X, self.parameters)
137106
return GenResults(

ax/adapter/registry.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ class GeneratorSetup(NamedTuple):
173173
default_generator_kwargs: Mapping[str, Any] | None = None
174174
standard_adapter_kwargs: Mapping[str, Any] | None = None
175175
not_saved_generator_kwargs: Sequence[str] | None = None
176+
# Kwargs that were removed from the generator but may still exist in
177+
# serialized data. These are silently filtered out before validation.
178+
deprecated_generator_kwargs: Sequence[str] | None = None
176179

177180

178181
"""A mapping of string keys that indicate a generator, to the corresponding
@@ -209,6 +212,8 @@ class GeneratorSetup(NamedTuple):
209212
adapter_class=RandomAdapter,
210213
generator_class=SobolGenerator,
211214
transforms=Cont_X_trans,
215+
# These kwargs were removed but may exist in old serialized data.
216+
deprecated_generator_kwargs=["deduplicate"],
212217
),
213218
"Uniform": GeneratorSetup(
214219
adapter_class=RandomAdapter,
@@ -291,6 +296,11 @@ def __call__(
291296
adapter_class = model_setup_info.adapter_class
292297
search_space = experiment.search_space
293298

299+
# Filter out deprecated kwargs that may exist in old serialized data.
300+
if model_setup_info.deprecated_generator_kwargs:
301+
for deprecated_key in model_setup_info.deprecated_generator_kwargs:
302+
kwargs.pop(deprecated_key, None)
303+
294304
if not silently_filter_kwargs:
295305
# Check correct kwargs are present
296306
callables = (generator_class, adapter_class)

ax/adapter/tests/test_random_adapter.py

Lines changed: 1 addition & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# pyre-strict
88

9-
import dataclasses
109
from unittest import mock
1110

1211
import numpy as np
@@ -20,15 +19,10 @@
2019
from ax.core.parameter import ParameterType, RangeParameter
2120
from ax.core.parameter_constraint import ParameterConstraint
2221
from ax.core.search_space import SearchSpace
23-
from ax.exceptions.core import SearchSpaceExhausted
2422
from ax.generators.random.base import RandomGenerator
2523
from ax.generators.random.sobol import SobolGenerator
2624
from ax.utils.common.testutils import TestCase
27-
from ax.utils.testing.core_stubs import (
28-
get_data,
29-
get_search_space_for_range_values,
30-
get_small_discrete_search_space,
31-
)
25+
from ax.utils.testing.core_stubs import get_data
3226
from ax.utils.testing.modeling_stubs import get_experiment_for_value
3327

3428

@@ -147,22 +141,6 @@ def test_gen_simple(self) -> None:
147141
self.assertIsNone(gen_args["linear_constraints"])
148142
self.assertIsNone(gen_args["fixed_features"])
149143

150-
def test_deduplicate(self) -> None:
151-
exp = Experiment(search_space=get_small_discrete_search_space())
152-
sobol = RandomAdapter(
153-
experiment=exp,
154-
generator=SobolGenerator(deduplicate=True),
155-
transforms=Cont_X_trans,
156-
)
157-
for _ in range(4): # Search space is {[0, 1], {"red", "panda"}}
158-
# Generate & attach trials to the experiment so that the
159-
# generated points are used for deduplication.
160-
gr = sobol.gen(1)
161-
exp.new_trial(generator_run=gr).mark_running(no_runner_required=True)
162-
self.assertEqual(len(gr.arms), 1)
163-
with self.assertRaises(SearchSpaceExhausted):
164-
sobol.gen(1)
165-
166144
def test_search_space_not_expanded(self) -> None:
167145
data = get_data(num_non_sq_arms=0)
168146
sq_arm = Arm(name="status_quo", parameters={"x": 10.0, "y": 1.0, "z": 1.0})
@@ -186,124 +164,6 @@ def test_search_space_not_expanded(self) -> None:
186164
sobol.gen(1)
187165
self.assertEqual(sobol._model_space, sobol._search_space)
188166

189-
def test_generated_points(self) -> None:
190-
# Checks for generated points argument passed to Generator.gen.
191-
# Search space has two range parameters in [0, 5].
192-
exp = Experiment(
193-
search_space=get_search_space_for_range_values(min=0.0, max=5.0)
194-
)
195-
ssd = extract_search_space_digest(
196-
search_space=exp.search_space,
197-
param_names=list(exp.search_space.parameters.keys()),
198-
)
199-
ssd = dataclasses.replace(ssd, bounds=[(0.0, 1.0), (0.0, 1.0)])
200-
generator = SobolGenerator(deduplicate=True)
201-
gen_res = generator.gen(n=1, search_space_digest=ssd, rounding_func=lambda x: x)
202-
# Using Cont_X_trans, particularly UnitX here to test transform application.
203-
adapter = RandomAdapter(
204-
experiment=exp, generator=generator, transforms=Cont_X_trans
205-
)
206-
207-
# No pending points or previous trials on the experiment.
208-
with mock.patch.object(generator, "gen", return_value=gen_res) as mock_gen:
209-
adapter.gen(n=1)
210-
self.assertIsNone(mock_gen.call_args.kwargs["generated_points"])
211-
212-
# Attach two trials to the experiment.
213-
exp.new_trial().add_arm(Arm(parameters={"x": 0.0, "y": 0.0})).mark_running(
214-
no_runner_required=True
215-
)
216-
exp.new_trial().add_arm(Arm(parameters={"x": 2.0, "y": 2.0})).mark_running(
217-
no_runner_required=True
218-
)
219-
with mock.patch.object(generator, "gen", return_value=gen_res) as mock_gen:
220-
adapter.gen(n=1)
221-
self.assertEqual(
222-
mock_gen.call_args.kwargs["generated_points"].tolist(),
223-
[[0.0, 0.0], [0.4, 0.4]],
224-
)
225-
226-
# Add pending points -- only unique ones should be passed down.
227-
pending_observations = {
228-
m: [ObservationFeatures(parameters={"x": 3.0, "y": 3.0})]
229-
for m in ("m1", "m2")
230-
}
231-
with mock.patch.object(generator, "gen", return_value=gen_res) as mock_gen:
232-
adapter.gen(n=1, pending_observations=pending_observations)
233-
self.assertEqual(
234-
mock_gen.call_args.kwargs["generated_points"].tolist(),
235-
[[0.0, 0.0], [0.4, 0.4], [0.6, 0.6]],
236-
)
237-
238-
# Turn off deduplicate, nothing should be passed down.
239-
generator.deduplicate = False
240-
with mock.patch.object(generator, "gen", return_value=gen_res) as mock_gen:
241-
adapter.gen(n=1, pending_observations=pending_observations)
242-
self.assertIsNone(mock_gen.call_args.kwargs["generated_points"])
243-
244-
# Test filtering out-of-design arms during deduplication
245-
# Create experiment with in-design and out-of-design arms
246-
exp_with_ood_arms = Experiment(
247-
search_space=get_search_space_for_range_values(min=0.0, max=5.0)
248-
)
249-
in_design_arm = Arm(
250-
name="in_design", parameters={"x": 2.0, "y": 3.0}
251-
) # Within [0, 5]
252-
out_of_design_arm = Arm(
253-
name="out_of_design", parameters={"x": 6.0, "y": 7.0}
254-
) # Outside [0, 5]
255-
256-
exp_with_ood_arms.new_trial().add_arm(in_design_arm).mark_running(
257-
no_runner_required=True
258-
)
259-
exp_with_ood_arms.new_trial().add_arm(out_of_design_arm).mark_running(
260-
no_runner_required=True
261-
)
262-
263-
generator = SobolGenerator(deduplicate=True)
264-
adapter_mixed = RandomAdapter(
265-
experiment=exp_with_ood_arms,
266-
generator=generator,
267-
transforms=Cont_X_trans,
268-
)
269-
270-
# Only the in-design arm should be included in generated_points
271-
with mock.patch.object(generator, "gen", return_value=gen_res) as mock_gen:
272-
adapter_mixed.gen(n=1)
273-
274-
generated_points = mock_gen.call_args.kwargs["generated_points"]
275-
self.assertEqual(len(generated_points), 1)
276-
277-
# Test case where all arms are out-of-design
278-
exp_all_out_of_design = Experiment(
279-
search_space=get_search_space_for_range_values(min=0.0, max=5.0)
280-
)
281-
out_of_design_arm1 = Arm(name="out1", parameters={"x": 6.0, "y": 7.0})
282-
out_of_design_arm2 = Arm(name="out2", parameters={"x": -1.0, "y": 8.0})
283-
284-
exp_all_out_of_design.new_trial().add_arm(out_of_design_arm1).mark_running(
285-
no_runner_required=True
286-
)
287-
exp_all_out_of_design.new_trial().add_arm(out_of_design_arm2).mark_running(
288-
no_runner_required=True
289-
)
290-
291-
generator_all_out = SobolGenerator(deduplicate=True)
292-
adapter_all_out = RandomAdapter(
293-
experiment=exp_all_out_of_design,
294-
generator=generator_all_out,
295-
transforms=Cont_X_trans,
296-
)
297-
298-
# When all arms are out-of-design, generated_points should be empty
299-
with mock.patch.object(
300-
generator_all_out, "gen", return_value=gen_res
301-
) as mock_gen:
302-
adapter_all_out.gen(n=1)
303-
304-
generated_points_all_out = mock_gen.call_args.kwargs["generated_points"]
305-
self.assertIsNone(generated_points_all_out)
306-
307167
def test_generation_with_all_fixed(self) -> None:
308168
# Make sure candidate generation succeeds and returns correct parameters
309169
# when all parameters are fixed.

ax/adapter/tests/test_registry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def test_view_defaults(self) -> None:
174174
(
175175
{
176176
"seed": None,
177-
"deduplicate": True,
178177
"init_position": 0,
179178
"scramble": True,
180179
"generated_points": None,
@@ -194,7 +193,7 @@ def test_view_defaults(self) -> None:
194193
self.assertTrue(
195194
all(
196195
kw in Generators.SOBOL.view_kwargs()[0]
197-
for kw in ["seed", "deduplicate", "init_position", "scramble"]
196+
for kw in ["seed", "init_position", "scramble"]
198197
),
199198
all(
200199
kw in Generators.SOBOL.view_kwargs()[1]

ax/benchmark/tests/test_benchmark.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -934,11 +934,7 @@ def test_replication_with_generation_node(self) -> None:
934934
nodes=[
935935
GenerationNode(
936936
name="Sobol",
937-
generator_specs=[
938-
GeneratorSpec(
939-
Generators.SOBOL, generator_kwargs={"deduplicate": True}
940-
)
941-
],
937+
generator_specs=[GeneratorSpec(Generators.SOBOL)],
942938
)
943939
]
944940
),

ax/generation_strategy/dispatch_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _make_sobol_step(
5151
enforce_num_trials: bool = True,
5252
max_parallelism: int | None = None,
5353
seed: int | None = None,
54-
should_deduplicate: bool = False,
54+
should_deduplicate: bool = True,
5555
) -> GenerationStep:
5656
"""Shortcut for creating a Sobol generation step."""
5757
return GenerationStep(
@@ -61,7 +61,7 @@ def _make_sobol_step(
6161
min_trials_observed=min_trials_observed or ceil(num_trials / 2),
6262
enforce_num_trials=enforce_num_trials,
6363
max_parallelism=max_parallelism,
64-
generator_kwargs={"deduplicate": True, "seed": seed},
64+
generator_kwargs={"seed": seed},
6565
should_deduplicate=should_deduplicate,
6666
use_all_trials_in_exp=True,
6767
)
@@ -76,7 +76,7 @@ def _make_botorch_step(
7676
generator_kwargs: dict[str, Any] | None = None,
7777
winsorization_config: None
7878
| (WinsorizationConfig | dict[str, WinsorizationConfig]) = None,
79-
should_deduplicate: bool = False,
79+
should_deduplicate: bool = True,
8080
disable_progbar: bool | None = None,
8181
jit_compile: bool | None = None,
8282
derelativize_with_raw_status_quo: bool = False,
@@ -299,7 +299,7 @@ def choose_generation_strategy_legacy(
299299
max_parallelism_cap: int | None = None,
300300
max_parallelism_override: int | None = None,
301301
optimization_config: OptimizationConfig | None = None,
302-
should_deduplicate: bool = False,
302+
should_deduplicate: bool = True,
303303
use_saasbo: bool = False,
304304
disable_progbar: bool | None = None,
305305
jit_compile: bool | None = None,

ax/generation_strategy/generation_node.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ class GenerationNode(SerializationMixin, SortableBase):
122122

123123
# Required options:
124124
generator_specs: list[GeneratorSpec]
125-
# TODO: Move `should_deduplicate` to `GeneratorSpec` if possible, and make optional
126125
should_deduplicate: bool
127126
_name: str
128127

@@ -150,7 +149,7 @@ def __init__(
150149
generator_specs: list[GeneratorSpec],
151150
transition_criteria: Sequence[TransitionCriterion] | None = None,
152151
best_model_selector: BestModelSelector | None = None,
153-
should_deduplicate: bool = False,
152+
should_deduplicate: bool = True,
154153
input_constructors: TInputConstructorsByPurpose | None = None,
155154
previous_node_name: str | None = None,
156155
trial_type: str | None = None,
@@ -1014,7 +1013,7 @@ def __new__(
10141013
min_trials_observed: int = 0,
10151014
max_parallelism: int | None = None,
10161015
enforce_num_trials: bool = True,
1017-
should_deduplicate: bool = False,
1016+
should_deduplicate: bool = True,
10181017
generator_name: str | None = None,
10191018
use_all_trials_in_exp: bool = False,
10201019
use_update: bool = False, # DEPRECATED.

ax/generation_strategy/tests/test_dispatch_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -860,22 +860,24 @@ def test_max_parallelism_adjustments(self) -> None:
860860
self.assertEqual(self._get_max_parallelism(sobol_gpei._nodes[1]), 10)
861861

862862
def test_set_should_deduplicate(self) -> None:
863+
# Default is now should_deduplicate=True
863864
sobol_gpei = choose_generation_strategy_legacy(
864865
search_space=get_branin_search_space(),
865866
use_batch_trials=True,
866867
num_initialization_trials=3,
867868
)
868869
self.assertListEqual(
869-
[s.should_deduplicate for s in sobol_gpei._nodes], [False] * 2
870+
[s.should_deduplicate for s in sobol_gpei._nodes], [True] * 2
870871
)
872+
# Explicitly set should_deduplicate=False
871873
sobol_gpei = choose_generation_strategy_legacy(
872874
search_space=get_branin_search_space(),
873875
use_batch_trials=True,
874876
num_initialization_trials=3,
875-
should_deduplicate=True,
877+
should_deduplicate=False,
876878
)
877879
self.assertListEqual(
878-
[s.should_deduplicate for s in sobol_gpei._nodes], [True] * 2
880+
[s.should_deduplicate for s in sobol_gpei._nodes], [False] * 2
879881
)
880882

881883
def test_setting_experiment_attribute(self) -> None:

0 commit comments

Comments
 (0)