Skip to content

Commit 2d28dc0

Browse files
Louis Tiao and facebook-github-bot
Louis Tiao
authored and committed
Latest observations from MapData (facebook#3112)
Summary: Pull Request resolved: facebook#3112 Differential Revision: D66434621
1 parent e470b2e commit 2d28dc0

File tree

3 files changed

+149
-26
lines changed

3 files changed

+149
-26
lines changed

ax/core/map_data.py

+48-8
Original file line number | Diff line number | Diff line change
@@ -278,15 +278,15 @@ def from_multiple_data(
278278
def df(self) -> pd.DataFrame:
279279
"""Returns a Data shaped DataFrame"""
280280

281-
# If map_keys is empty just return the df
282281
if self._memo_df is not None:
283282
return self._memo_df
284283

284+
# If map_keys is empty just return the df
285285
if len(self.map_keys) == 0:
286286
return self.map_df
287287

288-
self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates(
289-
MapData.DEDUPLICATE_BY_COLUMNS, keep="last"
288+
self._memo_df = _tail(
289+
map_df=self.map_df, map_keys=self.map_keys, n=1, sort=True
290290
)
291291

292292
return self._memo_df
@@ -340,6 +340,32 @@ def clone(self) -> MapData:
340340
description=self.description,
341341
)
342342

343+
def latest(
344+
self,
345+
map_keys: list[str] | None = None,
346+
rows_per_group: int = 1,
347+
) -> MapData:
348+
"""Return a new MapData with the most recently observed `rows_per_group`
349+
rows for each (arm, metric) group, determined by the `map_key` values,
350+
where higher implies more recent.
351+
352+
This function considers only the relative ordering of the `map_key` values,
353+
making it most suitable when these values are equally spaced.
354+
355+
If `rows_per_group` is greater than the number of rows in a given
356+
(arm, metric) group, then all rows are returned.
357+
"""
358+
if map_keys is None:
359+
map_keys = self.map_keys
360+
361+
return MapData(
362+
df=_tail(
363+
map_df=self.map_df, map_keys=map_keys, n=rows_per_group, sort=True
364+
),
365+
map_key_infos=self.map_key_infos,
366+
description=self.description,
367+
)
368+
343369
def subsample(
344370
self,
345371
map_key: str | None = None,
@@ -348,11 +374,13 @@ def subsample(
348374
limit_rows_per_metric: int | None = None,
349375
include_first_last: bool = True,
350376
) -> MapData:
351-
"""Subsample the `map_key` column in an equally-spaced manner (if there is
352-
a `self.map_keys` is length one, then `map_key` can be set to None). The
353-
values of the `map_key` column are not taken into account, so this function
354-
is most reasonable when those values are equally-spaced. There are three
355-
ways that this can be done:
377+
"""Return a new MapData that subsamples the `map_key` column in an
378+
equally-spaced manner. If `self.map_keys` has a length of one, `map_key`
379+
can be set to None. This function considers only the relative ordering
380+
of the `map_key` values, making it most suitable when these values are
381+
equally spaced.
382+
383+
There are three ways that this can be done:
356384
1. If `keep_every = k` is set, then every kth row of the DataFrame in the
357385
`map_key` column is kept after grouping by `DEDUPLICATE_BY_COLUMNS`.
358386
In other words, every kth step of each (arm, metric) will be kept.
@@ -456,6 +484,18 @@ def _subsample_rate(
456484
)
457485

458486

487+
def _tail(
488+
map_df: pd.DataFrame,
489+
map_keys: list[str],
490+
n: int = 1,
491+
sort: bool = True,
492+
) -> pd.DataFrame:
493+
df = map_df.sort_values(map_keys).groupby(MapData.DEDUPLICATE_BY_COLUMNS).tail(n)
494+
if sort:
495+
df.sort_values(MapData.DEDUPLICATE_BY_COLUMNS, inplace=True)
496+
return df
497+
498+
459499
def _subsample_one_metric(
460500
map_df: pd.DataFrame,
461501
map_key: str | None = None,

ax/core/observation.py

+27-17
Original file line number | Diff line number | Diff line change
@@ -452,6 +452,7 @@ def observations_from_data(
452452
statuses_to_include: set[TrialStatus] | None = None,
453453
statuses_to_include_map_metric: set[TrialStatus] | None = None,
454454
map_keys_as_parameters: bool = False,
455+
latest_rows_per_group: int | None = None,
455456
limit_rows_per_metric: int | None = None,
456457
limit_rows_per_group: int | None = None,
457458
) -> list[Observation]:
@@ -472,17 +473,21 @@ def observations_from_data(
472473
trials with statuses in this set. Defaults to all statuses except abandoned.
473474
map_keys_as_parameters: Whether map_keys should be returned as part of
474475
the parameters of the Observation objects.
475-
limit_rows_per_metric: If specified, and if data is an instance of MapData,
476-
uses MapData.subsample() with
477-
`limit_rows_per_metric` equal to the specified value on the first
478-
map_key (map_data.map_keys[0]) to subsample the MapData. This is
479-
useful in, e.g., cases where learning curves are frequently
480-
updated, leading to an intractable number of Observation objects
481-
created.
482-
limit_rows_per_group: If specified, and if data is an instance of MapData,
483-
uses MapData.subsample() with
484-
`limit_rows_per_group` equal to the specified value on the first
485-
map_key (map_data.map_keys[0]) to subsample the MapData.
476+
latest_rows_per_group: If specified and data is an instance of MapData,
477+
uses MapData.latest() with `rows_per_group=latest_rows_per_group` to
478+
retrieve the most recent rows for each group. Useful in cases where
479+
learning curves are frequently updated, preventing an excessive
480+
number of Observation objects. Overrides `limit_rows_per_metric`
481+
and `limit_rows_per_group`.
482+
limit_rows_per_metric: If specified and data is an instance of MapData,
483+
uses MapData.subsample() with `limit_rows_per_metric` on the first
484+
map_key (map_data.map_keys[0]) to subsample the MapData. Useful for
485+
managing the number of Observation objects when learning curves are
486+
frequently updated. Ignored if `latest_rows_per_group` is specified.
487+
limit_rows_per_group: If specified and data is an instance of MapData,
488+
uses MapData.subsample() with `limit_rows_per_group` on the first
489+
map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if
490+
`latest_rows_per_group` is specified.
486491
487492
Returns:
488493
List of Observation objects.
@@ -499,13 +504,18 @@ def observations_from_data(
499504
if is_map_data:
500505
data = checked_cast(MapData, data)
501506

502-
if limit_rows_per_metric is not None or limit_rows_per_group is not None:
503-
data = data.subsample(
504-
map_key=data.map_keys[0],
505-
limit_rows_per_metric=limit_rows_per_metric,
506-
limit_rows_per_group=limit_rows_per_group,
507-
include_first_last=True,
507+
if latest_rows_per_group is not None:
508+
data = data.latest(
509+
map_keys=data.map_keys, rows_per_group=latest_rows_per_group
508510
)
511+
else:
512+
if limit_rows_per_metric is not None or limit_rows_per_group is not None:
513+
data = data.subsample(
514+
map_key=data.map_keys[0],
515+
limit_rows_per_metric=limit_rows_per_metric,
516+
limit_rows_per_group=limit_rows_per_group,
517+
include_first_last=True,
518+
)
509519

510520
map_keys.extend(data.map_keys)
511521
obs_cols = obs_cols.union(data.map_keys)

ax/core/tests/test_map_data.py

+74-1
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
# pyre-strict
77

88

9+
import numpy as np
910
import pandas as pd
1011
from ax.core.data import Data
1112
from ax.core.map_data import MapData, MapKeyInfo
@@ -236,7 +237,17 @@ def test_upcast(self) -> None:
236237

237238
self.assertIsNotNone(fresh._memo_df) # Assert df is cached after first call
238239

239-
def test_subsample(self) -> None:
240+
self.assertTrue(
241+
fresh.df.equals(
242+
fresh.map_df.sort_values(fresh.map_keys).drop_duplicates(
243+
MapData.DEDUPLICATE_BY_COLUMNS, keep="last"
244+
)
245+
)
246+
)
247+
248+
def test_latest(self) -> None:
249+
seed = 8888
250+
240251
arm_names = ["0_0", "1_0", "2_0", "3_0"]
241252
max_epochs = [25, 50, 75, 100]
242253
metric_names = ["a", "b"]
@@ -259,6 +270,68 @@ def test_subsample(self) -> None:
259270
)
260271
large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos)
261272

273+
shuffled_large_map_df = large_map_data.map_df.groupby(
274+
MapData.DEDUPLICATE_BY_COLUMNS
275+
).sample(frac=1, random_state=seed)
276+
shuffled_large_map_data = MapData(
277+
df=shuffled_large_map_df, map_key_infos=self.map_key_infos
278+
)
279+
280+
for rows_per_group in [1, 40]:
281+
large_map_data_latest = large_map_data.latest(rows_per_group=rows_per_group)
282+
283+
if rows_per_group == 1:
284+
self.assertTrue(
285+
large_map_data_latest.map_df.groupby("metric_name")
286+
.epoch.transform(lambda col: set(col) == set(max_epochs))
287+
.all()
288+
)
289+
290+
# when rows_per_group is larger than the number of rows
291+
# actually observed in a group
292+
actual_rows_per_group = large_map_data_latest.map_df.groupby(
293+
MapData.DEDUPLICATE_BY_COLUMNS
294+
).size()
295+
expected_rows_per_group = np.minimum(
296+
large_map_data_latest.map_df.groupby(
297+
MapData.DEDUPLICATE_BY_COLUMNS
298+
).epoch.max(),
299+
rows_per_group,
300+
)
301+
self.assertTrue(actual_rows_per_group.equals(expected_rows_per_group))
302+
303+
# behavior should be consistent even if map_keys are not in ascending order
304+
shuffled_large_map_data_latest = shuffled_large_map_data.latest(
305+
rows_per_group=rows_per_group
306+
)
307+
self.assertTrue(
308+
shuffled_large_map_data_latest.map_df.equals(
309+
large_map_data_latest.map_df
310+
)
311+
)
312+
313+
def test_subsample(self) -> None:
314+
arm_names = ["0_0", "1_0", "2_0", "3_0"]
315+
max_epochs = [25, 50, 75, 100]
316+
metric_names = ["a", "b"]
317+
large_map_df = pd.DataFrame(
318+
[
319+
{
320+
"arm_name": arm_name,
321+
"epoch": epoch + 1,
322+
"mean": epoch * 0.1,
323+
"sem": 0.1,
324+
"trial_index": trial_index,
325+
"metric_name": metric_name,
326+
}
327+
for metric_name in metric_names
328+
for trial_index, (arm_name, max_epoch) in enumerate(
329+
zip(arm_names, max_epochs)
330+
)
331+
for epoch in range(max_epoch)
332+
]
333+
)
334+
large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos)
262335
large_map_df_sparse_metric = pd.DataFrame(
263336
[
264337
{

0 commit comments

Comments
 (0)