Skip to content

Commit 6e0bd89

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Decouple parameter discovery from data extraction in _get_fit_args (facebook#5200)
Summary: Adds a `data_parameters` argument to `TorchAdapter._get_fit_args` that decouples SSD construction (model params) from data column extraction (target params). This lets the TL adapter set `_model_space` to include source-only RangeParameters directly, so the SSD naturally covers the full joint feature space -- eliminating the need for the `_expand_ssd_to_joint_space` post-hoc expansion. Overrides `_set_search_space` to add source-only RangeParameters from the joint search space to `_model_space` while preserving target bounds for shared params (Normalize stays anchored to target bounds). At gen time, `self.parameters` is temporarily swapped to target-only so `extract_search_space_digest` sees only params present in the gen-time search space. Deletes `_expand_ssd_to_joint_space` (~90 lines). Differential Revision: D104702983
1 parent d7f94ac commit 6e0bd89

3 files changed

Lines changed: 263 additions & 213 deletions

File tree

ax/adapter/torch.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ def _get_fit_args(
774774
search_space: SearchSpace,
775775
experiment_data: ExperimentData,
776776
update_outcomes_and_parameters: bool,
777+
data_parameters: list[str] | None = None,
777778
) -> tuple[
778779
list[SupervisedDataset],
779780
list[list[TCandidateMetadata]] | None,
@@ -791,6 +792,11 @@ def _get_fit_args(
791792
update_outcomes_and_parameters: Whether to update `self.outcomes` with
792793
all outcomes found in the observations and `self.parameters` with
793794
all parameters in the search space. Typically only used in `_fit`.
795+
data_parameters: When provided, columns to extract from
796+
``experiment_data``. Defaults to ``self.parameters``. Useful when
797+
the model space is larger than the data (e.g. transfer learning
798+
with heterogeneous search spaces where the data only contains
799+
target columns but the model operates in a joint feature space).
794800
795801
Returns:
796802
The datasets & metadata, extracted from the ``experiment_data``, and the
@@ -818,12 +824,23 @@ def _get_fit_args(
818824
search_space_digest = extract_search_space_digest(
819825
search_space=search_space, param_names=self.parameters
820826
)
827+
extract_params = (
828+
data_parameters if data_parameters is not None else self.parameters
829+
)
830+
# When data_parameters differs from self.parameters, the SSD's
831+
# task_feature indices don't match the data columns. Pass None
832+
# to skip MultiTaskDataset wrapping (the caller handles it).
833+
extraction_ssd = (
834+
None
835+
if data_parameters is not None and data_parameters != self.parameters
836+
else search_space_digest
837+
)
821838
# Convert observations to datasets
822839
datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data(
823840
experiment_data=experiment_data,
824841
outcomes=self.outcomes,
825-
parameters=self.parameters,
826-
search_space_digest=search_space_digest,
842+
parameters=extract_params,
843+
search_space_digest=extraction_ssd,
827844
)
828845
datasets = self._update_w_aux_exp_datasets(datasets=datasets)
829846

ax/adapter/transfer_learning/adapter.py

Lines changed: 67 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from __future__ import annotations
99

10-
import dataclasses
1110
import warnings
1211
from collections.abc import Mapping, Sequence
1312
from logging import Logger
@@ -39,7 +38,7 @@
3938
from ax.core.observation import ObservationData, ObservationFeatures
4039
from ax.core.optimization_config import OptimizationConfig
4140
from ax.core.parameter import FixedParameter, RangeParameter
42-
from ax.core.search_space import SearchSpace, SearchSpaceDigest
41+
from ax.core.search_space import SearchSpace
4342
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
4443
from ax.generation_strategy.best_model_selector import (
4544
ReductionCriterion,
@@ -64,6 +63,8 @@
6463

6564

6665
class TransferLearningAdapter(TorchAdapter):
66+
_source_only_params: set[str]
67+
6768
def __init__(
6869
self,
6970
*,
@@ -150,22 +151,7 @@ def __init__(
150151
target_search_space=search_space,
151152
)
152153

153-
# Add source-only backfilled params as FixedParameter to the target
154-
# search space so that the compatibility check passes and the model
155-
# space includes these params (FixedToTunable will later convert them
156-
# to RangeParameter using the joint space bounds).
157154
search_space = search_space.clone() # avoid mutating caller's object
158-
target_param_names = set(search_space.parameters.keys())
159-
for name, param in self.joint_search_space.parameters.items():
160-
if name not in target_param_names and param.backfill_value is not None:
161-
search_space.add_parameter(
162-
FixedParameter(
163-
name=name,
164-
parameter_type=param.parameter_type,
165-
value=param.backfill_value,
166-
)
167-
)
168-
169155
# Include backfill param names in filled_params so Phase 1 of
170156
# check_search_space_compatibility passes for target-only params.
171157
filled_params.extend(self.joint_search_space.backfill_values().keys())
@@ -196,6 +182,28 @@ def __init__(
196182
default_model_gen_options=default_model_gen_options,
197183
)
198184

185+
def _set_search_space(self, search_space: SearchSpace) -> None:
186+
"""Set search space and model space for transfer learning.
187+
188+
Overrides the base class to add source-only params (as RangeParameters)
189+
to ``_model_space`` while preserving target bounds for shared params.
190+
This ensures the SSD naturally covers the full joint feature space
191+
without needing post-hoc expansion, and Normalize is anchored to target
192+
bounds so target data maps to [0, 1].
193+
"""
194+
self._search_space = search_space.clone()
195+
model_space = search_space.clone()
196+
self._source_only_params: set[str] = set()
197+
for name, param in self.joint_search_space.parameters.items():
198+
if name not in model_space.parameters and isinstance(param, RangeParameter):
199+
model_space.add_parameter(param.clone())
200+
# Only mark as source-only if no backfill value exists.
201+
# Backfilled params have known values (via FillMissingParameters)
202+
# and should be included in data extraction.
203+
if param.backfill_value is None:
204+
self._source_only_params.add(name)
205+
self._model_space = model_space
206+
199207
def _transform_data(
200208
self,
201209
experiment_data: ExperimentData,
@@ -505,94 +513,18 @@ def _get_task_datasets(
505513
)
506514
return task_datasets
507515

508-
def _expand_ssd_to_joint_space(
516+
def _get_target_data_parameters(
509517
self,
510-
search_space_digest: SearchSpaceDigest,
511-
) -> SearchSpaceDigest:
512-
"""Expand SSD bounds and feature_names to cover the joint search space.
513-
514-
The SSD produced by ``_get_fit_args`` reflects the target search space.
515-
When source experiments have additional parameters, the model operates
516-
in the full joint feature space. This method appends bounds and feature
517-
names for source-only parameters so that input transforms receive
518-
correct full-space bounds.
518+
all_params: list[str],
519+
) -> list[str]:
520+
"""Filter a joint parameter list to target-only params + task feature.
521+
522+
Source-only params (those added by ``_set_search_space`` from the joint
523+
space) are excluded because the target experiment data does not have
524+
those columns. Uses untransformed names, which are stable across
525+
transforms (Range params are never renamed by OneHot, IntToFloat, etc.).
519526
"""
520-
existing_names = set(search_space_digest.feature_names)
521-
extra_names: list[str] = []
522-
extra_bounds: list[tuple[int | float, int | float]] = []
523-
# Only collect parameters absent from the target SSD. Shared
524-
# parameters that appear in both target and source keep the target
525-
# bounds -- source observations outside those bounds will normalize
526-
# outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
527-
# for a __target__ task in [0, 1]^D.
528-
for name, param in self.joint_search_space.parameters.items():
529-
if name not in existing_names and isinstance(param, RangeParameter):
530-
extra_names.append(name)
531-
extra_bounds.append((param.lower, param.upper))
532-
if not extra_names:
533-
return search_space_digest
534-
# Insert source-only params before the task feature
535-
task_features = search_space_digest.task_features
536-
if len(task_features) == 1:
537-
tf_idx = task_features[0]
538-
names = list(search_space_digest.feature_names)
539-
bounds = list(search_space_digest.bounds)
540-
# Raise if index-based fields (other than the task feature
541-
# itself) reference indices at or above tf_idx, since we would
542-
# need to shift them when inserting extra params.
543-
for field_name in (
544-
"ordinal_features",
545-
"categorical_features",
546-
"fidelity_features",
547-
):
548-
indices = getattr(search_space_digest, field_name)
549-
if any(i >= tf_idx for i in indices):
550-
raise UnsupportedError(
551-
f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
552-
)
553-
if any(
554-
i >= tf_idx and i not in task_features
555-
for i in search_space_digest.discrete_choices
556-
):
557-
raise UnsupportedError(
558-
f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
559-
)
560-
if search_space_digest.hierarchical_dependencies is not None and any(
561-
i >= tf_idx for i in search_space_digest.hierarchical_dependencies
562-
):
563-
raise UnsupportedError(
564-
"Cannot expand SSD: hierarchical_dependencies contains "
565-
f"index >= {tf_idx}."
566-
)
567-
names[tf_idx:tf_idx] = extra_names
568-
bounds[tf_idx:tf_idx] = extra_bounds
569-
n_extra = len(extra_names)
570-
new_task_features = [tf_idx + n_extra]
571-
new_target_values = dict(search_space_digest.target_values)
572-
if tf_idx in new_target_values:
573-
new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx)
574-
new_discrete = dict(search_space_digest.discrete_choices)
575-
if tf_idx in new_discrete:
576-
new_discrete[new_task_features[0]] = new_discrete.pop(tf_idx)
577-
return dataclasses.replace(
578-
search_space_digest,
579-
feature_names=names,
580-
bounds=bounds,
581-
task_features=new_task_features,
582-
target_values=new_target_values,
583-
discrete_choices=new_discrete,
584-
)
585-
elif len(task_features) == 0:
586-
# No task feature -- just append.
587-
return dataclasses.replace(
588-
search_space_digest,
589-
feature_names=search_space_digest.feature_names + extra_names,
590-
bounds=search_space_digest.bounds + extra_bounds,
591-
)
592-
else:
593-
raise UnsupportedError(
594-
"Multiple task features are not supported in transfer learning."
595-
)
527+
return [p for p in all_params if p not in self._source_only_params]
596528

597529
def _fit(
598530
self,
@@ -610,15 +542,20 @@ def _fit(
610542
if experiment_data.arm_data.empty:
611543
# Temporarily unset self.outcomes to avoid an error in _get_fit_args.
612544
self.outcomes = []
545+
# Pre-compute the joint param ordering (mirrors _get_fit_args logic)
546+
# so we can derive the target-only subset for data extraction.
547+
all_params = list(search_space.parameters.keys())
548+
task_name = Keys.TASK_FEATURE_NAME.value
549+
if task_name in all_params:
550+
all_params.remove(task_name)
551+
all_params.append(task_name)
552+
target_data_params = self._get_target_data_parameters(all_params)
613553
datasets, candidate_metadata, search_space_digest = self._get_fit_args(
614554
search_space=search_space,
615555
experiment_data=experiment_data,
616556
update_outcomes_and_parameters=True,
557+
data_parameters=target_data_params,
617558
)
618-
# Expand SSD bounds to cover source-only params from the joint search
619-
# space. This ensures Normalize (and other input transforms) get bounds
620-
# for the full feature space, not just the target dims.
621-
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
622559
if experiment_data.arm_data.empty:
623560
self.outcomes = outcomes
624561
# Temporarily set datasets to None. We will construct empty datasets
@@ -656,12 +593,13 @@ def _cross_validate(
656593
) -> list[ObservationData]:
657594
if self.parameters is None:
658595
raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate"))
596+
target_data_params = self._get_target_data_parameters(self.parameters)
659597
datasets, _, search_space_digest = self._get_fit_args(
660598
search_space=search_space,
661599
experiment_data=cv_training_data,
662600
update_outcomes_and_parameters=False,
601+
data_parameters=target_data_params,
663602
)
664-
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
665603
# Add the task feature to SSD, to ensure that a multi-task model is selected.
666604
if len(search_space_digest.task_features) > 1:
667605
raise UnsupportedError(
@@ -714,23 +652,22 @@ def gen(
714652
to ``RemoveFixed.untransform_observation_features``, which requires updating
715653
the signature of all transforms.
716654
"""
717-
# If a fixed parameter in the target search space, is
718-
# a range parameter in the joint search space, then we
719-
# should set it as a fixed feature here.
655+
# If a fixed parameter in the target search space is a range
656+
# parameter in the joint search space, pin it as a fixed feature.
720657
search_space = search_space or self._search_space
721658
for name, target_p in search_space.parameters.items():
722659
if (
723660
isinstance(target_p, FixedParameter)
724661
and (p := self.joint_search_space.parameters.get(name)) is not None
725662
and isinstance(p, RangeParameter)
726663
):
727-
# add to fixed features
728664
if fixed_features is None:
729665
fixed_features = ObservationFeatures(parameters={})
730666
fixed_features.parameters.setdefault(name, target_p.value)
731-
# Fix source-only params so the optimizer doesn't search over them.
732-
# Center is a reasonable default; LearnedFeatureImputation overwrites
733-
# these with learned values when configured.
667+
# Fix source-only params that ARE in the search space (e.g. injected
668+
# as FixedParam with a backfill value) so the optimizer doesn't search
669+
# over them. Params NOT in the search space are handled by the model
670+
# internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
734671
joint_center = self.joint_search_space.compute_naive_center()
735672
for name, param in self.joint_search_space.parameters.items():
736673
if name not in search_space.parameters and isinstance(
@@ -739,6 +676,20 @@ def gen(
739676
if fixed_features is None:
740677
fixed_features = ObservationFeatures(parameters={})
741678
fixed_features.parameters.setdefault(name, joint_center[name])
679+
# At gen time, optimize over the target search space only.
680+
# _source_only_params covers params without backfill (stable
681+
# untransformed names). Backfilled source-only params are identified
682+
# by checking joint SS backfill_values — their names are also stable
683+
# (RangeParameters are never renamed by transforms).
684+
saved_parameters = self.parameters
685+
backfilled_source_only = {
686+
name
687+
for name, param in self.joint_search_space.parameters.items()
688+
if name not in self._experiment.search_space.parameters
689+
and param.backfill_value is not None
690+
}
691+
exclude = self._source_only_params | backfilled_source_only
692+
self.parameters = [p for p in self.parameters if p not in exclude]
742693
generator_run = super().gen(
743694
n=n,
744695
search_space=search_space,
@@ -747,6 +698,7 @@ def gen(
747698
fixed_features=fixed_features,
748699
model_gen_options=model_gen_options,
749700
)
701+
self.parameters = saved_parameters
750702
# Remove the parameters that are not in the target experiment's search
751703
# space, and update candidate_metadata_by_arm_signature to reflect the
752704
# new arm. We use the experiment's search space rather than

0 commit comments

Comments
 (0)