Keep observations from repeated arms in MapData.map_df (#3399)
Summary:
Pull Request resolved: #3399

**Context:** `MapData.df` keeps only the most recent observation for each (arm, metric) pair, so when an arm appears in multiple trials, data from only one of those trials survives. `Data.df` does allow data from the same arm in different trials. `MapData`'s behavior is surprising, and downstream functions such as `BestPointMixin._get_trace` assume that every trial with data will be present in `map_df`, an assumption this deduplication can violate.

**This diff:** `MapData.df` now keeps the most recent observation for each (trial_index, arm, metric) combination.
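
**Illustration (not part of the diff):** a minimal pandas sketch of why the deduplication key matters. It is not Ax's actual implementation; the `latest_per_group` helper and the toy data are hypothetical, and "most recent" is assumed here to mean the largest `epoch` value.

```python
import pandas as pd

# Toy data: arm "0_0" is evaluated in both trial 0 and trial 1.
df = pd.DataFrame(
    [
        {"trial_index": 0, "arm_name": "0_0", "metric_name": "a", "epoch": 0, "mean": 3.0},
        {"trial_index": 0, "arm_name": "0_0", "metric_name": "a", "epoch": 1, "mean": 3.5},
        {"trial_index": 1, "arm_name": "0_0", "metric_name": "a", "epoch": 0, "mean": 2.0},
        {"trial_index": 1, "arm_name": "0_0", "metric_name": "a", "epoch": 1, "mean": 2.5},
    ]
)


def latest_per_group(frame: pd.DataFrame, by: list[str], map_key: str = "epoch") -> pd.DataFrame:
    """Hypothetical helper: keep the row with the largest `map_key` in each group."""
    return (
        frame.sort_values(map_key, kind="stable")
        .drop_duplicates(subset=by, keep="last")
        .sort_values(["trial_index", "arm_name", "metric_name"])
        .reset_index(drop=True)
    )


# Old key: rows from different trials collapse into one row per (arm, metric).
old = latest_per_group(df, ["arm_name", "metric_name"])

# New key: one row per (trial, arm, metric), so every trial with data is kept.
new = latest_per_group(df, ["trial_index", "arm_name", "metric_name"])

print(old)
print(new)
```

With the old key, only one of the two trials remains in the result; with the new key, both trial 0 and trial 1 keep their latest observation.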

Reviewed By: mpolson64

Differential Revision: D69955018

fbshipit-source-id: 608fb2b1332f45fb7b6bd4d085ef6150318adce2
esantorella authored and facebook-github-bot committed Feb 21, 2025
1 parent a96cc6f commit 63a1eaf
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ax/core/map_data.py
@@ -95,7 +95,7 @@ class MapData(Data):
     `experiment.attach_data()` (this requires a description to be set.)
     """

-    DEDUPLICATE_BY_COLUMNS = ["arm_name", "metric_name"]
+    DEDUPLICATE_BY_COLUMNS = ["trial_index", "arm_name", "metric_name"]

     _map_df: pd.DataFrame
     _memo_df: pd.DataFrame | None
13 changes: 13 additions & 0 deletions ax/core/tests/test_map_data.py
@@ -18,6 +18,15 @@ def setUp(self) -> None:
         super().setUp()
         self.df = pd.DataFrame(
             [
+                {
+                    "arm_name": "0_0",
+                    "epoch": 0,
+                    "mean": 3.0,
+                    "sem": 0.3,
+                    "trial_index": 0,
+                    "metric_name": "a",
+                },
+                # repeated arm 0_0
                 {
                     "arm_name": "0_0",
                     "epoch": 0,
@@ -78,6 +87,10 @@ def setUp(self) -> None:

         self.mmd = MapData(df=self.df, map_key_infos=self.map_key_infos)

+    def test_df(self) -> None:
+        df = self.mmd.df
+        self.assertEqual(set(df["trial_index"].drop_duplicates()), {0, 1})
+
     def test_map_key_info(self) -> None:
         self.assertEqual(self.map_key_infos, self.mmd.map_key_infos)
