2
2
import io
3
3
from torchtext .utils import download_from_url , extract_archive
4
4
from torchtext .experimental .datasets .raw .common import RawTextIterableDataset
5
- from torchtext .experimental .datasets .raw .common import check_default_set
6
- from torchtext .experimental .datasets .raw .common import wrap_datasets
5
+ from torchtext .experimental .datasets .raw .common import wrap_split_argument
6
+ from torchtext .experimental .datasets .raw .common import add_docstring_header
7
7
8
8
URLS = {
9
9
'WikiText2' :
19
19
}
20
20
21
21
22
- def _setup_datasets (dataset_name , root , split_ , year , language , offset ):
23
- if dataset_name == 'WMTNewsCrawl' :
24
- split = check_default_set (split_ , ('train' ,), dataset_name )
25
- else :
26
- split = check_default_set (split_ , ('train' , 'test' , 'valid' ), dataset_name )
27
-
22
+ def _setup_datasets (dataset_name , root , split , year , language , offset ):
28
23
if dataset_name == 'PennTreebank' :
29
24
extracted_files = [download_from_url (URLS ['PennTreebank' ][key ],
30
25
root = root , hash_value = MD5 ['PennTreebank' ][key ],
@@ -49,23 +44,13 @@ def _setup_datasets(dataset_name, root, split_, year, language, offset):
49
44
datasets .append (RawTextIterableDataset (dataset_name ,
50
45
NUM_LINES [dataset_name ][item ], iter (io .open (path [item ], encoding = "utf8" )), offset = offset ))
51
46
52
- return wrap_datasets ( tuple ( datasets ), split_ )
47
+ return datasets
53
48
54
49
50
+ @wrap_split_argument
51
+ @add_docstring_header
55
52
def WikiText2 (root = '.data' , split = ('train' , 'valid' , 'test' ), offset = 0 ):
56
- """ Defines WikiText2 datasets.
57
-
58
- Create language modeling dataset: WikiText2
59
- Separately returns the train/test/valid set
60
-
61
- Args:
62
- root: Directory where the datasets are saved. Default: ".data"
63
- split: a string or tuple for the returned datasets. Default: ('train', 'valid, 'test')
64
- By default, all the three datasets (train, test, valid) are generated. Users
65
- could also choose any one or two of them, for example ('train', 'test') or
66
- just a string 'train'.
67
- offset: the number of the starting line. Default: 0
68
-
53
+ """
69
54
Examples:
70
55
>>> from torchtext.experimental.datasets.raw import WikiText2
71
56
>>> train_dataset, valid_dataset, test_dataset = WikiText2()
@@ -76,19 +61,10 @@ def WikiText2(root='.data', split=('train', 'valid', 'test'), offset=0):
76
61
return _setup_datasets ("WikiText2" , root , split , None , None , offset )
77
62
78
63
64
+ @wrap_split_argument
65
+ @add_docstring_header
79
66
def WikiText103 (root = '.data' , split = ('train' , 'valid' , 'test' ), offset = 0 ):
80
- """ Defines WikiText103 datasets.
81
-
82
- Create language modeling dataset: WikiText103
83
- Separately returns the train/test/valid set
84
-
85
- Args:
86
- root: Directory where the datasets are saved. Default: ".data"
87
- split: the returned datasets. Default: ('train', 'valid','test')
88
- By default, all the three datasets (train, test, valid) are generated. Users
89
- could also choose any one or two of them, for example ('train', 'test').
90
- offset: the number of the starting line. Default: 0
91
-
67
+ """
92
68
Examples:
93
69
>>> from torchtext.experimental.datasets.raw import WikiText103
94
70
>>> train_dataset, valid_dataset, test_dataset = WikiText103()
@@ -98,21 +74,10 @@ def WikiText103(root='.data', split=('train', 'valid', 'test'), offset=0):
98
74
return _setup_datasets ("WikiText103" , root , split , None , None , offset )
99
75
100
76
77
+ @wrap_split_argument
78
+ @add_docstring_header
101
79
def PennTreebank (root = '.data' , split = ('train' , 'valid' , 'test' ), offset = 0 ):
102
- """ Defines PennTreebank datasets.
103
-
104
- Create language modeling dataset: PennTreebank
105
- Separately returns the train/test/valid set
106
-
107
- Args:
108
- root: Directory where the datasets are saved. Default: ".data"
109
- split: a string or tuple for the returned datasets
110
- (Default: ('train', 'test','valid'))
111
- By default, all the three datasets ('train', 'valid', 'test') are generated. Users
112
- could also choose any one or two of them, for example ('train', 'test') or
113
- just a string 'train'.
114
- offset: the number of the starting line. Default: 0
115
-
80
+ """
116
81
Examples:
117
82
>>> from torchtext.experimental.datasets.raw import PennTreebank
118
83
>>> train_dataset, valid_dataset, test_dataset = PennTreebank()
@@ -123,18 +88,11 @@ def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
123
88
return _setup_datasets ("PennTreebank" , root , split , None , None , offset )
124
89
125
90
126
- def WMTNewsCrawl (root = '.data' , split = ('train' ), year = 2010 , language = 'en' , offset = 0 ):
127
- """ Defines WMT News Crawl.
128
-
129
- Create language modeling dataset: WMTNewsCrawl
130
-
131
- Args:
132
- root: Directory where the datasets are saved. Default: ".data"
133
- split: a string or tuple for the returned datasets.
134
- (Default: 'train')
135
- year: the year of the dataset (Default: 2010)
91
+ @wrap_split_argument
92
+ @add_docstring_header
93
+ def WMTNewsCrawl (root = '.data' , split = 'train' , offset = 0 , year = 2010 , language = 'en' ):
94
+ """ year: the year of the dataset (Default: 2010)
136
95
language: the language of the dataset (Default: 'en')
137
- offset: the number of the starting line. Default: 0
138
96
139
97
Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test.
140
98
"""
@@ -148,12 +106,14 @@ def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en', offset
148
106
'PennTreebank' : PennTreebank ,
149
107
'WMTNewsCrawl' : WMTNewsCrawl
150
108
}
109
+
151
110
NUM_LINES = {
152
111
'WikiText2' : {'train' : 36718 , 'valid' : 3760 , 'test' : 4358 },
153
112
'WikiText103' : {'train' : 1801350 , 'valid' : 3760 , 'test' : 4358 },
154
113
'PennTreebank' : {'train' : 42068 , 'valid' : 3370 , 'test' : 3761 },
155
114
'WMTNewsCrawl' : {'train' : 17676013 }
156
115
}
116
+
157
117
MD5 = {
158
118
'WikiText2' : '542ccefacc6c27f945fb54453812b3cd' ,
159
119
'WikiText103' : '9ddaacaf6af0710eda8c456decff7832' ,
0 commit comments