Skip to content

Commit 06d84cd

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Use StratifiedStandardize for per-task Y standardization in TL (#5194)
Summary: Adds per-task outcome standardization to the transfer learning adapter, ensuring each task's observations are standardized independently rather than jointly. Updates the default transform pipeline to use TL-specific outcome transforms. This removes ambiguity on whether the right transforms have been applied (e.g. QuickBO/warm-starting), where standardization is not performed across, but within experiments. Differential Revision: D102197139
1 parent 93d1add commit 06d84cd

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

ax/adapter/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
]
140140

141141
Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY]
142+
TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY]
142143

143144
# Expected `List[Type[Transform]]` for 2nd anonymous parameter to
144145
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.

ax/adapter/transfer_learning/adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Generators,
2424
GeneratorSetup,
2525
MBM_X_trans,
26-
Y_trans,
26+
TL_Y_trans,
2727
)
2828
from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter
2929
from ax.adapter.transfer_learning.utils import get_joint_search_space
@@ -54,6 +54,7 @@
5454
from ax.utils.common.logger import get_logger
5555
from botorch.models.multitask import MultiTaskGP
5656
from botorch.models.transforms.input import InputTransform, Normalize
57+
from botorch.models.transforms.outcome import StratifiedStandardize
5758
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
5859
from gpytorch.kernels.kernel import Kernel
5960
from pyre_extensions import assert_is_instance
@@ -793,7 +794,7 @@ def transfer_learning_generator_specs_constructor(
793794
Args:
794795
model_class: The MultiTask BoTorch Model to use in the BOTL.
795796
transform: Optional list of transforms to use in the Adapter.
796-
Defaults to MBM_X_trans + [MetadataToTask] + Y_trans.
797+
Defaults to MBM_X_trans + [MetadataToTask] + TL_Y_trans.
797798
jit_compile: Whether to use jit compilation in Pyro when the fully Bayesian
798799
model is used.
799800
torch_device: What torch device to use (defaults to None, i.e. falls back to
@@ -828,7 +829,7 @@ def transfer_learning_generator_specs_constructor(
828829
input_transform_options: dict[str, dict[str, Any]] = {
829830
"Normalize": {},
830831
}
831-
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
832+
transforms = transforms or MBM_X_trans + [MetadataToTask] + TL_Y_trans
832833
transform_configs = get_derelativize_config(
833834
derelativize_with_raw_status_quo=derelativize_with_raw_status_quo
834835
)
@@ -846,6 +847,7 @@ def transfer_learning_generator_specs_constructor(
846847
botorch_model_class=model_class,
847848
model_options=botorch_model_kwargs or {},
848849
input_transform_classes=input_transform_classes,
850+
outcome_transform_classes=[StratifiedStandardize],
849851
input_transform_options=input_transform_options,
850852
mll_options=mll_kwargs,
851853
covar_module_class=covar_module_class,
@@ -887,5 +889,5 @@ def transfer_learning_generator_specs_constructor(
887889
GENERATOR_KEY_TO_GENERATOR_SETUP["BOTL"] = GeneratorSetup(
888890
adapter_class=TransferLearningAdapter,
889891
generator_class=BoTorchGenerator,
890-
transforms=MBM_X_trans + [MetadataToTask] + Y_trans,
892+
transforms=MBM_X_trans + [MetadataToTask] + TL_Y_trans,
891893
)

0 commit comments

Comments
 (0)