Skip to content

Commit 927965b

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Use StratifiedStandardize for per-task Y standardization in TL (#5194)
Summary: Adds per-task outcome standardization to the transfer learning adapter, ensuring each task's observations are standardized independently rather than jointly. Updates the default transform pipeline to use TL-specific outcome transforms. This removes ambiguity on whether the right transforms have been applied (e.g. QuickBO/warm-starting), where standardization is not performed across, but within experiments. Differential Revision: D102197139
1 parent e7b3e85 commit 927965b

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

ax/adapter/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
]
140140

141141
Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY]
142+
TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY]
142143

143144
# Expected `List[Type[Transform]]` for 2nd anonymous parameter to
144145
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.

ax/adapter/transfer_learning/adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
Generators,
2323
GeneratorSetup,
2424
MBM_X_trans,
25-
Y_trans,
25+
TL_Y_trans,
2626
)
2727
from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter
2828
from ax.adapter.transfer_learning.utils import get_joint_search_space
@@ -53,6 +53,7 @@
5353
from ax.utils.common.logger import get_logger
5454
from botorch.models.multitask import MultiTaskGP
5555
from botorch.models.transforms.input import InputTransform, Normalize
56+
from botorch.models.transforms.outcome import StratifiedStandardize
5657
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
5758
from gpytorch.kernels.kernel import Kernel
5859
from pyre_extensions import assert_is_instance
@@ -744,7 +745,7 @@ def transfer_learning_generator_specs_constructor(
744745
Args:
745746
model_class: The MultiTask BoTorch Model to use in the BOTL.
746747
transform: Optional list of transforms to use in the Adapter.
747-
Defaults to MBM_X_trans + [MetadataToTask] + Y_trans.
748+
Defaults to MBM_X_trans + [MetadataToTask] + TL_Y_trans.
748749
jit_compile: Whether to use jit compilation in Pyro when the fully Bayesian
749750
model is used.
750751
torch_device: What torch device to use (defaults to None, i.e. falls back to
@@ -779,7 +780,7 @@ def transfer_learning_generator_specs_constructor(
779780
input_transform_options: dict[str, dict[str, Any]] = {
780781
"Normalize": {},
781782
}
782-
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
783+
transforms = transforms or MBM_X_trans + [MetadataToTask] + TL_Y_trans
783784
transform_configs = get_derelativize_config(
784785
derelativize_with_raw_status_quo=derelativize_with_raw_status_quo
785786
)
@@ -797,6 +798,7 @@ def transfer_learning_generator_specs_constructor(
797798
botorch_model_class=model_class,
798799
model_options=botorch_model_kwargs or {},
799800
input_transform_classes=input_transform_classes,
801+
outcome_transform_classes=[StratifiedStandardize],
800802
input_transform_options=input_transform_options,
801803
mll_options=mll_kwargs,
802804
covar_module_class=covar_module_class,
@@ -838,5 +840,5 @@ def transfer_learning_generator_specs_constructor(
838840
GENERATOR_KEY_TO_GENERATOR_SETUP["BOTL"] = GeneratorSetup(
839841
adapter_class=TransferLearningAdapter,
840842
generator_class=BoTorchGenerator,
841-
transforms=MBM_X_trans + [MetadataToTask] + Y_trans,
843+
transforms=MBM_X_trans + [MetadataToTask] + TL_Y_trans,
842844
)

0 commit comments

Comments
 (0)