Skip to content

Commit 63a1eaf

Browse files
esantorellafacebook-github-bot
authored andcommitted
Keep observations from repeated arms in MapData.map_df (#3399)
Summary: Pull Request resolved: #3399 **Context:** MapData’s df has only the most recent observation for each (arm, metric) pair, so any arm that appears in multiple trials won’t be included. Data.df does allow data from the same arm in different trials. MapData's behavior is surprising, and downstream functions such as `BestPointMixin._get_trace` assume that every trial with data will be present in `map_df`, so something isn't right. **This diff:** Now keeps the most recent observation for each (trial_index, arm, metric). Reviewed By: mpolson64 Differential Revision: D69955018 fbshipit-source-id: 608fb2b1332f45fb7b6bd4d085ef6150318adce2
1 parent a96cc6f commit 63a1eaf

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

ax/core/map_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class MapData(Data):
9595
`experiment.attach_data()` (this requires a description to be set.)
9696
"""
9797

98-
DEDUPLICATE_BY_COLUMNS = ["arm_name", "metric_name"]
98+
DEDUPLICATE_BY_COLUMNS = ["trial_index", "arm_name", "metric_name"]
9999

100100
_map_df: pd.DataFrame
101101
_memo_df: pd.DataFrame | None

ax/core/tests/test_map_data.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ def setUp(self) -> None:
1818
super().setUp()
1919
self.df = pd.DataFrame(
2020
[
21+
{
22+
"arm_name": "0_0",
23+
"epoch": 0,
24+
"mean": 3.0,
25+
"sem": 0.3,
26+
"trial_index": 0,
27+
"metric_name": "a",
28+
},
29+
# repeated arm 0_0
2130
{
2231
"arm_name": "0_0",
2332
"epoch": 0,
@@ -78,6 +87,10 @@ def setUp(self) -> None:
7887

7988
self.mmd = MapData(df=self.df, map_key_infos=self.map_key_infos)
8089

90+
def test_df(self) -> None:
91+
df = self.mmd.df
92+
self.assertEqual(set(df["trial_index"].drop_duplicates()), {0, 1})
93+
8194
def test_map_key_info(self) -> None:
8295
self.assertEqual(self.map_key_infos, self.mmd.map_key_infos)
8396

0 commit comments

Comments
 (0)