diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 61d7319ab1f..ae2e608df0c 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -11,9 +11,10 @@ from collections import defaultdict from collections.abc import Callable, Sequence from logging import Logger -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, Union + +import ax.generation_strategy as gs_module # @manual -from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun @@ -49,6 +50,7 @@ ) from ax.utils.common.base import SortableBase from ax.utils.common.constants import Keys +from ax.utils.common.kwargs import consolidate_kwargs from ax.utils.common.logger import get_logger from ax.utils.common.serialization import SerializationMixin from pyre_extensions import none_throws @@ -190,7 +192,6 @@ def __init__( self._previous_node_name = previous_node_name self._trial_type = trial_type self._should_skip = should_skip - # pyre-fixme[8]: Incompatible attribute type self.fallback_specs: dict[type[Exception], GeneratorSpec] = ( fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK ) @@ -296,6 +297,25 @@ def _fitted_model(self) -> Adapter | None: """ return self.model_spec_to_gen_from._fitted_model + def __repr__(self) -> str: + """String representation of this ``GenerationNode`` (note that it + will abridge some aspects of ``TransitionCriterion`` and + ``GeneratorSpec`` attributes). + """ + str_rep = f"{self.__class__.__name__}" + str_rep += f"(node_name='{self.node_name}'" + str_rep += ", model_specs=" + generator_spec_str = ( + ", ".join([spec._brief_repr() for spec in self.model_specs]) + .replace("\n", " ") + .replace("\t", "") + ) + str_rep += f"[{generator_spec_str}]" + str_rep += ( + f", transition_criteria={str(self._brief_transition_criteria_repr())}" + ) + return f"{str_rep})" + def fit( self, experiment: Experiment, @@ -342,55 +362,12 @@ def fit( }, ) - def _get_model_state_from_last_generator_run( - self, model_spec: GeneratorSpec - ) -> dict[str, Any]: - """Get the fit args from the last generator run for the model being fit. - - NOTE: This only works for the base GeneratorSpec class. Factory functions - are not supported and will return an empty dict. - - Args: - model_spec: The model spec to get the fit args for. - - Returns: - A dictionary of fit args extracted from the last generator run - that was generated by the model being fit. - """ - if ( - isinstance(model_spec, FactoryFunctionGeneratorSpec) - or self._generation_strategy is None - ): - # We cannot extract the args for factory functions (which are to be - # deprecated). If there is no GS, we cannot access the previous GRs. - return {} - curr_model = model_spec.model_enum - # Find the last GR that was generated by the model being fit. - grs = self.generation_strategy._generator_runs - for gr in reversed(grs): - if ( - gr._generation_node_name == self.node_name - and gr._model_key == model_spec.model_key - ): - break - else: - # No previous GR from this model. - return {} - # Extract the fit args from the GR. - return _extract_model_state_after_gen( - # pyre-ignore [61]: Local variable `gr` is undefined, or not always defined. - # Pyre is wrong here. If we reach this line, `gr` must be defined. - generator_run=gr, - model_class=curr_model.model_class, - ) - - # TODO [drfreund]: Move this up to `GenerationNodeInterface` once implemented. def gen( self, + *, + experiment: Experiment, n: int | None = None, pending_observations: dict[str, list[ObservationFeatures]] | None = None, - max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS, - arms_by_signature_for_deduplication: dict[str, Arm] | None = None, **model_gen_kwargs: Any, ) -> GeneratorRun: """This method generates candidates using `self._gen` and handles deduplication @@ -408,12 +385,6 @@ def gen( pending_observations: A map from metric name to pending observations for that metric, used by some models to avoid resuggesting points that are currently being evaluated. - max_gen_attempts_for_deduplication: Maximum number of attempts for - generating new candidates without duplicates. If non-duplicate - candidates are not generated with these attempts, a - ``GenerationStrategyRepeatedPoints`` exception will be raised. - arms_by_signature_for_deduplication: A dictionary mapping arm signatures to - the arms, to be used for deduplicating newly generated arms. model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``; these override any pre-specified in ``ModelSpec.model_gen_kwargs``. @@ -421,85 +392,33 @@ def gen( Returns: A ``GeneratorRun`` containing the newly generated candidates. """ - generator_run = None - n_gen_draws = 0 + model_gen_kwargs = model_gen_kwargs or {} try: - # Keep generating until each of `generator_run.arms` is not a duplicate - # of a previous arm, if `should_deduplicate is True` - while n_gen_draws < max_gen_attempts_for_deduplication: - n_gen_draws += 1 - generator_run = self._gen( - n=n, - pending_observations=pending_observations, - **model_gen_kwargs, - ) - if not ( - self.should_deduplicate - and arms_by_signature_for_deduplication - and any( - arm.signature in arms_by_signature_for_deduplication - for arm in generator_run.arms - ) - ): # Not deduplicating or generated a non-duplicate arm. - break - - logger.info( - "The generator run produced duplicate arms. Re-running the " - "generation step in an attempt to deduplicate. Candidates " - f"produced in the last generator run: {generator_run.arms}." - ) - - if n_gen_draws >= max_gen_attempts_for_deduplication: - raise GenerationStrategyRepeatedPoints( - MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE - ) - except Exception as e: - error_type = type(e) - if error_type not in self.fallback_specs: - raise e - - # identify fallback model to use - fallback_model = self.fallback_specs[error_type] - logger.warning( - f"gen failed with error {e}, " - "switching to fallback model with model_enum " - f"{fallback_model.model_enum}" - ) - - # fit fallback model using information from `self.experiment` - # as ground truth - fallback_model.fit( - experiment=self.experiment, - data=self.experiment.lookup_data(), - search_space=self.experiment.search_space, - optimization_config=self.experiment.optimization_config, - **self._get_model_state_from_last_generator_run( - model_spec=fallback_model - ), + # Generate from the main generator on this node. If deduplicating, + # keep generating until each of `generator_run.arms` is not a + # duplicate of a previous active arm (e.g. not from a failed trial) + # on the experiment. + gr = self._gen_maybe_deduplicate( + experiment=experiment, + n=n, + pending_observations=pending_observations, + **model_gen_kwargs, ) - # Switch _model_spec_to_gen_from to a fallback spec - self._model_spec_to_gen_from = fallback_model - generator_run = self._gen( + except Exception as e: + gr = self._try_gen_with_fallback( + exception=e, n=n, pending_observations=pending_observations, **model_gen_kwargs, ) - assert generator_run is not None, ( - "The GeneratorRun is None which is an unexpected state of this" - " GenerationStrategy. This occurred on GenerationNode: {self.node_name}." - ) - generator_run._generation_node_name = self.node_name + gr._generation_node_name = self.node_name # TODO: @mgarrard determine a more refined way to indicate trial type if self._trial_type is not None: - gen_metadata = ( - generator_run.gen_metadata - if generator_run.gen_metadata is not None - else {} - ) + gen_metadata = gr.gen_metadata if gr.gen_metadata is not None else {} gen_metadata["trial_type"] = self._trial_type - generator_run._gen_metadata = gen_metadata - return generator_run + gr._gen_metadata = gen_metadata + return gr def _gen( self, @@ -541,6 +460,114 @@ def _gen( **model_gen_kwargs, ) + def _try_gen_with_fallback( + self, + exception: Exception, + n: int | None, + pending_observations: dict[str, list[ObservationFeatures]] | None, + **model_gen_kwargs: Any, + ) -> GeneratorRun: + error_type = type(exception) + if error_type not in self.fallback_specs: + raise exception + + # identify fallback model to use + fallback_model = self.fallback_specs[error_type] + logger.warning( + f"gen failed with error {exception}, " + "switching to fallback model with model_enum " + f"{fallback_model.model_enum}" + ) + + # fit fallback model using information from `self.experiment` + # as ground truth + fallback_model.fit( + experiment=self.experiment, + data=self.experiment.lookup_data(), + search_space=self.experiment.search_space, + optimization_config=self.experiment.optimization_config, + **self._get_model_state_from_last_generator_run(model_spec=fallback_model), + ) + # Switch _model_spec_to_gen_from to a fallback spec + self._model_spec_to_gen_from = fallback_model + gr = self._gen( + n=n, + pending_observations=pending_observations, + **model_gen_kwargs, + ) + return gr + + def _gen_maybe_deduplicate( + self, + experiment: Experiment, + n: int | None, + pending_observations: dict[str, list[ObservationFeatures]] | None, + **model_gen_kwargs: Any, + ) -> GeneratorRun: + n_gen_draws = 0 + gr = None + dedup_against_arms = experiment.arms_by_signature_for_deduplication + # Keep generating until each of `generator_run.arms` is not a duplicate + # of a previous arm, if `should_deduplicate is True` + while n_gen_draws < MAX_GEN_ATTEMPTS: + n_gen_draws += 1 + gr = self._gen( + n=n, + pending_observations=pending_observations, + **model_gen_kwargs, + ) + if not self.should_deduplicate or not dedup_against_arms: + return gr # Not deduplicationg. + if all( + arm.signature not in dedup_against_arms for arm in gr.arms + ): # Not deduplicating or generated a non-duplicate arm. + return gr # Generated a set of non-duplicate arms. + logger.info( + "The generator run produced duplicate arms. Re-running the " + "generation step in an attempt to deduplicate. Candidates " + f"produced in the last generator run: {gr.arms}." + ) + + raise GenerationStrategyRepeatedPoints(MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE) + + def _get_model_state_from_last_generator_run( + self, model_spec: GeneratorSpec + ) -> dict[str, Any]: + """Get the fit args from the last generator run for the model being fit. + + NOTE: This only works for the base GeneratorSpec class. Factory functions + are not supported and will return an empty dict. + + Args: + model_spec: The model spec to get the fit args for. + + Returns: + A dictionary of fit args extracted from the last generator run + that was generated by the model being fit. + """ + if ( + isinstance(model_spec, FactoryFunctionGeneratorSpec) + or self._generation_strategy is None + ): + # We cannot extract the args for factory functions (which are to be + # deprecated). If there is no GS, we cannot access the previous GRs. + return {} + curr_model = model_spec.model_enum + # Find the last GR that was generated by the model being fit. + grs = self.generation_strategy._generator_runs + for gr in reversed(grs): + if ( + gr._generation_node_name == self.node_name + and gr._model_key == model_spec.model_key + ): + # Extract the fit args from the GR. + return _extract_model_state_after_gen( + generator_run=gr, + model_class=curr_model.model_class, + ) + # No previous GR from this model. + return {} + # ------------------------- Model selection logic helpers. ------------------------- def _pick_fitted_model_to_gen_from(self) -> GeneratorSpec: @@ -640,7 +667,7 @@ def should_transition_to_next_node( if len(self.transition_criteria) == 0: return False, self.node_name - # for each edge in node DAG, check if the transition criterion are met, if so + # for each edge in node DAG, check if the transition criteria are met, if so # transition to the next node defined by that edge. for next_node, all_tc in self.transition_edges.items(): transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet] @@ -741,25 +768,135 @@ def _brief_transition_criteria_repr(self) -> str: ) return f"[{tc_list}]" - def __repr__(self) -> str: - "String representation of this GenerationNode" - # add model specs - str_rep = f"{self.__class__.__name__}" - str_rep += f"(node_name='{self.node_name}'" + def apply_input_constructors( + self, + gen_kwargs: dict[str, Any], + ) -> dict[str, Union[int, ObservationFeatures | None]]: + # NOTE: In the future we might have to add new types ot the `Union` above + # or allow `Any` for the value type, but until we have more different types + # of input constructors, this provides a bit of additional typechecking. + return { + "n": self._determine_arms_from_node( + gen_kwargs=gen_kwargs, + ), + "fixed_features": self._determine_fixed_features_from_node( + gen_kwargs=gen_kwargs, + ), + "status_quo_features": self._determine_sq_features_from_node( + gen_kwargs=gen_kwargs, + ), + } + + def _determine_arms_from_node( + self, + gen_kwargs: dict[str, Any], + ) -> int: + """Calculates the number of arms to generate from the node that will be used + during generation. - str_rep += ", model_specs=" - model_spec_str = ( - ", ".join([spec._brief_repr() for spec in self.model_specs]) - .replace("\n", " ") - .replace("\t", "") + Args: + gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s + gen call, including arms_per_node: an optional map from node name to + the number of arms to generate from that node. If not provided, will + default to the number of arms specified in the node's + ``InputConstructors`` or n if no``InputConstructors`` are defined on + the node. + + Returns: + The number of arms to generate from the node that will be used during this + generation via ``_gen_multiple``. + """ + arms_per_node = gen_kwargs.get("arms_per_node") + purpose_N = ( + gs_module.generation_node_input_constructors.InputConstructorPurpose.N ) - str_rep += f"[{model_spec_str}]" + if arms_per_node is not None: + # arms_per_node provides a way to manually override input + # constructors. This should be used with caution, and only + # if you really know what you're doing. :) + arms_from_node = arms_per_node[self.node_name] + elif purpose_N not in self.input_constructors: + # if the node does not have an input constructor for N, then we + # assume a default of generating n arms from this node. + n = gen_kwargs.get("n") + arms_from_node = n if n is not None else self.generation_strategy.DEFAULT_N + else: + arms_from_node = self.input_constructors[purpose_N]( + previous_node=self.previous_node, + next_node=self, + gs_gen_call_kwargs=gen_kwargs, + experiment=self.experiment, + ) - str_rep += ( - f", transition_criteria={str(self._brief_transition_criteria_repr())}" + return arms_from_node + + def _determine_fixed_features_from_node( + self, + gen_kwargs: dict[str, Any], + ) -> ObservationFeatures | None: + """Uses the ``InputConstructors`` on the node to determine the fixed features + to pass into the model. If fixed_features are provided, the will take + precedence over the fixed_features from the node. + + Args: + node_to_gen_from: The node from which to generate from + gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s + gen call, including the fixed features passed to the ``gen`` method if + any. + + Returns: + An object of ObservationFeatures that represents the fixed features to + pass into the model. + """ + # passed_fixed_features represents the fixed features that were passed by the + # user to the gen method as overrides. + passed_fixed_features = gen_kwargs.get("fixed_features") + if passed_fixed_features is not None: + return passed_fixed_features + + node_fixed_features = None + input_constructors_module = gs_module.generation_node_input_constructors + purpose_fixed_features = ( + input_constructors_module.InputConstructorPurpose.FIXED_FEATURES ) + if purpose_fixed_features in self.input_constructors: + node_fixed_features = self.input_constructors[purpose_fixed_features]( + previous_node=self.previous_node, + next_node=self, + gs_gen_call_kwargs=gen_kwargs, + experiment=self.experiment, + ) + return node_fixed_features - return f"{str_rep})" + def _determine_sq_features_from_node( + self, + gen_kwargs: dict[str, Any], + ) -> ObservationFeatures | None: + """Uses the ``InputConstructors`` on the node to determine the status quo + features to pass into the model. + + Args: + node_to_gen_from: The node from which to generate from + gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s + gen call. + + Returns: + An object of ObservationFeatures that represents the status quo features + to pass into the model. + """ + node_sq_features = None + input_constructors_module = gs_module.generation_node_input_constructors + purpose_sq_features = ( + input_constructors_module.InputConstructorPurpose.STATUS_QUO_FEATURES + ) + if purpose_sq_features in self.input_constructors: + node_sq_features = self.input_constructors[purpose_sq_features]( + previous_node=self.previous_node, + next_node=self, + gs_gen_call_kwargs=gen_kwargs, + experiment=self.experiment, + ) + return node_sq_features class GenerationStep(GenerationNode, SortableBase): @@ -975,17 +1112,16 @@ def _unique_id(self) -> str: def gen( self, + *, + experiment: Experiment, n: int | None = None, pending_observations: dict[str, list[ObservationFeatures]] | None = None, - max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS, - arms_by_signature_for_deduplication: dict[str, Arm] | None = None, **model_gen_kwargs: Any, ) -> GeneratorRun: gr = super().gen( + experiment=experiment, n=n, pending_observations=pending_observations, - max_gen_attempts_for_deduplication=max_gen_attempts_for_deduplication, - arms_by_signature_for_deduplication=arms_by_signature_for_deduplication, **model_gen_kwargs, ) gr._generation_step_index = self.index diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index d7a6e5d0d47..c3e528503b5 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -37,7 +37,7 @@ from ax.utils.common.base import Base from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import assert_is_instance_list -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws logger: Logger = get_logger(__name__) @@ -758,7 +758,6 @@ def _gen_with_multiple_nodes( ) while continue_gen_for_trial: - pack_gs_gen_kwargs["grs_this_gen"] = grs should_transition, node_to_gen_from_name = ( self._curr.should_transition_to_next_node( raise_data_required_error=False @@ -770,52 +769,47 @@ def _gen_with_multiple_nodes( # reset should skip as conditions may have changed, do not reset # until now so node properties can be as up to date as possible node_to_gen_from._should_skip = False - arms_from_node = self._determine_arms_from_node( - node_to_gen_from=node_to_gen_from, - n=n, - gen_kwargs=pack_gs_gen_kwargs, - ) - fixed_features_from_node = self._determine_fixed_features_from_node( - node_to_gen_from=node_to_gen_from, - gen_kwargs=pack_gs_gen_kwargs, - ) - sq_ft_from_node = self._determine_sq_features_from_node( - node_to_gen_from=node_to_gen_from, gen_kwargs=pack_gs_gen_kwargs - ) self._maybe_transition_to_next_node() - if node_to_gen_from._should_skip: + input_constructor_values = self._curr.apply_input_constructors( + gen_kwargs=pack_gs_gen_kwargs + ) + if ( + node_to_gen_from._should_skip + ): # Determined during input constructor computation continue - self._fit_current_model(data=data, status_quo_features=sq_ft_from_node) + + # TODO[@drfreund,mgarrard]: We won't need this here if we figure + # out another way to pass SQ features. + sq_f = input_constructor_values.pop("status_quo_features") + if sq_f is not None: + sq_f = assert_is_instance(sq_f, ObservationFeatures) + self._fit_current_model(data=data, status_quo_features=sq_f) self._curr.generator_run_limit(raise_generation_errors=True) - if arms_from_node != 0: - try: - curr_node_gr = self._curr.gen( - n=arms_from_node, - pending_observations=pending_observations, - arms_by_signature_for_deduplication=( - experiment.arms_by_signature_for_deduplication - ), - fixed_features=fixed_features_from_node, - ) - except DataRequiredError as err: - # Model needs more data, so we log the error and return - # as many generator runs as we were able to produce, unless - # no trials were produced at all (in which case its safe to raise). - if len(grs) == 0: - raise - logger.debug(f"Model required more data: {err}.") - break - self._generator_runs.append(curr_node_gr) - grs.append(curr_node_gr) - # ensure that the points generated from each node are marked as pending - # points for future calls to gen - pending_observations = extend_pending_observations( - experiment=experiment, - pending_observations=pending_observations, - # only pass in the most recent generator run to avoid unnecessary - # deduplication in extend_pending_observations - generator_runs=[grs[-1]], + model_gen_kwargs = {**pack_gs_gen_kwargs} + model_gen_kwargs.update(input_constructor_values) + try: + curr_node_gr = self._curr.gen( + **model_gen_kwargs, ) + except DataRequiredError as err: + # Model needs more data, so we log the error and return + # as many generator runs as we were able to produce, unless + # no trials were produced at all (in which case its safe to raise). + if len(grs) == 0: + raise + logger.debug(f"Model required more data: {err}.") + break + self._generator_runs.append(curr_node_gr) + grs.append(curr_node_gr) + # ensure that the points generated from each node are marked as pending + # points for future calls to gen + pending_observations = extend_pending_observations( + experiment=experiment, + pending_observations=pending_observations, + # only pass in the most recent generator run to avoid unnecessary + # deduplication in extend_pending_observations + generator_runs=[grs[-1]], + ) continue_gen_for_trial = self._should_continue_gen_for_trial() return grs @@ -878,127 +872,6 @@ def _initialize_gen_kwargs( "pending_observations": pending_observations, } - def _determine_fixed_features_from_node( - self, - node_to_gen_from: GenerationNode, - gen_kwargs: dict[str, Any], - ) -> ObservationFeatures | None: - """Uses the ``InputConstructors`` on the node to determine the fixed features - to pass into the model. If fixed_features are provided, the will take - precedence over the fixed_features from the node. - - Args: - node_to_gen_from: The node from which to generate from - gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s - gen call, including the fixed features passed to the ``gen`` method if - any. - - Returns: - An object of ObservationFeatures that represents the fixed features to - pass into the model. - """ - # passed_fixed_features represents the fixed features that were passed by the - # user to the gen method as overrides. - passed_fixed_features = gen_kwargs.get("fixed_features") - if passed_fixed_features is not None: - return passed_fixed_features - - node_fixed_features = None - if ( - InputConstructorPurpose.FIXED_FEATURES - in node_to_gen_from.input_constructors - ): - node_fixed_features = node_to_gen_from.input_constructors[ - InputConstructorPurpose.FIXED_FEATURES - ]( - previous_node=node_to_gen_from.previous_node, - next_node=node_to_gen_from, - gs_gen_call_kwargs=gen_kwargs, - experiment=self.experiment, - ) - return node_fixed_features - - def _determine_sq_features_from_node( - self, - node_to_gen_from: GenerationNode, - gen_kwargs: dict[str, Any], - ) -> ObservationFeatures | None: - """Uses the ``InputConstructors`` on the node to determine the status quo - features to pass into the model. - - Args: - node_to_gen_from: The node from which to generate from - gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s - gen call. - - Returns: - An object of ObservationFeatures that represents the status quo features - to pass into the model. - """ - node_sq_features = None - if ( - InputConstructorPurpose.STATUS_QUO_FEATURES - in node_to_gen_from.input_constructors - ): - node_sq_features = node_to_gen_from.input_constructors[ - InputConstructorPurpose.STATUS_QUO_FEATURES - ]( - previous_node=node_to_gen_from.previous_node, - next_node=node_to_gen_from, - gs_gen_call_kwargs=gen_kwargs, - experiment=self.experiment, - ) - return node_sq_features - - def _determine_arms_from_node( - self, - node_to_gen_from: GenerationNode, - gen_kwargs: dict[str, Any], - n: int | None = None, - ) -> int: - """Calculates the number of arms to generate from the node that will be used - during generation. - - Args: - n: Integer representing how many arms should be in the generator run - produced by this method. NOTE: Some underlying models may ignore - the `n` and produce a model-determined number of arms. In that - case this method will also output a generator run with number of - arms that can differ from `n`. - node_to_gen_from: The node from which to generate from - gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s - gen call, including arms_per_node: an optional map from node name to - the number of arms to generate from that node. If not provided, will - default to the number of arms specified in the node's - ``InputConstructors`` or n if no``InputConstructors`` are defined on - the node. - - Returns: - The number of arms to generate from the node that will be used during this - generation via ``_gen_multiple``. - """ - arms_per_node = gen_kwargs.get("arms_per_node") - if arms_per_node is not None: - # arms_per_node provides a way to manually override input - # constructors. This should be used with caution, and only - # if you really know what you're doing. :) - arms_from_node = arms_per_node[node_to_gen_from.node_name] - elif InputConstructorPurpose.N not in node_to_gen_from.input_constructors: - # if the node does not have an input constructor for N, then we - # assume a default of generating n arms from this node. - arms_from_node = n if n is not None else self.DEFAULT_N - else: - arms_from_node = node_to_gen_from.input_constructors[ - InputConstructorPurpose.N - ]( - previous_node=node_to_gen_from.previous_node, - next_node=node_to_gen_from, - gs_gen_call_kwargs=gen_kwargs, - experiment=self.experiment, - ) - - return arms_from_node - # ------------------------- Model selection logic helpers. ------------------------- def _fit_current_model( diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index edd4203045b..64bb6c53647 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -164,7 +164,9 @@ def test_gen(self) -> None: self.sobol_model_spec, "gen", wraps=self.sobol_model_spec.gen ) as mock_model_spec_gen: gr = self.sobol_generation_node.gen( - n=1, pending_observations={"branin": []} + experiment=self.branin_experiment, + n=1, + pending_observations={"branin": []}, ) mock_model_spec_gen.assert_called_with(n=1, pending_observations={"branin": []}) self.assertEqual(gr._model_key, self.sobol_model_spec.model_key) @@ -194,7 +196,7 @@ def test_gen_with_trial_type(self) -> None: experiment=self.branin_experiment, data=self.branin_data, ) - gr = mbm_short.gen(n=2) + gr = mbm_short.gen(experiment=self.branin_experiment, n=2) gen_metadata = gr.gen_metadata self.assertIsNotNone(gen_metadata) self.assertEqual(gen_metadata["trial_type"], Keys.SHORT_RUN) @@ -206,10 +208,11 @@ def test_gen_with_no_trial_type(self) -> None: experiment=self.branin_experiment, data=self.branin_data, ) - gr = self.sobol_generation_node.gen(n=2) + gr = self.sobol_generation_node.gen(experiment=self.branin_experiment, n=2) self.assertIsNotNone(gr.gen_metadata) self.assertFalse("trial_type" in gr.gen_metadata) + @mock_botorch_optimize def test_model_gen_kwargs_deepcopy(self) -> None: sampler = SobolQMCNormalSampler(torch.Size([1])) node = GenerationNode( @@ -234,7 +237,9 @@ def test_model_gen_kwargs_deepcopy(self) -> None: experiment=self.branin_experiment, data=dat, ) - node.gen(n=1, pending_observations={"branin": []}) + node.gen( + experiment=self.branin_experiment, n=1, pending_observations={"branin": []} + ) # verify that sampler is not modified in-place by checking base samples self.assertIs( node.model_spec_to_gen_from.model_gen_kwargs["model_gen_options"][ @@ -435,7 +440,9 @@ def test_gen(self) -> None: ) # Check that with `ModelSelectionNode` generation from a node with # multiple model specs does not fail. - gr = self.model_selection_node.gen(n=1, pending_observations={"branin": []}) + gr = self.model_selection_node.gen( + experiment=self.branin_experiment, n=1, pending_observations={"branin": []} + ) # Check that the metric aggregation function is called twice, once for each # model spec. self.assertEqual(self.mock_aggregation.call_count, 2) diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index 4ffaffcadbf..946e0673133 100644 --- a/ax/modelbridge/cross_validation.py +++ b/ax/modelbridge/cross_validation.py @@ -21,6 +21,7 @@ from ax.core.observation import Observation, ObservationData, recombine_observations from ax.core.optimization_config import OptimizationConfig from ax.modelbridge.base import Adapter, unwrap_observation_data +from ax.modelbridge.random import RandomAdapter from ax.utils.common.logger import get_logger from ax.utils.stats.model_fit_stats import ( coefficient_of_determination, @@ -436,7 +437,8 @@ def get_fit_and_std_quality_and_generalization_dict( } except Exception as e: - warn("Encountered exception in computing model fit quality: " + str(e)) + if not isinstance(Adapter, RandomAdapter): + warn("Encountered exception in computing model fit quality: " + str(e)) return { "model_fit_quality": None, "model_std_quality": None, diff --git a/ax/utils/common/kwargs.py b/ax/utils/common/kwargs.py index 7c4a23c1c02..501ce138c1c 100644 --- a/ax/utils/common/kwargs.py +++ b/ax/utils/common/kwargs.py @@ -20,7 +20,8 @@ def consolidate_kwargs( - kwargs_iterable: Iterable[dict[str, Any] | None], keywords: Iterable[str] + kwargs_iterable: Iterable[dict[str, Any] | None], + keywords: Iterable[str] | None = None, ) -> dict[str, Any]: """Combine an iterable of kwargs into a single dict of kwargs, where kwargs by duplicate keys that appear later in the iterable get priority over the @@ -37,7 +38,13 @@ def consolidate_kwargs( all_kwargs = {} for kwargs in kwargs_iterable: if kwargs is not None: - all_kwargs.update({kw: p for kw, p in kwargs.items() if kw in keywords}) + all_kwargs.update( + { + kw: p + for kw, p in kwargs.items() + if keywords is None or kw in keywords + } + ) return all_kwargs