Skip to content

Commit debdf00

Browse files
mgarrardmeta-codesync[bot]
authored andcommitted
Split TransitionCriterion into TransitionCriterion and GenerationCriterion (facebook#4854)
Summary: Pull Request resolved: facebook#4854 **TLDR:** This diff splits TransitionCriterion into (1) TransitionCriterion and (2) GenerationBlockingCriterion. I think this makes sense to do because it *greatly* increases the conceptual clarity of the transition criterion. Some ways it does this include: 1. Removal of confusing dual purpose flags — block_transition_if_unmet and block_generation_if_met flags. Now transition criteria are inferred to block transition if unmet and generation criteria are inferred to raise informative errors if the criteria is met. 2. Each criterion contains less flags, and the flags are more directly intuitive. 3. With upcoming removal of special logic for online, we will need to add more generation blocking criteria (ie do we have an opt config), it is better to make this change before adding more criteria that will need to be migrated 4. It will allows the logic for transition and generation to be smoother — this diff keeps things ~= to exisiting logic as possible to minimize diff review overhead, but in subsequent diffs we can save fit time if we know we can’t generate from this node + can’t transition. It will also allow for some further clarification on generation/transition blocking logic that i think is contributing to the confusion of the file 5. i like that creating a new generation blocking criteria with a specific error to raise is easy and painless **Cons of this change:** - it’s a large change, sorry about that. - There is some duplication between TrialBased transition criterion and generation criterion. I explored using a Mixin here, but i find mixins tend to add unnecessary inheritance structures to reason about. **Most important files for review, in order of importance** 1. transition_criterion.py 2. generation_node.py 3. decoder.py 4. encoders.py 5. registery.py 6. generation_strategy_dispatch.py 7. generation_nodes.py 8. generation_strategy.py The remaining files are mainly trivial updates to tests **Note about backwards compatibility:** * This diff will directly decode legacy MaxGenerationParallelism as a generation blocking criterion called MaxGenrationParallelism * Historically, there are some instances of mintrials that have block_gen_if_met=True, this usually comes from enforce_num_trials=True. Now we call this MaxTrialsAwaitingData, and MinTrials is decoded as that. I am open to other, better names for this new criterion. **Other notes/potential improvements:** - we could split transition criterion, generation criterion, and utils into their own files. i kinda like them together, and if we do want to do this split i’d like to do it in a follow up to try to minimize an already v large blast radius Reviewed By: lena-kashtelyan Differential Revision: D92201085 fbshipit-source-id: 27a8b6c6afe81d35d1c41b7e45c384a2934c3257
1 parent c2f9aec commit debdf00

20 files changed

Lines changed: 947 additions & 518 deletions

ax/api/utils/generation_strategy_dispatch.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
GenerationStrategy,
2323
)
2424
from ax.generation_strategy.generator_spec import GeneratorSpec
25-
from ax.generation_strategy.transition_criterion import MinTrials
25+
from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials
2626
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
2727
from botorch.acquisition.acquisition import AcquisitionFunction
2828
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
@@ -71,21 +71,28 @@ def _get_sobol_node(
7171
MinTrials( # This represents the initialization budget.
7272
threshold=initialization_budget,
7373
transition_to="MBM",
74-
block_gen_if_met=(not allow_exceeding_initialization_budget),
75-
block_transition_if_unmet=True,
7674
use_all_trials_in_exp=use_existing_trials_for_initialization,
7775
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
7876
),
7977
MinTrials( # This represents minimum observed trials requirement.
8078
threshold=min_observed_initialization_trials,
8179
transition_to="MBM",
82-
block_gen_if_met=False,
83-
block_transition_if_unmet=True,
8480
use_all_trials_in_exp=True,
8581
only_in_statuses=[TrialStatus.COMPLETED],
8682
count_only_trials_with_data=True,
8783
),
8884
]
85+
# If we want to enforce the initialization budget, add a pausing
86+
# criterion that prevents exceeding the budget.
87+
pausing_criteria = None
88+
if not allow_exceeding_initialization_budget:
89+
pausing_criteria = [
90+
MaxTrialsAwaitingData(
91+
threshold=initialization_budget,
92+
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
93+
use_all_trials_in_exp=use_existing_trials_for_initialization,
94+
)
95+
]
8996
return GenerationNode(
9097
name="Sobol",
9198
generator_specs=[
@@ -95,6 +102,7 @@ def _get_sobol_node(
95102
)
96103
],
97104
transition_criteria=transition_criteria,
105+
pausing_criteria=pausing_criteria,
98106
should_deduplicate=True,
99107
suggested_experiment_status=ExperimentStatus.INITIALIZATION,
100108
)

ax/api/utils/tests/test_generation_strategy_dispatch.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ax.exceptions.core import UserInputError
2020
from ax.generation_strategy.center_generation_node import CenterGenerationNode
2121
from ax.generation_strategy.dispatch_utils import get_derelativize_config
22-
from ax.generation_strategy.transition_criterion import MinTrials
22+
from ax.generation_strategy.transition_criterion import MaxTrialsAwaitingData, MinTrials
2323
from ax.generators.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
2424
from ax.utils.common.testutils import TestCase
2525
from ax.utils.testing.core_stubs import (
@@ -99,16 +99,12 @@ def test_choose_gs_fast_with_options(self) -> None:
9999
MinTrials(
100100
threshold=2,
101101
transition_to="MBM",
102-
block_gen_if_met=False,
103-
block_transition_if_unmet=True,
104102
use_all_trials_in_exp=False,
105103
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
106104
),
107105
MinTrials(
108106
threshold=4,
109107
transition_to="MBM",
110-
block_gen_if_met=False,
111-
block_transition_if_unmet=True,
112108
use_all_trials_in_exp=True,
113109
only_in_statuses=[TrialStatus.COMPLETED],
114110
count_only_trials_with_data=True,
@@ -390,7 +386,18 @@ def test_abandoned_and_failed_trials_excluded_from_initialization_budget(
390386
first_tc.not_in_statuses, [TrialStatus.FAILED, TrialStatus.ABANDONED]
391387
)
392388
self.assertEqual(first_tc.threshold, 5)
393-
self.assertTrue(first_tc.block_gen_if_met)
389+
# Verify MaxTrialsAwaitingData is in pausing_criteria
390+
pausing_criteria = [
391+
pc
392+
for pc in sobol_node._pausing_criteria
393+
if isinstance(pc, MaxTrialsAwaitingData)
394+
]
395+
self.assertEqual(len(pausing_criteria), 1)
396+
self.assertEqual(pausing_criteria[0].threshold, 5)
397+
self.assertEqual(
398+
pausing_criteria[0].not_in_statuses,
399+
[TrialStatus.FAILED, TrialStatus.ABANDONED],
400+
)
394401

395402
# Test the actual behavior: Generate 5 trials, mark 3 as ABANDONED,
396403
# verify that Sobol can still generate more trials

0 commit comments

Comments
 (0)