diff --git a/ax/core/observation.py b/ax/core/observation.py index c73d0075add..5c76c3a0604 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -439,7 +439,7 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: feature_cols = OBS_COLS.intersection(data.df.columns) # note we use this check, rather than isinstance, since # only some Adapters (e.g. MapTorchAdapter) - # use observations_from_map_data, which is required + # use observations_from_data, which is required # to properly handle MapData features (e.g. fidelity). if is_map_data: data = assert_is_instance(data, MapData) @@ -464,74 +464,36 @@ def observations_from_data( data: Data, statuses_to_include: set[TrialStatus] | None = None, statuses_to_include_map_metric: set[TrialStatus] | None = None, -) -> list[Observation]: - """Convert Data to observations. - - Converts a Data object to a list of Observation objects. Pulls arm parameters from - from experiment. Overrides fidelity parameters in the arm with those found in the - Data object. - - Uses a diagonal covariance matrix across metric_names. - - Args: - experiment: Experiment with arm parameters. - data: Data of observations. - statuses_to_include: data from non-MapMetrics will only be included for trials - with statuses in this set. Defaults to all statuses except abandoned. - statuses_to_include_map_metric: data from MapMetrics will only be included for - trials with statuses in this set. Defaults to completed status only. - - Returns: - List of Observation objects. - """ - if statuses_to_include is None: - statuses_to_include = NON_ABANDONED_STATUSES - if statuses_to_include_map_metric is None: - statuses_to_include_map_metric = {TrialStatus.COMPLETED} - feature_cols = get_feature_cols(data) - return _observations_from_dataframe( - experiment=experiment, - df=data.df, - cols=feature_cols, - statuses_to_include=statuses_to_include, - statuses_to_include_map_metric=statuses_to_include_map_metric, - map_keys=[], - ) - - -def observations_from_map_data( - experiment: experiment.Experiment, - map_data: MapData, - statuses_to_include: set[TrialStatus] | None = None, - statuses_to_include_map_metric: set[TrialStatus] | None = None, map_keys_as_parameters: bool = False, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, ) -> list[Observation]: - """Convert MapData to observations. + """Convert Data (or MapData) to observations. - Converts a MapData object to a list of Observation objects. Pulls arm parameters - from experiment. Overrides fidelity parameters in the arm with those found in the - Data object. + Converts a Data (or MapData) object to a list of Observation objects. + Pulls arm parameters from from experiment. Overrides fidelity parameters + in the arm with those found in the Data object. Uses a diagonal covariance matrix across metric_names. Args: experiment: Experiment with arm parameters. - map_data: MapData of observations. + data: Data (or MapData) of observations. statuses_to_include: data from non-MapMetrics will only be included for trials with statuses in this set. Defaults to all statuses except abandoned. statuses_to_include_map_metric: data from MapMetrics will only be included for trials with statuses in this set. Defaults to all statuses except abandoned. map_keys_as_parameters: Whether map_keys should be returned as part of the parameters of the Observation objects. - limit_rows_per_metric: If specified, uses MapData.subsample() with + limit_rows_per_metric: If specified, and if data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_metric` equal to the specified value on the first map_key (map_data.map_keys[0]) to subsample the MapData. This is useful in, e.g., cases where learning curves are frequently updated, leading to an intractable number of Observation objects created. - limit_rows_per_group: If specified, uses MapData.subsample() with + limit_rows_per_group: If specified, and if data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_group` equal to the specified value on the first map_key (map_data.map_keys[0]) to subsample the MapData. @@ -542,19 +504,27 @@ def observations_from_map_data( statuses_to_include = NON_ABANDONED_STATUSES if statuses_to_include_map_metric is None: statuses_to_include_map_metric = NON_ABANDONED_STATUSES - if limit_rows_per_metric is not None or limit_rows_per_group is not None: - map_data = map_data.subsample( - map_key=map_data.map_keys[0], - limit_rows_per_metric=limit_rows_per_metric, - limit_rows_per_group=limit_rows_per_group, - include_first_last=True, - ) - feature_cols = get_feature_cols(map_data, is_map_data=True) + is_map_data = isinstance(data, MapData) + map_keys = [] + if is_map_data: + data = assert_is_instance(data, MapData) + map_keys.extend(data.map_keys) + if limit_rows_per_metric is not None or limit_rows_per_group is not None: + data = data.subsample( + map_key=map_keys[0], + limit_rows_per_metric=limit_rows_per_metric, + limit_rows_per_group=limit_rows_per_group, + include_first_last=True, + ) + df = data.map_df + else: + df = data.df + feature_cols = get_feature_cols(data, is_map_data=is_map_data) return _observations_from_dataframe( experiment=experiment, - df=map_data.map_df, + df=df, cols=feature_cols, - map_keys=map_data.map_keys, + map_keys=map_keys, statuses_to_include=statuses_to_include, statuses_to_include_map_metric=statuses_to_include_map_metric, map_keys_as_parameters=map_keys_as_parameters, diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index c0b7b8f7a2e..10ec6564a16 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -23,7 +23,6 @@ ObservationData, ObservationFeatures, observations_from_data, - observations_from_map_data, recombine_observations, separate_observations, ) @@ -475,7 +474,7 @@ def test_ObservationsFromMapData(self) -> None: MapKeyInfo(key="timestamp", default_value=0.0), ], ) - observations = observations_from_map_data(experiment, data) + observations = observations_from_data(experiment, data) self.assertEqual(len(observations), 3) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index e2c6825e0a5..a70411ccb2e 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -26,7 +26,6 @@ ObservationData, ObservationFeatures, observations_from_data, - observations_from_map_data, recombine_observations, separate_observations, ) @@ -297,19 +296,15 @@ def _prepare_observations( ) -> list[Observation]: if experiment is None or data is None: return [] - if not self._fit_only_completed_map_metrics and isinstance(data, MapData): - return observations_from_map_data( - experiment=experiment, - map_data=data, - map_keys_as_parameters=True, - statuses_to_include=self.statuses_to_fit, - statuses_to_include_map_metric=self.statuses_to_fit_map_metric, - ) + map_keys_as_parameters = ( + not self._fit_only_completed_map_metrics and isinstance(data, MapData) + ) return observations_from_data( experiment=experiment, data=data, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, + map_keys_as_parameters=map_keys_as_parameters, ) def _transform_data( diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 597407556e6..264150dc678 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -19,7 +19,7 @@ Observation, ObservationData, ObservationFeatures, - observations_from_map_data, + observations_from_data, separate_observations, ) from ax.core.optimization_config import OptimizationConfig @@ -256,14 +256,14 @@ def _prepare_observations( """ if experiment is None or data is None: return [] - return observations_from_map_data( + return observations_from_data( experiment=experiment, - map_data=data, # pyre-ignore[6]: Checked in __init__. - map_keys_as_parameters=True, + data=data, limit_rows_per_metric=self._map_data_limit_rows_per_metric, limit_rows_per_group=self._map_data_limit_rows_per_group, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, + map_keys_as_parameters=True, ) def _compute_in_design( diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index e882bc7efa6..b60d3dfca46 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -1040,23 +1040,18 @@ def test_SetModelSpace(self) -> None: self.assertEqual(sum(m.training_in_design), 7) self.assertEqual(m.model_space.parameters["x2"].upper, 20) - @mock.patch( - "ax.modelbridge.base.observations_from_map_data", - autospec=True, - return_value=([get_observation1()]), - ) @mock.patch( "ax.modelbridge.base.observations_from_data", autospec=True, - return_value=([get_observation1(), get_observation2()]), + return_value=([get_observation1()]), ) def test_fit_only_completed_map_metrics( - self, mock_observations_from_data: Mock, mock_observations_from_map_data: Mock + self, mock_observations_from_data: Mock ) -> None: # NOTE: If empty data object is not passed, observations are not # extracted, even with mock. # _prepare_observations is called in the constructor and itself calls - # observations_from_map_data. + # observations_from_data with map_keys_as_parameters=True Adapter( search_space=get_search_space_for_value(), model=0, @@ -1065,13 +1060,16 @@ def test_fit_only_completed_map_metrics( status_quo_name="1_1", fit_only_completed_map_metrics=False, ) - self.assertTrue(mock_observations_from_map_data.called) - self.assertFalse(mock_observations_from_data.called) - - # calling without map data calls regular observations_from_data even - # if fit_only_completed_map_metrics is False + mock_observations_from_data.assert_called_once_with( + experiment=mock.ANY, + data=mock.ANY, + statuses_to_include=mock.ANY, + statuses_to_include_map_metric=mock.ANY, + map_keys_as_parameters=True, + ) mock_observations_from_data.reset_mock() - mock_observations_from_map_data.reset_mock() + # calling without map data calls observations_from_data with + # map_keys_as_parameters=False even if fit_only_completed_map_metrics is False Adapter( search_space=get_search_space_for_value(), model=0, @@ -1080,5 +1078,10 @@ def test_fit_only_completed_map_metrics( status_quo_name="1_1", fit_only_completed_map_metrics=False, ) - self.assertFalse(mock_observations_from_map_data.called) - self.assertTrue(mock_observations_from_data.called) + mock_observations_from_data.assert_called_once_with( + experiment=mock.ANY, + data=mock.ANY, + statuses_to_include=mock.ANY, + statuses_to_include_map_metric=mock.ANY, + map_keys_as_parameters=False, + )