
Commit f79d608

esantorella authored and facebook-github-bot committed
Clarify behavior on standard deviations with <1 degree of freedom and silence some unit test warnings (#2357)
Summary:

## Motivation

Unit tests were producing a lot of warnings about taking standard deviations across fewer than 2 observations, and it was not clear whether these warnings were legitimate in context.

* For checking the standardization of input data, no longer check the standard deviation if there is just one observation.
* For the Standardize outcome transform, explicitly set standard deviations to 1 when there is only one observation. This matches the legacy behavior, but previously that wasn't clear because the standard deviation would become NaN before being corrected to 1.
* Raise an error on attempting to standardize 0 observations. This never worked, so the failure is now explicit.

Pull Request resolved: #2357

Test Plan: Added unit tests

## Related PRs

Reviewed By: Balandat

Differential Revision: D57931412

Pulled By: esantorella

fbshipit-source-id: 36a9c81a950a0b92749673fdd22aec62b45aaae9
1 parent 1e73b30 commit f79d608
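In code terms, the behavior changes summarized above look roughly like the following sketch (written for this summary, not taken from the diff; tensor values are illustrative):

```python
import torch
from botorch.models.transforms.outcome import Standardize

tf = Standardize(m=2)

# Zero observations: standardizing never worked here, and now fails loudly.
try:
    tf(torch.zeros(0, 2), None)
except ValueError as e:
    print(e)  # Can't standardize with no observations. Y.shape=torch.Size([0, 2]).

# One observation: the standard deviation is explicitly set to 1 (matching the
# legacy behavior), so the transform just centers the single row.
Y_tf, _ = tf(torch.tensor([[1.5, -0.5]]), None)
print(Y_tf)  # tensor([[0., 0.]])
```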

File tree: 5 files changed, +77 -31 lines changed


botorch/models/transforms/outcome.py (+9 -1)

```diff
@@ -286,7 +286,15 @@ def forward(
                 f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
                 f"{self._m}."
             )
-        stdvs = Y.std(dim=-2, keepdim=True)
+        if Y.shape[-2] < 1:
+            raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
+
+        elif Y.shape[-2] == 1:
+            stdvs = torch.ones(
+                (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
+            )
+        else:
+            stdvs = Y.std(dim=-2, keepdim=True)
         stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
         means = Y.mean(dim=-2, keepdim=True)
         if self._outputs is not None:
```
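For context, here is a minimal sketch of how the new single-observation branch composes with the existing `_min_stdv` clamp; it mirrors the diff above, with an illustrative stand-in for `self._min_stdv`:

```python
import torch

Y = torch.tensor([[0.7, -1.2]])  # one observation, m = 2 outcomes

# New branch: with a single row, stdvs is set to ones directly rather than
# letting Y.std(dim=-2) return NaN (std with fewer than 1 degree of freedom).
stdvs = torch.ones((*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device)

# Existing clamp: any std below min_stdv is also replaced by 1.0, so the
# single-observation case now lines up with the near-constant-data case.
min_stdv = 1e-8  # illustrative stand-in for the transform's self._min_stdv
stdvs = stdvs.where(stdvs >= min_stdv, torch.full_like(stdvs, 1.0))

means = Y.mean(dim=-2, keepdim=True)
Y_tf = (Y - means) / stdvs
print(Y_tf)  # tensor([[0., 0.]]) -- a single observation is centered to zero
```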

botorch/models/utils/assorted.py (+23 -10)

```diff
@@ -171,7 +171,7 @@ def check_min_max_scaling(
             )
             if raise_on_fail:
                 raise InputDataError(msg)
-            warnings.warn(msg, InputDataWarning)
+            warnings.warn(msg, InputDataWarning, stacklevel=2)
 
 
 def check_standardization(
@@ -191,15 +191,28 @@ def check_standardization(
         raise_on_fail: If True, raise an exception instead of a warning.
     """
     with torch.no_grad():
-        Ymean, Ystd = torch.mean(Y, dim=-2), torch.std(Y, dim=-2)
-        if torch.abs(Ymean).max() > atol_mean or torch.abs(Ystd - 1).max() > atol_std:
-            msg = (
-                f"Input data is not standardized (mean = {Ymean}, std = {Ystd}). "
-                "Please consider scaling the input to zero mean and unit variance."
-            )
-            if raise_on_fail:
-                raise InputDataError(msg)
-            warnings.warn(msg, InputDataWarning)
+        Ymean = torch.mean(Y, dim=-2)
+        mean_not_zero = torch.abs(Ymean).max() > atol_mean
+        if Y.shape[-2] <= 1:
+            if mean_not_zero:
+                msg = (
+                    f"Data is not standardized (mean = {Ymean}). "
+                    "Please consider scaling the input to zero mean and unit variance."
+                )
+                if raise_on_fail:
+                    raise InputDataError(msg)
+                warnings.warn(msg, InputDataWarning, stacklevel=2)
+        else:
+            Ystd = torch.std(Y, dim=-2)
+            std_not_one = torch.abs(Ystd - 1).max() > atol_std
+            if mean_not_zero or std_not_one:
+                msg = (
+                    f"Data is not standardized (std = {Ystd}, mean = {Ymean}). "
+                    "Please consider scaling the input to zero mean and unit variance."
+                )
+                if raise_on_fail:
+                    raise InputDataError(msg)
+                warnings.warn(msg, InputDataWarning, stacklevel=2)
 
 
 def validate_input_scaling(
```
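A short usage sketch of the reworked check (the import path follows the file shown above; values are illustrative). With a single observation per batch, only the mean is validated, since a standard deviation over one row has fewer than 1 degree of freedom:

```python
import warnings

import torch
from botorch.models.utils.assorted import check_standardization

# One observation per batch: only the mean is checked.
check_standardization(Y=torch.zeros(3, 1, 2))  # zero mean -> no warning

with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    check_standardization(Y=torch.ones(3, 1, 2))  # nonzero mean -> warns
print(ws[0].message)  # Data is not standardized (mean = ...). Please consider ...
```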

botorch/utils/testing.py (+1 -1)

```diff
@@ -60,7 +60,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
             )
             warnings.filterwarnings(
                 "ignore",
-                message="Input data is not standardized.",
+                message="Data is not standardized.",
                 category=InputDataWarning,
             )
             warnings.filterwarnings(
```

test/models/transforms/test_outcome.py (+21 -11)

```diff
@@ -115,7 +115,14 @@ def test_is_linear(self) -> None:
         )
         self.assertEqual(posterior_is_gpt, transform._is_linear)
 
-    def test_standardize(self):
+    def test_standardize_raises_when_no_observations(self) -> None:
+        tf = Standardize(m=1)
+        with self.assertRaisesRegex(
+            ValueError, "Can't standardize with no observations."
+        ):
+            tf(torch.zeros(0, 1, device=self.device), None)
+
+    def test_standardize(self) -> None:
         # test error on incompatible dim
         tf = Standardize(m=1)
         with self.assertRaisesRegex(
@@ -134,9 +141,10 @@ def test_standardize(self):
         ms = (1, 2)
         batch_shapes = (torch.Size(), torch.Size([2]))
         dtypes = (torch.float, torch.double)
+        ns = [1, 3]
 
         # test transform, untransform, untransform_posterior
-        for m, batch_shape, dtype in itertools.product(ms, batch_shapes, dtypes):
+        for m, batch_shape, dtype, n in itertools.product(ms, batch_shapes, dtypes, ns):
             # test init
             tf = Standardize(m=m, batch_shape=batch_shape)
             self.assertTrue(tf.training)
@@ -148,7 +156,7 @@ def test_standardize(self):
             # no observation noise
             with torch.random.fork_rng():
                 torch.manual_seed(0)
-                Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
+                Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
             Y_tf, Yvar_tf = tf(Y, None)
             self.assertTrue(tf.training)
             self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
@@ -171,14 +179,16 @@ def test_standardize(self):
             tf = Standardize(m=m, batch_shape=batch_shape)
             with torch.random.fork_rng():
                 torch.manual_seed(0)
-                Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
+                Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
                 Yvar = 1e-8 + torch.rand(
-                    *batch_shape, 3, m, device=self.device, dtype=dtype
+                    *batch_shape, n, m, device=self.device, dtype=dtype
                 )
             Y_tf, Yvar_tf = tf(Y, Yvar)
             self.assertTrue(tf.training)
             self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
-            Yvar_tf_expected = Yvar / Y.std(dim=-2, keepdim=True) ** 2
+            Yvar_tf_expected = (
+                Yvar if n == 1 else Yvar / Y.std(dim=-2, keepdim=True) ** 2
+            )
             self.assertAllClose(Yvar_tf, Yvar_tf_expected)
             tf.eval()
             self.assertFalse(tf.training)
@@ -190,7 +200,7 @@ def test_standardize(self):
             for interleaved, lazy in itertools.product((True, False), (True, False)):
                 if m == 1 and interleaved:  # interleave has no meaning for m=1
                     continue
-                shape = batch_shape + torch.Size([3, m])
+                shape = batch_shape + torch.Size([n, m])
                 posterior = _get_test_posterior(
                     shape,
                     device=self.device,
@@ -216,12 +226,12 @@ def test_standardize(self):
             # Untransform BlockDiagLinearOperator.
             if m > 1:
                 base_lcv = DiagLinearOperator(
-                    torch.rand(*batch_shape, m, 3, device=self.device, dtype=dtype)
+                    torch.rand(*batch_shape, m, n, device=self.device, dtype=dtype)
                 )
                 lcv = BlockDiagLinearOperator(base_lcv)
                 mvn = MultitaskMultivariateNormal(
                     mean=torch.rand(
-                        *batch_shape, 3, m, device=self.device, dtype=dtype
+                        *batch_shape, n, m, device=self.device, dtype=dtype
                     ),
                     covariance_matrix=lcv,
                     interleaved=False,
@@ -240,7 +250,7 @@ def test_standardize(self):
                 samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
                 self.assertEqual(
                     samples2.shape,
-                    torch.Size([4, 2]) + batch_shape + torch.Size([3, m]),
+                    torch.Size([4, 2]) + batch_shape + torch.Size([n, m]),
                 )
 
             # untransform_posterior for non-GPyTorch posterior
@@ -252,7 +262,7 @@ def test_standardize(self):
             )
             p_utf2 = tf.untransform_posterior(posterior2)
             self.assertEqual(p_utf2.device.type, self.device.type)
-            self.assertTrue(p_utf2.dtype == dtype)
+            self.assertEqual(p_utf2.dtype, dtype)
             mean_expected = tf.means + tf.stdvs * posterior.mean
             variance_expected = tf.stdvs**2 * posterior.variance
             self.assertAllClose(p_utf2.mean, mean_expected)
```

test/models/utils/test_assorted.py (+23 -8)

```diff
@@ -141,26 +141,41 @@ def test_check_min_max_scaling(self):
     def test_check_standardization(self):
         # Ensure that it is not filtered out.
         warnings.filterwarnings("always", category=InputDataWarning)
+        torch.manual_seed(0)
         Y = torch.randn(3, 4, 2)
         # check standardized input
         Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
         with warnings.catch_warnings(record=True) as ws:
             check_standardization(Y=Yst)
         self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
         check_standardization(Y=Yst, raise_on_fail=True)
-        # check nonzero mean
+
+        # check standardized input with one observation
+        y = torch.zeros((3, 1, 2))
         with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=y)
+        self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
+        check_standardization(Y=y, raise_on_fail=True)
+
+        # check nonzero mean for case where >= 2 observations per batch
+        msg_more_than_1_obs = r"Data is not standardized \(std ="
+        with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
             check_standardization(Y=Yst + 1)
-        self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
-        self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-        with self.assertRaises(InputDataError):
+        with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
             check_standardization(Y=Yst + 1, raise_on_fail=True)
+
+        # check nonzero mean for case where < 2 observations per batch
+        msg_one_obs = r"Data is not standardized \(mean ="
+        y = torch.ones((3, 1, 2), dtype=torch.float32)
+        with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
+            check_standardization(Y=y)
+        with self.assertRaisesRegex(InputDataError, msg_one_obs):
+            check_standardization(Y=y, raise_on_fail=True)
+
         # check non-unit variance
-        with warnings.catch_warnings(record=True) as ws:
+        with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
             check_standardization(Y=Yst * 2)
-        self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
-        self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-        with self.assertRaises(InputDataError):
+        with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
             check_standardization(Y=Yst * 2, raise_on_fail=True)
 
     def test_validate_input_scaling(self):
```
