From 7c2788fc55b9cb167ee7678badde004f8d049370 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Fri, 21 Feb 2025 17:50:34 -0800 Subject: [PATCH] Only load completed map metrics by default (#3409) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3409 This commit ensures that the `Adapter`only loads completed map metrics by default. This makes sure that any method (not-necessarily map-data aware) can be applied by default. Reviewed By: dme65, saitcakmak Differential Revision: D70012386 fbshipit-source-id: 519df369fefc11aa3f8abd234827092db7d511c4 --- ax/core/observation.py | 8 ++++++-- ax/core/tests/test_observation.py | 4 +++- ax/modelbridge/base.py | 1 + ax/modelbridge/map_torch.py | 3 +++ 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ax/core/observation.py b/ax/core/observation.py index 6732d00cbcb..5b91f0b2774 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -468,6 +468,7 @@ def observations_from_data( latest_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, + load_only_completed_map_metrics: bool = True, ) -> list[Observation]: """Convert Data (or MapData) to observations. @@ -501,6 +502,8 @@ def observations_from_data( uses MapData.subsample() with `limit_rows_per_group` on the first map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if `latest_rows_per_group` is specified. + load_only_completed_map_metrics: If True, only loads the last observation + for each completed MapMetric. Returns: List of Observation objects. @@ -511,7 +514,8 @@ def observations_from_data( statuses_to_include_map_metric = NON_ABANDONED_STATUSES is_map_data = isinstance(data, MapData) map_keys = [] - if is_map_data: + take_map_branch = is_map_data and not load_only_completed_map_metrics + if take_map_branch: data = assert_is_instance(data, MapData) map_keys.extend(data.map_keys) if latest_rows_per_group is not None: @@ -526,7 +530,7 @@ def observations_from_data( df = data.map_df else: df = data.df - feature_cols = get_feature_cols(data, is_map_data=is_map_data) + feature_cols = get_feature_cols(data, is_map_data=take_map_branch) return _observations_from_dataframe( experiment=experiment, df=df, diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index 10ec6564a16..8a281156a44 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -474,7 +474,9 @@ def test_ObservationsFromMapData(self) -> None: MapKeyInfo(key="timestamp", default_value=0.0), ], ) - observations = observations_from_data(experiment, data) + observations = observations_from_data( + experiment, data, load_only_completed_map_metrics=False + ) self.assertEqual(len(observations), 3) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index f02cee14aa0..44c6b52411e 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -306,6 +306,7 @@ def _prepare_observations( statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=map_keys_as_parameters, + load_only_completed_map_metrics=self._fit_only_completed_map_metrics, ) def _transform_data( diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 959b9d78e95..70f77c9e962 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -201,8 +201,10 @@ def _fit( `self.parameters_with_map_keys` instead of `self.parameters`. """ self.parameters = list(search_space.parameters.keys()) + print("MAPTORCH parameters", parameters) if parameters is None: parameters = self.parameters_with_map_keys + print("MAPTORCH parameters", parameters) super()._fit( model=model, search_space=search_space, @@ -265,6 +267,7 @@ def _prepare_observations( statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=True, + load_only_completed_map_metrics=False, ) def _compute_in_design(