Skip to content

Commit 6f88d2d

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Open-source compute_task_selection_cv (#5096)
Summary: Pull Request resolved: #5096 Reviewed By: saitcakmak Differential Revision: D96362189 fbshipit-source-id: 10961d1b9bda94c7d21e554ec1cf318f1e7ba1e4
1 parent 168516b commit 6f88d2d

3 files changed

Lines changed: 872 additions & 0 deletions

File tree

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from __future__ import annotations
9+
10+
from collections.abc import Callable
11+
from logging import Logger
12+
13+
from ax.adapter.base import Adapter
14+
from ax.adapter.cross_validation import compute_diagnostics, cross_validate
15+
from ax.adapter.transforms.winsorize import Winsorize
16+
from ax.core.auxiliary import AuxiliaryExperimentPurpose
17+
from ax.core.auxiliary_source import AuxiliarySource
18+
from ax.core.data import Data
19+
from ax.core.experiment import Experiment
20+
from ax.core.observation import Observation
21+
from ax.exceptions.core import AxError
22+
from ax.generation_strategy.generation_strategy import GenerationStrategy
23+
from ax.utils.common.logger import get_logger
24+
from ax.utils.stats.model_fit_stats import (
25+
DIAGNOSTIC_FN_DIRECTIONS,
26+
ModelFitMetricDirection,
27+
)
28+
from botorch.exceptions.errors import ModelFittingError
29+
from pyre_extensions import assert_is_instance
30+
31+
logger: Logger = get_logger(__name__)
32+
33+
34+
def _mean_diagnostic(
35+
diagnostics: dict[str, dict[str, float]],
36+
eval_criterion: str,
37+
metric_names: list[str],
38+
) -> float:
39+
"""Compute the mean of ``eval_criterion`` across ``metric_names``."""
40+
criterion_values = diagnostics[eval_criterion]
41+
return sum(criterion_values[m] for m in metric_names) / len(metric_names)
42+
43+
44+
def _get_winsorization_test_selector(
45+
adapter: Adapter,
46+
) -> Callable[[Observation], bool] | None:
47+
"""Return a test selector that excludes observations outside Winsorize cutoffs.
48+
49+
When a model uses the Winsorize transform, observations whose raw values
50+
fall outside the learned clipping bounds are not meaningful test points
51+
because their observed values would be clipped during transformation.
52+
This selector keeps only observations where all metrics' means
53+
lie strictly within their cutoff ranges, so that cross-validation scores are
54+
computed on un-clipped data.
55+
56+
Returns None if the adapter has no Winsorize transform or if all cutoffs
57+
are effectively unbounded (negative infinity to positive infinity).
58+
"""
59+
if "Winsorize" not in adapter.transforms:
60+
return None
61+
62+
winsorize_transform: Winsorize = assert_is_instance(
63+
adapter.transforms["Winsorize"], Winsorize
64+
)
65+
66+
# Check if all cutoffs are effectively unbounded.
67+
all_unbounded = all(
68+
lo == float("-inf") and hi == float("inf")
69+
for lo, hi in winsorize_transform.cutoffs.values()
70+
)
71+
if all_unbounded:
72+
return None
73+
74+
def test_selector(obs: Observation) -> bool:
75+
od = obs.data
76+
for i, metric_signature in enumerate(od.metric_signatures):
77+
cutoffs = winsorize_transform.cutoffs.get(metric_signature)
78+
if cutoffs is None:
79+
continue
80+
mean = od.means[i]
81+
if mean <= cutoffs[0] or mean >= cutoffs[1]:
82+
return False
83+
return True
84+
85+
return test_selector
86+
87+
88+
def _fit_and_cv(
89+
generation_strategy: GenerationStrategy,
90+
experiment: Experiment,
91+
data: Data,
92+
eval_criterion: str,
93+
metric_names: list[str],
94+
) -> float:
95+
"""Clone a GenerationStrategy, fit the appropriate node, and compute
96+
mean CV score.
97+
98+
Uses ``GenerationStrategy.fit`` to let the GS select the correct
99+
node (e.g. TL vs non-TL) based on the current experiment state, then
100+
runs cross-validation on the best fitted adapter.
101+
"""
102+
gs = generation_strategy.clone_reset()
103+
adapter = gs.fit(experiment=experiment, data=data)
104+
if adapter is None:
105+
raise AxError("No fitted adapter after fitting the generation node.")
106+
test_selector = _get_winsorization_test_selector(adapter)
107+
cv_results = cross_validate(adapter, untransform=False, test_selector=test_selector)
108+
return _mean_diagnostic(
109+
compute_diagnostics(cv_results), eval_criterion, metric_names
110+
)
111+
112+
113+
def compute_task_selection_cv(
114+
source_experiments: list[Experiment],
115+
target_experiment: Experiment,
116+
generation_strategy: GenerationStrategy,
117+
target_data: Data | None = None,
118+
eval_criterion: str = "MSE",
119+
max_tasks: int = 2,
120+
) -> list[str]:
121+
"""Greedy forward task selection via cross-validation (RP_CV).
122+
123+
Starting from a target-only model, greedily adds source tasks one at a
124+
time, keeping each addition only if it improves the leave-one-out
125+
cross-validation score on the target data.
126+
127+
The metric names are extracted from the target experiment's
128+
``optimization_config``. When the objective has multiple metrics
129+
(e.g. ``MultiObjective``), the mean ``eval_criterion`` across all
130+
objective metrics is used for selection.
131+
132+
At each step the generation strategy is cloned and
133+
``GenerationStrategy.fit`` is called so that the GS picks the
134+
appropriate node (TL or non-TL) based on whether auxiliary sources are
135+
attached to the experiment. The node is then fitted and
136+
cross-validated.
137+
138+
The direction (minimize vs maximize) for ``eval_criterion`` is looked up
139+
automatically from ``DIAGNOSTIC_FN_DIRECTIONS``.
140+
141+
Args:
142+
source_experiments: Candidate source experiments.
143+
target_experiment: Target experiment with attached data and an
144+
``optimization_config`` whose objective defines the metrics.
145+
generation_strategy: A ``GenerationStrategy`` that will be cloned
146+
via ``clone_reset()`` before each fit. The GS should contain
147+
nodes that handle both the single-task (no auxiliary sources)
148+
and multi-task (with auxiliary sources) cases via transition
149+
criteria.
150+
target_data: Data to use for fitting and CV. If ``None``, uses
151+
``target_experiment.lookup_data()``.
152+
eval_criterion: Diagnostic key from ``compute_diagnostics``.
153+
Must be a key in ``DIAGNOSTIC_FN_DIRECTIONS``.
154+
Defaults to ``"MSE"``.
155+
max_tasks: Maximum number of sources to select. Defaults to 2.
156+
157+
Returns:
158+
Ordered list of selected source experiment names, in the order
159+
they were greedily added. Empty if no source improves CV.
160+
161+
Raises:
162+
AxError: If source experiments have duplicate names.
163+
ValueError: If the target experiment has no data or if
164+
``eval_criterion`` is not in ``DIAGNOSTIC_FN_DIRECTIONS``.
165+
"""
166+
# Validate unique source names.
167+
source_names: list[str] = []
168+
for i, exp in enumerate(source_experiments):
169+
if not exp.has_name:
170+
exp.name = f"source_{i}"
171+
if exp.name in source_names:
172+
raise AxError("Source experiments must have unique names.")
173+
source_names.append(exp.name)
174+
175+
if target_data is None:
176+
target_data = target_experiment.lookup_data()
177+
if target_data.df.empty:
178+
raise ValueError(
179+
"Target experiment has no data. Cannot perform CV task selection."
180+
)
181+
182+
if eval_criterion not in DIAGNOSTIC_FN_DIRECTIONS:
183+
raise ValueError(
184+
f"Unknown eval_criterion '{eval_criterion}'. "
185+
f"Must be one of {list(DIAGNOSTIC_FN_DIRECTIONS.keys())}."
186+
)
187+
minimize = (
188+
DIAGNOSTIC_FN_DIRECTIONS[eval_criterion] == ModelFitMetricDirection.MINIMIZE
189+
)
190+
191+
opt_config = target_experiment.optimization_config
192+
if opt_config is None:
193+
metric_names = list(
194+
set(target_experiment.metrics.keys()).intersection(
195+
target_data.df.metric_names.unique()
196+
)
197+
)
198+
else:
199+
metric_names = list(opt_config.metric_names)
200+
logger.info(f"Evaluating CV on metrics: {metric_names}")
201+
202+
aux_srcs: list[AuxiliarySource] = [
203+
AuxiliarySource(experiment=exp) for exp in source_experiments
204+
]
205+
206+
# Fit base adapter (target only) and compute baseline CV score.
207+
logger.info("Fitting base adapter (target only) for CV baseline.")
208+
best_score = _fit_and_cv(
209+
generation_strategy=generation_strategy,
210+
experiment=target_experiment,
211+
data=target_data,
212+
eval_criterion=eval_criterion,
213+
metric_names=metric_names,
214+
)
215+
logger.info(f"Baseline mean CV {eval_criterion}: {best_score:.6f}")
216+
217+
# Save original auxiliary experiments to restore later.
218+
original_aux = target_experiment.auxiliary_experiments_by_purpose.get(
219+
AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT
220+
)
221+
222+
selected_names: list[str] = []
223+
selected_aux_srcs: list[AuxiliarySource] = []
224+
remaining_idcs: set[int] = set(range(len(aux_srcs)))
225+
226+
try:
227+
for step in range(max_tasks):
228+
best_idx: int | None = None
229+
for i in remaining_idcs:
230+
candidate_aux = selected_aux_srcs + [aux_srcs[i]]
231+
target_experiment.auxiliary_experiments_by_purpose[
232+
AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT
233+
] = candidate_aux # pyre-ignore[6]
234+
235+
try:
236+
score = _fit_and_cv(
237+
generation_strategy=generation_strategy,
238+
experiment=target_experiment,
239+
data=target_data,
240+
eval_criterion=eval_criterion,
241+
metric_names=metric_names,
242+
)
243+
except (AxError, ModelFittingError, RuntimeError) as e:
244+
logger.warning(
245+
f"CV failed for candidate '{source_names[i]}': {e}. Skipping.",
246+
exc_info=True,
247+
)
248+
continue
249+
250+
is_better = score < best_score if minimize else score > best_score
251+
if is_better:
252+
best_score = score
253+
best_idx = i
254+
255+
if best_idx is None:
256+
logger.info(
257+
f"No improvement at step {step + 1}. "
258+
f"Stopping with {len(selected_names)} selected sources."
259+
)
260+
break
261+
262+
selected_aux_srcs.append(aux_srcs[best_idx])
263+
remaining_idcs.remove(best_idx)
264+
selected_names.append(source_names[best_idx])
265+
logger.info(
266+
f"Step {step + 1}: selected '{source_names[best_idx]}' "
267+
f"(mean CV {eval_criterion}={best_score:.6f})"
268+
)
269+
finally:
270+
# Restore original auxiliary experiments.
271+
if original_aux is not None:
272+
target_experiment.auxiliary_experiments_by_purpose[
273+
AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT
274+
] = original_aux
275+
else:
276+
target_experiment.auxiliary_experiments_by_purpose.pop(
277+
AuxiliaryExperimentPurpose.TRANSFERABLE_EXPERIMENT, None
278+
)
279+
280+
return selected_names
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict

0 commit comments

Comments
 (0)