Skip to content

Commit 25d79ec

Browse files
mgarrardfacebook-github-bot
authored andcommitted
Clean up center generation node implementation (#4734)
Summary: This diff cleans up center generation node logic to hopefully make it a bit more eloquent. It does the following: * Removes AutoTransitionAfterGenOrExhaustion as it is duplicative of AutoTransitionAfterGen, it just has a different defalt * Updates custom gen override to skip if either (1) the center point already exists or (2) we were unable to find a suitable center point * makes update_generator_state to be a pass through since we'll set the property during gen override *only compute center params once by adding a property on the center generation node Notably, it would be only in a misconfigured state that we would use sobol fallback now, this is consistent with David's initial changes, see D87797277 for explaination of that chage -- we could even remove sobol fallback from the node if we want Reviewed By: saitcakmak Differential Revision: D90040731 Privacy Context Container: L1307644
1 parent c62f3fd commit 25d79ec

5 files changed

Lines changed: 13 additions & 159 deletions

File tree

ax/generation_strategy/center_generation_node.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
from ax.exceptions.generation_strategy import AxGenerationException
2222
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
2323
from ax.generation_strategy.generator_spec import GeneratorSpec
24-
from ax.generation_strategy.transition_criterion import (
25-
AutoTransitionAfterGenOrExhaustion,
26-
)
24+
from ax.generation_strategy.transition_criterion import AutoTransitionAfterGen
2725
from pyre_extensions import none_throws
2826

2927

@@ -43,7 +41,7 @@ def __init__(self, next_node_name: str) -> None:
4341
super().__init__(
4442
name="CenterOfSearchSpace",
4543
transition_criteria=[
46-
AutoTransitionAfterGenOrExhaustion(
44+
AutoTransitionAfterGen(
4745
transition_to=next_node_name,
4846
continue_trial_generation=False,
4947
)
@@ -58,9 +56,12 @@ def __init__(self, next_node_name: str) -> None:
5856
),
5957
**self.fallback_specs, # This includes the default fallbacks.
6058
}
59+
# custom property to enable single center point computation
60+
self._center_params: TParameterization | None = None
6161

6262
def update_generator_state(self, experiment: Experiment, data: Data) -> None:
63-
self.search_space = experiment.search_space
63+
# State is already set in gen() and will persist during generation
64+
pass
6465

6566
def gen(
6667
self,
@@ -79,27 +80,26 @@ def gen(
7980
before attempting generation. If so, it sets _should_skip to True and
8081
returns None, allowing the generation strategy to transition to the next node.
8182
"""
82-
# Check if center already exists or is infeasible
8383
self.search_space = experiment.search_space
84-
center_params = self.compute_center_params()
84+
self._center_params = self.compute_center_params()
8585

8686
# Check if unable to find a suitable center
87-
if center_params is None:
87+
if self._center_params is None:
8888
self._should_skip = True
8989
return None
9090

9191
# Check if center already exists in experiment
92-
center_arm = Arm(parameters=center_params)
92+
center_arm = Arm(parameters=self._center_params)
9393
if center_arm.signature in experiment.arms_by_signature:
9494
self._should_skip = True
9595
return None
9696

97-
# Otherwise, proceed with normal generation
9897
return super().gen(
9998
experiment=experiment,
10099
pending_observations=pending_observations,
101100
skip_fit=skip_fit,
102101
data=data,
102+
n=n,
103103
arms_per_node=arms_per_node,
104104
**gs_gen_kwargs,
105105
)
@@ -164,12 +164,4 @@ def get_next_candidate(
164164
parameter bounds and parameter constraints w.r.t non-log range parameters.
165165
This finds the center of the largest inscribed ball in the feasible region.
166166
"""
167-
center_params = self.compute_center_params()
168-
if center_params is None:
169-
# raising an exception here will cause fallback to sobol, currently
170-
# it should be very unlikely to hit this case
171-
raise AxGenerationException(
172-
"Center of the search space does not satisfy parameter "
173-
"constraints. The generation strategy will fallback to Sobol. "
174-
)
175-
return center_params
167+
return none_throws(self._center_params)

ax/generation_strategy/external_generation_node.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
from ax.generation_strategy.transition_criterion import TransitionCriterion
2424

2525

26-
# TODO[drfreund]: Introduce a `GenerationNodeInterface` to
27-
# make inheritance/overriding of `GenNode` methods cleaner.
2826
class ExternalGenerationNode(GenerationNode, ABC):
2927
"""A generation node intended to be used with non-Ax methods for
3028
candidate generation.

ax/generation_strategy/tests/test_center_generation_node.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from unittest.mock import patch
1111

1212
from ax.adapter.registry import Generators
13-
1413
from ax.core.arm import Arm
1514
from ax.core.experiment import Experiment
1615
from ax.core.parameter import (
@@ -30,9 +29,7 @@
3029
from ax.generation_strategy.generation_node import GenerationNode
3130
from ax.generation_strategy.generation_strategy import GenerationStrategy
3231
from ax.generation_strategy.generator_spec import GeneratorSpec
33-
from ax.generation_strategy.transition_criterion import (
34-
AutoTransitionAfterGenOrExhaustion,
35-
)
32+
from ax.generation_strategy.transition_criterion import AutoTransitionAfterGen
3633
from ax.utils.common.testutils import TestCase
3734
from ax.utils.testing.core_stubs import get_branin_experiment
3835
from pyre_extensions import none_throws
@@ -123,7 +120,7 @@ def test_center_generation(self) -> None:
123120
self.assertEqual(
124121
self.node.transition_criteria,
125122
[
126-
AutoTransitionAfterGenOrExhaustion(
123+
AutoTransitionAfterGen(
127124
transition_to="test", continue_trial_generation=False
128125
)
129126
],

ax/generation_strategy/tests/test_transition_criterion.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111

1212
import pandas as pd
1313
from ax.adapter.registry import Generators
14-
from ax.core.arm import Arm
1514
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
1615
from ax.core.data import Data
1716
from ax.core.trial_status import TrialStatus
1817
from ax.exceptions.core import UserInputError
19-
from ax.generation_strategy.center_generation_node import CenterGenerationNode
2018
from ax.generation_strategy.generation_strategy import (
2119
GenerationNode,
2220
GenerationStep,
@@ -25,7 +23,6 @@
2523
from ax.generation_strategy.generator_spec import GeneratorSpec
2624
from ax.generation_strategy.transition_criterion import (
2725
AutoTransitionAfterGen,
28-
AutoTransitionAfterGenOrExhaustion,
2926
AuxiliaryExperimentCheck,
3027
IsSingleObjective,
3128
MaxGenerationParallelism,
@@ -376,69 +373,6 @@ def test_auto_with_should_skip_node(self) -> None:
376373
.is_met(experiment=experiment, curr_node=gs._nodes[0])
377374
)
378375

379-
def test_auto_transition_after_gen_or_exhaustion(self) -> None:
380-
"""Test AutoTransitionAfterGenOrExhaustion transitions after generation
381-
or when search space is exhausted.
382-
"""
383-
experiment = self.branin_experiment
384-
385-
# Test 1: Transition after successful generation (like AutoTransitionAfterGen)
386-
gs = GenerationStrategy(
387-
name="test",
388-
nodes=[
389-
GenerationNode(
390-
name="sobol_1",
391-
generator_specs=[self.sobol_generator_spec],
392-
transition_criteria=[
393-
AutoTransitionAfterGenOrExhaustion(transition_to="sobol_2")
394-
],
395-
),
396-
GenerationNode(
397-
name="sobol_2", generator_specs=[self.sobol_generator_spec]
398-
),
399-
],
400-
)
401-
gs.experiment = experiment
402-
403-
# Generate from first node
404-
gs.gen(experiment=experiment)
405-
self.assertEqual(gs.current_node_name, "sobol_1")
406-
407-
# Should transition to next node on next gen after generating
408-
gs.gen(experiment=experiment)
409-
self.assertEqual(gs.current_node_name, "sobol_2")
410-
411-
# Test 2: Transition immediately when search space is exhausted
412-
# Use CenterGenerationNode which can only generate one unique candidate
413-
experiment2 = get_branin_experiment()
414-
# Add the center point so it's already in the experiment
415-
center_arm = Arm(parameters={"x1": 2.5, "x2": 7.5})
416-
experiment2.new_trial().add_arm(arm=center_arm)
417-
418-
gs2 = GenerationStrategy(
419-
name="test_exhaustion",
420-
nodes=[
421-
CenterGenerationNode(next_node_name="sobol"),
422-
GenerationNode(
423-
name="sobol", generator_specs=[self.sobol_generator_spec]
424-
),
425-
],
426-
)
427-
gs2.experiment = experiment2
428-
429-
# Since center already exists, should skip CenterGenerationNode
430-
# and transition directly to sobol
431-
gr = gs2.gen(experiment=experiment2, n=1)[0]
432-
self.assertEqual(gr[0]._generation_node_name, "sobol")
433-
self.assertEqual(gs2.current_node_name, "sobol")
434-
435-
# Test 3: Call block_continued_generation_error
436-
criterion = AutoTransitionAfterGenOrExhaustion(transition_to="sobol_2")
437-
# This method has a pass statement, so calling it should not raise an error
438-
criterion.block_continued_generation_error(
439-
node_name="sobol_1", experiment=experiment, trials_from_node=set()
440-
)
441-
442376
def test_is_single_objective_does_not_transition(self) -> None:
443377
exp = self.branin_experiment
444378
exp.optimization_config = get_branin_multi_objective_optimization_config()

ax/generation_strategy/transition_criterion.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -603,73 +603,6 @@ def block_continued_generation_error(
603603
pass
604604

605605

606-
class AutoTransitionAfterGenOrExhaustion(TransitionCriterion):
607-
"""A class to designate automatic transition from one GenerationNode to another
608-
after generating a candidate or when the search space is exhausted.
609-
610-
This criterion is met when either:
611-
1. A GeneratorRun is generated by this GenerationNode, OR
612-
2. The search space is exhausted (no more unique candidates can be generated)
613-
614-
This is particularly useful for nodes that have a limited search space, such as
615-
CenterGenerationNode, which can only generate one unique candidate (the center).
616-
If that candidate already exists or is infeasible, the search space is exhausted and
617-
the node should transition to the next node.
618-
619-
Args:
620-
transition_to: The name of the GenerationNode the GenerationStrategy should
621-
transition to next.
622-
block_transition_if_unmet: A flag to prevent the node from completing and
623-
being able to transition to another node. This criterion defaults to
624-
setting this to True to ensure we validate that either a GeneratorRun was
625-
generated or the search space is exhausted.
626-
continue_trial_generation: A flag to indicate that all generation for a given
627-
trial is not completed, and thus even after transition, the next node will
628-
continue to generate arms for the same trial. Example usage: in
629-
``BatchTrial``s we may enable generation of arms within a batch from
630-
different ``GenerationNodes`` by setting this flag to True.
631-
"""
632-
633-
def __init__(
634-
self,
635-
transition_to: str,
636-
block_transition_if_unmet: bool | None = True,
637-
continue_trial_generation: bool | None = False,
638-
) -> None:
639-
super().__init__(
640-
transition_to=transition_to,
641-
block_transition_if_unmet=block_transition_if_unmet,
642-
continue_trial_generation=continue_trial_generation,
643-
)
644-
645-
def is_met(
646-
self,
647-
experiment: Experiment,
648-
curr_node: GenerationNode,
649-
) -> bool:
650-
"""Return True if any GeneratorRun is generated by this GenerationNode
651-
or if the node should be skipped due to search space exhaustion.
652-
"""
653-
if curr_node._should_skip:
654-
return True
655-
last_gr_from_gs = curr_node.generation_strategy.last_generator_run
656-
if (
657-
last_gr_from_gs is not None
658-
and last_gr_from_gs._generation_node_name == curr_node.name
659-
):
660-
return True
661-
return False
662-
663-
def block_continued_generation_error(
664-
self,
665-
node_name: str,
666-
experiment: Experiment,
667-
trials_from_node: set[int],
668-
) -> None:
669-
"""Error to be raised if the `block_gen_if_met` flag is set to True."""
670-
pass
671-
672-
673606
class AuxiliaryExperimentCheck(TransitionCriterion):
674607
"""A class to transition from one GenerationNode to another by checking if certain
675608
types of Auxiliary Experiment purposes exists.

0 commit comments

Comments
 (0)