7
7
8
8
from __future__ import annotations
9
9
10
+ from bisect import bisect_right
10
11
from collections .abc import Iterable , Sequence
11
12
from copy import deepcopy
12
13
from logging import Logger
13
14
from typing import Any , Generic , TypeVar
14
15
15
16
import numpy as np
17
+ import numpy .typing as npt
16
18
import pandas as pd
17
19
from ax .core .data import Data
18
20
from ax .core .types import TMapTrialEvaluation
@@ -412,6 +414,48 @@ def subsample(
412
414
)
413
415
414
416
417
+ def _ceil_divide (
418
+ a : int | np .int_ | npt .NDArray [np .int_ ], b : int | np .int_ | npt .NDArray [np .int_ ]
419
+ ) -> np .int_ | npt .NDArray [np .int_ ]:
420
+ return - np .floor_divide (- a , b )
421
+
422
+
423
+ def _subsample_rate (
424
+ map_df : pd .DataFrame ,
425
+ keep_every : int | None = None ,
426
+ limit_rows_per_group : int | None = None ,
427
+ limit_rows_per_metric : int | None = None ,
428
+ ) -> int | np .int_ :
429
+ if keep_every is not None :
430
+ return keep_every
431
+
432
+ grouped_map_df = map_df .groupby (MapData .DEDUPLICATE_BY_COLUMNS )
433
+ group_sizes = grouped_map_df .size ()
434
+ max_rows = group_sizes .max ()
435
+
436
+ if limit_rows_per_group is not None :
437
+ return _ceil_divide (max_rows , limit_rows_per_group )
438
+
439
+ if limit_rows_per_metric is not None :
440
+ # search for the `keep_every` such that when you apply it to each group,
441
+ # the total number of rows is smaller than `limit_rows_per_metric`.
442
+ ks = np .arange (max_rows , 0 , - 1 )
443
+ # total sizes in ascending order
444
+ total_sizes = np .sum (
445
+ _ceil_divide (group_sizes .values , ks [..., np .newaxis ]), axis = 1
446
+ )
447
+ # binary search
448
+ i = bisect_right (total_sizes , limit_rows_per_metric )
449
+ # if no such `k` is found, then `derived_keep_every` stays as 1.
450
+ if i > 0 :
451
+ return ks [i - 1 ]
452
+
453
+ raise ValueError (
454
+ "at least one of `keep_every`, `limit_rows_per_group`, "
455
+ "or `limit_rows_per_metric` must be specified."
456
+ )
457
+
458
+
415
459
def _subsample_one_metric (
416
460
map_df : pd .DataFrame ,
417
461
map_key : str | None = None ,
@@ -421,30 +465,21 @@ def _subsample_one_metric(
421
465
include_first_last : bool = True ,
422
466
) -> pd .DataFrame :
423
467
"""Helper function to subsample a dataframe that holds a single metric."""
424
- derived_keep_every = 1
425
- if keep_every is not None :
426
- derived_keep_every = keep_every
427
- elif limit_rows_per_group is not None :
428
- max_rows = map_df .groupby (MapData .DEDUPLICATE_BY_COLUMNS ).size ().max ()
429
- derived_keep_every = np .ceil (max_rows / limit_rows_per_group )
430
- elif limit_rows_per_metric is not None :
431
- group_sizes = map_df .groupby (MapData .DEDUPLICATE_BY_COLUMNS ).size ().to_numpy ()
432
- # search for the `keep_every` such that when you apply it to each group,
433
- # the total number of rows is smaller than `limit_rows_per_metric`.
434
- for k in range (1 , group_sizes .max () + 1 ):
435
- if (np .ceil (group_sizes / k )).sum () <= limit_rows_per_metric :
436
- derived_keep_every = k
437
- break
438
- # if no such `k` is found, then `derived_keep_every` stays as 1.
468
+
469
+ grouped_map_df = map_df .groupby (MapData .DEDUPLICATE_BY_COLUMNS )
470
+
471
+ derived_keep_every = _subsample_rate (
472
+ map_df , keep_every , limit_rows_per_group , limit_rows_per_metric
473
+ )
439
474
440
475
if derived_keep_every <= 1 :
441
476
filtered_map_df = map_df
442
477
else :
443
478
filtered_dfs = []
444
- for _ , df_g in map_df . groupby ( MapData . DEDUPLICATE_BY_COLUMNS ) :
479
+ for _ , df_g in grouped_map_df :
445
480
df_g = df_g .sort_values (map_key )
446
481
if include_first_last :
447
- rows_per_group = int ( np . ceil ( len (df_g ) / derived_keep_every ) )
482
+ rows_per_group = _ceil_divide ( len (df_g ), derived_keep_every )
448
483
linspace_idcs = np .linspace (0 , len (df_g ) - 1 , rows_per_group )
449
484
idcs = np .round (linspace_idcs ).astype (int )
450
485
filtered_df = df_g .iloc [idcs ]
0 commit comments