
Commit f3550a1

Merge pull request #266 from CUNY-CL/mapper
Moves to mapper interface
2 parents 0a91f56 + 8bc0127

12 files changed: +276 −207 lines

pyproject.toml (+1 −1)

@@ -11,7 +11,7 @@ exclude = ["examples*"]
 
 [project]
 name = "yoyodyne"
-version = "0.2.16"
+version = "0.2.17"
 description = "Small-vocabulary neural sequence-to-sequence models"
 readme = "README.md"
 requires-python = ">= 3.9"
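For reference, the bumped version can be checked at runtime with the standard library; a minimal sketch, assuming the new release is installed under the package name yoyodyne:

from importlib import metadata

# Prints the installed package version, e.g. "0.2.17" after this release.
print(metadata.version("yoyodyne"))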

yoyodyne/data/__init__.py (+2)

@@ -7,6 +7,8 @@
 from .datamodules import DataModule  # noqa: F401
 from .datasets import Dataset  # noqa: F401
 from .indexes import Index  # noqa: F401
+from .mappers import Mapper  # noqa: F401
+from .tsv import TsvParser  # noqa: F401
 
 
 def add_argparse_args(parser: argparse.ArgumentParser) -> None:
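With these re-exports in place, both classes should be importable directly from the data subpackage; an illustrative one-liner, assuming the package is installed:

from yoyodyne.data import Mapper, TsvParser  # re-exported as of this commit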

yoyodyne/data/collators.py (+1)

@@ -2,6 +2,7 @@
 
 import argparse
 import dataclasses
+
 from typing import List
 
 import torch

yoyodyne/data/datamodules.py (+61 −24)

@@ -6,13 +6,42 @@
 from torch.utils import data
 
 from .. import defaults, util
-from . import collators, datasets, indexes, tsv
+from . import collators, datasets, indexes, mappers, tsv
 
 
 class DataModule(lightning.LightningDataModule):
-    """Parses, indexes, collates and loads data.
-
-    The batch size tuner is permitted to mutate the `batch_size` argument.
+    """Data module.
+
+    This is responsible for indexing the data, collating/padding, and
+    generating datasets.
+
+    Args:
+        model_dir: Path for checkpoints, indexes, and logs.
+        train: Path for training data TSV.
+        val: Path for validation data TSV.
+        predict: Path for prediction data TSV.
+        test: Path for test data TSV.
+        source_col: 1-indexed column in TSV containing source strings.
+        features_col: 1-indexed column in TSV containing features strings.
+        target_col: 1-indexed column in TSV containing target strings.
+        source_sep: String used to split source string into symbols; an empty
+            string indicates that each Unicode codepoint is its own symbol.
+        features_sep: String used to split features string into symbols; an
+            empty string indicates that each Unicode codepoint is its own
+            symbol.
+        target_sep: String used to split target string into symbols; an empty
+            string indicates that each Unicode codepoint is its own symbol.
+        separate_features: Whether or not a separate encoder should be used
+            for features.
+        tie_embeddings: Whether or not source and target embeddings are tied.
+            If not, then source symbols are wrapped in {...}.
+        batch_size: Desired batch size.
+        max_source_length: The maximum length of a source string; this includes
+            concatenated feature strings if not using separate features. An
+            error will be raised if any source exceeds this limit.
+        max_target_length: The maximum length of a target string. A warning
+            will be raised and the target strings will be truncated if any
+            target exceeds this limit.
     """
 
     train: Optional[str]
@@ -37,18 +66,16 @@ def __init__(
         source_col: int = defaults.SOURCE_COL,
         features_col: int = defaults.FEATURES_COL,
         target_col: int = defaults.TARGET_COL,
-        # String parsing arguments.
         source_sep: str = defaults.SOURCE_SEP,
         features_sep: str = defaults.FEATURES_SEP,
         target_sep: str = defaults.TARGET_SEP,
-        # Collator options.
-        batch_size: int = defaults.BATCH_SIZE,
+        # Modeling options.
         separate_features: bool = False,
+        tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
+        # Other.
+        batch_size: int = defaults.BATCH_SIZE,
         max_source_length: int = defaults.MAX_SOURCE_LENGTH,
         max_target_length: int = defaults.MAX_TARGET_LENGTH,
-        tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
-        # Indexing.
-        index: Optional[indexes.Index] = None,
     ):
         super().__init__()
         self.train = train
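For orientation, this is roughly how a DataModule might be constructed with the reorganized keyword arguments. The paths and values below are made up; only the argument names come from the docstring and signature above:

from yoyodyne.data import DataModule

datamodule = DataModule(
    # Hypothetical paths.
    model_dir="/tmp/yoyodyne-model",
    train="train.tsv",
    val="val.tsv",
    source_col=1,
    target_col=2,
    # Modeling options.
    separate_features=False,
    tie_embeddings=True,
    # Other.
    batch_size=32,
    max_source_length=128,
    max_target_length=128,
)

Note that the removed index argument is no longer passed; the index is now built (or loaded) by the data module itself.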
@@ -83,7 +110,7 @@ def __init__(
     def _make_index(
         self, model_dir: str, tie_embeddings: bool
     ) -> indexes.Index:
-        # Computes index.
+        """Creates the index from a training set."""
         source_vocabulary: Set[str] = set()
         features_vocabulary: Set[str] = set()
         target_vocabulary: Set[str] = set()
@@ -107,21 +134,22 @@ def _make_index(
             for source in self.parser.samples(self.train):
                 source_vocabulary.update(source)
         index = indexes.Index(
-            source_vocabulary=sorted(source_vocabulary),
+            source_vocabulary=source_vocabulary,
             features_vocabulary=(
-                sorted(features_vocabulary) if features_vocabulary else None
-            ),
-            target_vocabulary=(
-                sorted(target_vocabulary) if target_vocabulary else None
+                features_vocabulary if features_vocabulary else None
             ),
+            target_vocabulary=target_vocabulary if target_vocabulary else None,
             tie_embeddings=tie_embeddings,
         )
+        # Writes it to the model directory.
         index.write(model_dir)
         return index
 
+    # Logging.
+
     @staticmethod
     def pprint(vocabulary: Iterable) -> str:
-        """Prints the vocabulary for debugging adn logging purposes."""
+        """Prints the vocabulary for debugging and logging purposes."""
         return ", ".join(f"{symbol!r}" for symbol in vocabulary)
 
     def log_vocabularies(self) -> None:
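Given the !r-join shown above, the static pprint helper renders a vocabulary as a comma-separated list of quoted symbols. A small illustrative call with made-up symbols:

from yoyodyne.data import DataModule

# Each symbol is rendered with repr() and the results are joined.
print(DataModule.pprint(["a", "b", "<UNK>"]))  # 'a', 'b', '<UNK>'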
@@ -140,6 +168,8 @@ def log_vocabularies(self) -> None:
                 f"{self.pprint(self.index.target_vocabulary)}"
             )
 
+    # Properties.
+
     @property
     def has_features(self) -> bool:
         return self.parser.has_features
@@ -148,13 +178,6 @@ def has_features(self) -> bool:
     def has_target(self) -> bool:
         return self.parser.has_target
 
-    def _dataset(self, path: str) -> datasets.Dataset:
-        return datasets.Dataset(
-            list(self.parser.samples(path)),
-            self.index,
-            self.parser,
-        )
-
     # Required API.
 
     def train_dataloader(self) -> data.DataLoader:
@@ -165,6 +188,7 @@ def train_dataloader(self) -> data.DataLoader:
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def val_dataloader(self) -> data.DataLoader:
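The new persistent_workers=True flag is standard torch.utils.data.DataLoader behavior: with num_workers > 0, worker processes are kept alive between epochs rather than respawned. A standalone sketch with a toy dataset (not yoyodyne code) showing the same combination of options:

import torch
from torch.utils import data

# Toy dataset standing in for the real yoyodyne Dataset.
dataset = data.TensorDataset(torch.arange(10))
loader = data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=1,
    persistent_workers=True,  # workers survive across epochs
)
for batch in loader:
    pass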
@@ -173,7 +197,9 @@ def val_dataloader(self) -> data.DataLoader:
             self._dataset(self.val),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def predict_dataloader(self) -> data.DataLoader:
@@ -182,7 +208,9 @@ def predict_dataloader(self) -> data.DataLoader:
             self._dataset(self.predict),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def test_dataloader(self) -> data.DataLoader:
@@ -191,5 +219,14 @@ def test_dataloader(self) -> data.DataLoader:
             self._dataset(self.test),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
+        )
+
+    def _dataset(self, path: str) -> datasets.Dataset:
+        return datasets.Dataset(
+            list(self.parser.samples(path)),
+            mappers.Mapper(self.index),
+            self.parser,
         )
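The commit's central change is that datasets now receive a mapper built from the index (mappers.Mapper(self.index)) rather than the index itself. The new yoyodyne/data/mappers.py is not shown in this excerpt, so the following is only a loose sketch of what a symbol-to-tensor mapper could look like; the class body, method name, and lookup logic are assumptions for illustration, not the library's actual API:

from typing import Dict, Iterable

import torch


class Mapper:
    """Hypothetical mapper sitting between an index and the datasets."""

    def __init__(self, vocabulary: Iterable[str]) -> None:
        # Assigns consecutive integer ids to symbols.
        self.symbol2id: Dict[str, int] = {
            symbol: i for i, symbol in enumerate(vocabulary)
        }

    def encode(self, symbols: Iterable[str]) -> torch.Tensor:
        # Converts a symbol sequence into a 1-D tensor of ids.
        return torch.tensor([self.symbol2id[symbol] for symbol in symbols])


# Made-up usage.
mapper = Mapper(["a", "b", "c"])
print(mapper.encode(["a", "c"]))  # tensor([0, 2])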
