Skip to content

Commit e407627

Browse files
Carl Hvarfnerfacebook-github-bot
authored andcommitted
Use StratifiedStandardize for per-task Y standardization in TL (#5194)
Summary: Adds per-task outcome standardization to the transfer learning adapter, ensuring each task's observations are standardized independently rather than jointly. Updates the default transform pipeline to use TL-specific outcome transforms. This removes ambiguity on whether the right transforms have been applied (e.g. QuickBO/warm-starting), where standardization is not performed across, but within experiments. Differential Revision: D102197139
1 parent 6e0bd89 commit e407627

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

ax/adapter/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
]
140140

141141
Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY]
142+
TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY]
142143

143144
# Expected `List[Type[Transform]]` for 2nd anonymous parameter to
144145
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.

ax/adapter/transfer_learning/adapter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
Generators,
2323
GeneratorSetup,
2424
MBM_X_trans,
25-
Y_trans,
25+
TL_Y_trans,
2626
)
2727
from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter
2828
from ax.adapter.transfer_learning.utils import get_joint_search_space
@@ -53,6 +53,7 @@
5353
from ax.utils.common.logger import get_logger
5454
from botorch.models.multitask import MultiTaskGP
5555
from botorch.models.transforms.input import InputTransform, Normalize
56+
from botorch.models.transforms.outcome import StratifiedStandardize
5657
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
5758
from gpytorch.kernels.kernel import Kernel
5859
from pyre_extensions import assert_is_instance
@@ -745,7 +746,7 @@ def transfer_learning_generator_specs_constructor(
745746
Args:
746747
model_class: The MultiTask BoTorch Model to use in the BOTL.
747748
transform: Optional list of transforms to use in the Adapter.
748-
Defaults to MBM_X_trans + [MetadataToTask] + Y_trans.
749+
Defaults to MBM_X_trans + [MetadataToTask] + TL_Y_trans.
749750
jit_compile: Whether to use jit compilation in Pyro when the fully Bayesian
750751
model is used.
751752
torch_device: What torch device to use (defaults to None, i.e. falls back to
@@ -780,7 +781,7 @@ def transfer_learning_generator_specs_constructor(
780781
input_transform_options: dict[str, dict[str, Any]] = {
781782
"Normalize": {},
782783
}
783-
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
784+
transforms = transforms or MBM_X_trans + [MetadataToTask] + TL_Y_trans
784785
transform_configs = get_derelativize_config(
785786
derelativize_with_raw_status_quo=derelativize_with_raw_status_quo
786787
)
@@ -798,6 +799,7 @@ def transfer_learning_generator_specs_constructor(
798799
botorch_model_class=model_class,
799800
model_options=botorch_model_kwargs or {},
800801
input_transform_classes=input_transform_classes,
802+
outcome_transform_classes=[StratifiedStandardize],
801803
input_transform_options=input_transform_options,
802804
mll_options=mll_kwargs,
803805
covar_module_class=covar_module_class,
@@ -839,5 +841,5 @@ def transfer_learning_generator_specs_constructor(
839841
GENERATOR_KEY_TO_GENERATOR_SETUP["BOTL"] = GeneratorSetup(
840842
adapter_class=TransferLearningAdapter,
841843
generator_class=BoTorchGenerator,
842-
transforms=MBM_X_trans + [MetadataToTask] + Y_trans,
844+
transforms=MBM_X_trans + [MetadataToTask] + TL_Y_trans,
843845
)

0 commit comments

Comments
 (0)