@@ -597,17 +597,23 @@ def test_log(self):
597597 self .assertTrue (torch .equal (Y_tf [..., [0 ]], Y_tf_subset ))
598598 self .assertIsNone (Yvar_tf_subset )
599599
600- # test error if observation noise present
600+ # test with observation noise (delta method)
601601 tf = Log ()
602- Y = torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
602+ Y = 1e-2 + torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
603603 Yvar = 1e-8 + torch .rand (
604604 * batch_shape , 3 , m , device = self .device , dtype = dtype
605605 )
606- with self .assertRaises (NotImplementedError ):
607- tf (Y , Yvar )
606+ Y_tf , Yvar_tf = tf (Y , Yvar )
607+ self .assertTrue (tf .training )
608+ self .assertAllClose (Y_tf , torch .log (Y ))
609+ # Delta method: Var[log(Y)] ≈ Var[Y] / Y^2
610+ self .assertAllClose (Yvar_tf , Yvar / Y .pow (2 ))
608611 tf .eval ()
609- with self .assertRaises (NotImplementedError ):
610- tf .untransform (Y , Yvar )
612+ self .assertFalse (tf .training )
613+ Y_utf , Yvar_utf = tf .untransform (Y_tf , Yvar_tf )
614+ self .assertAllClose (Y_utf , Y )
615+ # Reverse: Var[Y] = Var[log(Y)] * exp(2 * log(Y)) = Var[log(Y)] * Y^2
616+ self .assertAllClose (Yvar_utf , Yvar )
611617
612618 # untransform_posterior
613619 tf = Log ()
@@ -661,14 +667,22 @@ def test_log(self):
661667 with self .assertRaises (NotImplementedError ):
662668 tf_subset = tf .subset_output (idcs = [0 ])
663669
664- # with observation noise
670+ # with observation noise (subset of outputs)
665671 tf = Log (outputs = outputs )
666- Y = torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
672+ Y = 1e-2 + torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
667673 Yvar = 1e-8 + torch .rand (
668674 * batch_shape , 3 , m , device = self .device , dtype = dtype
669675 )
670- with self .assertRaises (NotImplementedError ):
671- tf (Y , Yvar )
676+ Y_tf , Yvar_tf = tf (Y , Yvar )
677+ # output 0 should be untransformed, output 1 should be transformed
678+ self .assertAllClose (Y_tf [..., 0 ], Y [..., 0 ])
679+ self .assertAllClose (Y_tf [..., 1 ], torch .log (Y [..., 1 ]))
680+ self .assertAllClose (Yvar_tf [..., 0 ], Yvar [..., 0 ])
681+ self .assertAllClose (Yvar_tf [..., 1 ], Yvar [..., 1 ] / Y [..., 1 ].pow (2 ))
682+ tf .eval ()
683+ Y_utf , Yvar_utf = tf .untransform (Y_tf , Yvar_tf )
684+ self .assertAllClose (Y_utf , Y )
685+ self .assertAllClose (Yvar_utf , Yvar )
672686
673687 # error on untransform_posterior
674688 with self .assertRaises (NotImplementedError ):
@@ -722,13 +736,17 @@ def test_chained_outcome_transform(self):
722736 with self .assertRaises (RuntimeError ):
723737 tf .subset_output (idcs = [0 , 1 , 2 ])
724738
725- # test error if observation noise present
726- Y = torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
739+ # test observation noise is propagated through chained transform
740+ Y = 1e-2 + torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
727741 Yvar = 1e-8 + torch .rand (
728742 * batch_shape , 3 , m , device = self .device , dtype = dtype
729743 )
730- with self .assertRaises (NotImplementedError ):
731- tf (Y , Yvar )
744+ tf1 = Log ()
745+ tf2 = Standardize (m = m , batch_shape = batch_shape )
746+ tf = ChainedOutcomeTransform (log = tf1 , standardize = tf2 )
747+ Y_tf , Yvar_tf = tf (Y , Yvar )
748+ self .assertEqual (Y_tf .shape , Y .shape )
749+ self .assertEqual (Yvar_tf .shape , Yvar .shape )
732750
733751 # untransform_posterior
734752 tf1 = Log ()
@@ -781,15 +799,19 @@ def test_chained_outcome_transform(self):
781799 torch .allclose (Y_utf , Y )
782800 self .assertIsNone (Yvar_utf )
783801
784- # with observation noise
785- Y = torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
802+ # with observation noise (subset outputs)
803+ Y = 1e-2 + torch .rand (* batch_shape , 3 , m , device = self .device , dtype = dtype )
786804 Yvar = 1e-8 + torch .rand (
787805 * batch_shape , 3 , m , device = self .device , dtype = dtype
788806 )
789- with self .assertRaises (NotImplementedError ):
790- tf (Y , Yvar )
807+ tf1 = Log (outputs = outputs )
808+ tf2 = Standardize (m = m , outputs = outputs , batch_shape = batch_shape )
809+ tf = ChainedOutcomeTransform (log = tf1 , standardize = tf2 )
810+ Y_tf , Yvar_tf = tf (Y , Yvar )
811+ self .assertEqual (Y_tf .shape , Y .shape )
812+ self .assertEqual (Yvar_tf .shape , Yvar .shape )
791813
792- # error on untransform_posterior
814+ # error on untransform_posterior (subset outputs not supported)
793815 with self .assertRaises (NotImplementedError ):
794816 tf .untransform_posterior (None )
795817
0 commit comments