diff --git a/ax/core/observation.py b/ax/core/observation.py index 6732d00cbcb..5b91f0b2774 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -468,6 +468,7 @@ def observations_from_data( latest_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, + load_only_completed_map_metrics: bool = True, ) -> list[Observation]: """Convert Data (or MapData) to observations. @@ -501,6 +502,8 @@ def observations_from_data( uses MapData.subsample() with `limit_rows_per_group` on the first map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if `latest_rows_per_group` is specified. + load_only_completed_map_metrics: If True, only loads the last observation + for each completed MapMetric. Returns: List of Observation objects. @@ -511,7 +514,8 @@ def observations_from_data( statuses_to_include_map_metric = NON_ABANDONED_STATUSES is_map_data = isinstance(data, MapData) map_keys = [] - if is_map_data: + take_map_branch = is_map_data and not load_only_completed_map_metrics + if take_map_branch: data = assert_is_instance(data, MapData) map_keys.extend(data.map_keys) if latest_rows_per_group is not None: @@ -526,7 +530,7 @@ def observations_from_data( df = data.map_df else: df = data.df - feature_cols = get_feature_cols(data, is_map_data=is_map_data) + feature_cols = get_feature_cols(data, is_map_data=take_map_branch) return _observations_from_dataframe( experiment=experiment, df=df, diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index f02cee14aa0..44c6b52411e 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -306,6 +306,7 @@ def _prepare_observations( statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=map_keys_as_parameters, + load_only_completed_map_metrics=self._fit_only_completed_map_metrics, ) def _transform_data(