diff --git a/ax/core/observation.py b/ax/core/observation.py index 6732d00cbcb..5b91f0b2774 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -468,6 +468,7 @@ def observations_from_data( latest_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, + load_only_completed_map_metrics: bool = True, ) -> list[Observation]: """Convert Data (or MapData) to observations. @@ -501,6 +502,8 @@ def observations_from_data( uses MapData.subsample() with `limit_rows_per_group` on the first map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if `latest_rows_per_group` is specified. + load_only_completed_map_metrics: If True, only loads the last observation + for each completed MapMetric. Returns: List of Observation objects. @@ -511,7 +514,8 @@ def observations_from_data( statuses_to_include_map_metric = NON_ABANDONED_STATUSES is_map_data = isinstance(data, MapData) map_keys = [] - if is_map_data: + take_map_branch = is_map_data and not load_only_completed_map_metrics + if take_map_branch: data = assert_is_instance(data, MapData) map_keys.extend(data.map_keys) if latest_rows_per_group is not None: @@ -526,7 +530,7 @@ def observations_from_data( df = data.map_df else: df = data.df - feature_cols = get_feature_cols(data, is_map_data=is_map_data) + feature_cols = get_feature_cols(data, is_map_data=take_map_branch) return _observations_from_dataframe( experiment=experiment, df=df, diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index 10ec6564a16..8a281156a44 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -474,7 +474,9 @@ def test_ObservationsFromMapData(self) -> None: MapKeyInfo(key="timestamp", default_value=0.0), ], ) - observations = observations_from_data(experiment, data) + observations = observations_from_data( + experiment, data, load_only_completed_map_metrics=False + ) self.assertEqual(len(observations), 3) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index f02cee14aa0..44c6b52411e 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -306,6 +306,7 @@ def _prepare_observations( statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=map_keys_as_parameters, + load_only_completed_map_metrics=self._fit_only_completed_map_metrics, ) def _transform_data( diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 959b9d78e95..70f77c9e962 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -201,8 +201,10 @@ def _fit( `self.parameters_with_map_keys` instead of `self.parameters`. """ self.parameters = list(search_space.parameters.keys()) + print("MAPTORCH parameters", parameters) if parameters is None: parameters = self.parameters_with_map_keys + print("MAPTORCH parameters", parameters) super()._fit( model=model, search_space=search_space, @@ -265,6 +267,7 @@ def _prepare_observations( statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=True, + load_only_completed_map_metrics=False, ) def _compute_in_design(