Skip to content

Commit 1d7bbdb

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Support observation noise in Log outcome transform (#3245)
Summary: Removes NotImplementedError when Yvar is provided to the Log transform. Uses delta method approximation: Yvar_tf = Yvar / Y^2 in forward, Yvar = Yvar_tf * exp(2 * Y_tf) in untransform. Documents that this assumes Gaussian noise in log-space (log-normal in original space). Closes #2623 Pull Request resolved: #3245 Reviewed By: Balandat Differential Revision: D97353503 Pulled By: saitcakmak fbshipit-source-id: 971317aa197d2c2872c553ed89be30e478f3b26e
1 parent 37de228 commit 1d7bbdb

2 files changed

Lines changed: 78 additions & 26 deletions

File tree

botorch/models/transforms/outcome.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from __future__ import annotations
2424

25+
import logging
2526
import warnings
2627
from abc import ABC, abstractmethod
2728
from collections import OrderedDict
@@ -39,6 +40,8 @@
3940
from torch import Tensor
4041
from torch.nn import Module, ModuleDict
4142

43+
logger: logging.Logger = logging.getLogger(__name__)
44+
4245

4346
class OutcomeTransform(Module, ABC):
4447
"""Abstract base class for outcome transforms."""
@@ -726,6 +729,11 @@ class Log(OutcomeTransform):
726729
Useful if the targets are modeled using a (multivariate) log-Normal
727730
distribution. This means that we can use a standard GP model on the
728731
log-transformed outcomes and un-transform the model posterior of that GP.
732+
733+
When observation noise is provided, the variance is transformed using the
734+
delta method approximation: Var[log(Y)] ≈ Var[Y] / Y^2. This assumes that
735+
the observation noise is Gaussian in the log-transformed space, which
736+
corresponds to log-normal observation noise in the original space.
729737
"""
730738

731739
def __init__(self, outputs: list[int] | None = None) -> None:
@@ -789,10 +797,18 @@ def forward(
789797
dim=-1,
790798
)
791799
if Yvar is not None:
792-
# TODO: Delta method, possibly issue warning
793-
raise NotImplementedError(
794-
"Log does not yet support transforming observation noise"
795-
)
800+
# Delta method: Var[log(Y)] ≈ Var[Y] / Y^2
801+
Yvar_tf = Yvar / Y.clamp(min=1e-8).pow(2)
802+
if outputs is not None:
803+
Yvar = torch.stack(
804+
[
805+
Yvar_tf[..., i] if i in outputs else Yvar[..., i]
806+
for i in range(Y.size(-1))
807+
],
808+
dim=-1,
809+
)
810+
else:
811+
Yvar = Yvar_tf
796812
return Y_tf, Yvar
797813

798814
def untransform(
@@ -825,10 +841,24 @@ def untransform(
825841
dim=-1,
826842
)
827843
if Yvar is not None:
828-
# TODO: Delta method, possibly issue warning
829-
raise NotImplementedError(
830-
"Log does not yet support transforming observation noise"
844+
# Reverse of delta method: Var[Y] = Var[log(Y)] * Y^2
845+
# Since Y = exp(Y_log), this is Var[log(Y)] * exp(2 * Y_log)
846+
logger.debug(
847+
"Log.untransform: Reverse delta method for observation noise "
848+
"is a lossy operation. The untransformed variance is an "
849+
"approximation that may not exactly match the original variance."
831850
)
851+
Yvar_utf = Yvar * torch.exp(2.0 * Y)
852+
if outputs is not None:
853+
Yvar = torch.stack(
854+
[
855+
Yvar_utf[..., i] if i in outputs else Yvar[..., i]
856+
for i in range(Y.size(-1))
857+
],
858+
dim=-1,
859+
)
860+
else:
861+
Yvar = Yvar_utf
832862
return Y_utf, Yvar
833863

834864
def untransform_posterior(

test/models/transforms/test_outcome.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -597,17 +597,23 @@ def test_log(self):
597597
self.assertTrue(torch.equal(Y_tf[..., [0]], Y_tf_subset))
598598
self.assertIsNone(Yvar_tf_subset)
599599

600-
# test error if observation noise present
600+
# test with observation noise (delta method)
601601
tf = Log()
602-
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
602+
Y = 1e-2 + torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
603603
Yvar = 1e-8 + torch.rand(
604604
*batch_shape, 3, m, device=self.device, dtype=dtype
605605
)
606-
with self.assertRaises(NotImplementedError):
607-
tf(Y, Yvar)
606+
Y_tf, Yvar_tf = tf(Y, Yvar)
607+
self.assertTrue(tf.training)
608+
self.assertAllClose(Y_tf, torch.log(Y))
609+
# Delta method: Var[log(Y)] ≈ Var[Y] / Y^2
610+
self.assertAllClose(Yvar_tf, Yvar / Y.pow(2))
608611
tf.eval()
609-
with self.assertRaises(NotImplementedError):
610-
tf.untransform(Y, Yvar)
612+
self.assertFalse(tf.training)
613+
Y_utf, Yvar_utf = tf.untransform(Y_tf, Yvar_tf)
614+
self.assertAllClose(Y_utf, Y)
615+
# Reverse: Var[Y] = Var[log(Y)] * exp(2 * log(Y)) = Var[log(Y)] * Y^2
616+
self.assertAllClose(Yvar_utf, Yvar)
611617

612618
# untransform_posterior
613619
tf = Log()
@@ -661,14 +667,22 @@ def test_log(self):
661667
with self.assertRaises(NotImplementedError):
662668
tf_subset = tf.subset_output(idcs=[0])
663669

664-
# with observation noise
670+
# with observation noise (subset of outputs)
665671
tf = Log(outputs=outputs)
666-
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
672+
Y = 1e-2 + torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
667673
Yvar = 1e-8 + torch.rand(
668674
*batch_shape, 3, m, device=self.device, dtype=dtype
669675
)
670-
with self.assertRaises(NotImplementedError):
671-
tf(Y, Yvar)
676+
Y_tf, Yvar_tf = tf(Y, Yvar)
677+
# output 0 should be untransformed, output 1 should be transformed
678+
self.assertAllClose(Y_tf[..., 0], Y[..., 0])
679+
self.assertAllClose(Y_tf[..., 1], torch.log(Y[..., 1]))
680+
self.assertAllClose(Yvar_tf[..., 0], Yvar[..., 0])
681+
self.assertAllClose(Yvar_tf[..., 1], Yvar[..., 1] / Y[..., 1].pow(2))
682+
tf.eval()
683+
Y_utf, Yvar_utf = tf.untransform(Y_tf, Yvar_tf)
684+
self.assertAllClose(Y_utf, Y)
685+
self.assertAllClose(Yvar_utf, Yvar)
672686

673687
# error on untransform_posterior
674688
with self.assertRaises(NotImplementedError):
@@ -722,13 +736,17 @@ def test_chained_outcome_transform(self):
722736
with self.assertRaises(RuntimeError):
723737
tf.subset_output(idcs=[0, 1, 2])
724738

725-
# test error if observation noise present
726-
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
739+
# test observation noise is propagated through chained transform
740+
Y = 1e-2 + torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
727741
Yvar = 1e-8 + torch.rand(
728742
*batch_shape, 3, m, device=self.device, dtype=dtype
729743
)
730-
with self.assertRaises(NotImplementedError):
731-
tf(Y, Yvar)
744+
tf1 = Log()
745+
tf2 = Standardize(m=m, batch_shape=batch_shape)
746+
tf = ChainedOutcomeTransform(log=tf1, standardize=tf2)
747+
Y_tf, Yvar_tf = tf(Y, Yvar)
748+
self.assertEqual(Y_tf.shape, Y.shape)
749+
self.assertEqual(Yvar_tf.shape, Yvar.shape)
732750

733751
# untransform_posterior
734752
tf1 = Log()
@@ -781,15 +799,19 @@ def test_chained_outcome_transform(self):
781799
torch.allclose(Y_utf, Y)
782800
self.assertIsNone(Yvar_utf)
783801

784-
# with observation noise
785-
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
802+
# with observation noise (subset outputs)
803+
Y = 1e-2 + torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
786804
Yvar = 1e-8 + torch.rand(
787805
*batch_shape, 3, m, device=self.device, dtype=dtype
788806
)
789-
with self.assertRaises(NotImplementedError):
790-
tf(Y, Yvar)
807+
tf1 = Log(outputs=outputs)
808+
tf2 = Standardize(m=m, outputs=outputs, batch_shape=batch_shape)
809+
tf = ChainedOutcomeTransform(log=tf1, standardize=tf2)
810+
Y_tf, Yvar_tf = tf(Y, Yvar)
811+
self.assertEqual(Y_tf.shape, Y.shape)
812+
self.assertEqual(Yvar_tf.shape, Yvar.shape)
791813

792-
# error on untransform_posterior
814+
# error on untransform_posterior (subset outputs not supported)
793815
with self.assertRaises(NotImplementedError):
794816
tf.untransform_posterior(None)
795817

0 commit comments

Comments
 (0)