77
88from __future__ import annotations
99
10+ import dataclasses
1011import warnings
1112from collections .abc import Mapping , Sequence
1213from logging import Logger
3839from ax .core .observation import ObservationData , ObservationFeatures
3940from ax .core .optimization_config import OptimizationConfig
4041from ax .core .parameter import FixedParameter , RangeParameter
41- from ax .core .search_space import SearchSpace
42+ from ax .core .search_space import SearchSpace , SearchSpaceDigest
4243from ax .exceptions .core import DataRequiredError , UnsupportedError , UserInputError
4344from ax .generation_strategy .best_model_selector import (
4445 ReductionCriterion ,
@@ -504,6 +505,86 @@ def _get_task_datasets(
504505 )
505506 return task_datasets
506507
508+ def _expand_ssd_to_joint_space (
509+ self ,
510+ search_space_digest : SearchSpaceDigest ,
511+ ) -> SearchSpaceDigest :
512+ """Expand SSD bounds and feature_names to cover the joint search space.
513+
514+ The SSD produced by ``_get_fit_args`` reflects the target search space.
515+ When source experiments have additional parameters, the model operates
516+ in the full joint feature space. This method appends bounds and feature
517+ names for source-only parameters so that input transforms receive
518+ correct full-space bounds.
519+ """
520+ existing_names = set (search_space_digest .feature_names )
521+ extra_names : list [str ] = []
522+ extra_bounds : list [tuple [int | float , int | float ]] = []
523+ # Only collect parameters absent from the target SSD. Shared
524+ # parameters that appear in both target and source keep the target
525+ # bounds -- source observations outside those bounds will normalize
526+ # outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
527+ # for a __target__ task in [0, 1]^D.
528+ for name , param in self .joint_search_space .parameters .items ():
529+ if name not in existing_names and isinstance (param , RangeParameter ):
530+ extra_names .append (name )
531+ extra_bounds .append ((param .lower , param .upper ))
532+ if not extra_names :
533+ return search_space_digest
534+ # Insert source-only params before the task feature
535+ task_features = search_space_digest .task_features
536+ if len (task_features ) == 1 :
537+ tf_idx = task_features [0 ]
538+ names = list (search_space_digest .feature_names )
539+ bounds = list (search_space_digest .bounds )
540+ # Assert no index-based fields need shifting.
541+ for field_name in (
542+ "ordinal_features" ,
543+ "categorical_features" ,
544+ "fidelity_features" ,
545+ ):
546+ indices = getattr (search_space_digest , field_name )
547+ if any (i >= tf_idx for i in indices ):
548+ raise UnsupportedError (
549+ f"Cannot expand SSD: { field_name } contains index >= { tf_idx } ."
550+ )
551+ if any (i >= tf_idx for i in search_space_digest .discrete_choices ):
552+ raise UnsupportedError (
553+ f"Cannot expand SSD: discrete_choices contains index >= { tf_idx } ."
554+ )
555+ if search_space_digest .hierarchical_dependencies is not None and any (
556+ i >= tf_idx for i in search_space_digest .hierarchical_dependencies
557+ ):
558+ raise UnsupportedError (
559+ "Cannot expand SSD: hierarchical_dependencies contains "
560+ f"index >= { tf_idx } ."
561+ )
562+ names [tf_idx :tf_idx ] = extra_names
563+ bounds [tf_idx :tf_idx ] = extra_bounds
564+ # Task feature index shifts by the number of inserted params.
565+ new_task_features = [tf_idx + len (extra_names )]
566+ new_target_values = dict (search_space_digest .target_values )
567+ if tf_idx in new_target_values :
568+ new_target_values [new_task_features [0 ]] = new_target_values .pop (tf_idx )
569+ return dataclasses .replace (
570+ search_space_digest ,
571+ feature_names = names ,
572+ bounds = bounds ,
573+ task_features = new_task_features ,
574+ target_values = new_target_values ,
575+ )
576+ elif len (task_features ) == 0 :
577+ # No task feature -- just append.
578+ return dataclasses .replace (
579+ search_space_digest ,
580+ feature_names = search_space_digest .feature_names + extra_names ,
581+ bounds = search_space_digest .bounds + extra_bounds ,
582+ )
583+ else :
584+ raise UnsupportedError (
585+ "Multiple task features are not supported in transfer learning."
586+ )
587+
507588 def _fit (
508589 self ,
509590 search_space : SearchSpace ,
@@ -525,6 +606,10 @@ def _fit(
525606 experiment_data = experiment_data ,
526607 update_outcomes_and_parameters = True ,
527608 )
609+ # Expand SSD bounds to cover source-only params from the joint search
610+ # space. This ensures Normalize (and other input transforms) get bounds
611+ # for the full feature space, not just the target dims.
612+ search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
528613 if experiment_data .arm_data .empty :
529614 self .outcomes = outcomes
530615 # Temporarily set datasets to None. We will construct empty datasets
@@ -567,6 +652,7 @@ def _cross_validate(
567652 experiment_data = cv_training_data ,
568653 update_outcomes_and_parameters = False ,
569654 )
655+ search_space_digest = self ._expand_ssd_to_joint_space (search_space_digest )
570656 # Add the task feature to SSD, to ensure that a multi-task model is selected.
571657 if len (search_space_digest .task_features ) > 1 :
572658 raise UnsupportedError (
@@ -612,7 +698,7 @@ def gen(
612698
613699 Once the ``GeneratorRun`` is produced, it checks for any fixed parameters
614700 that are not in the target search space and removes them. This is a hack
615- around limitations of the ``RemoveFixed`` transform. Since we construct the
701+ around limitations of the Ax ``RemoveFixed`` transform. Since we construct the
616702 transforms with the joint space, we end up adding back all fixed parameters
617703 from the joint space rather than adding only the parameters from the
618704 target search space. A proper fix would require passing in the search space
@@ -633,6 +719,17 @@ def gen(
633719 if fixed_features is None :
634720 fixed_features = ObservationFeatures (parameters = {})
635721 fixed_features .parameters .setdefault (name , target_p .value )
722+ # Fix source-only params so the optimizer doesn't search over them.
723+ # Center is a reasonable default; LearnedFeatureImputation overwrites
724+ # these with learned values when configured.
725+ joint_center = self .joint_search_space .compute_naive_center ()
726+ for name , param in self .joint_search_space .parameters .items ():
727+ if name not in search_space .parameters and isinstance (
728+ param , RangeParameter
729+ ):
730+ if fixed_features is None :
731+ fixed_features = ObservationFeatures (parameters = {})
732+ fixed_features .parameters .setdefault (name , joint_center [name ])
636733 generator_run = super ().gen (
637734 n = n ,
638735 search_space = search_space ,
@@ -719,12 +816,8 @@ def transfer_learning_generator_specs_constructor(
719816 selector in case there is model selection enabled.
720817 """
721818 input_transform_classes : list [type [InputTransform ]] = [Normalize ]
722- input_transform_options = {
723- "Normalize" : {
724- # None for bounds here ensures we do not use bounds from
725- # the search space digest.
726- "bounds" : None ,
727- }
819+ input_transform_options : dict [str , dict [str , Any ]] = {
820+ "Normalize" : {},
728821 }
729822 transforms = transforms or MBM_X_trans + [MetadataToTask ] + Y_trans
730823 transform_configs = get_derelativize_config (
0 commit comments