@@ -216,9 +216,8 @@ def greedy_decode(
         emissions, transitions, state = self.decode_step(
             encoded, mask, symbol, state
         )
-        symbol, likelihood = self._greedy_step(
-            emissions, transitions[:, 0].unsqueeze(1)
-        )
+        likelihood = transitions[:, 0].unsqueeze(1)
+        symbol = self._greedy_step(emissions, likelihood)
         predictions = [symbol]
         # Tracks when each sequence has decoded an END.
         final = torch.zeros(batch_size, device=self.device, dtype=bool)
@@ -233,7 +232,7 @@ def greedy_decode(
             likelihood = likelihood.logsumexp(dim=2, keepdim=True).transpose(
                 1, 2
             )
-            symbol, likelihood = self._greedy_step(emissions, likelihood)
+            symbol = self._greedy_step(emissions, likelihood)
             predictions.append(symbol)
             final = torch.logical_or(final, symbol == special.END_IDX)
             if final.all():
@@ -257,15 +256,12 @@ def _greedy_step(
                 symbol sequence.

         Returns:
-            Tuple[torch.Tensor, torch.Tensor]: greedily decoded symbol
-                for current timestep and the current likelihood of the
-                decoded symbol sequence.
+            torch.Tensor: greedily decoded symbol for the current timestep.
         """
         probabilities = likelihood + emissions.transpose(1, 2)
         probabilities = probabilities.logsumexp(dim=2)
         # -> B.
-        symbol = torch.argmax(probabilities, dim=1)
-        return symbol, likelihood
+        return torch.argmax(probabilities, dim=1)

     @staticmethod
     def _gather_at_idx(
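For context, here is a minimal standalone sketch of what the refactored greedy step does: `_greedy_step` now returns only the argmax symbol, and the caller tracks `likelihood` itself. The free function name `greedy_step`, the tensor shapes, and the random inputs below are illustrative assumptions, not taken from the diff.

```python
import torch


def greedy_step(emissions: torch.Tensor, likelihood: torch.Tensor) -> torch.Tensor:
    """Sketch of the refactored greedy step: returns only the decoded symbol."""
    # Broadcast the running likelihood over the vocabulary dimension and
    # marginalize over source positions in log space.
    probabilities = likelihood + emissions.transpose(1, 2)
    probabilities = probabilities.logsumexp(dim=2)
    # -> B.
    return torch.argmax(probabilities, dim=1)


# Assumed (hypothetical) shapes: emissions is (B, source_length, vocab_size),
# likelihood is (B, 1, source_length).
B, source_length, vocab_size = 2, 5, 7
emissions = torch.log_softmax(torch.randn(B, source_length, vocab_size), dim=2)
likelihood = torch.log_softmax(torch.randn(B, 1, source_length), dim=2)
symbol = greedy_step(emissions, likelihood)
print(symbol.shape)  # torch.Size([2])
```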