diff --git a/ax/core/observation.py b/ax/core/observation.py index 6732d00cbcb..a4a42f5dce9 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -468,6 +468,7 @@ def observations_from_data( latest_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, + load_only_completed_map_metrics: bool = True, ) -> list[Observation]: """Convert Data (or MapData) to observations. @@ -501,6 +502,8 @@ def observations_from_data( uses MapData.subsample() with `limit_rows_per_group` on the first map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if `latest_rows_per_group` is specified. + load_only_completed_map_metrics: If True, only loads the last observation + for each completed MapMetric. Returns: List of Observation objects. @@ -511,7 +514,7 @@ def observations_from_data( statuses_to_include_map_metric = NON_ABANDONED_STATUSES is_map_data = isinstance(data, MapData) map_keys = [] - if is_map_data: + if is_map_data and not load_only_completed_map_metrics: data = assert_is_instance(data, MapData) map_keys.extend(data.map_keys) if latest_rows_per_group is not None: diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index f02cee14aa0..44c6b52411e 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -306,6 +306,7 @@ def _prepare_observations( statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=map_keys_as_parameters, + load_only_completed_map_metrics=self._fit_only_completed_map_metrics, ) def _transform_data(