From 7e250776497e5088c61ffccfb8caa62e61260644 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Fri, 21 Feb 2025 14:31:46 -0800 Subject: [PATCH] Only load completed map metrics by default (#3409) Summary: This commit ensures that the `Adapter`only loads completed map metrics by default. This makes sure that any method (not-necessarily map-data aware) can be applied by default. Reviewed By: dme65 Differential Revision: D70012386 --- ax/core/observation.py | 8 ++++++-- ax/modelbridge/base.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ax/core/observation.py b/ax/core/observation.py index 6732d00cbcb..5b91f0b2774 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -468,6 +468,7 @@ def observations_from_data( latest_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, + load_only_completed_map_metrics: bool = True, ) -> list[Observation]: """Convert Data (or MapData) to observations. @@ -501,6 +502,8 @@ def observations_from_data( uses MapData.subsample() with `limit_rows_per_group` on the first map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if `latest_rows_per_group` is specified. + load_only_completed_map_metrics: If True, only loads the last observation + for each completed MapMetric. Returns: List of Observation objects. @@ -511,7 +514,8 @@ def observations_from_data( statuses_to_include_map_metric = NON_ABANDONED_STATUSES is_map_data = isinstance(data, MapData) map_keys = [] - if is_map_data: + take_map_branch = is_map_data and not load_only_completed_map_metrics + if take_map_branch: data = assert_is_instance(data, MapData) map_keys.extend(data.map_keys) if latest_rows_per_group is not None: @@ -526,7 +530,7 @@ def observations_from_data( df = data.map_df else: df = data.df - feature_cols = get_feature_cols(data, is_map_data=is_map_data) + feature_cols = get_feature_cols(data, is_map_data=take_map_branch) return _observations_from_dataframe( experiment=experiment, df=df, diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index f02cee14aa0..44c6b52411e 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -306,6 +306,7 @@ def _prepare_observations( statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=map_keys_as_parameters, + load_only_completed_map_metrics=self._fit_only_completed_map_metrics, ) def _transform_data(