Skip to content

Commit

Permalink
Use input constructors to define status quo features for now (#2990)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2990

In the future we'd like to pursue an option outlined by sdaulton: "we would still need to set target trial in the transform_config for Derelativize in model_kwargs, unless we added some logic like use latest trial or something in derelativize. It seems like target trial is definitely the trial index that we should use"

But we need to be able to support relative constraints in next Monday's launch, so this seemed easier and faster for me to implement. Would like any thoughts on whether this is an acceptable workaround for the next 2-4 weeks?

Reviewed By: sdaulton

Differential Revision: D65161089

fbshipit-source-id: 6a4f8116dcd06cd248bdf96e519ea1d96096e127
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 31, 2024
1 parent 95dceca commit 2c2241e
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 6 deletions.
43 changes: 43 additions & 0 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class InputConstructorPurpose(Enum):

N = "n"
FIXED_FEATURES = "fixed_features"
STATUS_QUO_FEATURES = "status_quo_features"


class NodeInputConstructors(FuncEnum):
Expand All @@ -48,6 +49,7 @@ class NodeInputConstructors(FuncEnum):
REPEAT_N = "repeat_arm_n"
REMAINING_N = "remaining_n"
TARGET_TRIAL_FIXED_FEATURES = "set_target_trial"
STATUS_QUO_FEATURES = "get_status_quo"

# pyre-ignore[3]: Input constructors will be used to make different inputs,
# so we need to allow `Any` return type here.
Expand All @@ -73,6 +75,47 @@ def __call__(
# ------------------------- Purpose: `fixed_features` ------------------------- #


def get_status_quo(
    previous_node: GenerationNode | None,
    next_node: GenerationNode,
    gs_gen_call_kwargs: dict[str, Any],
    experiment: Experiment,
) -> ObservationFeatures | None:
    """Build the status quo observation features used to fit the next node.

    Args:
        previous_node: The node in the ``GenerationStrategy`` that is being
            transitioned away from. Not read by this constructor.
        next_node: The node in the ``GenerationStrategy`` that will consume
            the constructed input. Only referenced in error messages here.
        gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s
            gen call. Not read by this constructor.
        experiment: The experiment associated with this
            ``GenerationStrategy``. Supplies the status quo arm and is used
            to look up the target trial index.

    Returns:
        An ``ObservationFeatures`` carrying the experiment's status quo
        parameterization and the target trial index.

    Raises:
        AxGenerationException: If no target trial index can be identified,
            or if the experiment has no status quo set.
    """
    trial_index = get_target_trial_index(experiment=experiment)
    if trial_index is None:
        missing_target_trial_msg = (
            f"Attempting to construct status quo input into {next_node} but couldn't "
            "identify the target trial. Often this could be due to no trials on the "
            f"experiment that are in status {STATUSES_EXPECTING_DATA} on the "
            f"experiment. The trials on this experiment are: {experiment.trials}."
        )
        raise AxGenerationException(missing_target_trial_msg)

    # The target trial check deliberately comes first, so that error wins
    # when both the target trial and the status quo are missing.
    status_quo_arm = experiment.status_quo
    if status_quo_arm is None:
        missing_status_quo_msg = (
            f"Attempting to construct status quo input into {next_node} but the "
            "experiment has no status quo. Please set a status quo before "
            "generating."
        )
        raise AxGenerationException(missing_status_quo_msg)

    return ObservationFeatures(
        parameters=status_quo_arm.parameters,
        trial_index=trial_index,
    )


def set_target_trial(
previous_node: GenerationNode | None,
next_node: GenerationNode,
Expand Down
57 changes: 53 additions & 4 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ def gen_with_multiple_nodes(
gen_kwargs=gen_kwargs,
passed_fixed_features=fixed_features,
)
sq_ft_from_node = self._determine_sq_features_from_node(
node_to_gen_from=node_to_gen_from, gen_kwargs=gen_kwargs
)
# TODO: @mgarrard clean this up after gens merge. This is currently needed
# because the actual transition occurs in gs.gen(), but if a node is
# skipped, we need to transition here to actually initiate that transition
Expand All @@ -481,6 +484,7 @@ def gen_with_multiple_nodes(
n=arms_from_node,
pending_observations=pending_observations,
fixed_features=fixed_features_from_node,
status_quo_features=sq_ft_from_node,
)
)
# ensure that the points generated from each node are marked as pending
Expand Down Expand Up @@ -833,6 +837,7 @@ def _gen_multiple(
data: Data | None = None,
n: int = 1,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
status_quo_features: ObservationFeatures | None = None,
**model_gen_kwargs: Any,
) -> list[GeneratorRun]:
"""Produce multiple generator runs at once, to be made into multiple
Expand Down Expand Up @@ -866,11 +871,13 @@ def _gen_multiple(
model_gen_kwargs: Keyword arguments that are passed through to
``GenerationNode.gen``, which will pass them through to
``ModelSpec.gen``, which will pass them to ``ModelBridge.gen``.
status_quo_features: An ``ObservationFeatures`` of the status quo arm,
needed by some models during fit to accommodate relative constraints.
Includes the status quo parameterization and target trial index.
"""
self.experiment = experiment
self._maybe_transition_to_next_node()
self._fit_current_model(data=data)

self._fit_current_model(data=data, status_quo_features=status_quo_features)
# Get GeneratorRun limit that respects the node's transition criterion that
# affect the number of generator runs that can be produced.
gr_limit = self._curr.generator_run_limit(raise_generation_errors=True)
Expand Down Expand Up @@ -978,6 +985,28 @@ def _determine_fixed_features_from_node(
)
return node_fixed_features

def _determine_sq_features_from_node(
    self,
    node_to_gen_from: GenerationNode,
    gen_kwargs: dict[str, Any],
) -> ObservationFeatures | None:
    """Resolve the status quo features for the node about to generate.

    If the node registers an input constructor under
    ``InputConstructorPurpose.STATUS_QUO_FEATURES``, that constructor is
    called and its result returned.

    Args:
        node_to_gen_from: The node whose input constructors are consulted.
            It is passed to the constructor as ``next_node``, and its
            ``previous_node`` is passed as ``previous_node``.
        gen_kwargs: Forwarded to the constructor as ``gs_gen_call_kwargs``.

    Returns:
        The value produced by the node's status quo input constructor, or
        ``None`` if the node defines no constructor for that purpose.
    """
    # TODO: @mgarrard to merge the input constructor logic into a single method
    purpose = InputConstructorPurpose.STATUS_QUO_FEATURES
    constructors = node_to_gen_from.input_constructors
    if purpose not in constructors:
        return None

    build_sq_features = constructors[purpose]
    return build_sq_features(
        previous_node=node_to_gen_from.previous_node,
        next_node=node_to_gen_from,
        gs_gen_call_kwargs=gen_kwargs,
        experiment=self.experiment,
    )

def _determine_arms_from_node(
self,
node_to_gen_from: GenerationNode,
Expand Down Expand Up @@ -1029,16 +1058,36 @@ def _determine_arms_from_node(

# ------------------------- Model selection logic helpers. -------------------------

def _fit_current_model(
    self,
    data: Data | None,
    status_quo_features: ObservationFeatures | None = None,
) -> None:
    """Fit the model on the current generation node.

    This does not move between generation nodes. The node is fit exactly
    once per call, and the resulting fitted model is cached on
    ``self._model``.

    Args:
        data: Optional ``Data`` to fit with; if ``None``, the data is
            obtained via ``self.experiment.lookup_data()``.
        status_quo_features: An ``ObservationFeatures`` of the status quo
            arm, needed by some models during fit to accommodate relative
            constraints. Includes the status quo parameterization and
            target trial index.
    """
    data = self.experiment.lookup_data() if data is None else data

    # Only pass ``status_quo_features`` if not None to avoid errors
    # with ``ExternalGenerationNode``.
    if status_quo_features is not None:
        self._curr.fit(
            experiment=self.experiment,
            data=data,
            status_quo_features=status_quo_features,
        )
    else:
        self._curr.fit(
            experiment=self.experiment,
            data=data,
        )
    self._model = self._curr._fitted_model

def _maybe_transition_to_next_node(
Expand Down
62 changes: 60 additions & 2 deletions ax/modelbridge/tests/test_generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ def setUp(self) -> None:
# InputConstructorPurpose.N matches NodeInputConstructors.*_N. We may need
# to switch to constructing this mapping manually in the future.
self.purposes_to_input_constructors = {
p: [ip for ip in NodeInputConstructors if ip.name.endswith(f"_{p.name}")]
p: [ip for ip in NodeInputConstructors if ip.name.endswith(f"{p.name}")]
for p in InputConstructorPurpose
}

self.all_purposes_expected_signatures = {
InputConstructorPurpose.N: inspect.Signature(
parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS,
Expand All @@ -89,6 +90,10 @@ def setUp(self) -> None:
parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS,
return_annotation=ObservationFeatures | None,
),
InputConstructorPurpose.STATUS_QUO_FEATURES: inspect.Signature(
parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS,
return_annotation=ObservationFeatures | None,
),
}

def test_all_constructors_have_expected_signature_for_purpose(self) -> None:
Expand All @@ -108,6 +113,7 @@ def test_all_constructors_have_expected_signature_for_purpose(self) -> None:
untested_constructors.remove(constructor)

# There should be no untested constructors left.
print(untested_constructors)
self.assertEqual(len(untested_constructors), 0)

def test_consume_all_n_constructor(self) -> None:
Expand Down Expand Up @@ -301,6 +307,50 @@ def test_set_target_trial_long_run_wins(self) -> None:
),
)

def test_status_quo_features_no_sq(self) -> None:
    """The status quo constructor raises if the experiment has no status quo."""
    # Add a running (not completed) short-run trial before calling the
    # constructor; no status quo is set on the experiment.
    self._add_sobol_trial(
        experiment=self.experiment,
        trial_type=Keys.SHORT_RUN,
        complete=False,
        num_arms=1,
    )
    constructor_kwargs = {
        "previous_node": None,
        "next_node": self.sobol_generation_node,
        "gs_gen_call_kwargs": {},
        "experiment": self.experiment,
    }
    with self.assertRaisesRegex(
        AxGenerationException, "experiment has no status quo"
    ):
        NodeInputConstructors.STATUS_QUO_FEATURES(**constructor_kwargs)

def test_status_quo_features(self) -> None:
    """The status quo constructor returns the status quo parameters and
    the target trial index.
    """
    # Two running long-run trials with a status quo: first 1 arm, then 3.
    for arm_count in (1, 3):
        self._add_sobol_trial(
            experiment=self.experiment,
            trial_type=Keys.LONG_RUN,
            complete=False,
            num_arms=arm_count,
            with_status_quo=True,
        )

    sq_ft = NodeInputConstructors.STATUS_QUO_FEATURES(
        previous_node=None,
        next_node=self.sobol_generation_node,
        gs_gen_call_kwargs={},
        experiment=self.experiment,
    )

    # Trial index 1 is the second trial added above (the one with 3 arms).
    expected_sq_ft = ObservationFeatures(
        parameters={"x1": 0, "x2": 0}, trial_index=1
    )
    self.assertEqual(sq_ft, expected_sq_ft)

def test_set_target_trial_most_arms_long_run_wins(self) -> None:
self._add_sobol_trial(
experiment=self.experiment,
Expand Down Expand Up @@ -544,6 +594,7 @@ def _add_sobol_trial(
trial_type: str | None = None,
complete: bool = True,
num_arms: int = 1,
with_status_quo: bool = False,
) -> BatchTrial:
"""Helper function to add a trial to an experiment, takes a trial type and
whether or not the trial is complete, and number of arms"""
Expand All @@ -554,7 +605,14 @@ def _add_sobol_trial(
optimize_for_power=False,
trial_type=trial_type,
generator_runs=grs,
).run()
)
if with_status_quo:
experiment.status_quo = Arm(parameters={"x1": 0, "x2": 0})
trial.set_status_quo_with_weight(
status_quo=self.experiment.status_quo,
weight=1.0,
)
trial.run()
if complete:
trial.mark_completed()
return trial
Expand Down
8 changes: 8 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ def get_branin_experiment_with_multi_objective(
num_objectives: int = 2,
with_trial: bool = False,
with_completed_trial: bool = False,
with_relative_constraint: bool = False,
) -> Experiment:
exp = Experiment(
name="branin_test_experiment",
Expand All @@ -625,6 +626,7 @@ def get_branin_experiment_with_multi_objective(
get_branin_multi_objective_optimization_config(
has_objective_thresholds=has_objective_thresholds,
num_objectives=num_objectives,
with_relative_constraint=with_relative_constraint,
)
if has_optimization_config
else None
Expand Down Expand Up @@ -1708,6 +1710,7 @@ def _validate_num_objectives(num_objectives: int) -> None:
def get_branin_multi_objective_optimization_config(
has_objective_thresholds: bool = False,
num_objectives: int = 2,
with_relative_constraint: bool = False,
) -> MultiObjectiveOptimizationConfig:
_validate_num_objectives(num_objectives=num_objectives)
# minimum Branin value is 0.397887
Expand Down Expand Up @@ -1740,6 +1743,11 @@ def get_branin_multi_objective_optimization_config(
return MultiObjectiveOptimizationConfig(
objective=get_branin_multi_objective(num_objectives=num_objectives),
objective_thresholds=objective_thresholds,
outcome_constraints=[
get_outcome_constraint(get_branin_metric(name="branin_d"), relative=True)
]
if with_relative_constraint
else None,
)


Expand Down

0 comments on commit 2c2241e

Please sign in to comment.