Skip to content

Commit

Permalink
Only load completed map metrics by default (#3409)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3409

This commit ensures that the `Adapter`only loads completed map metrics by default. This makes sure that any method (not-necessarily map-data aware) can be applied by default.

Reviewed By: dme65, saitcakmak

Differential Revision: D70012386

fbshipit-source-id: 519df369fefc11aa3f8abd234827092db7d511c4
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Feb 22, 2025
1 parent 63a1eaf commit 7c2788f
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
8 changes: 6 additions & 2 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ def observations_from_data(
latest_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
limit_rows_per_group: int | None = None,
load_only_completed_map_metrics: bool = True,
) -> list[Observation]:
"""Convert Data (or MapData) to observations.
Expand Down Expand Up @@ -501,6 +502,8 @@ def observations_from_data(
uses MapData.subsample() with `limit_rows_per_group` on the first
map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if
`latest_rows_per_group` is specified.
load_only_completed_map_metrics: If True, only loads the last observation
for each completed MapMetric.
Returns:
List of Observation objects.
Expand All @@ -511,7 +514,8 @@ def observations_from_data(
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
is_map_data = isinstance(data, MapData)
map_keys = []
if is_map_data:
take_map_branch = is_map_data and not load_only_completed_map_metrics
if take_map_branch:
data = assert_is_instance(data, MapData)
map_keys.extend(data.map_keys)
if latest_rows_per_group is not None:
Expand All @@ -526,7 +530,7 @@ def observations_from_data(
df = data.map_df
else:
df = data.df
feature_cols = get_feature_cols(data, is_map_data=is_map_data)
feature_cols = get_feature_cols(data, is_map_data=take_map_branch)
return _observations_from_dataframe(
experiment=experiment,
df=df,
Expand Down
4 changes: 3 additions & 1 deletion ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,9 @@ def test_ObservationsFromMapData(self) -> None:
MapKeyInfo(key="timestamp", default_value=0.0),
],
)
observations = observations_from_data(experiment, data)
observations = observations_from_data(
experiment, data, load_only_completed_map_metrics=False
)

self.assertEqual(len(observations), 3)

Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def _prepare_observations(
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=map_keys_as_parameters,
load_only_completed_map_metrics=self._fit_only_completed_map_metrics,
)

def _transform_data(
Expand Down
3 changes: 3 additions & 0 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ def _fit(
`self.parameters_with_map_keys` instead of `self.parameters`.
"""
self.parameters = list(search_space.parameters.keys())
print("MAPTORCH parameters", parameters)
if parameters is None:
parameters = self.parameters_with_map_keys
print("MAPTORCH parameters", parameters)
super()._fit(
model=model,
search_space=search_space,
Expand Down Expand Up @@ -265,6 +267,7 @@ def _prepare_observations(
statuses_to_include=self.statuses_to_fit,
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
map_keys_as_parameters=True,
load_only_completed_map_metrics=False,
)

def _compute_in_design(
Expand Down

0 comments on commit 7c2788f

Please sign in to comment.