Skip to content

Commit da3318e

Browse files
Add paste diff comparison links to TransferLearningAnalysis (facebook#4980)
Summary: Pull Request resolved: facebook#4980 Follow-up to D92926519. Instead of listing overlapping parameter names inline in the table (which can get long and hard to read), add support for generating a paste diff that shows the current experiment's parameters side-by-side with the source experiment's overlapping parameters. This follows the same callable-injection pattern used by MetricFetchingErrorsAnalysis: the core analysis accepts an optional `create_diff_paste_callable` that takes (before_content, after_content, title) and returns a diffing URL. When provided, a "Comparison" column is added to the table containing the diff link. The existing "Parameters" column is preserved. The diff content uses a YAML-like format with `experiment_name` and `parameter_names` keys, making the side-by-side comparison easy to read. The source (old) experiment appears on the left and the current (new) experiment on the right. Reviewed By: eonofrey Differential Revision: D95218298 fbshipit-source-id: 372b980372ef5b1f4621011a6e4ad7b96799c015
1 parent 3e8e4e7 commit da3318e

3 files changed

Lines changed: 137 additions & 11 deletions

File tree

ax/analysis/healthcheck/tests/test_transfer_learning_analysis.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ def _make_experiment(
3939
)
4040

4141

42+
def _dummy_create_diff_paste(
43+
before_content: str, after_content: str, title: str
44+
) -> str:
45+
"""Dummy callable that returns a fake diffing URL."""
46+
return "https://www.internalfb.com/intern/diffing/?paste_number=12345"
47+
48+
4249
_MOCK_TARGET = "ax.storage.sqa_store.load.identify_transferable_experiments"
4350

4451

@@ -162,3 +169,81 @@ def test_experiment_name_passed_to_identify(self, mock_identify: object) -> None
162169
mock_identify.assert_called_once() # pyre-ignore[16]
163170
call_kwargs = mock_identify.call_args.kwargs # pyre-ignore[16]
164171
self.assertEqual(call_kwargs["experiment_name"], "test_experiment")
172+
173+
@patch(_MOCK_TARGET)
def test_diff_paste_callable_adds_comparison_column(
    self, mock_identify: object
) -> None:
    """Supplying ``create_diff_paste_callable`` adds a 'Comparison' column.

    The 'Parameters' column must still be present, holding the sorted,
    comma-separated overlapping parameter names, while 'Comparison' holds
    the URL returned by the callable.
    """
    experiment = _make_experiment(
        ["x1", "x2", "x3", "x4"], experiment_type="my_type"
    )
    transferable = {
        "source_exp": TransferLearningMetadata(
            overlap_parameters=["x1", "x2", "x3"],
        ),
    }
    mock_identify.return_value = transferable  # pyre-ignore[16]

    card = TransferLearningAnalysis(
        create_diff_paste_callable=_dummy_create_diff_paste,
    ).compute(experiment=experiment)

    # Both the new and the pre-existing column must be in the table.
    self.assertIn("Comparison", card.df.columns)
    self.assertIn("Parameters", card.df.columns)
    # The single source experiment yields the first (and only) row.
    self.assertIn("diffing", card.df.iloc[0]["Comparison"])
    self.assertEqual(card.df.iloc[0]["Parameters"], "x1, x2, x3")
195+
196+
@patch(_MOCK_TARGET)
def test_diff_paste_callable_receives_correct_content(
    self, mock_identify: object
) -> None:
    """Check the exact (before, after, title) handed to the diff callable.

    Both contents must start with an experiment-name header followed by
    the parameter names in sorted order; the title must mention both
    experiments.
    """
    experiment = _make_experiment(
        ["alpha", "beta", "gamma"], experiment_type="my_type"
    )
    mock_identify.return_value = {  # pyre-ignore[16]
        "source_exp": TransferLearningMetadata(
            overlap_parameters=["gamma", "alpha"],
        ),
    }
    recorded_calls: list[tuple[str, str, str]] = []

    def _record_call(before: str, after: str, title: str) -> str:
        # Remember every invocation so the arguments can be inspected below.
        recorded_calls.append((before, after, title))
        return "https://example.com/diff"

    TransferLearningAnalysis(
        create_diff_paste_callable=_record_call,
    ).compute(experiment=experiment)

    # One source experiment -> exactly one diff requested.
    self.assertEqual(len(recorded_calls), 1)
    before_text, after_text, diff_title = recorded_calls[0]

    # Source (old) experiment: overlapping parameters only, sorted.
    expected_before = (
        "experiment_name: source_exp (old)\nparameter_names:\n - alpha\n - gamma"
    )
    self.assertEqual(before_text, expected_before)

    # Current (new) experiment: every search-space parameter, sorted.
    expected_after = (
        "experiment_name: test_experiment (new)\n"
        "parameter_names:\n - alpha\n - beta\n - gamma"
    )
    self.assertEqual(after_text, expected_after)

    self.assertIn("source_exp", diff_title)
    self.assertIn("test_experiment", diff_title)
235+
236+
@patch(_MOCK_TARGET)
def test_no_callable_has_no_comparison_column(self, mock_identify: object) -> None:
    """With no diff callable, only the 'Parameters' column is produced.

    The 'Comparison' column must be absent from the resulting table.
    """
    experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
    transferable = {
        "source_exp": TransferLearningMetadata(
            overlap_parameters=["x1", "x2"],
        ),
    }
    mock_identify.return_value = transferable  # pyre-ignore[16]

    # Default construction: create_diff_paste_callable is not supplied.
    card = TransferLearningAnalysis().compute(experiment=experiment)

    self.assertIn("Parameters", card.df.columns)
    self.assertNotIn("Comparison", card.df.columns)

ax/analysis/healthcheck/transfer_learning_analysis.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from __future__ import annotations
99

1010
import json
11-
from typing import final, TYPE_CHECKING
11+
from collections.abc import Callable
12+
from typing import Any, final, TYPE_CHECKING
1213

1314
import markdown as md
1415
import pandas as pd
@@ -46,11 +47,26 @@ def __init__(
4647
overlap_threshold: float = 0.50,
4748
max_num_exps: int = 10,
4849
config: SQAConfig | None = None,
50+
create_diff_paste_callable: Callable[[str, str, str], str] | None = None,
4951
) -> None:
52+
"""
53+
Args:
54+
experiment_types: List of experiment types to search for.
55+
overlap_threshold: Minimum proportion of overlapping parameters.
56+
max_num_exps: Max number of transferable experiments to return.
57+
config: SQAConfig to use for the query.
58+
create_diff_paste_callable: A function that takes
59+
(before_content, after_content, title) and returns a URL to a
60+
paste diff comparing the current experiment's parameters with
61+
a source experiment's overlapping parameters. If provided, a
62+
"Comparison" column is added to the table containing the diff
63+
link.
64+
"""
5065
self.experiment_types = experiment_types
5166
self.overlap_threshold = overlap_threshold
5267
self.max_num_exps = max_num_exps
5368
self.config = config
69+
self.create_diff_paste_callable = create_diff_paste_callable
5470

5571
@override
5672
def compute(
@@ -104,23 +120,42 @@ def compute(
104120
)
105121

106122
total_parameters = len(experiment.search_space.parameters)
123+
current_params_sorted = sorted(experiment.search_space.parameters.keys())
107124

108-
rows = []
125+
rows: list[dict[str, Any]] = []
109126
for exp_name, metadata in transferable_experiments.items():
110127
overlap_count = len(metadata.overlap_parameters)
111128
overlap_pct = (
112129
(overlap_count / total_parameters * 100)
113130
if total_parameters > 0
114131
else 0.0
115132
)
116-
rows.append(
117-
{
118-
"Experiment": exp_name,
119-
"Overlapping Parameters": overlap_count,
120-
"Overlap (%)": round(overlap_pct, 1),
121-
"Parameters": ", ".join(sorted(metadata.overlap_parameters)),
122-
}
123-
)
133+
overlap_sorted = sorted(metadata.overlap_parameters)
134+
row: dict[str, Any] = {
135+
"Experiment": exp_name,
136+
"Overlapping Parameters": overlap_count,
137+
"Overlap (%)": round(overlap_pct, 1),
138+
"Parameters": ", ".join(overlap_sorted),
139+
}
140+
141+
if self.create_diff_paste_callable is not None:
142+
create_diff_paste = self.create_diff_paste_callable
143+
before_params = "\n".join(f" - {p}" for p in overlap_sorted)
144+
before_content = (
145+
f"experiment_name: {exp_name} (old)\n"
146+
f"parameter_names:\n{before_params}"
147+
)
148+
after_params = "\n".join(f" - {p}" for p in current_params_sorted)
149+
after_content = (
150+
f"experiment_name: {experiment.name} (new)\n"
151+
f"parameter_names:\n{after_params}"
152+
)
153+
title = f"Parameter comparison: {experiment.name} vs {exp_name}"
154+
row["Comparison"] = create_diff_paste(
155+
before_content, after_content, title
156+
)
157+
158+
rows.append(row)
124159

125160
df = pd.DataFrame(rows)
126161

ax/analysis/overview.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
# pyre-strict
77

8+
from collections.abc import Callable
89
from typing import Any, final
910

1011
from ax.adapter.base import Adapter
@@ -116,6 +117,7 @@ def __init__(
116117
tier_metadata: dict[str, Any] | None = None,
117118
model_fit_threshold: float | None = None,
118119
sqa_config: Any = None,
120+
create_diff_paste_callable: Callable[[str, str, str], str] | None = None,
119121
) -> None:
120122
super().__init__()
121123
self.can_generate = can_generate
@@ -127,6 +129,7 @@ def __init__(
127129
self.tier_metadata = tier_metadata
128130
self.model_fit_threshold = model_fit_threshold
129131
self.sqa_config = sqa_config
132+
self.create_diff_paste_callable = create_diff_paste_callable
130133

131134
@override
132135
def validate_applicable_state(
@@ -232,7 +235,10 @@ def compute(
232235
if not has_batch_trials
233236
else None,
234237
BaselineImprovementAnalysis() if not has_batch_trials else None,
235-
TransferLearningAnalysis(config=self.sqa_config),
238+
TransferLearningAnalysis(
239+
config=self.sqa_config,
240+
create_diff_paste_callable=self.create_diff_paste_callable,
241+
),
236242
*[
237243
SearchSpaceAnalysis(trial_index=trial.index)
238244
for trial in candidate_trials

0 commit comments

Comments
 (0)