Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only load last observation of map data by default #3403

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def observations_from_data(
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
is_map_data = isinstance(data, MapData)
map_keys = []
if is_map_data:
if is_map_data and len(assert_is_instance(data, MapData).map_df) > 0:
data = assert_is_instance(data, MapData)
map_keys.extend(data.map_keys)
if latest_rows_per_group is not None:
Expand Down
9 changes: 8 additions & 1 deletion ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
fit_only_completed_map_metrics: bool = True,
latest_rows_per_group: int | None = 1,
) -> None:
"""
Applies transforms and fits model.
Expand Down Expand Up @@ -159,6 +160,11 @@ def __init__(
fit_only_completed_map_metrics: Whether to fit a model to map metrics only
when the trial is completed. This is useful for applications like
modeling partially completed learning curves in AutoML.
latest_rows_per_group: If specified and data is an instance of MapData,
uses MapData.latest() with `rows_per_group=latest_rows_per_group` to
retrieve the most recent rows for each group. Useful in cases where
learning curves are frequently updated, preventing an excessive
number of Observation objects.
"""
t_fit_start = time.monotonic()
transforms = transforms or []
Expand Down Expand Up @@ -188,6 +194,7 @@ def __init__(
self._fit_abandoned = fit_abandoned
self._fit_tracking_metrics = fit_tracking_metrics
self._fit_only_completed_map_metrics = fit_only_completed_map_metrics
self._latest_rows_per_group = latest_rows_per_group
self.outcomes: list[str] = []
self._experiment_has_immutable_search_space_and_opt_config: bool = (
experiment is not None and experiment.immutable_search_space_and_opt_config
Expand Down Expand Up @@ -302,7 +309,7 @@ def _prepare_observations(
return observations_from_data(
experiment=experiment,
data=data,
latest_rows_per_group=None,
latest_rows_per_group=self._latest_rows_per_group,
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
Expand Down
4 changes: 2 additions & 2 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,8 +1062,8 @@ def test_fit_only_completed_map_metrics(
)
_, kwargs = mock_observations_from_data.call_args
self.assertTrue(kwargs["map_keys_as_parameters"])
# assert `latest_rows_per_group` is not specified or is None
self.assertIsNone(kwargs.get("latest_rows_per_group"))
# assert `latest_rows_per_group` is not specified or is 1
self.assertEqual(kwargs.get("latest_rows_per_group"), 1)
mock_observations_from_data.reset_mock()

# calling without map data calls observations_from_data with
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __init__(
fit_on_init: bool = True,
default_model_gen_options: TConfig | None = None,
fit_only_completed_map_metrics: bool = True,
latest_rows_per_group: int | None = 1,
) -> None:
# This warning is being added while we are on 0.4.3, so it will be
# released in 0.4.4 or 0.5.0. The `torch_dtype` argument can be removed
Expand Down Expand Up @@ -163,6 +164,7 @@ def __init__(
fit_tracking_metrics=fit_tracking_metrics,
fit_on_init=fit_on_init,
fit_only_completed_map_metrics=fit_only_completed_map_metrics,
latest_rows_per_group=latest_rows_per_group,
)

def feature_importances(self, metric_name: str) -> dict[str, float]:
Expand Down
Loading