Skip to content

Commit 864307a

Browse files
committed
refactor(filter): make polars compatible
1 parent ac0c243 commit 864307a

1 file changed

Lines changed: 63 additions & 29 deletions

File tree

carps/analysis/utils.py

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,17 @@
66
import logging
77
from collections.abc import Sequence
88
from pathlib import Path
9-
from typing import TYPE_CHECKING, Any
9+
from typing import Any
1010

1111
import matplotlib.pyplot as plt
1212
import numpy as np
13+
import pandas as pd
14+
import polars as pl
1315
import seaborn as sns
1416
from matplotlib.lines import Line2D
1517

1618
from carps.utils.loggingutils import get_logger
1719

18-
if TYPE_CHECKING:
19-
import pandas as pd
20-
21-
2220
# colorblind_palette = ["#88CCEE", "#44AA99", "#117733", "#999933", "#DDCC77", "#CC6677", "#882255", "#AA4499", "#DDDDDD"] # noqa: E501
2321
# Hex color palette (named as colorblind-friendly). Identical to the commented-out
# variant above except that the last entry is "#7A7A7A" instead of "#DDDDDD".
colorblind_palette = ["#88CCEE", "#44AA99", "#117733", "#999933", "#DDCC77", "#CC6677", "#882255", "#AA4499", "#7A7A7A"]
2422
# Module-level logger obtained through the project's `get_logger` helper.
logger = get_logger("analysis utils")
@@ -103,40 +101,76 @@ def setup_seaborn(font_scale: float | None = None) -> None:
103101

104102

105103
def filter_only_final_performance(
    df: pd.DataFrame | pl.DataFrame,
    x_column: str = "n_trials_norm",
    max_x: float = 1,
    key_performance: str = "trial_value__cost_inc",
) -> pd.DataFrame | pl.DataFrame:
    """Return the best (incumbent) row of every run within a budget.

    Algorithm:
        1. Filter: keep only rows where `x_column` <= `max_x`.
        2. Group: partition by ["optimizer_id", "task_id", "seed"].
        3. Select: within each partition keep the complete row holding the
           minimum `key_performance` value.
        4. Tie-breaking: if several rows share the minimum, the one with the
           smallest `x_column` value wins; remaining ties keep input order.

    Rows with a missing `key_performance` value are ranked last, so they are
    only returned for runs that have no valid value within the budget.

    Parameters
    ----------
    df : pd.DataFrame | pl.DataFrame
        The optimization traces. Must contain the columns "optimizer_id",
        "task_id", "seed", `x_column` and `key_performance`.
    x_column : str, optional
        The budget column to filter on, by default "n_trials_norm".
    max_x : float, optional
        The budget cutoff; rows beyond it are ignored, by default 1.
    key_performance : str, optional
        The column to be minimized, by default "trial_value__cost_inc".

    Returns:
    -------
    pd.DataFrame | pl.DataFrame
        One row per (optimizer_id, task_id, seed) that has data within the
        budget, with the group columns first. The return type matches the
        input type. For pandas the result is ordered by the group columns
        and carries a fresh RangeIndex.

    Raises:
    ------
    TypeError
        If `df` is neither a pandas nor a polars DataFrame.
    """
    group_cols = ["optimizer_id", "task_id", "seed"]

    # --- Pandas backend ---
    if isinstance(df, pd.DataFrame):
        within_budget = df[df[x_column] <= max_x]
        # Two chained *stable* single-column sorts: order by cost, ties by budget,
        # remaining ties by input order. (The default quicksort is not stable.)
        ranked = within_budget.sort_values(x_column, kind="stable").sort_values(key_performance, kind="stable")
        # `head(1)` keeps whole rows. `GroupBy.first()` would instead pick the first
        # non-null value per column and could stitch together values of different rows.
        best = ranked.groupby(group_cols).head(1)
        # Reproduce the layout of `groupby(..., as_index=False)`: group columns first,
        # groups in sorted order, and a clean positional index.
        column_order = [*group_cols, *(col for col in best.columns if col not in group_cols)]
        return best.sort_values(group_cols)[column_order].reset_index(drop=True)

    # --- Polars backend ---
    if isinstance(df, pl.DataFrame):
        return (
            df.filter(pl.col(x_column) <= max_x)
            # Polars sorts nulls first and unstably by default; push nulls to the end so a
            # missing cost is never mistaken for the minimum, and keep ties deterministic.
            .sort([key_performance, x_column], descending=False, nulls_last=True, maintain_order=True)
            .group_by(group_cols, maintain_order=True)
            .first()
        )

    raise TypeError(f"Unsupported dataframe type: {type(df)}. Expected Pandas or Polars.")
140174

141175

142176
def convert_mixed_types_to_str(logs: pd.DataFrame, logger: logging.Logger | None = None) -> pd.DataFrame:

0 commit comments

Comments
 (0)