Skip to content

Commit 7c2788f

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Only load completed map metrics by default (#3409)
Summary: Pull Request resolved: #3409 This commit ensures that the `Adapter`only loads completed map metrics by default. This makes sure that any method (not-necessarily map-data aware) can be applied by default. Reviewed By: dme65, saitcakmak Differential Revision: D70012386 fbshipit-source-id: 519df369fefc11aa3f8abd234827092db7d511c4
1 parent 63a1eaf commit 7c2788f

File tree

4 files changed

+13
-3
lines changed

4 files changed

+13
-3
lines changed

ax/core/observation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ def observations_from_data(
468468
latest_rows_per_group: int | None = None,
469469
limit_rows_per_metric: int | None = None,
470470
limit_rows_per_group: int | None = None,
471+
load_only_completed_map_metrics: bool = True,
471472
) -> list[Observation]:
472473
"""Convert Data (or MapData) to observations.
473474
@@ -501,6 +502,8 @@ def observations_from_data(
501502
uses MapData.subsample() with `limit_rows_per_group` on the first
502503
map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if
503504
`latest_rows_per_group` is specified.
505+
load_only_completed_map_metrics: If True, only loads the last observation
506+
for each completed MapMetric.
504507
505508
Returns:
506509
List of Observation objects.
@@ -511,7 +514,8 @@ def observations_from_data(
511514
statuses_to_include_map_metric = NON_ABANDONED_STATUSES
512515
is_map_data = isinstance(data, MapData)
513516
map_keys = []
514-
if is_map_data:
517+
take_map_branch = is_map_data and not load_only_completed_map_metrics
518+
if take_map_branch:
515519
data = assert_is_instance(data, MapData)
516520
map_keys.extend(data.map_keys)
517521
if latest_rows_per_group is not None:
@@ -526,7 +530,7 @@ def observations_from_data(
526530
df = data.map_df
527531
else:
528532
df = data.df
529-
feature_cols = get_feature_cols(data, is_map_data=is_map_data)
533+
feature_cols = get_feature_cols(data, is_map_data=take_map_branch)
530534
return _observations_from_dataframe(
531535
experiment=experiment,
532536
df=df,

ax/core/tests/test_observation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,9 @@ def test_ObservationsFromMapData(self) -> None:
474474
MapKeyInfo(key="timestamp", default_value=0.0),
475475
],
476476
)
477-
observations = observations_from_data(experiment, data)
477+
observations = observations_from_data(
478+
experiment, data, load_only_completed_map_metrics=False
479+
)
478480

479481
self.assertEqual(len(observations), 3)
480482

ax/modelbridge/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def _prepare_observations(
306306
statuses_to_include=self.statuses_to_fit,
307307
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
308308
map_keys_as_parameters=map_keys_as_parameters,
309+
load_only_completed_map_metrics=self._fit_only_completed_map_metrics,
309310
)
310311

311312
def _transform_data(

ax/modelbridge/map_torch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,10 @@ def _fit(
201201
`self.parameters_with_map_keys` instead of `self.parameters`.
202202
"""
203203
self.parameters = list(search_space.parameters.keys())
204+
print("MAPTORCH parameters", parameters)
204205
if parameters is None:
205206
parameters = self.parameters_with_map_keys
207+
print("MAPTORCH parameters", parameters)
206208
super()._fit(
207209
model=model,
208210
search_space=search_space,
@@ -265,6 +267,7 @@ def _prepare_observations(
265267
statuses_to_include=self.statuses_to_fit,
266268
statuses_to_include_map_metric=self.statuses_to_fit_map_metric,
267269
map_keys_as_parameters=True,
270+
load_only_completed_map_metrics=False,
268271
)
269272

270273
def _compute_in_design(

0 commit comments

Comments
 (0)