@@ -65,7 +65,7 @@ class WrappedRNNEncoder:
6565 """Wraps RNN encoder modules to work with packing.
6666
6767 The derived modules do not pass an initial hidden state (or cell state, in
68- the case of GRUs , so it is effectively zero.
68+ the case of LSTMs) , so it is effectively zero.
6969 """
7070
7171 @staticmethod
@@ -91,7 +91,7 @@ def _pack(
    @abc.abstractmethod
    def forward(
        self, sequence: torch.Tensor, lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, RNNState]: ...
+    ) -> torch.Tensor: ...


class WrappedGRUEncoder(nn.GRU, WrappedRNNEncoder):
@@ -101,19 +101,19 @@ def forward(
        self,
        sequence: torch.Tensor,
        lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, RNNState]:
-        packed, hidden = super().forward(self._pack(sequence, lengths))
-        return self._pad(packed), RNNState(hidden)
+    ) -> torch.Tensor:
+        packed, _ = super().forward(self._pack(sequence, lengths))
+        return self._pad(packed)


class WrappedLSTMEncoder(nn.LSTM, WrappedRNNEncoder):
    """Wraps LSTM API to work with packing."""

    def forward(
        self, sequence: torch.Tensor, lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, RNNState]:
-        packed, (hidden, cell) = super().forward(self._pack(sequence, lengths))
-        return self._pad(packed), RNNState(hidden, cell)
+    ) -> torch.Tensor:
+        packed, _ = super().forward(self._pack(sequence, lengths))
+        return self._pad(packed)


class RNNEncoder(RNNModule):
@@ -132,8 +132,7 @@ def forward(self, source: data.PaddedTensor) -> torch.Tensor:
        Returns:
            torch.Tensor.
        """
-        encoded, _ = self.module(self.embed(source.padded), source.lengths())
-        return encoded
+        return self.module(self.embed(source.padded), source.lengths())

    @property
    def output_size(self) -> int:
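
For context, here is a minimal self-contained sketch of the behavior these wrappers implement after this change (illustrative only, using a plain nn.GRU and torch's packing utilities with made-up sizes, not the repository's WrappedGRUEncoder itself): pack the padded batch, run the RNN without an explicit initial hidden state so it defaults to zeros, unpad, and return only the padded outputs while discarding the final hidden state.

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Illustrative sizes (hypothetical, not from the repo).
batch, max_len, emb_size, hidden_size = 2, 5, 8, 16
sequence = torch.randn(batch, max_len, emb_size)  # padded embeddings
lengths = torch.tensor([5, 3])  # true length of each sequence

rnn = nn.GRU(emb_size, hidden_size, batch_first=True)

# Pack so the RNN skips padded timesteps.
packed = pack_padded_sequence(
    sequence, lengths, batch_first=True, enforce_sorted=False
)
# No initial hidden state is passed, so it is effectively zero.
packed_out, _ = rnn(packed)
# Unpad back to (batch, max_len, hidden_size) and return only this tensor;
# the final hidden state is discarded, mirroring the new forward signature.
encoded, _ = pad_packed_sequence(
    packed_out, batch_first=True, total_length=max_len
)
print(encoded.shape)  # torch.Size([2, 5, 16])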