Commit 6090284

jduerholt authored and facebook-github-bot committed
remove input transform checks (#1568)
Summary:

## Motivation

As I am currently refactoring our internal codebase, I had a look at sdaulton's PR regarding probabilistic reparameterization. From my understanding, one has to use it by representing the categoricals with a one-hot encoding for the reparameterized acquisition function, and then eventually transforming the input to a numeric representation via `OneHotToNumeric`, especially when one wants to use it together with `MixedSingleTaskGP`. Currently, `MixedSingleTaskGP` is very strict about which input transforms are allowed. This PR lifts those restrictions to make it usable with `OneHotToNumeric`. Note that the transform also has to be instantiated with `transform_on_train=False`, and `train_X` has to be transformed before it is passed to the constructor of `MixedSingleTaskGP`; otherwise the indices for the different kernels get mixed up.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #1568

Test Plan: Unit tests.

## Related PRs

#1534

Reviewed By: esantorella

Differential Revision: D42230252

Pulled By: Balandat

fbshipit-source-id: b6a0a12d926fbab9890a75438eb60ef849441149
1 parent 00e84ff commit 6090284
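As a rough sketch of the workflow the summary describes (not part of the commit itself; the data, shapes, and the `categorical_features={2: 3}` mapping are illustrative assumptions about the `OneHotToNumeric` API):

```python
import torch
from botorch.models import MixedSingleTaskGP
from botorch.models.transforms.input import OneHotToNumeric

# Toy data: 2 continuous features plus one categorical with 3 levels,
# one-hot encoded into columns 2..4 (5 columns total).
n = 10
X_cont = torch.rand(n, 2, dtype=torch.double)
X_cat = torch.nn.functional.one_hot(
    torch.randint(3, (n,)), num_classes=3
).to(X_cont)
train_X_onehot = torch.cat([X_cont, X_cat], dim=-1)  # shape (n, 5)
train_Y = torch.randn(n, 1, dtype=torch.double)

# The one-hot block starts at column 2 with cardinality 3. As the PR notes,
# transform_on_train must be False so the already-transformed training
# inputs are not transformed a second time in the model's forward pass.
tf = OneHotToNumeric(
    dim=5,
    categorical_features={2: 3},
    transform_on_train=False,
)

# Transform train_X *before* passing it to the constructor; in the numeric
# representation the categorical sits at index 2, which is what cat_dims
# must point at for the kernels to line up.
train_X_numeric = tf.transform(train_X_onehot)  # shape (n, 3)
model = MixedSingleTaskGP(
    train_X_numeric, train_Y, cat_dims=[2], input_transform=tf
)
```

At evaluation time, e.g. inside an acquisition function operating on the one-hot representation, the transform then maps candidate points to the numeric representation before the kernels see them.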

File tree

2 files changed: +5 −51 lines


botorch/models/gp_regression_mixed.py

+5 −17

@@ -89,29 +89,17 @@ def __init__(
                 `.posterior` on the model will be on the original scale).
             input_transform: An input transform that is applied in the model's
                 forward pass. Only input transforms are allowed which do not
-                transform the categorical dimensions. This can be achieved
-                by using the `indices` argument when constructing the transform.
+                transform the categorical dimensions. If you want to use it
+                for example in combination with a `OneHotToNumeric` input transform
+                one has to instantiate the transform with `transform_on_train` == False
+                and pass in the already transformed input.
         """
-        if input_transform is not None:
-            if not hasattr(input_transform, "indices"):
-                raise ValueError(
-                    "Only continuous inputs can be transformed. "
-                    "Please use `indices` in the `input_transform`."
-                )
-            # check that no cat dim is in indices
-            elif any(idx in input_transform.indices for idx in cat_dims):
-                raise ValueError(
-                    "Only continuous inputs can be transformed. "
-                    "Categorical index found in `indices` of the `input_transform`."
-                )
         if len(cat_dims) == 0:
             raise ValueError(
                 "Must specify categorical dimensions for MixedSingleTaskGP"
             )
         self._ignore_X_dims_scaling_check = cat_dims
-        input_batch_shape, aug_batch_shape = self.get_batch_dimensions(
-            train_X=train_X, train_Y=train_Y
-        )
+        _, aug_batch_shape = self.get_batch_dimensions(train_X=train_X, train_Y=train_Y)

         if cont_kernel_factory is None:
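For standard transforms such as `Normalize`, the retained test below still exercises the supported pattern of restricting the transform to the continuous dimensions via `indices`. A minimal sketch of that pattern, with made-up shapes and data:

```python
import torch
from botorch.models import MixedSingleTaskGP
from botorch.models.transforms.input import Normalize

d, cat_dims = 4, [3]  # last column is categorical
ord_dims = sorted(set(range(d)) - set(cat_dims))  # [0, 1, 2]

train_X = torch.rand(10, d, dtype=torch.double)
train_X[:, 3] = torch.randint(3, (10,)).to(train_X)  # category codes 0..2
train_Y = torch.randn(10, 1, dtype=torch.double)

model = MixedSingleTaskGP(
    train_X,
    train_Y,
    cat_dims=cat_dims,
    # Normalize only the continuous columns; the categorical column
    # passes through untouched.
    input_transform=Normalize(d=d, indices=ord_dims),
)
```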

test/models/test_gp_regression_mixed.py

+0 −34

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import itertools
-import random
 import warnings

 import torch
@@ -44,15 +43,6 @@ def test_gp(self):
         )
         cat_dims = list(range(ncat))
         ord_dims = sorted(set(range(d)) - set(cat_dims))
-        with self.assertRaises(ValueError):
-            MixedSingleTaskGP(
-                train_X,
-                train_Y,
-                cat_dims=cat_dims,
-                input_transform=Normalize(
-                    d=d, bounds=bounds.to(**tkwargs), transform_on_train=True
-                ),
-            )
         # test correct indices
         if (ncat < 3) and (ncat > 0):
             MixedSingleTaskGP(
@@ -66,30 +56,6 @@ def test_gp(self):
                     indices=ord_dims,
                 ),
             )
-            with self.assertRaises(ValueError):
-                MixedSingleTaskGP(
-                    train_X,
-                    train_Y,
-                    cat_dims=cat_dims,
-                    input_transform=Normalize(
-                        d=d,
-                        bounds=bounds.to(**tkwargs),
-                        transform_on_train=True,
-                        indices=cat_dims,
-                    ),
-                )
-            with self.assertRaises(ValueError):
-                MixedSingleTaskGP(
-                    train_X,
-                    train_Y,
-                    cat_dims=cat_dims,
-                    input_transform=Normalize(
-                        d=d,
-                        bounds=bounds.to(**tkwargs),
-                        transform_on_train=True,
-                        indices=ord_dims + [random.choice(cat_dims)],
-                    ),
-                )

         if len(cat_dims) == 0:
             with self.assertRaises(ValueError):
