Skip to content

Commit 1e01aa5

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Use search space bounds for Normalize in transfer learning adapter (facebook#5184)
Summary: The transfer learning adapter explicitly passed `bounds=None` to Normalize, forcing `learn_bounds=True`. As a result, the Normalize bounds were learned from data instead of being fixed to the search space, so they drifted during training and differed between benchmark configs despite identical search spaces. Remove the `bounds=None` override so that `_set_default_bounds` provides the correct search space bounds from the SearchSpaceDigest.
Reviewed By: sdaulton
Differential Revision: D100669010
1 parent 9037c4b commit 1e01aa5

3 files changed

Lines changed: 226 additions & 13 deletions

File tree

ax/adapter/transfer_learning/adapter.py

Lines changed: 101 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from __future__ import annotations
99

10+
import dataclasses
1011
import warnings
1112
from collections.abc import Mapping, Sequence
1213
from logging import Logger
@@ -38,7 +39,7 @@
3839
from ax.core.observation import ObservationData, ObservationFeatures
3940
from ax.core.optimization_config import OptimizationConfig
4041
from ax.core.parameter import FixedParameter, RangeParameter
41-
from ax.core.search_space import SearchSpace
42+
from ax.core.search_space import SearchSpace, SearchSpaceDigest
4243
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
4344
from ax.generation_strategy.best_model_selector import (
4445
ReductionCriterion,
@@ -504,6 +505,86 @@ def _get_task_datasets(
504505
)
505506
return task_datasets
506507

508+
def _expand_ssd_to_joint_space(
    self,
    search_space_digest: SearchSpaceDigest,
) -> SearchSpaceDigest:
    """Widen an SSD so its features and bounds span the joint search space.

    The incoming digest describes the target search space only. When
    ``self.joint_search_space`` holds range parameters that the digest does
    not list (source-only parameters), their names and ``(lower, upper)``
    bounds are added so that input transforms see bounds for every feature.

    Parameters already present in the digest keep their target bounds, so
    source observations outside those bounds normalize outside [0, 1]. This
    is intentional, as the GP hyperprior is calibrated for a target task in
    [0, 1]^D.

    Args:
        search_space_digest: Digest describing the target search space.

    Returns:
        The same ``search_space_digest`` object if there is nothing to add.
        Otherwise a copy in which the source-only parameters are appended
        (no task feature) or inserted directly before the task feature
        (exactly one task feature), with ``task_features`` and
        ``target_values`` re-indexed accordingly.

    Raises:
        UnsupportedError: If parameters must be added and the digest has
            more than one task feature, or if it has one task feature and
            any of ``ordinal_features``, ``categorical_features``,
            ``fidelity_features``, ``discrete_choices`` or
            ``hierarchical_dependencies`` refers to an index at or beyond
            the task feature (those indices would need shifting, which is
            not implemented).
    """
    target_names = set(search_space_digest.feature_names)
    # Range parameters known to the joint space but absent from the target SSD.
    source_only: list[tuple[str, tuple[int | float, int | float]]] = [
        (p_name, (p.lower, p.upper))
        for p_name, p in self.joint_search_space.parameters.items()
        if p_name not in target_names and isinstance(p, RangeParameter)
    ]
    if not source_only:
        return search_space_digest
    extra_names = [p_name for p_name, _ in source_only]
    extra_bounds = [p_bounds for _, p_bounds in source_only]

    task_features = search_space_digest.task_features
    if len(task_features) > 1:
        raise UnsupportedError(
            "Multiple task features are not supported in transfer learning."
        )
    if len(task_features) == 0:
        # Without a task feature the new columns simply go at the end, so
        # no existing index changes meaning.
        return dataclasses.replace(
            search_space_digest,
            feature_names=search_space_digest.feature_names + extra_names,
            bounds=search_space_digest.bounds + extra_bounds,
        )

    # Exactly one task feature: the new columns are placed right before it.
    tf_idx = task_features[0]

    # Index-based fields are not re-indexed here, so refuse any digest in
    # which one of them points at or past the insertion position.
    for field_name in (
        "ordinal_features",
        "categorical_features",
        "fidelity_features",
    ):
        if any(idx >= tf_idx for idx in getattr(search_space_digest, field_name)):
            raise UnsupportedError(
                f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
            )
    if any(idx >= tf_idx for idx in search_space_digest.discrete_choices):
        raise UnsupportedError(
            f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
        )
    dependencies = search_space_digest.hierarchical_dependencies
    if dependencies is not None and any(idx >= tf_idx for idx in dependencies):
        raise UnsupportedError(
            "Cannot expand SSD: hierarchical_dependencies contains "
            f"index >= {tf_idx}."
        )

    target_feature_names = list(search_space_digest.feature_names)
    target_bounds = list(search_space_digest.bounds)
    widened_names = (
        target_feature_names[:tf_idx] + extra_names + target_feature_names[tf_idx:]
    )
    widened_bounds = target_bounds[:tf_idx] + extra_bounds + target_bounds[tf_idx:]

    # The task feature moves right by the number of inserted columns; its
    # target value (if any) has to follow it to the new index.
    shifted_tf_idx = tf_idx + len(extra_names)
    remapped_target_values = dict(search_space_digest.target_values)
    if tf_idx in remapped_target_values:
        remapped_target_values[shifted_tf_idx] = remapped_target_values.pop(tf_idx)

    return dataclasses.replace(
        search_space_digest,
        feature_names=widened_names,
        bounds=widened_bounds,
        task_features=[shifted_tf_idx],
        target_values=remapped_target_values,
    )
587+
507588
def _fit(
508589
self,
509590
search_space: SearchSpace,
@@ -525,6 +606,10 @@ def _fit(
525606
experiment_data=experiment_data,
526607
update_outcomes_and_parameters=True,
527608
)
609+
# Expand SSD bounds to cover source-only params from the joint search
610+
# space. This ensures Normalize (and other input transforms) get bounds
611+
# for the full feature space, not just the target dims.
612+
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
528613
if experiment_data.arm_data.empty:
529614
self.outcomes = outcomes
530615
# Temporarily set datasets to None. We will construct empty datasets
@@ -567,6 +652,7 @@ def _cross_validate(
567652
experiment_data=cv_training_data,
568653
update_outcomes_and_parameters=False,
569654
)
655+
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
570656
# Add the task feature to SSD, to ensure that a multi-task model is selected.
571657
if len(search_space_digest.task_features) > 1:
572658
raise UnsupportedError(
@@ -612,7 +698,7 @@ def gen(
612698
613699
Once the ``GeneratorRun`` is produced, it checks for any fixed parameters
614700
that are not in the target search space and removes them. This is a hack
615-
around limitations of the ``RemoveFixed`` transform. Since we construct the
701+
around limitations of the Ax ``RemoveFixed`` transform. Since we construct the
616702
transforms with the joint space, we end up adding back all fixed parameters
617703
from the joint space rather than adding only the parameters from the
618704
target search space. A proper fix would require passing in the search space
@@ -633,6 +719,17 @@ def gen(
633719
if fixed_features is None:
634720
fixed_features = ObservationFeatures(parameters={})
635721
fixed_features.parameters.setdefault(name, target_p.value)
722+
# Fix source-only params so the optimizer doesn't search over them.
723+
# Center is a reasonable default; LearnedFeatureImputation overwrites
724+
# these with learned values when configured.
725+
joint_center = self.joint_search_space.compute_naive_center()
726+
for name, param in self.joint_search_space.parameters.items():
727+
if name not in search_space.parameters and isinstance(
728+
param, RangeParameter
729+
):
730+
if fixed_features is None:
731+
fixed_features = ObservationFeatures(parameters={})
732+
fixed_features.parameters.setdefault(name, joint_center[name])
636733
generator_run = super().gen(
637734
n=n,
638735
search_space=search_space,
@@ -719,12 +816,8 @@ def transfer_learning_generator_specs_constructor(
719816
selector in case there is model selection enabled.
720817
"""
721818
input_transform_classes: list[type[InputTransform]] = [Normalize]
722-
input_transform_options = {
723-
"Normalize": {
724-
# None for bounds here ensures we do not use bounds from
725-
# the search space digest.
726-
"bounds": None,
727-
}
819+
input_transform_options: dict[str, dict[str, Any]] = {
820+
"Normalize": {},
728821
}
729822
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
730823
transform_configs = get_derelativize_config(
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from unittest.mock import MagicMock, PropertyMock
9+
10+
from ax.adapter.transfer_learning.adapter import TransferLearningAdapter
11+
from ax.core.parameter import ParameterType, RangeParameter
12+
from ax.core.search_space import SearchSpace, SearchSpaceDigest
13+
from ax.exceptions.core import UnsupportedError
14+
from ax.utils.common.testutils import TestCase
15+
16+
17+
class ExpandSsdToJointSpaceTest(TestCase):
    """Tests for ``TransferLearningAdapter._expand_ssd_to_joint_space``.

    The method is invoked unbound on a ``MagicMock`` adapter, so only the
    ``joint_search_space`` property has to be provided for each scenario.
    """

    def setUp(self) -> None:
        super().setUp()
        self.adapter = MagicMock(spec=TransferLearningAdapter)

    def _make_joint_ss(self, params: dict[str, tuple[float, float]]) -> SearchSpace:
        """Build a search space of float range parameters from name -> bounds."""
        range_parameters = []
        for n, (lo, hi) in params.items():
            range_parameters.append(
                RangeParameter(
                    name=n,
                    lower=lo,
                    upper=hi,
                    parameter_type=ParameterType.FLOAT,
                )
            )
        return SearchSpace(parameters=range_parameters)

    def _install_joint_space(self, params: dict[str, tuple[float, float]]) -> None:
        """Make the mocked adapter report ``params`` as its joint search space."""
        joint_space = self._make_joint_ss(params)
        type(self.adapter).joint_search_space = PropertyMock(return_value=joint_space)

    def _expand(self, ssd: SearchSpaceDigest) -> SearchSpaceDigest:
        """Run the method under test against the mocked adapter."""
        return TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd)

    def test_no_extra_params_returns_unchanged(self) -> None:
        # Joint space has nothing beyond the target SSD: same object comes back.
        self._install_joint_space({"x1": (0, 1), "x2": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
        )
        self.assertIs(self._expand(ssd), ssd)

    def test_single_task_feature_inserts_before_task(self) -> None:
        # "x3" is source-only and must land directly before the task column.
        self._install_joint_space({"x1": (0, 1), "x2": (0, 1), "x3": (-2, 5)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
        )
        expanded = self._expand(ssd)
        self.assertEqual(expanded.feature_names, ["x1", "x2", "x3", "task"])
        self.assertEqual(expanded.bounds, [(0, 1), (0, 1), (-2, 5), (0, 2)])
        self.assertEqual(expanded.task_features, [3])
        self.assertEqual(expanded.target_values, {3: 0})

    def test_zero_task_features_appends(self) -> None:
        # No task column: the source-only parameter is appended at the end.
        self._install_joint_space({"x1": (0, 1), "x2": (-1, 3)})
        ssd = SearchSpaceDigest(
            feature_names=["x1"],
            bounds=[(0, 1)],
        )
        expanded = self._expand(ssd)
        self.assertEqual(expanded.feature_names, ["x1", "x2"])
        self.assertEqual(expanded.bounds, [(0, 1), (-1, 3)])

    def test_discrete_choices_at_task_idx_raises(self) -> None:
        self._install_joint_space({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
            discrete_choices={2: [0, 1, 2]},
        )
        with self.assertRaisesRegex(UnsupportedError, "discrete_choices"):
            self._expand(ssd)

    def test_hierarchical_dependencies_at_task_idx_raises(self) -> None:
        self._install_joint_space({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "x2", "task"],
            bounds=[(0, 1), (0, 1), (0, 2)],
            task_features=[2],
            target_values={2: 0},
            hierarchical_dependencies={2: {0: [1]}},
        )
        with self.assertRaisesRegex(UnsupportedError, "hierarchical_dependencies"):
            self._expand(ssd)

    def test_multiple_task_features_raises(self) -> None:
        self._install_joint_space({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)})
        ssd = SearchSpaceDigest(
            feature_names=["x1", "task1", "task2"],
            bounds=[(0, 1), (0, 1), (0, 1)],
            task_features=[1, 2],
        )
        with self.assertRaisesRegex(UnsupportedError, "Multiple task features"):
            self._expand(ssd)

ax/generators/torch/botorch_modular/surrogate.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -738,11 +738,15 @@ def fit(
738738
candidate_metadata=candidate_metadata,
739739
)
740740

741-
# Only update the outcome names and models if the dataset input matches
742-
# the feature names from the search space digest. Otherwise we only
743-
# keep the model within self._submodels as it may be models fitted on
744-
# auxiliary data such as the preference model for BOPE
745-
if set(dataset.feature_names) == feature_names_set:
741+
# Only update the outcome names and models if the dataset input
742+
# matches the feature names from the SSD. In heterogeneous TL,
743+
# _expand_ssd_to_joint_space adds source-only features to the SSD,
744+
# so the target MultiTaskDataset's feature_names will be a strict
745+
# subset -- the missing names are source-only params.
746+
if set(dataset.feature_names) == feature_names_set or (
747+
isinstance(dataset, MultiTaskDataset)
748+
and set(dataset.feature_names).issubset(feature_names_set)
749+
):
746750
models.append(model)
747751
outcome_names.extend(dataset.outcome_names)
748752

0 commit comments

Comments
 (0)