1- # Copyright 2018 The Texar Authors. All Rights Reserved.
1+ # Copyright 2019 The Texar Authors. All Rights Reserved.
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
@@ -54,7 +54,10 @@ def _assert_same_size(outputs, output_size):
5454 flat_output = nest .flatten (outputs )
5555
5656 for (output , size ) in zip (flat_output , flat_output_size ):
57- if output [0 ].shape != tf .TensorShape (size ):
57+ if isinstance (size , tf .TensorShape ):
58+ if output .shape == size :
59+ pass
60+ elif output [0 ].shape != tf .TensorShape (size ):
5861 raise ValueError (
5962 "The output size does not match the the required output_size" )
6063
@@ -518,7 +521,8 @@ class instance.
518521 - output: A Tensor or a (nested) tuple of Tensors with the same \
519522 structure and size of :attr:`output_size`. The batch dimension \
520523 equals :attr:`num_samples` if specified, or is determined by the \
521- distribution dimensionality.
524+ distribution dimensionality. If :attr:`transform` is `False`, \
525+ :attr:`output` will be equal to :attr:`sample`.
522526 - sample: The sample from the distribution, prior to transformation.
523527
524528 Raises:
@@ -549,9 +553,10 @@ class instance.
549553 fn_modules = ['tensorflow' , 'tensorflow.nn' , 'texar.custom' ]
550554 activation_fn = get_function (self .hparams .activation_fn , fn_modules )
551555 output = _mlp_transform (sample , self ._output_size , activation_fn )
556+ else :
557+ output = sample
552558
553559 _assert_same_size (output , self ._output_size )
554-
555560 if not self ._built :
556561 self ._add_internal_trainable_variables ()
557562 self ._built = True
@@ -616,7 +621,7 @@ def default_hparams():
616621 def _build (self ,
617622 distribution = 'MultivariateNormalDiag' ,
618623 distribution_kwargs = None ,
619- transform = False ,
624+ transform = True ,
620625 num_samples = None ):
621626 """Samples from a distribution and optionally performs transformation
622627 with an MLP layer.
@@ -649,7 +654,8 @@ class instance.
649654 - output: A Tensor or a (nested) tuple of Tensors with the same \
650655 structure and size of :attr:`output_size`. The batch dimension \
651656 equals :attr:`num_samples` if specified, or is determined by the \
652- distribution dimensionality.
657+ distribution dimensionality. If :attr:`transform` is `False`, \
658+ :attr:`output` will be equal to :attr:`sample`.
653659 - sample: The sample from the distribution, prior to transformation.
654660
655661 Raises:
@@ -661,31 +667,32 @@ class instance.
661667 "tensorflow.contrib.distributions" , "texar.custom" ])
662668
663669 if num_samples :
664- output = dstr .sample (num_samples )
670+ sample = dstr .sample (num_samples )
665671 else :
666- output = dstr .sample ()
672+ sample = dstr .sample ()
667673
668674 if dstr .event_shape == []:
669- output = tf .reshape (output ,
670- output .shape .concatenate (tf .TensorShape (1 )))
675+ sample = tf .reshape (sample ,
676+ sample .shape .concatenate (tf .TensorShape (1 )))
671677
672678 # Disable gradients through samples
673- output = tf .stop_gradient (output )
679+ sample = tf .stop_gradient (sample )
674680
675- output = tf .cast (output , tf .float32 )
681+ sample = tf .cast (sample , tf .float32 )
676682
677683 if transform :
678684 fn_modules = ['tensorflow' , 'tensorflow.nn' , 'texar.custom' ]
679685 activation_fn = get_function (self .hparams .activation_fn , fn_modules )
680- output = _mlp_transform (output , self ._output_size , activation_fn )
686+ output = _mlp_transform (sample , self ._output_size , activation_fn )
687+ else :
688+ output = sample
681689
682690 _assert_same_size (output , self ._output_size )
683-
684691 if not self ._built :
685692 self ._add_internal_trainable_variables ()
686693 self ._built = True
687694
688- return output
695+ return output , sample
689696
690697
691698#class ConcatConnector(ConnectorBase):
0 commit comments