Commit aaa41c9

Merge pull request #311 from kylebgorman/really
RNN encapsulation
2 parents 4abca3c + a0d8d37

22 files changed (+1544, -1769 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ exclude = ["examples*"]
 [project]
 name = "yoyodyne"
-version = "0.2.20"
+version = "0.3.0"
 description = "Small-vocabulary neural sequence-to-sequence models"
 readme = "README.md"
 requires-python = ">= 3.9"

yoyodyne/data/batches.py

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ def lengths(self) -> torch.Tensor:
         Returns:
             torch.Tensor.
         """
-        return (self.mask == 0).sum(dim=1).cpu()
+        return (~self.mask).sum(dim=1).cpu()


 class PaddedBatch(nn.Module):
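The two expressions are equivalent when the mask is boolean, but logical negation makes the intent explicit. A minimal sketch of the length computation, assuming `mask` is a B x seq_len boolean tensor in which True marks padding positions:

```python
import torch

# Hypothetical example mask; True marks padding positions.
mask = torch.tensor(
    [
        [False, False, False, True, True],    # length 3
        [False, False, False, False, False],  # length 5
    ]
)
# Negation marks the real (non-pad) positions; summing over the
# sequence dimension yields per-example lengths.
lengths = (~mask).sum(dim=1)
print(lengths)  # tensor([3, 5])
```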

yoyodyne/data/datamodules.py

Lines changed: 4 additions & 0 deletions

@@ -180,6 +180,10 @@ def has_features(self) -> bool:
     def has_target(self) -> bool:
         return self.parser.has_target

+    @property
+    def has_separate_features(self) -> bool:
+        return self.collator.separate_features
+
     # Required API.

     def train_dataloader(self) -> data.DataLoader:

yoyodyne/models/base.py

Lines changed: 14 additions & 18 deletions

@@ -55,9 +55,9 @@ class BaseModel(abc.ABC, lightning.LightningModule):
     embedding_size: int
     encoder_layers: int
     decoder_layers: int
-    features_encoder_cls: Optional[modules.base.BaseModule]
+    features_encoder_cls: Optional[modules.BaseModule]
     hidden_size: int
-    source_encoder_cls: modules.base.BaseModule
+    source_encoder_cls: modules.BaseModule
     # Other stuff.
     eval_metrics: Set[evaluators.Evaluator]
     loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
@@ -249,6 +249,12 @@ def has_features_encoder(self):
     def num_parameters(self) -> int:
         return sum(part.numel() for part in self.parameters())

+    def start_symbol(self, batch_size: int) -> torch.Tensor:
+        """Generates a tensor of start symbols for the batch."""
+        return torch.tensor([special.START_IDX], device=self.device).repeat(
+            batch_size, 1
+        )
+
     def training_step(
         self,
         batch: data.PaddedBatch,
@@ -280,14 +286,16 @@ def training_step(
         )
         return loss

-    def validation_epoch_end(self, validation_step_outputs: Dict) -> Dict:
+    def validation_epoch_end(
+        self, validation_step_outputs: Dict
+    ) -> Dict[str, float]:
         """Computes average loss and average accuracy.

         Args:
             validation_step_outputs (Dict).

         Returns:
-            Dict: averaged metrics over all validation steps.
+            Dict[str, float]: averaged metrics over all validation steps.
         """
         avg_val_loss = torch.tensor(
             [v["val_loss"] for v in validation_step_outputs]
@@ -311,7 +319,7 @@ def validation_step(
         self,
         batch: data.PaddedBatch,
         batch_idx: int,
-    ) -> Dict:
+    ) -> Dict[str, float]:
         """Runs one validation step.

         This is called by the PL Trainer.
@@ -366,19 +374,7 @@ def predict_step(
         if self.beam_width > 1:
             return self(batch)
         else:
-            return self._get_predicted(self(batch))
-
-    def _get_predicted(self, predictions: torch.Tensor) -> torch.Tensor:
-        """Picks the best index from the vocabulary.
-
-        Args:
-            predictions (torch.Tensor): B x seq_len x target_vocab_size.
-
-        Returns:
-            torch.Tensor: indices of the argmax at each timestep.
-        """
-        assert len(predictions.size()) == 3
-        return torch.argmax(predictions, dim=2)
+            return torch.argmax(self(batch), dim=2)

def add_argparse_args(parser: argparse.ArgumentParser) -> None:
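A minimal sketch of the two behaviors this diff introduces, assuming a hypothetical START_IDX value of 1 (yoyodyne defines the real index in its `special` module):

```python
import torch

START_IDX = 1  # hypothetical value for illustration

# start_symbol: tile the START index into a B x 1 column of start symbols.
batch_size = 4
starts = torch.tensor([START_IDX]).repeat(batch_size, 1)
print(starts.shape)  # torch.Size([4, 1])

# The inlined greedy prediction: given B x seq_len x vocab_size logits,
# argmax over the vocabulary dimension picks the best symbol per timestep,
# which is all the removed _get_predicted helper did.
logits = torch.randn(batch_size, 10, 32)
predictions = torch.argmax(logits, dim=2)
print(predictions.shape)  # torch.Size([4, 10])
```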

yoyodyne/models/beam_search.py

Lines changed: 164 additions & 0 deletions

@@ -0,0 +1,164 @@
+"""Beam search classes.
+
+A Cell is a (possibly partial) hypothesis containing the decoder output,
+the symbol sequence, and the hypothesis's log-likelihood. Cells can
+generate their candidate extensions (in the form of new Cells) when
+provided with additional decoder output; they also know when they have reached
+a final state (i.e., when END has been generated).
+
+A Beam holds a collection of Cells and an in-progress heap.
+
+Current limitations:
+
+* Beam search uses Python's heap implementation; this is reasonably performant
+  in CPython (it uses a C extension module where available) but there may be a
+  better pure PyTorch solution.
+* Beam search assumes a batch size of 1; it is not clear how to extend it to
+  larger batches.
+* We hard-code the use of log-likelihoods; the addition of two log
+  probabilities is equivalent to multiplying real numbers.
+* Beam search is designed to support RNN and attentive RNN models and interface
+  issues might arise with other architectures.
+* Not much attention has been paid to keeping data on device.
+
+See rnn.py for sample usage.
+"""
+
+from __future__ import annotations
+
+import dataclasses
+import heapq
+
+from typing import Iterator, List
+
+import torch
+from torch import nn
+
+from . import modules
+from .. import special
+
+
+@dataclasses.dataclass(order=True)
+class Cell:
+    """Represents a (potentially partial) hypothesis in the beam search.
+
+    Only the log-likelihood field is used for comparison.
+
+    A cell is "final" once it has decoded the END symbol.
+
+    Args:
+        state (modules.RNNState).
+        symbols (List[int], optional).
+        score (float, optional).
+    """
+
+    state: modules.RNNState = dataclasses.field(compare=False)
+    symbols: List[int] = dataclasses.field(
+        compare=False, default_factory=lambda: [special.START_IDX]
+    )
+    score: float = dataclasses.field(compare=True, default=0.0)
+
+    def extensions(
+        self, state: modules.RNNState, scores: torch.Tensor
+    ) -> Iterator[Cell]:
+        """Generates extension cells.
+
+        Args:
+            state (modules.RNNState).
+            scores (torch.Tensor).
+
+        Yields:
+            Cell: all single-symbol extensions of the current cell.
+        """
+        for symbol, score in enumerate(scores):
+            yield Cell(
+                state, self.symbols + [symbol], self.score + score.item()
+            )
+
+    @property
+    def symbol(self) -> int:
+        return self.symbols[-1]
+
+    @property
+    def final(self) -> bool:
+        return self.symbols[-1] == special.END_IDX
+
+
+class Beam:
+    """The beam.
+
+    This stores the current set of beam cells and an in-progress heap of
+    the next set separately.
+
+    A beam is "final" once every cell has decoded the END symbol.
+
+    Args:
+        beam_width (int).
+        state (modules.RNNState).
+    """
+
+    beam_width: int
+    # Current cells.
+    cells: List[Cell]
+    # Heap of the next set of cells.
+    heap: List[Cell]
+
+    def __init__(self, beam_width, state: modules.RNNState):
+        self.beam_width = beam_width
+        self.cells = [Cell(state)]
+        self.heap = []
+
+    def __len__(self) -> int:
+        return len(self.cells)
+
+    def push(self, cell: Cell) -> None:
+        """Inserts the cell into the heap, maintaining the specified beam size.
+
+        Args:
+            cell (Cell).
+        """
+        if len(self.heap) < self.beam_width:
+            heapq.heappush(self.heap, cell)
+        else:
+            heapq.heappushpop(self.heap, cell)
+
+    def update(self) -> None:
+        """Replaces the current cells and clears the heap."""
+        self.cells = sorted(self.heap, reverse=True)
+        self.heap.clear()
+
+    @property
+    def final(self) -> bool:
+        return all(cell.final for cell in self.cells)
+
+    def predictions(self, device: torch.device) -> torch.Tensor:
+        """Converts the best sequences into a padded tensor of predictions.
+
+        This implementation assumes batch size is 1.
+
+        Args:
+            device (torch.device): the device to move the data to.
+
+        Returns:
+            torch.Tensor: a B x beam_width x seq_length tensor of predictions.
+        """
+        return nn.utils.rnn.pad_sequence(
+            [torch.tensor(cell.symbols, device=device) for cell in self.cells],
+            batch_first=True,
+            padding_value=special.PAD_IDX,
+        ).unsqueeze(0)
+
+    def scores(self, device: torch.device) -> torch.Tensor:
+        """Converts the sequence scores into tensors.
+
+        This implementation assumes batch size is 1.
+
+        Args:
+            device (torch.device): the device to move the data to.
+
+        Returns:
+            torch.Tensor: a B x beam_width tensor of log-likelihoods.
+        """
+        return torch.tensor(
+            [cell.score for cell in self.cells], device=device
+        ).unsqueeze(0)
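The module docstring points to rnn.py for sample usage; as a rough orientation, a single beam-search step under this API might look like the sketch below. Here `decode_step` is a hypothetical stand-in (not part of this diff) for one step of the RNN decoder, taking a state and the last symbol and returning `(new_state, log_probs)`, where `log_probs` is a 1-D tensor over the target vocabulary.

```python
def beam_step(beam, decode_step):
    """One expansion step: extend every non-final cell by one symbol."""
    for cell in beam.cells:
        if cell.final:
            # Final hypotheses are carried over unchanged.
            beam.push(cell)
            continue
        state, log_probs = decode_step(cell.state, cell.symbol)
        # Each extension appends one symbol and adds its log-probability;
        # push keeps only the beam_width best candidates on the heap.
        for extension in cell.extensions(state, log_probs):
            beam.push(extension)
    # Promote the heap contents to the current cells and clear the heap.
    beam.update()

# The caller would loop `beam_step` until `beam.final`, then read off
# `beam.predictions(device)` and `beam.scores(device)`.
```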
