Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only load completed map metrics by default #3409

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ def observations_from_data(
latest_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
limit_rows_per_group: int | None = None,
load_only_completed_map_metrics: bool = True,
) -> list[Observation]:
"""Convert Data (or MapData) to observations.

Expand Down Expand Up @@ -501,6 +502,8 @@ def observations_from_data(
uses MapData.subsample() with `limit_rows_per_group` on the first
map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if
`latest_rows_per_group` is specified.
load_only_completed_map_metrics: If True, only loads the last observation
for each completed MapMetric.

Returns:
List of Observation objects.
Expand All @@ -511,7 +514,8 @@ def observations_from_data(
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
is_map_data = isinstance(data, MapData)
map_keys = []
if is_map_data:
take_map_branch = is_map_data and not load_only_completed_map_metrics
if take_map_branch:
data = assert_is_instance(data, MapData)
map_keys.extend(data.map_keys)
if latest_rows_per_group is not None:
Expand All @@ -526,7 +530,7 @@ def observations_from_data(
df = data.map_df
else:
df = data.df
feature_cols = get_feature_cols(data, is_map_data=is_map_data)
feature_cols = get_feature_cols(data, is_map_data=take_map_branch)
return _observations_from_dataframe(
experiment=experiment,
df=df,
Expand Down
4 changes: 3 additions & 1 deletion ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,9 @@ def test_ObservationsFromMapData(self) -> None:
MapKeyInfo(key="timestamp", default_value=0.0),
],
)
observations = observations_from_data(experiment, data)
observations = observations_from_data(
experiment, data, load_only_completed_map_metrics=False
)

self.assertEqual(len(observations), 3)

Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def _prepare_observations(
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
load_only_completed_map_metrics=self._fit_only_completed_map_metrics,
)

def _transform_data(
Expand Down
3 changes: 3 additions & 0 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ def _fit(
`self.parameters_with_map_keys` instead of `self.parameters`.
"""
self.parameters = list(search_space.parameters.keys())
print("MAPTORCH parameters", parameters)
if parameters is None:
parameters = self.parameters_with_map_keys
print("MAPTORCH parameters", parameters)
super()._fit(
model=model,
search_space=search_space,
Expand Down Expand Up @@ -265,6 +267,7 @@ def _prepare_observations(
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=True,
load_only_completed_map_metrics=False,
)

def _compute_in_design(
Expand Down