Skip to content

Commit 9590eae

Browse files
ltiao authored and facebook-github-bot committed
Simplified and optimized logic for calculating per-metric subsampling rate for MapData (#3106)
Summary: This refines the logic for calculating per-metric subsampling rates in `MapData.subsample` and incorporates a (probably premature) performance optimization, achieved by utilizing binary search on a sorted list instead of linear search. Reviewed By: Balandat. Differential Revision: D66366076.
1 parent b335598 commit 9590eae

File tree

1 file changed

+52
-17
lines changed

1 file changed

+52
-17
lines changed

ax/core/map_data.py

+52-17
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77

88
from __future__ import annotations
99

10+
from bisect import bisect_right
1011
from collections.abc import Iterable, Sequence
1112
from copy import deepcopy
1213
from logging import Logger
1314
from typing import Any, Generic, TypeVar
1415

1516
import numpy as np
17+
import numpy.typing as npt
1618
import pandas as pd
1719
from ax.core.data import Data
1820
from ax.core.types import TMapTrialEvaluation
@@ -412,6 +414,48 @@ def subsample(
412414
)
413415

414416

417+
def _ceil_divide(
418+
a: int | np.int_ | npt.NDArray[np.int_], b: int | np.int_ | npt.NDArray[np.int_]
419+
) -> np.int_ | npt.NDArray[np.int_]:
420+
return -np.floor_divide(-a, b)
421+
422+
423+
def _subsample_rate(
    map_df: pd.DataFrame,
    keep_every: int | None = None,
    limit_rows_per_group: int | None = None,
    limit_rows_per_metric: int | None = None,
) -> int:
    """Compute the ``keep_every`` subsampling rate for a single-metric map df.

    Exactly one of the three keyword arguments is expected; they are checked
    in priority order ``keep_every`` > ``limit_rows_per_group`` >
    ``limit_rows_per_metric``.

    Args:
        map_df: Dataframe holding the rows of a single metric; rows are
            grouped by ``MapData.DEDUPLICATE_BY_COLUMNS``.
        keep_every: If given, returned as-is.
        limit_rows_per_group: If given, return the smallest rate such that no
            group exceeds this many rows.
        limit_rows_per_metric: If given, return the smallest rate such that
            the total row count across groups does not exceed this limit.
            If no rate can satisfy the limit (more groups than allowed rows),
            falls back to 1 (keep everything).

    Returns:
        The subsampling rate (keep one row out of every ``keep_every``).

    Raises:
        ValueError: If none of the three arguments is specified.
    """
    if keep_every is not None:
        return keep_every

    grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
    group_sizes = grouped_map_df.size()
    max_rows = group_sizes.max()

    if limit_rows_per_group is not None:
        return _ceil_divide(max_rows, limit_rows_per_group).item()

    if limit_rows_per_metric is not None:
        # Search for the smallest `keep_every` such that applying it to each
        # group keeps the total number of rows within `limit_rows_per_metric`.
        # `ks` is descending, so the corresponding total sizes are ascending,
        # which is the sorted order `bisect_right` requires.
        ks = np.arange(max_rows, 0, -1)
        total_sizes = np.sum(
            _ceil_divide(group_sizes.values, ks[..., np.newaxis]), axis=1
        )
        # Binary search: `i` is the first index whose total exceeds the limit,
        # so `ks[i - 1]` (if it exists) is the smallest feasible `keep_every`.
        i = bisect_right(total_sizes, limit_rows_per_metric)
        if i > 0:
            return ks[i - 1].item()
        # No `k` satisfies the limit (even `k = max_rows` keeps one row per
        # group and there are more groups than `limit_rows_per_metric`), so
        # keep every row. Previously this case fell through to the
        # "must be specified" error below, which was misleading because
        # `limit_rows_per_metric` *was* specified.
        return 1

    raise ValueError(
        "at least one of `keep_every`, `limit_rows_per_group`, "
        "or `limit_rows_per_metric` must be specified."
    )
457+
458+
415459
def _subsample_one_metric(
416460
map_df: pd.DataFrame,
417461
map_key: str | None = None,
@@ -421,30 +465,21 @@ def _subsample_one_metric(
421465
include_first_last: bool = True,
422466
) -> pd.DataFrame:
423467
"""Helper function to subsample a dataframe that holds a single metric."""
424-
derived_keep_every = 1
425-
if keep_every is not None:
426-
derived_keep_every = keep_every
427-
elif limit_rows_per_group is not None:
428-
max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max()
429-
derived_keep_every = np.ceil(max_rows / limit_rows_per_group)
430-
elif limit_rows_per_metric is not None:
431-
group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy()
432-
# search for the `keep_every` such that when you apply it to each group,
433-
# the total number of rows is smaller than `limit_rows_per_metric`.
434-
for k in range(1, group_sizes.max() + 1):
435-
if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric:
436-
derived_keep_every = k
437-
break
438-
# if no such `k` is found, then `derived_keep_every` stays as 1.
468+
469+
grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
470+
471+
derived_keep_every = _subsample_rate(
472+
map_df, keep_every, limit_rows_per_group, limit_rows_per_metric
473+
)
439474

440475
if derived_keep_every <= 1:
441476
filtered_map_df = map_df
442477
else:
443478
filtered_dfs = []
444-
for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS):
479+
for _, df_g in grouped_map_df:
445480
df_g = df_g.sort_values(map_key)
446481
if include_first_last:
447-
rows_per_group = int(np.ceil(len(df_g) / derived_keep_every))
482+
rows_per_group = _ceil_divide(len(df_g), derived_keep_every)
448483
linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group)
449484
idcs = np.round(linspace_idcs).astype(int)
450485
filtered_df = df_g.iloc[idcs]

0 commit comments

Comments
 (0)