
Commit 9053d95

[BC-breaking] Standardize raw dataset doc strings and argument order. (#1151)
1 parent ef363fa commit 9053d95

6 files changed: +477 -547 lines
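Of the six files changed, three are shown below. The BC-breaking part of the commit is visible in language_modeling.py: WMTNewsCrawl's year and language arguments now come after offset, and the dataset constructors are wrapped by wrap_split_argument, whose wrapper is defined as new_fn(**kwargs) and therefore accepts keyword arguments only. A hedged sketch of how call sites are affected (the calls are illustrative, not taken from this diff):

    # Old signature (before this commit):
    #     def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en', offset=0)
    # New signature (after this commit):
    #     def WMTNewsCrawl(root='.data', split='train', offset=0, year=2010, language='en')

    from torchtext.experimental.datasets.raw import WMTNewsCrawl

    # Keyword calls survive the reordering:
    train = WMTNewsCrawl(split='train', year=2010, language='en')

    # Positional calls do not: WMTNewsCrawl('.data', 'train') now raises
    # TypeError, because the wrapper takes **kwargs only.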

torchtext/experimental/datasets/raw/common.py

+105
@@ -1,4 +1,6 @@
 import torch
+import inspect
+import functools


 def check_default_set(split, target_select, dataset_name):
@@ -25,6 +27,109 @@ def wrap_datasets(datasets, split):
     return datasets


+def dataset_docstring_header(fn):
+    """
+    Returns docstring for a dataset based on function arguments.
+
+    Assumes function signature of form (root='.data', split=<some tuple of strings>, offset=0, **kwargs)
+    """
+    argspec = inspect.getfullargspec(fn)
+    if not (argspec.args[0] == "root" and
+            argspec.args[1] == "split" and
+            argspec.args[2] == "offset"):
+        raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
+    default_split = argspec.defaults[1]
+
+    if isinstance(default_split, tuple):
+        example_subset = default_split[:2]
+        if len(default_split) < 3:
+            example_subset = (default_split[1],)
+        return """{} dataset
+
+    Separately returns the {} split
+
+    Args:
+        root: Directory where the datasets are saved.
+            Default: ".data"
+        split: split or splits to be returned. Can be a string or tuple of strings.
+            By default, all three datasets are generated. Users
+            could also choose any subset of them, for example {} or just 'train'.
+            Default: {}
+        offset: the number of the starting line.
+            Default: 0
+    """.format(fn.__name__, "/".join(default_split), str(example_subset), str(default_split)) + fn.__doc__
+
+    if isinstance(default_split, str):
+        return """{} dataset
+
+    Only returns the {default_split} split
+
+    Args:
+        root: Directory where the datasets are saved.
+            Default: ".data"
+        split: Only {default_split} is available.
+            Default: {default_split}
+        offset: the number of the starting line.
+            Default: 0
+    """.format(fn.__name__, default_split=default_split) + fn.__doc__
+
+    raise ValueError("default_split type expected to be of string or tuple but got {}".format(type(default_split)))
+
+
+def add_docstring_header(fn):
+    fn.__doc__ = dataset_docstring_header(fn)
+    return fn
+
+
+def wrap_split_argument(fn):
+    """
+    Wraps given function of specific signature to extend the behavior of split
+    to support individual strings. The given function is expected to have a split
+    kwarg that accepts tuples of strings, e.g. ('train', 'valid'), and the returned
+    function will have a split argument that also accepts strings, e.g. 'train', which
+    are then turned into single-entry tuples. Furthermore, the return value of the wrapped
+    function is unpacked if split is only a single string to enable behavior such as
+
+    train = AG_NEWS(split='train')
+    train, valid = AG_NEWS(split=('train', 'valid'))
+    """
+
+    argspec = inspect.getfullargspec(fn)
+    if not (argspec.args[0] == "root" and
+            argspec.args[1] == "split" and
+            argspec.args[2] == "offset" and
+            argspec.defaults[0] == ".data" and
+            argspec.defaults[2] == 0 and
+            argspec.varargs is None and
+            argspec.varkw is None and
+            len(argspec.kwonlyargs) == 0 and
+            argspec.kwonlydefaults is None and
+            len(argspec.annotations) == 0
+            ):
+        raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
+
+    # functools.wraps only forwards __module__, __name__, etc.
+    # (see https://docs.python.org/3/library/functools.html#functools.update_wrapper)
+    # but not default values of arguments. The wrapped function fn is assumed to have
+    # keyword arguments with default values only, so only a dictionary of default
+    # values is needed to support that behavior for new_fn as well.
+    fn_kwargs_dict = {}
+    for arg, default in zip(argspec.args, argspec.defaults):
+        fn_kwargs_dict[arg] = default
+
+    @functools.wraps(fn)
+    def new_fn(**kwargs):
+        for arg in fn_kwargs_dict:
+            if arg not in kwargs:
+                kwargs[arg] = fn_kwargs_dict[arg]
+        orig_split = kwargs["split"]
+        kwargs["split"] = check_default_set(orig_split, argspec.defaults[1], fn.__name__)
+        result = fn(**kwargs)
+        return wrap_datasets(tuple(result), orig_split)
+
+    return new_fn
+
+
 class RawTextIterableDataset(torch.utils.data.IterableDataset):
     """Defines an abstraction for raw text iterable datasets.
     """

torchtext/experimental/datasets/raw/language_modeling.py

+19 -59

@@ -2,8 +2,8 @@
 import io
 from torchtext.utils import download_from_url, extract_archive
 from torchtext.experimental.datasets.raw.common import RawTextIterableDataset
-from torchtext.experimental.datasets.raw.common import check_default_set
-from torchtext.experimental.datasets.raw.common import wrap_datasets
+from torchtext.experimental.datasets.raw.common import wrap_split_argument
+from torchtext.experimental.datasets.raw.common import add_docstring_header

 URLS = {
     'WikiText2':
@@ -19,12 +19,7 @@
 }


-def _setup_datasets(dataset_name, root, split_, year, language, offset):
-    if dataset_name == 'WMTNewsCrawl':
-        split = check_default_set(split_, ('train',), dataset_name)
-    else:
-        split = check_default_set(split_, ('train', 'test', 'valid'), dataset_name)
-
+def _setup_datasets(dataset_name, root, split, year, language, offset):
     if dataset_name == 'PennTreebank':
         extracted_files = [download_from_url(URLS['PennTreebank'][key],
                                              root=root, hash_value=MD5['PennTreebank'][key],
@@ -49,23 +44,13 @@ def _setup_datasets(dataset_name, root, split_, year, language, offset):
         datasets.append(RawTextIterableDataset(dataset_name,
                                                NUM_LINES[dataset_name][item], iter(io.open(path[item], encoding="utf8")), offset=offset))

-    return wrap_datasets(tuple(datasets), split_)
+    return datasets


+@wrap_split_argument
+@add_docstring_header
 def WikiText2(root='.data', split=('train', 'valid', 'test'), offset=0):
-    """ Defines WikiText2 datasets.
-
-    Create language modeling dataset: WikiText2
-    Separately returns the train/test/valid set
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets. Default: ('train', 'valid', 'test')
-            By default, all the three datasets (train, test, valid) are generated. Users
-            could also choose any one or two of them, for example ('train', 'test') or
-            just a string 'train'.
-        offset: the number of the starting line. Default: 0
-
+    """
     Examples:
         >>> from torchtext.experimental.datasets.raw import WikiText2
         >>> train_dataset, valid_dataset, test_dataset = WikiText2()
@@ -76,19 +61,10 @@ def WikiText2(root='.data', split=('train', 'valid', 'test'), offset=0):
     return _setup_datasets("WikiText2", root, split, None, None, offset)


+@wrap_split_argument
+@add_docstring_header
 def WikiText103(root='.data', split=('train', 'valid', 'test'), offset=0):
-    """ Defines WikiText103 datasets.
-
-    Create language modeling dataset: WikiText103
-    Separately returns the train/test/valid set
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: the returned datasets. Default: ('train', 'valid', 'test')
-            By default, all the three datasets (train, test, valid) are generated. Users
-            could also choose any one or two of them, for example ('train', 'test').
-        offset: the number of the starting line. Default: 0
-
+    """
     Examples:
         >>> from torchtext.experimental.datasets.raw import WikiText103
         >>> train_dataset, valid_dataset, test_dataset = WikiText103()
@@ -98,21 +74,10 @@ def WikiText103(root='.data', split=('train', 'valid', 'test'), offset=0):
     return _setup_datasets("WikiText103", root, split, None, None, offset)


+@wrap_split_argument
+@add_docstring_header
 def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
-    """ Defines PennTreebank datasets.
-
-    Create language modeling dataset: PennTreebank
-    Separately returns the train/test/valid set
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets
-            (Default: ('train', 'test', 'valid'))
-            By default, all the three datasets ('train', 'valid', 'test') are generated. Users
-            could also choose any one or two of them, for example ('train', 'test') or
-            just a string 'train'.
-        offset: the number of the starting line. Default: 0
-
+    """
     Examples:
         >>> from torchtext.experimental.datasets.raw import PennTreebank
         >>> train_dataset, valid_dataset, test_dataset = PennTreebank()
@@ -123,18 +88,11 @@ def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
     return _setup_datasets("PennTreebank", root, split, None, None, offset)


-def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en', offset=0):
-    """ Defines WMT News Crawl.
-
-    Create language modeling dataset: WMTNewsCrawl
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets.
-            (Default: 'train')
-        year: the year of the dataset (Default: 2010)
+@wrap_split_argument
+@add_docstring_header
+def WMTNewsCrawl(root='.data', split='train', offset=0, year=2010, language='en'):
+    """ year: the year of the dataset (Default: 2010)
         language: the language of the dataset (Default: 'en')
-        offset: the number of the starting line. Default: 0

     Note: WMTNewsCrawl provides datasets based on the year and language instead of train/valid/test.
     """
@@ -148,12 +106,14 @@ def WMTNewsCrawl(root='.data', split=('train'), year=2010, language='en', offset=0):
     'PennTreebank': PennTreebank,
     'WMTNewsCrawl': WMTNewsCrawl
 }
+
 NUM_LINES = {
     'WikiText2': {'train': 36718, 'valid': 3760, 'test': 4358},
     'WikiText103': {'train': 1801350, 'valid': 3760, 'test': 4358},
     'PennTreebank': {'train': 42068, 'valid': 3370, 'test': 3761},
     'WMTNewsCrawl': {'train': 17676013}
 }
+
 MD5 = {
     'WikiText2': '542ccefacc6c27f945fb54453812b3cd',
     'WikiText103': '9ddaacaf6af0710eda8c456decff7832',
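Taken together, the two decorators give every constructor above the same calling convention. A short usage sketch, restating the behavior described in wrap_split_argument's docstring (the calls are illustrative; running them downloads the data on first use):

    from torchtext.experimental.datasets.raw import WikiText2

    # A bare string split now returns a single dataset instead of a 1-tuple:
    test_dataset = WikiText2(split='test')

    # A tuple of splits returns one raw iterator per requested split:
    train_dataset, valid_dataset = WikiText2(split=('train', 'valid'))

    # The Args section of the docstring is generated by add_docstring_header:
    print(WikiText2.__doc__)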

torchtext/experimental/datasets/raw/question_answer.py

+23 -32

@@ -1,8 +1,8 @@
 from torchtext.utils import download_from_url
 import json
 from torchtext.experimental.datasets.raw.common import RawTextIterableDataset
-from torchtext.experimental.datasets.raw.common import check_default_set
-from torchtext.experimental.datasets.raw.common import wrap_datasets
+from torchtext.experimental.datasets.raw.common import wrap_split_argument
+from torchtext.experimental.datasets.raw.common import add_docstring_header

 URLS = {
     'SQuAD1':
@@ -30,59 +30,48 @@ def _create_data_from_json(data_path):
                     yield (_context, _question, _answers, _answer_start)


-def _setup_datasets(dataset_name, root, split_, offset):
-    split = check_default_set(split_, ('train', 'dev'), dataset_name)
+def _setup_datasets(dataset_name, root, split, offset):
     extracted_files = {key: download_from_url(URLS[dataset_name][key], root=root,
                                               hash_value=MD5[dataset_name][key], hash_type='md5') for key in split}
-    return wrap_datasets(tuple(RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item],
-                                                      _create_data_from_json(extracted_files[item]), offset=offset) for item in split), split_)
+    return [RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item],
+                                   _create_data_from_json(extracted_files[item]), offset=offset) for item in split]


+@wrap_split_argument
+@add_docstring_header
 def SQuAD1(root='.data', split=('train', 'dev'), offset=0):
-    """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD1.0.
+    """
+    Examples:
+        >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD1()
+        >>> for idx, (context, question, answer, ans_pos) in enumerate(train_dataset):
+        >>>     print(idx, (context, question, answer, ans_pos))
+
     The iterator yields a tuple of (raw context, raw question, a list of raw answer,
     a list of answer positions in the raw context).
     For example, ('Architecturally, the school has a Catholic character. Atop the ...',
                   'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
                   ['Saint Bernadette Soubirous'],
                   [515])
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets (Default: ('train', 'dev'))
-            By default, both datasets (train, dev) are generated. Users could also choose any one or two of them,
-            for example ('train', 'dev') or just a string 'train'.
-        offset: the number of the starting line. Default: 0
-
-    Examples:
-        >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD1()
-        >>> for idx, (context, question, answer, ans_pos) in enumerate(train_dataset):
-        >>>     print(idx, (context, question, answer, ans_pos))
     """

     return _setup_datasets("SQuAD1", root, split, offset)


+@wrap_split_argument
+@add_docstring_header
 def SQuAD2(root='.data', split=('train', 'dev'), offset=0):
-    """ A dataset iterator yields the data of Stanford Question Answering dataset - SQuAD2.0.
+    """
+    Examples:
+        >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD2()
+        >>> for idx, (context, question, answer, ans_pos) in enumerate(train_dataset):
+        >>>     print(idx, (context, question, answer, ans_pos))
+
     The iterator yields a tuple of (raw context, raw question, a list of raw answer,
     a list of answer positions in the raw context).
     For example, ('Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an ...',
                   'When did Beyonce start becoming popular?',
                   ['in the late 1990s'],
                   [269])
-
-    Args:
-        root: Directory where the datasets are saved. Default: ".data"
-        split: a string or tuple for the returned datasets (Default: ('train', 'dev'))
-            By default, both datasets (train, dev) are generated. Users could also choose any one or two of them,
-            for example ('train', 'dev') or just a string 'train'.
-        offset: the number of the starting line. Default: 0
-
-    Examples:
-        >>> train_dataset, dev_dataset = torchtext.experimental.datasets.raw.SQuAD2()
-        >>> for idx, (context, question, answer, ans_pos) in enumerate(train_dataset):
-        >>>     print(idx, (context, question, answer, ans_pos))
     """

     return _setup_datasets("SQuAD2", root, split, offset)
@@ -92,10 +81,12 @@ def SQuAD2(root='.data', split=('train', 'dev'), offset=0):
     'SQuAD1': SQuAD1,
     'SQuAD2': SQuAD2
 }
+
 NUM_LINES = {
     'SQuAD1': {'train': 87599, 'dev': 10570},
     'SQuAD2': {'train': 130319, 'dev': 11873}
 }
+
 MD5 = {
     'SQuAD1': {'train': '981b29407e0affa3b1b156f72073b945', 'dev': '3e85deb501d4e538b6bc56f786231552'},
     'SQuAD2': {'train': '62108c273c268d70893182d5cf8df740', 'dev': '246adae8b7002f8679c027697b0b7cf8'}
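The SQuAD constructors follow the same convention as the language-modeling ones; an illustrative sketch under the same assumptions (data is downloaded on first use):

    from torchtext.experimental.datasets.raw import SQuAD1

    # wrap_split_argument unpacks the single-entry result for a string split:
    dev_dataset = SQuAD1(split='dev')
    for idx, (context, question, answers, ans_pos) in enumerate(dev_dataset):
        print(idx, question, answers)
        if idx == 2:  # only peek at the first few records
            break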
