Skip to content

Commit 2c2241e

Browse files
mgarrardfacebook-github-bot
authored andcommitted
Use input constructors to define status quo features for now (#2990)
Summary: Pull Request resolved: #2990 In the future we'd like to pursue an option outlined by sdaulton: "we would still need to set target trial in the transform_config for Derelativize in model_kwargs, unless we added some logic like use latest trial or something in derelativize. It seems like target trial is definitely the trial index that we should use" But we need to be able to support relative constraints in next mondays launch so this seemed easier and faster for me to implement. Would like any thoughts on if this is an acceptable workaround for the next 2-4weeks? Reviewed By: sdaulton Differential Revision: D65161089 fbshipit-source-id: 6a4f8116dcd06cd248bdf96e519ea1d96096e127
1 parent 95dceca commit 2c2241e

File tree

4 files changed

+164
-6
lines changed

4 files changed

+164
-6
lines changed

ax/modelbridge/generation_node_input_constructors.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class InputConstructorPurpose(Enum):
3232

3333
N = "n"
3434
FIXED_FEATURES = "fixed_features"
35+
STATUS_QUO_FEATURES = "status_quo_features"
3536

3637

3738
class NodeInputConstructors(FuncEnum):
@@ -48,6 +49,7 @@ class NodeInputConstructors(FuncEnum):
4849
REPEAT_N = "repeat_arm_n"
4950
REMAINING_N = "remaining_n"
5051
TARGET_TRIAL_FIXED_FEATURES = "set_target_trial"
52+
STATUS_QUO_FEATURES = "get_status_quo"
5153

5254
# pyre-ignore[3]: Input constructors will be used to make different inputs,
5355
# so we need to allow `Any` return type here.
@@ -73,6 +75,47 @@ def __call__(
7375
# ------------------------- Purpose: `fixed_features` ------------------------- #
7476

7577

78+
def get_status_quo(
79+
previous_node: GenerationNode | None,
80+
next_node: GenerationNode,
81+
gs_gen_call_kwargs: dict[str, Any],
82+
experiment: Experiment,
83+
) -> ObservationFeatures | None:
84+
"""Get the status quo features to pass to the fit of the next node, if applicable.
85+
86+
Args:
87+
previous_node: The previous node in the ``GenerationStrategy``. This is the node
88+
that is being transition away from, and is provided for easy access to
89+
properties of this node.
90+
next_node: The next node in the ``GenerationStrategy``. This is the node that
91+
will leverage the inputs defined by this input constructor.
92+
gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
93+
gen call.
94+
experiment: The experiment associated with this ``GenerationStrategy``.
95+
Returns:
96+
An ``ObservationFeatures`` object that defines the status quo observation
97+
features for fitting the model in the next node.
98+
"""
99+
target_trial_idx = get_target_trial_index(experiment=experiment)
100+
if target_trial_idx is None:
101+
raise AxGenerationException(
102+
f"Attempting to construct status quo input into {next_node} but couldn't "
103+
"identify the target trial. Often this could be due to no trials on the "
104+
f"experiment that are in status {STATUSES_EXPECTING_DATA} on the "
105+
f"experiment. The trials on this experiment are: {experiment.trials}."
106+
)
107+
if experiment.status_quo is None:
108+
raise AxGenerationException(
109+
f"Attempting to construct status quo input into {next_node} but the "
110+
"experiment has no status quo. Please set a status quo before "
111+
"generating."
112+
)
113+
return ObservationFeatures(
114+
parameters=experiment.status_quo.parameters,
115+
trial_index=target_trial_idx,
116+
)
117+
118+
76119
def set_target_trial(
77120
previous_node: GenerationNode | None,
78121
next_node: GenerationNode,

ax/modelbridge/generation_strategy.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,9 @@ def gen_with_multiple_nodes(
466466
gen_kwargs=gen_kwargs,
467467
passed_fixed_features=fixed_features,
468468
)
469+
sq_ft_from_node = self._determine_sq_features_from_node(
470+
node_to_gen_from=node_to_gen_from, gen_kwargs=gen_kwargs
471+
)
469472
# TODO: @mgarrard clean this up after gens merge. This is currently needed
470473
# because the actual transition occurs in gs.gen(), but if a node is
471474
# skipped, we need to transition here to actually initiate that transition
@@ -481,6 +484,7 @@ def gen_with_multiple_nodes(
481484
n=arms_from_node,
482485
pending_observations=pending_observations,
483486
fixed_features=fixed_features_from_node,
487+
status_quo_features=sq_ft_from_node,
484488
)
485489
)
486490
# ensure that the points generated from each node are marked as pending
@@ -833,6 +837,7 @@ def _gen_multiple(
833837
data: Data | None = None,
834838
n: int = 1,
835839
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
840+
status_quo_features: ObservationFeatures | None = None,
836841
**model_gen_kwargs: Any,
837842
) -> list[GeneratorRun]:
838843
"""Produce multiple generator runs at once, to be made into multiple
@@ -866,11 +871,13 @@ def _gen_multiple(
866871
model_gen_kwargs: Keyword arguments that are passed through to
867872
``GenerationNode.gen``, which will pass them through to
868873
``ModelSpec.gen``, which will pass them to ``ModelBridge.gen``.
874+
status_quo_features: An ``ObservationFeature`` of the status quo arm,
875+
needed by some models during fit to accomadate relative constraints.
876+
Includes the status quo parameterization and target trial index.
869877
"""
870878
self.experiment = experiment
871879
self._maybe_transition_to_next_node()
872-
self._fit_current_model(data=data)
873-
880+
self._fit_current_model(data=data, status_quo_features=status_quo_features)
874881
# Get GeneratorRun limit that respects the node's transition criterion that
875882
# affect the number of generator runs that can be produced.
876883
gr_limit = self._curr.generator_run_limit(raise_generation_errors=True)
@@ -978,6 +985,28 @@ def _determine_fixed_features_from_node(
978985
)
979986
return node_fixed_features
980987

988+
def _determine_sq_features_from_node(
989+
self,
990+
node_to_gen_from: GenerationNode,
991+
gen_kwargs: dict[str, Any],
992+
) -> ObservationFeatures | None:
993+
"""todo"""
994+
# TODO: @mgarrard to merge the input constructor logic into a single method
995+
node_sq_features = None
996+
if (
997+
InputConstructorPurpose.STATUS_QUO_FEATURES
998+
in node_to_gen_from.input_constructors
999+
):
1000+
node_sq_features = node_to_gen_from.input_constructors[
1001+
InputConstructorPurpose.STATUS_QUO_FEATURES
1002+
](
1003+
previous_node=node_to_gen_from.previous_node,
1004+
next_node=node_to_gen_from,
1005+
gs_gen_call_kwargs=gen_kwargs,
1006+
experiment=self.experiment,
1007+
)
1008+
return node_sq_features
1009+
9811010
def _determine_arms_from_node(
9821011
self,
9831012
node_to_gen_from: GenerationNode,
@@ -1029,16 +1058,36 @@ def _determine_arms_from_node(
10291058

10301059
# ------------------------- Model selection logic helpers. -------------------------
10311060

1032-
def _fit_current_model(self, data: Data | None) -> None:
1061+
def _fit_current_model(
1062+
self,
1063+
data: Data | None,
1064+
status_quo_features: ObservationFeatures | None = None,
1065+
) -> None:
10331066
"""Fits or update the model on the current generation node (does not move
10341067
between generation nodes).
10351068
10361069
Args:
10371070
data: Optional ``Data`` to fit or update with; if not specified, generation
10381071
strategy will obtain the data via ``experiment.lookup_data``.
1072+
status_quo_features: An ``ObservationFeature`` of the status quo arm,
1073+
needed by some models during fit to accomadate relative constraints.
1074+
Includes the status quo parameterization and target trial index.
10391075
"""
10401076
data = self.experiment.lookup_data() if data is None else data
1041-
self._curr.fit(experiment=self.experiment, data=data)
1077+
1078+
# Only pass status_quo_features if not None to avoid errors
1079+
# with ``ExternalGenerationNode``.
1080+
if status_quo_features is not None:
1081+
self._curr.fit(
1082+
experiment=self.experiment,
1083+
data=data,
1084+
status_quo_features=status_quo_features,
1085+
)
1086+
else:
1087+
self._curr.fit(
1088+
experiment=self.experiment,
1089+
data=data,
1090+
)
10421091
self._model = self._curr._fitted_model
10431092

10441093
def _maybe_transition_to_next_node(

ax/modelbridge/tests/test_generation_node_input_constructors.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ def setUp(self) -> None:
7777
# InputConstructorPurpose.N matches NodeInputConstructors.*_N. We may need
7878
# to switch to constructing this mapping manually in the future.
7979
self.purposes_to_input_constructors = {
80-
p: [ip for ip in NodeInputConstructors if ip.name.endswith(f"_{p.name}")]
80+
p: [ip for ip in NodeInputConstructors if ip.name.endswith(f"{p.name}")]
8181
for p in InputConstructorPurpose
8282
}
83+
8384
self.all_purposes_expected_signatures = {
8485
InputConstructorPurpose.N: inspect.Signature(
8586
parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS,
@@ -89,6 +90,10 @@ def setUp(self) -> None:
8990
parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS,
9091
return_annotation=ObservationFeatures | None,
9192
),
93+
InputConstructorPurpose.STATUS_QUO_FEATURES: inspect.Signature(
94+
parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS,
95+
return_annotation=ObservationFeatures | None,
96+
),
9297
}
9398

9499
def test_all_constructors_have_expected_signature_for_purpose(self) -> None:
@@ -108,6 +113,7 @@ def test_all_constructors_have_expected_signature_for_purpose(self) -> None:
108113
untested_constructors.remove(constructor)
109114

110115
# There should be no untested constructors left.
116+
print(untested_constructors)
111117
self.assertEqual(len(untested_constructors), 0)
112118

113119
def test_consume_all_n_constructor(self) -> None:
@@ -301,6 +307,50 @@ def test_set_target_trial_long_run_wins(self) -> None:
301307
),
302308
)
303309

310+
def test_status_quo_features_no_sq(self) -> None:
311+
self._add_sobol_trial(
312+
experiment=self.experiment,
313+
trial_type=Keys.SHORT_RUN,
314+
complete=False,
315+
num_arms=1,
316+
)
317+
with self.assertRaisesRegex(
318+
AxGenerationException,
319+
"experiment has no status quo",
320+
):
321+
NodeInputConstructors.STATUS_QUO_FEATURES(
322+
previous_node=None,
323+
next_node=self.sobol_generation_node,
324+
gs_gen_call_kwargs={},
325+
experiment=self.experiment,
326+
)
327+
328+
def test_status_quo_features(self) -> None:
329+
self._add_sobol_trial(
330+
experiment=self.experiment,
331+
trial_type=Keys.LONG_RUN,
332+
complete=False,
333+
num_arms=1,
334+
with_status_quo=True,
335+
)
336+
self._add_sobol_trial(
337+
experiment=self.experiment,
338+
trial_type=Keys.LONG_RUN,
339+
complete=False,
340+
num_arms=3,
341+
with_status_quo=True,
342+
)
343+
sq_ft = NodeInputConstructors.STATUS_QUO_FEATURES(
344+
previous_node=None,
345+
next_node=self.sobol_generation_node,
346+
gs_gen_call_kwargs={},
347+
experiment=self.experiment,
348+
)
349+
self.assertEqual(
350+
sq_ft,
351+
ObservationFeatures(parameters={"x1": 0, "x2": 0}, trial_index=1),
352+
)
353+
304354
def test_set_target_trial_most_arms_long_run_wins(self) -> None:
305355
self._add_sobol_trial(
306356
experiment=self.experiment,
@@ -544,6 +594,7 @@ def _add_sobol_trial(
544594
trial_type: str | None = None,
545595
complete: bool = True,
546596
num_arms: int = 1,
597+
with_status_quo: bool = False,
547598
) -> BatchTrial:
548599
"""Helper function to add a trial to an experiment, takes a trial type and
549600
whether or not the trial is complete, and number of arms"""
@@ -554,7 +605,14 @@ def _add_sobol_trial(
554605
optimize_for_power=False,
555606
trial_type=trial_type,
556607
generator_runs=grs,
557-
).run()
608+
)
609+
if with_status_quo:
610+
experiment.status_quo = Arm(parameters={"x1": 0, "x2": 0})
611+
trial.set_status_quo_with_weight(
612+
status_quo=self.experiment.status_quo,
613+
weight=1.0,
614+
)
615+
trial.run()
558616
if complete:
559617
trial.mark_completed()
560618
return trial

ax/utils/testing/core_stubs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ def get_branin_experiment_with_multi_objective(
615615
num_objectives: int = 2,
616616
with_trial: bool = False,
617617
with_completed_trial: bool = False,
618+
with_relative_constraint: bool = False,
618619
) -> Experiment:
619620
exp = Experiment(
620621
name="branin_test_experiment",
@@ -625,6 +626,7 @@ def get_branin_experiment_with_multi_objective(
625626
get_branin_multi_objective_optimization_config(
626627
has_objective_thresholds=has_objective_thresholds,
627628
num_objectives=num_objectives,
629+
with_relative_constraint=with_relative_constraint,
628630
)
629631
if has_optimization_config
630632
else None
@@ -1708,6 +1710,7 @@ def _validate_num_objectives(num_objectives: int) -> None:
17081710
def get_branin_multi_objective_optimization_config(
17091711
has_objective_thresholds: bool = False,
17101712
num_objectives: int = 2,
1713+
with_relative_constraint: bool = False,
17111714
) -> MultiObjectiveOptimizationConfig:
17121715
_validate_num_objectives(num_objectives=num_objectives)
17131716
# minimum Branin value is 0.397887
@@ -1740,6 +1743,11 @@ def get_branin_multi_objective_optimization_config(
17401743
return MultiObjectiveOptimizationConfig(
17411744
objective=get_branin_multi_objective(num_objectives=num_objectives),
17421745
objective_thresholds=objective_thresholds,
1746+
outcome_constraints=[
1747+
get_outcome_constraint(get_branin_metric(name="branin_d"), relative=True)
1748+
]
1749+
if with_relative_constraint
1750+
else None,
17431751
)
17441752

17451753

0 commit comments

Comments
 (0)