 from torch.utils import data
 
 from .. import defaults, util
-from . import collators, datasets, indexes, tsv
+from . import collators, datasets, indexes, mappers, tsv
 
 
 class DataModule(lightning.LightningDataModule):
-    """Parses, indexes, collates and loads data.
-
-    The batch size tuner is permitted to mutate the `batch_size` argument.
+    """Data module.
+
+    This is responsible for indexing the data, collating/padding, and
+    generating datasets.
+
+    Args:
+        model_dir: Path for checkpoints, indexes, and logs.
+        train: Path for training data TSV.
+        val: Path for validation data TSV.
+        predict: Path for prediction data TSV.
+        test: Path for test data TSV.
+        source_col: 1-indexed column in TSV containing source strings.
+        features_col: 1-indexed column in TSV containing features strings.
+        target_col: 1-indexed column in TSV containing target strings.
+        source_sep: String used to split source string into symbols; an empty
+            string indicates that each Unicode codepoint is its own symbol.
+        features_sep: String used to split features string into symbols; an
+            empty string indicates that each Unicode codepoint is its own
+            symbol.
+        target_sep: String used to split target string into symbols; an empty
+            string indicates that each Unicode codepoint is its own symbol.
+        separate_features: Whether or not a separate encoder should be used
+            for features.
+        tie_embeddings: Whether or not source and target embeddings are tied.
+            If not, then source symbols are wrapped in {...}.
+        batch_size: Desired batch size.
+        max_source_length: The maximum length of a source string; this includes
+            concatenated feature strings if not using separate features. An
+            error will be raised if any source exceeds this limit.
+        max_target_length: The maximum length of a target string. A warning
+            will be raised and the target strings will be truncated if any
+            target exceeds this limit.
     """
 
     train: Optional[str]
@@ -37,18 +66,16 @@ def __init__(
         source_col: int = defaults.SOURCE_COL,
         features_col: int = defaults.FEATURES_COL,
         target_col: int = defaults.TARGET_COL,
-        # String parsing arguments.
         source_sep: str = defaults.SOURCE_SEP,
         features_sep: str = defaults.FEATURES_SEP,
         target_sep: str = defaults.TARGET_SEP,
-        # Collator options.
-        batch_size: int = defaults.BATCH_SIZE,
+        # Modeling options.
         separate_features: bool = False,
+        tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
+        # Other.
+        batch_size: int = defaults.BATCH_SIZE,
         max_source_length: int = defaults.MAX_SOURCE_LENGTH,
         max_target_length: int = defaults.MAX_TARGET_LENGTH,
-        tie_embeddings: bool = defaults.TIE_EMBEDDINGS,
-        # Indexing.
-        index: Optional[indexes.Index] = None,
     ):
         super().__init__()
         self.train = train
@@ -83,7 +110,7 @@ def __init__(
     def _make_index(
         self, model_dir: str, tie_embeddings: bool
     ) -> indexes.Index:
-        # Computes index.
+        """Creates the index from a training set."""
         source_vocabulary: Set[str] = set()
         features_vocabulary: Set[str] = set()
         target_vocabulary: Set[str] = set()
@@ -107,21 +134,22 @@ def _make_index(
             for source in self.parser.samples(self.train):
                 source_vocabulary.update(source)
         index = indexes.Index(
-            source_vocabulary=sorted(source_vocabulary),
+            source_vocabulary=source_vocabulary,
             features_vocabulary=(
-                sorted(features_vocabulary) if features_vocabulary else None
-            ),
-            target_vocabulary=(
-                sorted(target_vocabulary) if target_vocabulary else None
+                features_vocabulary if features_vocabulary else None
             ),
+            target_vocabulary=target_vocabulary if target_vocabulary else None,
             tie_embeddings=tie_embeddings,
         )
+        # Writes it to the model directory.
         index.write(model_dir)
         return index
 
+    # Logging.
+
     @staticmethod
     def pprint(vocabulary: Iterable) -> str:
-        """Prints the vocabulary for debugging adn logging purposes."""
+        """Prints the vocabulary for debugging and logging purposes."""
         return ", ".join(f"{symbol!r}" for symbol in vocabulary)
 
     def log_vocabularies(self) -> None:
@@ -140,6 +168,8 @@ def log_vocabularies(self) -> None:
             f"{self.pprint(self.index.target_vocabulary)}"
         )
 
+    # Properties.
+
     @property
     def has_features(self) -> bool:
         return self.parser.has_features
@@ -148,13 +178,6 @@ def has_features(self) -> bool:
     def has_target(self) -> bool:
         return self.parser.has_target
 
-    def _dataset(self, path: str) -> datasets.Dataset:
-        return datasets.Dataset(
-            list(self.parser.samples(path)),
-            self.index,
-            self.parser,
-        )
-
     # Required API.
 
     def train_dataloader(self) -> data.DataLoader:
@@ -165,6 +188,7 @@ def train_dataloader(self) -> data.DataLoader:
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def val_dataloader(self) -> data.DataLoader:
@@ -173,7 +197,9 @@ def val_dataloader(self) -> data.DataLoader:
             self._dataset(self.val),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def predict_dataloader(self) -> data.DataLoader:
@@ -182,7 +208,9 @@ def predict_dataloader(self) -> data.DataLoader:
             self._dataset(self.predict),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
         )
 
     def test_dataloader(self) -> data.DataLoader:
@@ -191,5 +219,14 @@ def test_dataloader(self) -> data.DataLoader:
             self._dataset(self.test),
             collate_fn=self.collator,
             batch_size=self.batch_size,
+            shuffle=False,
             num_workers=1,
+            persistent_workers=True,
+        )
+
+    def _dataset(self, path: str) -> datasets.Dataset:
+        return datasets.Dataset(
+            list(self.parser.samples(path)),
+            mappers.Mapper(self.index),
+            self.parser,
         )
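
For context, here is a minimal sketch of how the revised DataModule might be constructed after this change. It is illustrative only: the import path, the file paths, and the argument values below are assumptions, not part of this commit; only the argument names follow the docstring and signature in the diff above.

    # Illustrative sketch only: the import path and file paths are
    # hypothetical, not taken from this commit.
    from mypackage.data import DataModule  # assumed import location

    datamodule = DataModule(
        model_dir="checkpoints",  # where the index and logs are written
        train="train.tsv",        # training data TSV
        val="dev.tsv",            # validation data TSV
        source_col=1,             # 1-indexed source column
        target_col=2,             # 1-indexed target column
        source_sep="",            # empty: each Unicode codepoint is a symbol
        target_sep="",
        batch_size=32,
        separate_features=False,
    )
    # train_dataloader() shuffles; the val/predict/test dataloaders now pass
    # shuffle=False, and all four pass persistent_workers=True.

Keeping workers persistent avoids respawning DataLoader worker processes between epochs, which is presumably the motivation for the added keyword.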