77
88from __future__ import annotations
99
10- import dataclasses
1110import warnings
1211from collections .abc import Mapping , Sequence
1312from logging import Logger
3938from ax .core .observation import ObservationData , ObservationFeatures
4039from ax .core .optimization_config import OptimizationConfig
4140from ax .core .parameter import FixedParameter , RangeParameter
42- from ax .core .search_space import SearchSpace , SearchSpaceDigest
41+ from ax .core .search_space import SearchSpace
4342from ax .exceptions .core import DataRequiredError , UnsupportedError , UserInputError
4443from ax .generation_strategy .best_model_selector import (
4544 ReductionCriterion ,
6463
6564
6665class TransferLearningAdapter (TorchAdapter ):
66+ _source_only_params : set [str ]
67+
6768 def __init__ (
6869 self ,
6970 * ,
@@ -150,22 +151,7 @@ def __init__(
150151 target_search_space = search_space ,
151152 )
152153
153- # Add source-only backfilled params as FixedParameter to the target
154- # search space so that the compatibility check passes and the model
155- # space includes these params (FixedToTunable will later convert them
156- # to RangeParameter using the joint space bounds).
157154 search_space = search_space .clone () # avoid mutating caller's object
158- target_param_names = set (search_space .parameters .keys ())
159- for name , param in self .joint_search_space .parameters .items ():
160- if name not in target_param_names and param .backfill_value is not None :
161- search_space .add_parameter (
162- FixedParameter (
163- name = name ,
164- parameter_type = param .parameter_type ,
165- value = param .backfill_value ,
166- )
167- )
168-
169155 # Include backfill param names in filled_params so Phase 1 of
170156 # check_search_space_compatibility passes for target-only params.
171157 filled_params .extend (self .joint_search_space .backfill_values ().keys ())
@@ -196,6 +182,28 @@ def __init__(
196182 default_model_gen_options = default_model_gen_options ,
197183 )
198184
def _set_search_space(self, search_space: SearchSpace) -> None:
    """Set the search space and the model space for transfer learning.

    Overrides the base class so that ``_model_space`` is a clone of the
    target ``search_space`` extended with every ``RangeParameter`` that
    appears in ``self.joint_search_space`` but not in the target space.
    Parameters already present in the target space are left untouched, so
    shared parameters keep their target definition. The original author's
    intent is that the search space digest then covers the full joint
    feature space without post-hoc expansion, and that normalization stays
    anchored to target bounds.

    Side effects:
        * ``self._search_space`` is set to a clone of ``search_space``.
        * ``self._model_space`` is set to the extended clone.
        * ``self._source_only_params`` is reset to the names of the added
          parameters that have no backfill value.

    Args:
        search_space: The target search space. It is cloned, never mutated.
    """
    self._search_space = search_space.clone()
    model_space = search_space.clone()
    # Reset on every call so stale names from a previous search space do
    # not leak into later data extraction.
    self._source_only_params = set()
    for name, param in self.joint_search_space.parameters.items():
        # Skip parameters the target already defines, and joint parameters
        # that are not RangeParameters (only ranges are added here).
        if name in model_space.parameters or not isinstance(
            param, RangeParameter
        ):
            continue
        model_space.add_parameter(param.clone())
        # Only mark as source-only if no backfill value exists. Backfilled
        # params have known values and should stay in data extraction.
        if param.backfill_value is None:
            self._source_only_params.add(name)
    self._model_space = model_space
206+
199207 def _transform_data (
200208 self ,
201209 experiment_data : ExperimentData ,
@@ -505,94 +513,18 @@ def _get_task_datasets(
505513 )
506514 return task_datasets
507515
508- def _expand_ssd_to_joint_space (
516+ def _get_target_data_parameters (
509517 self ,
510- search_space_digest : SearchSpaceDigest ,
511- ) -> SearchSpaceDigest :
512- """Expand SSD bounds and feature_names to cover the joint search space.
513-
514- The SSD produced by ``_get_fit_args`` reflects the target search space.
515- When source experiments have additional parameters, the model operates
516- in the full joint feature space. This method appends bounds and feature
517- names for source-only parameters so that input transforms receive
518- correct full-space bounds.
518+ all_params : list [str ],
519+ ) -> list [str ]:
520+ """Filter a joint parameter list to target-only params + task feature.
521+
522+ Source-only params (those added by ``_set_search_space`` from the joint
523+ space) are excluded because the target experiment data does not have
524+ those columns. Uses untransformed names, which are stable across
525+ transforms (Range params are never renamed by OneHot, IntToFloat, etc.).
519526 """
520- existing_names = set (search_space_digest .feature_names )
521- extra_names : list [str ] = []
522- extra_bounds : list [tuple [int | float , int | float ]] = []
523- # Only collect parameters absent from the target SSD. Shared
524- # parameters that appear in both target and source keep the target
525- # bounds -- source observations outside those bounds will normalize
526- # outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
527- # for a __target__ task in [0, 1]^D.
528- for name , param in self .joint_search_space .parameters .items ():
529- if name not in existing_names and isinstance (param , RangeParameter ):
530- extra_names .append (name )
531- extra_bounds .append ((param .lower , param .upper ))
532- if not extra_names :
533- return search_space_digest
534- # Insert source-only params before the task feature
535- task_features = search_space_digest .task_features
536- if len (task_features ) == 1 :
537- tf_idx = task_features [0 ]
538- names = list (search_space_digest .feature_names )
539- bounds = list (search_space_digest .bounds )
540- # Raise if index-based fields (other than the task feature
541- # itself) reference indices at or above tf_idx, since we would
542- # need to shift them when inserting extra params.
543- for field_name in (
544- "ordinal_features" ,
545- "categorical_features" ,
546- "fidelity_features" ,
547- ):
548- indices = getattr (search_space_digest , field_name )
549- if any (i >= tf_idx for i in indices ):
550- raise UnsupportedError (
551- f"Cannot expand SSD: { field_name } contains index >= { tf_idx } ."
552- )
553- if any (
554- i >= tf_idx and i not in task_features
555- for i in search_space_digest .discrete_choices
556- ):
557- raise UnsupportedError (
558- f"Cannot expand SSD: discrete_choices contains index >= { tf_idx } ."
559- )
560- if search_space_digest .hierarchical_dependencies is not None and any (
561- i >= tf_idx for i in search_space_digest .hierarchical_dependencies
562- ):
563- raise UnsupportedError (
564- "Cannot expand SSD: hierarchical_dependencies contains "
565- f"index >= { tf_idx } ."
566- )
567- names [tf_idx :tf_idx ] = extra_names
568- bounds [tf_idx :tf_idx ] = extra_bounds
569- n_extra = len (extra_names )
570- new_task_features = [tf_idx + n_extra ]
571- new_target_values = dict (search_space_digest .target_values )
572- if tf_idx in new_target_values :
573- new_target_values [new_task_features [0 ]] = new_target_values .pop (tf_idx )
574- new_discrete = dict (search_space_digest .discrete_choices )
575- if tf_idx in new_discrete :
576- new_discrete [new_task_features [0 ]] = new_discrete .pop (tf_idx )
577- return dataclasses .replace (
578- search_space_digest ,
579- feature_names = names ,
580- bounds = bounds ,
581- task_features = new_task_features ,
582- target_values = new_target_values ,
583- discrete_choices = new_discrete ,
584- )
585- elif len (task_features ) == 0 :
586- # No task feature -- just append.
587- return dataclasses .replace (
588- search_space_digest ,
589- feature_names = search_space_digest .feature_names + extra_names ,
590- bounds = search_space_digest .bounds + extra_bounds ,
591- )
592- else :
593- raise UnsupportedError (
594- "Multiple task features are not supported in transfer learning."
595- )
527+ return [p for p in all_params if p not in self ._source_only_params ]
596528
597529 def _fit (
598530 self ,
@@ -610,15 +542,20 @@ def _fit(
610542 if experiment_data .arm_data .empty :
611543 # Temporarily unset self.outcomes to avoid an error in _get_fit_args.
612544 self .outcomes = []
545+ # Pre-compute the joint param ordering (mirrors _get_fit_args logic)
546+ # so we can derive the target-only subset for data extraction.
547+ all_params = list (search_space .parameters .keys ())
548+ task_name = Keys .TASK_FEATURE_NAME .value
549+ if task_name in all_params :
550+ all_params .remove (task_name )
551+ all_params .append (task_name )
552+ target_data_params = self ._get_target_data_parameters (all_params )
613553 datasets , candidate_metadata , search_space_digest = self ._get_fit_args (
614554 search_space = search_space ,
615555 experiment_data = experiment_data ,
616556 update_outcomes_and_parameters = True ,
557+ data_parameters = target_data_params ,
617558 )
618- # Expand SSD bounds to cover source-only params from the joint search
619- # space. This ensures Normalize (and other input transforms) get bounds
620- # for the full feature space, not just the target dims.
621- search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
622559 if experiment_data .arm_data .empty :
623560 self .outcomes = outcomes
624561 # Temporarily set datasets to None. We will construct empty datasets
@@ -656,12 +593,13 @@ def _cross_validate(
656593 ) -> list [ObservationData ]:
657594 if self .parameters is None :
658595 raise ValueError (FIT_MODEL_ERROR .format (action = "_cross_validate" ))
596+ target_data_params = self ._get_target_data_parameters (self .parameters )
659597 datasets , _ , search_space_digest = self ._get_fit_args (
660598 search_space = search_space ,
661599 experiment_data = cv_training_data ,
662600 update_outcomes_and_parameters = False ,
601+ data_parameters = target_data_params ,
663602 )
664- search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
665603 # Add the task feature to SSD, to ensure that a multi-task model is selected.
666604 if len (search_space_digest .task_features ) > 1 :
667605 raise UnsupportedError (
@@ -714,23 +652,22 @@ def gen(
714652 to ``RemoveFixed.untransform_observation_features``, which requires updating
715653 the signature of all transforms.
716654 """
717- # If a fixed parameter in the target search space, is
718- # a range parameter in the joint search space, then we
719- # should set it as a fixed feature here.
655+ # If a fixed parameter in the target search space is a range
656+ # parameter in the joint search space, pin it as a fixed feature.
720657 search_space = search_space or self ._search_space
721658 for name , target_p in search_space .parameters .items ():
722659 if (
723660 isinstance (target_p , FixedParameter )
724661 and (p := self .joint_search_space .parameters .get (name )) is not None
725662 and isinstance (p , RangeParameter )
726663 ):
727- # add to fixed features
728664 if fixed_features is None :
729665 fixed_features = ObservationFeatures (parameters = {})
730666 fixed_features .parameters .setdefault (name , target_p .value )
731- # Fix source-only params so the optimizer doesn't search over them.
732- # Center is a reasonable default; LearnedFeatureImputation overwrites
733- # these with learned values when configured.
667+ # Fix source-only params that ARE in the search space (e.g. injected
668+ # as FixedParam with a backfill value) so the optimizer doesn't search
669+ # over them. Params NOT in the search space are handled by the model
670+ # internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
734671 joint_center = self .joint_search_space .compute_naive_center ()
735672 for name , param in self .joint_search_space .parameters .items ():
736673 if name not in search_space .parameters and isinstance (
@@ -739,6 +676,20 @@ def gen(
739676 if fixed_features is None :
740677 fixed_features = ObservationFeatures (parameters = {})
741678 fixed_features .parameters .setdefault (name , joint_center [name ])
679+ # At gen time, optimize over the target search space only.
680+ # _source_only_params covers params without backfill (stable
681+ # untransformed names). Backfilled source-only params are identified
682+ # by checking joint SS backfill_values — their names are also stable
683+ # (RangeParameters are never renamed by transforms).
684+ saved_parameters = self .parameters
685+ backfilled_source_only = {
686+ name
687+ for name , param in self .joint_search_space .parameters .items ()
688+ if name not in self ._experiment .search_space .parameters
689+ and param .backfill_value is not None
690+ }
691+ exclude = self ._source_only_params | backfilled_source_only
692+ self .parameters = [p for p in self .parameters if p not in exclude ]
742693 generator_run = super ().gen (
743694 n = n ,
744695 search_space = search_space ,
@@ -747,6 +698,7 @@ def gen(
747698 fixed_features = fixed_features ,
748699 model_gen_options = model_gen_options ,
749700 )
701+ self .parameters = saved_parameters
750702 # Remove the parameters that are not in the target experiment's search
751703 # space, and update candidate_metadata_by_arm_signature to reflect the
752704 # new arm. We use the experiment's search space rather than
0 commit comments