Skip to content

Commit a0d8d37

Browse files
committed
Cleans up RNN wrappers
1 parent 1efb865 commit a0d8d37

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

yoyodyne/models/modules/rnn.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class WrappedRNNEncoder:
6565
"""Wraps RNN encoder modules to work with packing.
6666
6767
The derived modules do not pass an initial hidden state (or cell state, in
68-
the case of GRUs, so it is effectively zero.
68+
the case of LSTMs), so it is effectively zero.
6969
"""
7070

7171
@staticmethod
@@ -91,7 +91,7 @@ def _pack(
9191
@abc.abstractmethod
9292
def forward(
9393
self, sequence: torch.Tensor, lengths: torch.Tensor
94-
) -> Tuple[torch.Tensor, RNNState]: ...
94+
) -> torch.Tensor: ...
9595

9696

9797
class WrappedGRUEncoder(nn.GRU, WrappedRNNEncoder):
@@ -101,19 +101,19 @@ def forward(
101101
self,
102102
sequence: torch.Tensor,
103103
lengths: torch.Tensor,
104-
) -> Tuple[torch.Tensor, RNNState]:
105-
packed, hidden = super().forward(self._pack(sequence, lengths))
106-
return self._pad(packed), RNNState(hidden)
104+
) -> torch.Tensor:
105+
packed, _ = super().forward(self._pack(sequence, lengths))
106+
return self._pad(packed)
107107

108108

109109
class WrappedLSTMEncoder(nn.LSTM, WrappedRNNEncoder):
110110
"""Wraps LSTM API to work with packing."""
111111

112112
def forward(
113113
self, sequence: torch.Tensor, lengths: torch.Tensor
114-
) -> Tuple[torch.Tensor, RNNState]:
115-
packed, (hidden, cell) = super().forward(self._pack(sequence, lengths))
116-
return self._pad(packed), RNNState(hidden, cell)
114+
) -> torch.Tensor:
115+
packed, _ = super().forward(self._pack(sequence, lengths))
116+
return self._pad(packed)
117117

118118

119119
class RNNEncoder(RNNModule):
@@ -132,8 +132,7 @@ def forward(self, source: data.PaddedTensor) -> torch.Tensor:
132132
Returns:
133133
torch.Tensor.
134134
"""
135-
encoded, _ = self.module(self.embed(source.padded), source.lengths())
136-
return encoded
135+
return self.module(self.embed(source.padded), source.lengths())
137136

138137
@property
139138
def output_size(self) -> int:

0 commit comments

Comments
 (0)