-
Notifications
You must be signed in to change notification settings - Fork 372
Expand file tree
/
Copy pathmap_replay.py
More file actions
332 lines (284 loc) · 13.4 KB
/
map_replay.py
File metadata and controls
332 lines (284 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from collections import defaultdict
from logging import Logger
from typing import Any
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data, MAP_KEY
from ax.core.map_metric import MapMetric
from ax.core.metric import MetricFetchE, MetricFetchResult
from ax.core.trial import Trial
from ax.utils.common.logger import get_logger
from ax.utils.common.result import Err, Ok
from pyre_extensions import none_throws
# Module-level logger; MapDataReplayMetric uses it for its debug/info messages.
logger: Logger = get_logger(__name__)
class MapDataReplayState:
    """Shared state coordinator for replaying historical map data.

    Manages normalized cursor-based progression across multiple metrics
    and trials. Every trial owns a cursor that starts at 0 and is capped
    at 1; the cursor is mapped linearly onto the global
    ``[min_prog, max_prog]`` range of MAP_KEY values seen across all
    replayed metrics, which preserves cross-metric timing alignment.

    This class serves original MAP_KEY values (not normalized). Downstream
    early stopping strategies apply normalization independently via
    ``_maybe_normalize_map_key`` in ``ax.adapter.data_utils``.
    """

    def __init__(
        self,
        map_data: Data,
        metric_signatures: list[str],
        step_size: float = 0.01,
    ) -> None:
        """Initialize replay state from historical data.

        Args:
            map_data: Historical data containing progression data.
            metric_signatures: List of metric signatures to replay.
            step_size: Cursor increment per advancement step. Determines
                the granularity of replay (e.g. 0.01 = 100 steps).
        """
        self.step_size: float = step_size
        # Pre-index data by (trial_index, metric_signature) for O(1) lookups
        self._data: dict[tuple[int, str], pd.DataFrame] = {}
        trial_indices: set[int] = set()
        per_trial_max_prog: dict[int, float] = {}
        # Running global bounds; they stay None until a data point is seen.
        # Folding per-group extrema avoids copying every MAP_KEY value into
        # a Python list just to take one min and one max at the end.
        global_min: float | None = None
        global_max: float | None = None

        # The full frame does not depend on the metric, so fetch it once.
        full_df = map_data.full_df
        for metric_signature in metric_signatures:
            metric_df = full_df[full_df["metric_signature"] == metric_signature]
            replay_df = metric_df.sort_values(
                by=["trial_index", MAP_KEY], ascending=True, ignore_index=True
            )
            for raw_trial_index, group in replay_df.groupby("trial_index"):
                trial_index = int(raw_trial_index)
                self._data[(trial_index, metric_signature)] = group.reset_index(
                    drop=True
                )
                trial_indices.add(trial_index)

                prog_values = group[MAP_KEY].values
                group_min = float(prog_values.min())
                group_max = float(prog_values.max())
                global_min = (
                    group_min if global_min is None else min(global_min, group_min)
                )
                global_max = (
                    group_max if global_max is None else max(global_max, group_max)
                )
                # A trial may appear under several metrics; keep the largest
                # progression reached by any of them.
                per_trial_max_prog[trial_index] = max(
                    per_trial_max_prog.get(trial_index, group_max), group_max
                )

        # With no replayable rows both bounds collapse to 0.0, which the
        # query methods treat as "nothing to step through".
        self.min_prog: float = 0.0 if global_min is None else global_min
        self.max_prog: float = 0.0 if global_max is None else global_max
        self._per_trial_max_prog: dict[int, float] = per_trial_max_prog
        self._trial_cursors: defaultdict[int, float] = defaultdict(float)
        self._trial_indices: set[int] = trial_indices

    def _current_prog(self, trial_index: int) -> float:
        """Map a trial's cursor onto the global MAP_KEY range.

        Reads the cursor with ``.get`` so that querying a trial that was
        never advanced does not insert an entry into the cursor mapping.
        """
        cursor = self._trial_cursors.get(trial_index, 0.0)
        return self.min_prog + cursor * (self.max_prog - self.min_prog)

    def advance_trial(self, trial_index: int) -> None:
        """Advance the cursor for a trial by one step, capping it at 1.0."""
        self._trial_cursors[trial_index] = min(
            self._trial_cursors[trial_index] + self.step_size, 1.0
        )

    def has_trial_data(self, trial_index: int) -> bool:
        """Check if any replay data exists for a given trial."""
        return trial_index in self._trial_indices

    def is_trial_complete(self, trial_index: int) -> bool:
        """Check if a trial's cursor has reached its maximum progression.

        When the global range is degenerate (``min_prog == max_prog``) every
        trial is reported as complete. A trial with no recorded maximum is
        compared against 0.0.
        """
        if self.min_prog == self.max_prog:
            return True
        trial_max = self._per_trial_max_prog.get(trial_index, 0.0)
        return self._current_prog(trial_index) >= trial_max

    def get_data(self, trial_index: int, metric_signature: str) -> pd.DataFrame:
        """Get replay data for a trial up to the current cursor position.

        Returns a DataFrame filtered to rows where MAP_KEY <= current
        progression value, with original (non-normalized) MAP_KEY values.
        An empty DataFrame is returned when the (trial, metric) pair has no
        data; all rows are returned when the global range is degenerate.
        """
        df = self._data.get((trial_index, metric_signature))
        if df is None:
            return pd.DataFrame()
        if self.min_prog == self.max_prog:
            return df
        return df[df[MAP_KEY] <= self._current_prog(trial_index)]
class MapDataReplayMetric(MapMetric):
    """A metric for replaying historical map data.

    Each trial keeps an integer step counter. A fetch for a running trial
    bumps that counter by one (while more data remains) and returns every
    historical row whose MAP_KEY is at or below
    ``offset + counter * scaling_factor``.
    """

    def __init__(
        self,
        name: str,
        map_data: Data,
        metric_name: str,
        max_steps_validation: int | None = 200,
        lower_is_better: bool | None = None,
    ) -> None:
        """Inits MapDataReplayMetric.

        Args:
            name: The name of the metric.
            map_data: Historical data to use for replaying. It is assumed that
                there is a single curve (arm) per trial (i.e., no batch trials).
            metric_name: The metric to replay from `map_data`.
            max_steps_validation: If not None, we check to see that the inferred
                scaling factor and offset does not lead to a number of replay steps
                that is larger than `max_steps_validation` for any trial.
            lower_is_better: If True, lower metric values are considered
                desirable.
        """
        self.map_data: Data = map_data
        self.max_steps_validation = max_steps_validation
        self.metric_name: str = metric_name

        # Rows of the replayed metric, sorted by trial_index and then step.
        self._replay_df: pd.DataFrame = _prepare_replay_dataframe(
            map_data=map_data, metric_name=self.metric_name
        )

        # One sub-frame per trial so a fetch never has to scan the full frame.
        groups_by_trial: dict[int, pd.DataFrame] = {}
        for trial_idx, group in self._replay_df.groupby("trial_index"):
            groups_by_trial[int(trial_idx)] = group
        self._trial_groups: dict[int, pd.DataFrame] = groups_by_trial

        # Per-trial first/last step and point count; the offset and scaling
        # factor are derived from these once, up front.
        trial_stats = _compute_trial_stats(self._replay_df)
        self.offset: float = trial_stats["first_step"].min()
        self.scaling_factor: float = _compute_scaling_factor(
            trial_stats=trial_stats, offset=self.offset
        )

        # Only the last step is needed later on; keys are coerced to int so
        # they line up with the keys of _trial_groups.
        last_step_by_trial: dict[int, float] = {}
        for k, v in trial_stats["last_step"].items():
            last_step_by_trial[int(k)] = float(v)
        self._trial_last_step: dict[int, float] = last_step_by_trial

        # Number of replay steps already taken per trial (defaults to 0).
        self._trial_index_to_step: dict[int, int] = defaultdict(int)

        super().__init__(name=name, lower_is_better=lower_is_better)
        self._validate_replay_feasibility(trial_stats=trial_stats)

    @classmethod
    def is_available_while_running(cls) -> bool:
        """Always True: replayed data can be fetched while a trial runs."""
        return True

    def _validate_replay_feasibility(self, trial_stats: pd.DataFrame) -> None:
        """Check that the offset and scaling factor give a tractable replay.

        For every trial the number of replay steps is
        ``(last_step - offset) / scaling_factor``; validation fails if that
        exceeds ``max_steps_validation`` for any trial. Nothing is checked
        when ``max_steps_validation`` is None.

        Args:
            trial_stats: DataFrame with trial statistics (first_step, last_step,
                num_points). Passed in to avoid recomputing or storing it.

        Raises:
            ValueError: If some trial needs more than ``max_steps_validation``
                steps; the first such trial is reported.
        """
        if self.max_steps_validation is None:
            return

        steps_span = trial_stats["last_step"] - self.offset
        max_steps_per_trial = steps_span / self.scaling_factor
        too_many = max_steps_per_trial > self.max_steps_validation
        violating = max_steps_per_trial[too_many]

        if len(violating) > 0:
            trial_idx = violating.index[0]
            max_steps_trial = violating.iloc[0]
            raise ValueError(
                f"For trial {trial_idx}, the computed offset {self.offset} and "
                f"scaling factor {self.scaling_factor} lead to "
                f"{max_steps_trial} steps, which is larger than "
                f"{self.max_steps_validation} steps to replay."
            )

        max_steps = max_steps_per_trial.max()
        logger.debug(
            f"Validated MapReplayMetric {self.name} with "
            f"{len(trial_stats)} trials, scaling factor = "
            f"{self.scaling_factor:.2f}, and offset = {self.offset:.2f}, "
            f"resulting in maximum steps = {max_steps}."
        )

    def has_trial_data(self, trial_idx: int) -> bool:
        """Check if any replay data exists for a given trial."""
        # Membership test on the pre-grouped dict; no DataFrame scan needed.
        return trial_idx in self._trial_groups

    def more_replay_available(self, trial_idx: int) -> bool:
        """Check if more replay data is available for a given trial.

        Returns False for trials without historical data; otherwise True
        while the trial's current scaled step is still below its last step.
        """
        if trial_idx not in self._trial_last_step:
            return False
        steps_taken = self._trial_index_to_step[trial_idx]
        current_step = self.offset + steps_taken * self.scaling_factor
        return current_step < self._trial_last_step[trial_idx]

    def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult:
        """Return the replayed rows of ``trial`` up to its current step.

        If the trial is running and has data left, its step counter is
        advanced by one first. Any exception raised along the way (including
        the one for non-``Trial`` inputs) is wrapped in an ``Err`` result
        instead of propagating.
        """
        try:
            if not isinstance(trial, Trial):
                raise RuntimeError(
                    "Only (non-batch) Trials are supported by "
                    f"{self.__class__.__name__}."
                )
            trial_idx = trial.index

            # Increment the step counter if we can.
            can_advance = trial.status.is_running and self.more_replay_available(
                trial_idx=trial_idx
            )
            if can_advance:
                self._trial_index_to_step[trial_idx] += 1

            steps_taken = self._trial_index_to_step[trial_idx]
            trial_scaled_step = self.offset + steps_taken * self.scaling_factor
            logger.info(f"Trial {trial_idx} is at step {trial_scaled_step}.")

            # No historical rows at all for this trial.
            if trial_idx not in self._trial_groups:
                return Ok(value=Data.from_multiple_data(data=[]))

            # Only the trial's own (small) sub-frame is filtered.
            trial_group = self._trial_groups[trial_idx]
            reached = trial_group[MAP_KEY] <= trial_scaled_step
            trial_data = trial_group[reached]
            if trial_data.empty:
                return Ok(value=Data())

            # Scalars broadcast against the per-row arrays.
            columns = {
                "arm_name": none_throws(trial.arm).name,
                "metric_name": self.name,
                "mean": trial_data["mean"].values,
                "sem": trial_data["sem"].values,
                "trial_index": trial.index,
                "metric_signature": self.signature,
                MAP_KEY: trial_data[MAP_KEY].values,
            }
            result_df = pd.DataFrame(columns)
            return Ok(value=Data(df=result_df))
        except Exception as e:
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )
def _prepare_replay_dataframe(map_data: Data, metric_name: str) -> pd.DataFrame:
    """Build the sorted frame that replay lookups operate on.

    Keeps only the rows of ``map_data.full_df`` whose ``metric_name`` column
    equals ``metric_name`` and orders them by trial index, then by MAP_KEY,
    with a fresh 0..n-1 index.
    """
    full_df = map_data.full_df
    is_requested_metric = full_df["metric_name"] == metric_name
    metric_rows = full_df[is_requested_metric]
    # Sorting once here is what lets later per-trial code rely on row order.
    sort_columns = ["trial_index", MAP_KEY]
    return metric_rows.sort_values(
        by=sort_columns, ascending=True, ignore_index=True
    )
def _compute_trial_stats(replay_df: pd.DataFrame) -> pd.DataFrame:
    """Summarize the MAP_KEY column of each trial with one groupby pass.

    Returns a DataFrame indexed by trial_index with columns:
    - first_step: the first step value for each trial
    - last_step: the last step value for each trial
    - num_points: the number of data points per trial

    ``first``/``last`` only coincide with min/max because callers pass a
    frame that is already sorted by step within each trial.
    """
    steps_by_trial = replay_df.groupby("trial_index")[MAP_KEY]
    return steps_by_trial.agg(
        first_step="first",
        last_step="last",
        num_points="count",
    )
def _compute_scaling_factor(trial_stats: pd.DataFrame, offset: float) -> float:
    """Derive the replay step width from per-trial statistics.

    The scaling factor is:
    `mean_{trial in trials} (max_steps_trial - offset) / num_points_trial`,
    taken over trials that have at least one point and whose last step lies
    strictly above ``offset``. Falls back to 1.0 when no trial qualifies or
    when the mean is not strictly positive.
    """
    last_steps = trial_stats["last_step"]
    point_counts = trial_stats["num_points"]

    # Trials that would contribute a zero/negative width or divide by zero
    # are left out of the average.
    usable = (point_counts > 0) & (last_steps > offset)
    if not usable.any():
        return 1.0

    per_trial_width = (last_steps[usable] - offset) / point_counts[usable]
    scaling_factor = float(per_trial_width.mean())
    if scaling_factor > 0.0:
        return scaling_factor
    return 1.0