diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py index abc5fcda749..d088ae9dd35 100644 --- a/ax/modelbridge/generation_node_input_constructors.py +++ b/ax/modelbridge/generation_node_input_constructors.py @@ -32,6 +32,7 @@ class InputConstructorPurpose(Enum): N = "n" FIXED_FEATURES = "fixed_features" + STATUS_QUO_FEATURES = "status_quo_features" class NodeInputConstructors(FuncEnum): @@ -48,6 +49,7 @@ class NodeInputConstructors(FuncEnum): REPEAT_N = "repeat_arm_n" REMAINING_N = "remaining_n" TARGET_TRIAL_FIXED_FEATURES = "set_target_trial" + STATUS_QUO_FEATURES = "get_status_quo" # pyre-ignore[3]: Input constructors will be used to make different inputs, # so we need to allow `Any` return type here. @@ -73,6 +75,47 @@ def __call__( # ------------------------- Purpose: `fixed_features` ------------------------- # +def get_status_quo( + previous_node: GenerationNode | None, + next_node: GenerationNode, + gs_gen_call_kwargs: dict[str, Any], + experiment: Experiment, +) -> ObservationFeatures | None: + """Get the status quo features to pass to the fit of the next node, if applicable. + + Args: + previous_node: The previous node in the ``GenerationStrategy``. This is the node + that is being transition away from, and is provided for easy access to + properties of this node. + next_node: The next node in the ``GenerationStrategy``. This is the node that + will leverage the inputs defined by this input constructor. + gs_gen_call_kwargs: The kwargs passed to the ``GenerationStrategy``'s + gen call. + experiment: The experiment associated with this ``GenerationStrategy``. + Returns: + An ``ObservationFeatures`` object that defines the status quo observation + features for fitting the model in the next node. + """ + target_trial_idx = get_target_trial_index(experiment=experiment) + if target_trial_idx is None: + raise AxGenerationException( + f"Attempting to construct status quo input into {next_node} but couldn't " + "identify the target trial. Often this could be due to no trials on the " + f"experiment that are in status {STATUSES_EXPECTING_DATA} on the " + f"experiment. The trials on this experiment are: {experiment.trials}." + ) + if experiment.status_quo is None: + raise AxGenerationException( + f"Attempting to construct status quo input into {next_node} but the " + "experiment has no status quo. Please set a status quo before " + "generating." + ) + return ObservationFeatures( + parameters=experiment.status_quo.parameters, + trial_index=target_trial_idx, + ) + + def set_target_trial( previous_node: GenerationNode | None, next_node: GenerationNode, diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index d69e3d7e6b6..21143fa7b69 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -466,6 +466,9 @@ def gen_with_multiple_nodes( gen_kwargs=gen_kwargs, passed_fixed_features=fixed_features, ) + sq_ft_from_node = self._determine_sq_features_from_node( + node_to_gen_from=node_to_gen_from, gen_kwargs=gen_kwargs + ) # TODO: @mgarrard clean this up after gens merge. This is currently needed # because the actual transition occurs in gs.gen(), but if a node is # skipped, we need to transition here to actually initiate that transition @@ -481,6 +484,7 @@ def gen_with_multiple_nodes( n=arms_from_node, pending_observations=pending_observations, fixed_features=fixed_features_from_node, + status_quo_features=sq_ft_from_node, ) ) # ensure that the points generated from each node are marked as pending @@ -833,6 +837,7 @@ def _gen_multiple( data: Data | None = None, n: int = 1, pending_observations: dict[str, list[ObservationFeatures]] | None = None, + status_quo_features: ObservationFeatures | None = None, **model_gen_kwargs: Any, ) -> list[GeneratorRun]: """Produce multiple generator runs at once, to be made into multiple @@ -866,11 +871,13 @@ def _gen_multiple( model_gen_kwargs: Keyword arguments that are passed through to ``GenerationNode.gen``, which will pass them through to ``ModelSpec.gen``, which will pass them to ``ModelBridge.gen``. + status_quo_features: An ``ObservationFeature`` of the status quo arm, + needed by some models during fit to accomadate relative constraints. + Includes the status quo parameterization and target trial index. """ self.experiment = experiment self._maybe_transition_to_next_node() - self._fit_current_model(data=data) - + self._fit_current_model(data=data, status_quo_features=status_quo_features) # Get GeneratorRun limit that respects the node's transition criterion that # affect the number of generator runs that can be produced. gr_limit = self._curr.generator_run_limit(raise_generation_errors=True) @@ -978,6 +985,28 @@ def _determine_fixed_features_from_node( ) return node_fixed_features + def _determine_sq_features_from_node( + self, + node_to_gen_from: GenerationNode, + gen_kwargs: dict[str, Any], + ) -> ObservationFeatures | None: + """todo""" + # TODO: @mgarrard to merge the input constructor logic into a single method + node_sq_features = None + if ( + InputConstructorPurpose.STATUS_QUO_FEATURES + in node_to_gen_from.input_constructors + ): + node_sq_features = node_to_gen_from.input_constructors[ + InputConstructorPurpose.STATUS_QUO_FEATURES + ]( + previous_node=node_to_gen_from.previous_node, + next_node=node_to_gen_from, + gs_gen_call_kwargs=gen_kwargs, + experiment=self.experiment, + ) + return node_sq_features + def _determine_arms_from_node( self, node_to_gen_from: GenerationNode, @@ -1029,16 +1058,36 @@ def _determine_arms_from_node( # ------------------------- Model selection logic helpers. ------------------------- - def _fit_current_model(self, data: Data | None) -> None: + def _fit_current_model( + self, + data: Data | None, + status_quo_features: ObservationFeatures | None = None, + ) -> None: """Fits or update the model on the current generation node (does not move between generation nodes). Args: data: Optional ``Data`` to fit or update with; if not specified, generation strategy will obtain the data via ``experiment.lookup_data``. + status_quo_features: An ``ObservationFeature`` of the status quo arm, + needed by some models during fit to accomadate relative constraints. + Includes the status quo parameterization and target trial index. """ data = self.experiment.lookup_data() if data is None else data - self._curr.fit(experiment=self.experiment, data=data) + + # Only pass status_quo_features if not None to avoid errors + # with ``ExternalGenerationNode``. + if status_quo_features is not None: + self._curr.fit( + experiment=self.experiment, + data=data, + status_quo_features=status_quo_features, + ) + else: + self._curr.fit( + experiment=self.experiment, + data=data, + ) self._model = self._curr._fitted_model def _maybe_transition_to_next_node( diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index 0183ee0f021..732de6a1338 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -77,9 +77,10 @@ def setUp(self) -> None: # InputConstructorPurpose.N matches NodeInputConstructors.*_N. We may need # to switch to constructing this mapping manually in the future. self.purposes_to_input_constructors = { - p: [ip for ip in NodeInputConstructors if ip.name.endswith(f"_{p.name}")] + p: [ip for ip in NodeInputConstructors if ip.name.endswith(f"{p.name}")] for p in InputConstructorPurpose } + self.all_purposes_expected_signatures = { InputConstructorPurpose.N: inspect.Signature( parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS, @@ -89,6 +90,10 @@ def setUp(self) -> None: parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS, return_annotation=ObservationFeatures | None, ), + InputConstructorPurpose.STATUS_QUO_FEATURES: inspect.Signature( + parameters=EXPECTED_INPUT_CONSTRUCTOR_PARAMETER_ANNOTATIONS, + return_annotation=ObservationFeatures | None, + ), } def test_all_constructors_have_expected_signature_for_purpose(self) -> None: @@ -108,6 +113,7 @@ def test_all_constructors_have_expected_signature_for_purpose(self) -> None: untested_constructors.remove(constructor) # There should be no untested constructors left. + print(untested_constructors) self.assertEqual(len(untested_constructors), 0) def test_consume_all_n_constructor(self) -> None: @@ -301,6 +307,50 @@ def test_set_target_trial_long_run_wins(self) -> None: ), ) + def test_status_quo_features_no_sq(self) -> None: + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.SHORT_RUN, + complete=False, + num_arms=1, + ) + with self.assertRaisesRegex( + AxGenerationException, + "experiment has no status quo", + ): + NodeInputConstructors.STATUS_QUO_FEATURES( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={}, + experiment=self.experiment, + ) + + def test_status_quo_features(self) -> None: + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.LONG_RUN, + complete=False, + num_arms=1, + with_status_quo=True, + ) + self._add_sobol_trial( + experiment=self.experiment, + trial_type=Keys.LONG_RUN, + complete=False, + num_arms=3, + with_status_quo=True, + ) + sq_ft = NodeInputConstructors.STATUS_QUO_FEATURES( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={}, + experiment=self.experiment, + ) + self.assertEqual( + sq_ft, + ObservationFeatures(parameters={"x1": 0, "x2": 0}, trial_index=1), + ) + def test_set_target_trial_most_arms_long_run_wins(self) -> None: self._add_sobol_trial( experiment=self.experiment, @@ -544,6 +594,7 @@ def _add_sobol_trial( trial_type: str | None = None, complete: bool = True, num_arms: int = 1, + with_status_quo: bool = False, ) -> BatchTrial: """Helper function to add a trial to an experiment, takes a trial type and whether or not the trial is complete, and number of arms""" @@ -554,7 +605,14 @@ def _add_sobol_trial( optimize_for_power=False, trial_type=trial_type, generator_runs=grs, - ).run() + ) + if with_status_quo: + experiment.status_quo = Arm(parameters={"x1": 0, "x2": 0}) + trial.set_status_quo_with_weight( + status_quo=self.experiment.status_quo, + weight=1.0, + ) + trial.run() if complete: trial.mark_completed() return trial diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 694410c6678..4e7e120234a 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -615,6 +615,7 @@ def get_branin_experiment_with_multi_objective( num_objectives: int = 2, with_trial: bool = False, with_completed_trial: bool = False, + with_relative_constraint: bool = False, ) -> Experiment: exp = Experiment( name="branin_test_experiment", @@ -625,6 +626,7 @@ def get_branin_experiment_with_multi_objective( get_branin_multi_objective_optimization_config( has_objective_thresholds=has_objective_thresholds, num_objectives=num_objectives, + with_relative_constraint=with_relative_constraint, ) if has_optimization_config else None @@ -1708,6 +1710,7 @@ def _validate_num_objectives(num_objectives: int) -> None: def get_branin_multi_objective_optimization_config( has_objective_thresholds: bool = False, num_objectives: int = 2, + with_relative_constraint: bool = False, ) -> MultiObjectiveOptimizationConfig: _validate_num_objectives(num_objectives=num_objectives) # minimum Branin value is 0.397887 @@ -1740,6 +1743,11 @@ def get_branin_multi_objective_optimization_config( return MultiObjectiveOptimizationConfig( objective=get_branin_multi_objective(num_objectives=num_objectives), objective_thresholds=objective_thresholds, + outcome_constraints=[ + get_outcome_constraint(get_branin_metric(name="branin_d"), relative=True) + ] + if with_relative_constraint + else None, )