Skip to content

Commit 0666bd6

Browse files
sdaultonfacebook-github-bot
authored andcommitted
call maybe transition to next node in Analyses and enable skipping center GN when trials are provided (#4922)
Summary: see title. This important for using a new GS on an experiment with data in model-based analyses. This adds support for * skipping the center GN when trials with data are provided * Transitioning the GS when possible when extracting the adapter in analyses (important for fast-forwarding the GS for an experiment with data). Differential Revision: D93804600
1 parent 3a77c2b commit 0666bd6

File tree

7 files changed

+67
-23
lines changed

7 files changed

+67
-23
lines changed

ax/analysis/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def extract_relevant_adapter(
7676
"Provided GenerationStrategy has no adapter, but no Experiment was "
7777
"provided to source data to fit the adapter."
7878
)
79-
79+
generation_strategy.maybe_transition_to_next_node(raise_data_required_error=False)
8080
generation_strategy.current_node._fit(experiment=experiment)
8181
adapter = generation_strategy.adapter
8282

ax/api/utils/generation_strategy_dispatch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,12 @@ def choose_generation_strategy(
270270
if struct.initialize_with_center and (
271271
struct.initialization_budget is None or struct.initialization_budget > 0
272272
):
273-
center_node = CenterGenerationNode(next_node_name=nodes[0].name)
273+
center_node = CenterGenerationNode(
274+
next_node_name=nodes[0].name,
275+
use_existing_trials_for_initialization=(
276+
struct.use_existing_trials_for_initialization
277+
),
278+
)
274279
nodes.insert(0, center_node)
275280
gs_name = f"Center+{gs_name}"
276281
return GenerationStrategy(name=gs_name, nodes=nodes)

ax/generation_strategy/center_generation_node.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
@dataclass(init=False)
3030
class CenterGenerationNode(ExternalGenerationNode):
3131
next_node_name: str
32+
use_existing_trials_for_initialization: bool
3233

3334
def __init__(
3435
self,
3536
next_node_name: str,
3637
suggested_experiment_status: ExperimentStatus
3738
| None = ExperimentStatus.INITIALIZATION,
39+
use_existing_trials_for_initialization: bool = False,
3840
) -> None:
3941
"""A generation node that samples the center of the search space.
4042
This generation node is only used to generate the first point of the experiment.
@@ -49,6 +51,9 @@ def __init__(
4951
the center point.
5052
suggested_experiment_status: Optional suggested experiment status for this
5153
node.
54+
use_existing_trials_for_initialization: If True and the experiment already
55+
has trials, this node will be skipped during transition checks
56+
outside the gen flow (e.g., when fitting a model for analysis).
5257
"""
5358
super().__init__(
5459
name="CenterOfSearchSpace",
@@ -63,6 +68,9 @@ def __init__(
6368
)
6469
self.search_space: SearchSpace | None = None
6570
self.next_node_name = next_node_name
71+
self.use_existing_trials_for_initialization = (
72+
use_existing_trials_for_initialization
73+
)
6674
self.fallback_specs: dict[type[Exception], GeneratorSpec] = {
6775
AxGenerationException: GeneratorSpec(
6876
generator_enum=Generators.SOBOL, generator_key_override="Fallback_Sobol"
@@ -72,6 +80,23 @@ def __init__(
7280
# custom property to enable single center point computation
7381
self._center_params: TParameterization | None = None
7482

83+
def should_transition_to_next_node(
84+
self, raise_data_required_error: bool = True
85+
) -> tuple[bool, str]:
86+
# Lazily evaluate skip condition for cases where gen() hasn't been
87+
# called (e.g., maybe_transition_to_next_node from analysis code).
88+
# When use_existing_trials_for_initialization is True, existing trials
89+
# count toward initialization, so the center node can be skipped.
90+
if (
91+
not self._should_skip
92+
and self.use_existing_trials_for_initialization
93+
and len(self.experiment.trials) > 0
94+
):
95+
self._should_skip = True
96+
return super().should_transition_to_next_node(
97+
raise_data_required_error=raise_data_required_error
98+
)
99+
75100
def update_generator_state(self, experiment: Experiment, data: Data) -> None:
76101
# State is already set in gen() and will persist during generation
77102
pass

ax/generation_strategy/generation_strategy.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def current_generator_run_limit(
328328
generate any more generator runs at all.
329329
"""
330330
try:
331-
self._maybe_transition_to_next_node(raise_data_required_error=False)
331+
self.maybe_transition_to_next_node(raise_data_required_error=False)
332332
except GenerationStrategyCompleted:
333333
return 0, True
334334

@@ -543,7 +543,7 @@ def _gen_with_multiple_nodes(
543543
# reset should skip as conditions may have changed, do not reset
544544
# until now so node properties can be as up to date as possible
545545
node_to_gen_from._should_skip = False
546-
transitioned = self._maybe_transition_to_next_node()
546+
transitioned = self._transition_to_next_node()
547547
try:
548548
gr = self._curr.gen(
549549
experiment=experiment,
@@ -604,24 +604,11 @@ def _should_continue_gen_for_trial(self) -> bool:
604604

605605
# ------------------------- Node selection logic helpers. -------------------------
606606

607-
def _maybe_transition_to_next_node(
608-
self,
609-
raise_data_required_error: bool = True,
610-
) -> bool:
611-
"""Moves this generation strategy to next node if the current node's
612-
transition criteria are met. This method is safe to use both when generating
613-
candidates or simply checking how many generator runs (to be made into trials)
614-
can currently be produced.
615-
616-
NOTE: this method raises ``GenerationStrategyCompleted`` error if the
617-
optimization is complete
618-
619-
Args:
620-
raise_data_required_error: Whether to raise ``DataRequiredError`` in the
621-
maybe_step_completed method in GenerationNode class.
607+
def _transition_to_next_node(self, raise_data_required_error: bool = True) -> bool:
608+
"""Attempts a single transition to the next node if criteria are met.
622609
623610
Returns:
624-
Whether generation strategy moved to the next node.
611+
Whether the generation strategy moved to the next node.
625612
"""
626613
move_to_next_node, next_node = self._curr.should_transition_to_next_node(
627614
raise_data_required_error=raise_data_required_error
@@ -634,3 +621,29 @@ def _maybe_transition_to_next_node(
634621
)
635622
self._curr = self.nodes_by_name[next_node]
636623
return move_to_next_node
624+
625+
def maybe_transition_to_next_node(
626+
self, raise_data_required_error: bool = True
627+
) -> bool:
628+
"""Moves this generation strategy to next node if the current node's
629+
transition criteria are met, advancing through multiple nodes if
630+
possible. This method is safe to use both when generating candidates or
631+
simply checking how many generator runs (to be made into trials) can
632+
currently be produced.
633+
634+
NOTE: this method raises ``GenerationStrategyCompleted`` error if the
635+
optimization is complete
636+
637+
Args:
638+
raise_data_required_error: Whether to raise ``DataRequiredError`` in the
639+
maybe_step_completed method in GenerationNode class.
640+
641+
Returns:
642+
Whether generation strategy moved to the next node.
643+
"""
644+
moved = False
645+
while self._transition_to_next_node(
646+
raise_data_required_error=raise_data_required_error
647+
):
648+
moved = True
649+
return moved

ax/generation_strategy/tests/test_center_generation_node.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ def test_deduplication(self) -> None:
219219
def test_repr(self) -> None:
220220
self.assertEqual(
221221
repr(self.node),
222-
"CenterGenerationNode(next_node_name='test')",
222+
"CenterGenerationNode(next_node_name='test',"
223+
" use_existing_trials_for_initialization=False)",
223224
)
224225

225226
def test_equality(self) -> None:

ax/service/ax_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,7 @@ def fit_model(self) -> None:
12261226
"At least one trial must be completed with data to fit a model."
12271227
)
12281228
# Check if we should transition before generating the next candidate.
1229-
self.generation_strategy._maybe_transition_to_next_node()
1229+
self.generation_strategy.maybe_transition_to_next_node()
12301230
self.generation_strategy._curr._fit(experiment=self.experiment)
12311231

12321232
def verify_trial_parameterization(

ax/service/tests/test_ax_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2597,7 +2597,7 @@ def test_get_pareto_optimal_points_objective_threshold_inference(
25972597
ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials(
25982598
num_trials=20, include_objective_thresholds=False
25992599
)
2600-
ax_client.generation_strategy._maybe_transition_to_next_node()
2600+
ax_client.generation_strategy.maybe_transition_to_next_node()
26012601
ax_client.generation_strategy._curr._fit(experiment=ax_client.experiment)
26022602
with with_rng_seed(seed=RANDOM_SEED):
26032603
predicted_pareto = ax_client.get_pareto_optimal_parameters()

0 commit comments

Comments
 (0)