Skip to content

Commit 7b988f6

Browse files
authored
Merge pull request #316 from kylebgorman/slight
Slight simplification in hard attention loop
2 parents 33496e5 + 4c3c772 commit 7b988f6

File tree

2 files changed

+5
-10
lines changed

2 files changed

+5
-10
lines changed

yoyodyne/models/base.py

-1
Original file line number | Diff line number | Diff line change
@@ -370,7 +370,6 @@ def predict_step(
370370
using beam search, the predictions and scores as a tuple of
371371
tensors; if using greedy search, the predictions as a tensor.
372372
"""
373-
374373
if self.beam_width > 1:
375374
return self(batch)
376375
else:

yoyodyne/models/hard_attention.py

+5 -9
Original file line number | Diff line number | Diff line change
@@ -216,9 +216,8 @@ def greedy_decode(
216216
emissions, transitions, state = self.decode_step(
217217
encoded, mask, symbol, state
218218
)
219-
symbol, likelihood = self._greedy_step(
220-
emissions, transitions[:, 0].unsqueeze(1)
221-
)
219+
likelihood = transitions[:, 0].unsqueeze(1)
220+
symbol = self._greedy_step(emissions, likelihood)
222221
predictions = [symbol]
223222
# Tracks when each sequence has decoded an END.
224223
final = torch.zeros(batch_size, device=self.device, dtype=bool)
@@ -233,7 +232,7 @@ def greedy_decode(
233232
likelihood = likelihood.logsumexp(dim=2, keepdim=True).transpose(
234233
1, 2
235234
)
236-
symbol, likelihood = self._greedy_step(emissions, likelihood)
235+
symbol = self._greedy_step(emissions, likelihood)
237236
predictions.append(symbol)
238237
final = torch.logical_or(final, symbol == special.END_IDX)
239238
if final.all():
@@ -257,15 +256,12 @@ def _greedy_step(
257256
symbol sequence.
258257
259258
Returns:
260-
Tuple[torch.Tensor, torch.Tensor]: greedily decoded symbol
261-
for current timestep and the current likelihood of the
262-
decoded symbol sequence.
259+
torch.Tensor: greedily decoded symbol for the current timestep.
263260
"""
264261
probabilities = likelihood + emissions.transpose(1, 2)
265262
probabilities = probabilities.logsumexp(dim=2)
266263
# -> B.
267-
symbol = torch.argmax(probabilities, dim=1)
268-
return symbol, likelihood
264+
return torch.argmax(probabilities, dim=1)
269265

270266
@staticmethod
271267
def _gather_at_idx(

0 commit comments

Comments
 (0)