From a78451e749d779a4880f9636e069d056f5833ce6 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Fri, 21 Feb 2025 12:01:49 -0800 Subject: [PATCH] Only load last observation of map data by default Summary: This commit ensures that the `Adapter` only loads a single observation by default, even for map metrics. This makes sure that any method (not-necessarily map-data aware) can be applied by default. Differential Revision: D69992533 --- ax/core/observation.py | 2 +- ax/modelbridge/base.py | 9 ++++++++- ax/modelbridge/tests/test_base_modelbridge.py | 4 ++-- ax/modelbridge/torch.py | 2 ++ 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/ax/core/observation.py b/ax/core/observation.py index 6732d00cbcb..13f732df3a4 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -511,7 +511,7 @@ def observations_from_data( statuses_to_include_map_metric = NON_ABANDONED_STATUSES is_map_data = isinstance(data, MapData) map_keys = [] - if is_map_data: + if is_map_data and len(assert_is_instance(data, MapData).map_df) > 0: data = assert_is_instance(data, MapData) map_keys.extend(data.map_keys) if latest_rows_per_group is not None: diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index f02cee14aa0..e7f65f6c59c 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -112,6 +112,7 @@ def __init__( fit_tracking_metrics: bool = True, fit_on_init: bool = True, fit_only_completed_map_metrics: bool = True, + latest_rows_per_group: int | None = 1, ) -> None: """ Applies transforms and fits model. @@ -159,6 +160,11 @@ def __init__( fit_only_completed_map_metrics: Whether to fit a model to map metrics only when the trial is completed. This is useful for applications like modeling partially completed learning curves in AutoML. + latest_rows_per_group: If specified and data is an instance of MapData, + uses MapData.latest() with `rows_per_group=latest_rows_per_group` to + retrieve the most recent rows for each group. Useful in cases where + learning curves are frequently updated, preventing an excessive + number of Observation objects. """ t_fit_start = time.monotonic() transforms = transforms or [] @@ -188,6 +194,7 @@ def __init__( self._fit_abandoned = fit_abandoned self._fit_tracking_metrics = fit_tracking_metrics self._fit_only_completed_map_metrics = fit_only_completed_map_metrics + self._latest_rows_per_group = latest_rows_per_group self.outcomes: list[str] = [] self._experiment_has_immutable_search_space_and_opt_config: bool = ( experiment is not None and experiment.immutable_search_space_and_opt_config @@ -302,7 +309,7 @@ def _prepare_observations( return observations_from_data( experiment=experiment, data=data, - latest_rows_per_group=None, + latest_rows_per_group=self._latest_rows_per_group, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=map_keys_as_parameters, diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index bc0822ba9bb..70ce8365186 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -1062,8 +1062,8 @@ def test_fit_only_completed_map_metrics( ) _, kwargs = mock_observations_from_data.call_args self.assertTrue(kwargs["map_keys_as_parameters"]) - # assert `latest_rows_per_group` is not specified or is None - self.assertIsNone(kwargs.get("latest_rows_per_group")) + # assert `latest_rows_per_group` is not specified or is 1 + self.assertEqual(kwargs.get("latest_rows_per_group"), 1) mock_observations_from_data.reset_mock() # calling without map data calls observations_from_data with diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 54ee8d13811..729549ba223 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -120,6 +120,7 @@ def __init__( fit_on_init: bool = True, default_model_gen_options: TConfig | None = None, fit_only_completed_map_metrics: bool = True, + latest_rows_per_group: int | None = 1, ) -> None: # This warning is being added while we are on 0.4.3, so it will be # released in 0.4.4 or 0.5.0. The `torch_dtype` argument can be removed @@ -163,6 +164,7 @@ def __init__( fit_tracking_metrics=fit_tracking_metrics, fit_on_init=fit_on_init, fit_only_completed_map_metrics=fit_only_completed_map_metrics, + latest_rows_per_group=latest_rows_per_group, ) def feature_importances(self, metric_name: str) -> dict[str, float]: