Commit c2dd2ff

Remove offset option from torchtext.datasets, move torchtext.datasets.common to torchtext.data.dataset_utils (#1188)
1 parent 4a7aad7 commit c2dd2ff

28 files changed, +137 -162 lines

test/data/test_builtin_datasets.py (-12)

@@ -186,11 +186,6 @@ def test_raw_text_classification(self, info):
         self.assertEqual(torchtext.datasets.MD5[dataset_name], info['MD5'])
         del data_iter
 
-    def test_num_lines_of_dataset(self):
-        train_iter, test_iter = torchtext.datasets.AG_NEWS(offset=10)
-        _data = [item for item in train_iter]
-        self.assertEqual(len(_data), 119990)
-
     @parameterized.expand(list(sorted(torchtext.datasets.DATASETS.keys())))
     def test_raw_datasets_split_argument(self, dataset_name):
         if dataset_name in GOOGLE_DRIVE_BASED_DATASETS:
@@ -223,13 +218,6 @@ def test_datasets_split_argument(self, dataset_name):
         # Exercise default constructor
         _ = dataset()
 
-    def test_offset_dataset(self):
-        train_iter, test_iter = torchtext.datasets.AG_NEWS(split=('train', 'test'), offset=10)
-        container = [text[:20] for idx, (label, text) in enumerate(train_iter) if idx < 5]
-        self.assertEqual(container, ['Oil and Economy Clou', 'No Need for OPEC to ',
-                                     'Non-OPEC Nations Sho', 'Google IPO Auction O',
-                                     'Dollar Falls Broadly'])
-
     def test_next_method_dataset(self):
         train_iter, test_iter = torchtext.datasets.AG_NEWS()
         for_count = 0
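
The two deleted tests exercised the removed offset keyword. Equivalent behavior remains available on the caller's side; below is a minimal sketch, assuming this commit is installed and the AG_NEWS CSVs are already downloaded under the default .data root, using itertools.islice in place of offset:

```python
from itertools import islice

import torchtext

# Skip the first 10 training examples, then take the next 5
# (roughly what offset=10 plus the old test's idx < 5 filter did).
train_iter, test_iter = torchtext.datasets.AG_NEWS(split=('train', 'test'))
first_five_after_skip = [text[:20] for label, text in islice(train_iter, 10, 15)]
print(first_five_after_skip)
```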

torchtext/datasets/common.py → torchtext/data/datasets_utils.py (+15 -27)

@@ -44,12 +44,11 @@ def dataset_docstring_header(fn):
     """
     Returns docstring for a dataset based on function arguments.
 
-    Assumes function signature of form (root='.data', split=<some tuple of strings>, offset=0, **kwargs)
+    Assumes function signature of form (root='.data', split=<some tuple of strings>, **kwargs)
     """
     argspec = inspect.getfullargspec(fn)
     if not (argspec.args[0] == "root" and
-            argspec.args[1] == "split" and
-            argspec.args[2] == "offset"):
+            argspec.args[1] == "split"):
         raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
     default_split = argspec.defaults[1]
@@ -68,8 +67,6 @@ def dataset_docstring_header(fn):
         By default, all three datasets are generated. Users
         could also choose any subset of them, for example {} or just 'train'.
         Default: {}
-    offset: the number of the starting line.
-        Default: 0
     """.format(fn.__name__, "/".join(default_split), str(example_subset), str(default_split))
 
     if isinstance(default_split, str):
@@ -81,9 +78,7 @@ def dataset_docstring_header(fn):
     root: Directory where the datasets are saved.
         Default: ".data"
     split: Only {default_split} is available.
-        Default: {default_split}
-    offset: the number of the starting line.
-        Default: 0""".format(fn.__name__, default_split=default_split)
+        Default: {default_split}""".format(fn.__name__, default_split=default_split)
 
     raise ValueError("default_split type expected to be of string or tuple but got {}".format(type(default_split)))
@@ -116,9 +111,7 @@ def wrap_split_argument(fn):
     argspec = inspect.getfullargspec(fn)
     if not (argspec.args[0] == "root" and
             argspec.args[1] == "split" and
-            argspec.args[2] == "offset" and
             argspec.defaults[0] == ".data" and
-            argspec.defaults[2] == 0 and
             argspec.varargs is None and
             argspec.varkw is None and
             len(argspec.kwonlyargs) == 0 and
@@ -133,16 +126,15 @@ def wrap_split_argument(fn):
     # keyword arguments with default values only, so only a dictionary of default
     # values is needed to support that behavior for new_fn as well.
     fn_kwargs_dict = {}
-    for arg, default in zip(argspec.args[3:], argspec.defaults[3:]):
+    for arg, default in zip(argspec.args[2:], argspec.defaults[2:]):
         fn_kwargs_dict[arg] = default
 
     @functools.wraps(fn)
-    def new_fn(root='.data', split=argspec.defaults[1], offset=0, **kwargs):
+    def new_fn(root='.data', split=argspec.defaults[1], **kwargs):
         for arg in fn_kwargs_dict:
             if arg not in kwargs:
                 kwargs[arg] = fn_kwargs_dict[arg]
         kwargs["root"] = root
-        kwargs["offset"] = offset
         kwargs["split"] = check_default_set(split, argspec.defaults[1], fn.__name__)
         result = fn(**kwargs)
         return wrap_datasets(tuple(result), split)
@@ -154,32 +146,28 @@ class RawTextIterableDataset(torch.utils.data.IterableDataset):
     """Defines an abstraction for raw text iterable datasets.
     """
 
-    def __init__(self, name, full_num_lines, iterator, offset=0):
+    def __init__(self, name, full_num_lines, iterator):
         """Initiate text-classification dataset.
         """
         super(RawTextIterableDataset, self).__init__()
         self.name = name
         self.full_num_lines = full_num_lines
         self._iterator = iterator
-        self.start = offset
-        if offset < 0:
-            raise ValueError("Given offset must be non-negative, got {} instead.".format(offset))
-        self.num_lines = full_num_lines - offset
+        self.num_lines = full_num_lines
+        self.current_pos = None
 
     def __iter__(self):
-        for i, item in enumerate(self._iterator):
-            if i < self.start:
-                continue
-            if self.num_lines and i >= (self.start + self.num_lines):
-                break
-            yield item
+        return self
 
     def __next__(self):
+        if self.current_pos == self.num_lines - 1:
+            raise StopIteration
         item = next(self._iterator)
+        if self.current_pos is None:
+            self.current_pos = 0
+        else:
+            self.current_pos += 1
         return item
 
     def __len__(self):
         return self.num_lines
-
-    def get_iterator(self):
-        return self._iterator
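
With the offset logic gone, RawTextIterableDataset now returns itself from __iter__ and tracks its position in __next__, stopping after num_lines items. A rough usage sketch under the new protocol, assuming this commit is installed and the AG_NEWS data is already present under .data:

```python
import torchtext

train_iter, test_iter = torchtext.datasets.AG_NEWS(split=('train', 'test'))

print(len(train_iter))           # full_num_lines; no offset is subtracted any more
label, text = next(train_iter)   # __next__ pulls one example and advances current_pos
for label, text in train_iter:   # __iter__ returns self, so iteration resumes here
    pass
```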

torchtext/datasets/ag_news.py (+5 -5)

@@ -1,7 +1,7 @@
 from torchtext.utils import download_from_url, unicode_csv_reader
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
 import os
 import io
 
@@ -23,7 +23,7 @@
 
 @wrap_split_argument
 @add_docstring_header()
-def AG_NEWS(root='.data', split=('train', 'test'), offset=0):
+def AG_NEWS(root='.data', split=('train', 'test')):
     def _create_data_from_csv(data_path):
         with io.open(data_path, encoding="utf8") as f:
             reader = unicode_csv_reader(f)
@@ -37,5 +37,5 @@ def _create_data_from_csv(data_path):
                                        hash_value=MD5[item],
                                        hash_type='md5')
         datasets.append(RawTextIterableDataset("AG_NEWS", NUM_LINES[item],
-                                               _create_data_from_csv(path), offset=offset))
+                                               _create_data_from_csv(path)))
     return datasets
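
Downstream code that imported these helpers from torchtext.datasets.common has to switch to the new module path; a minimal sketch of the migration (the old path is removed by this commit):

```python
# Before this commit:
# from torchtext.datasets.common import RawTextIterableDataset, wrap_split_argument, add_docstring_header
# After this commit:
from torchtext.data.datasets_utils import RawTextIterableDataset
from torchtext.data.datasets_utils import wrap_split_argument
from torchtext.data.datasets_utils import add_docstring_header
```

The remaining dataset modules below receive the same two mechanical changes: the updated import path and the removal of offset from the function signature and the RawTextIterableDataset constructor call.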

torchtext/datasets/amazonreviewfull.py (+6 -6)

@@ -1,8 +1,8 @@
 from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
-from torchtext.datasets.common import find_match
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
+from torchtext.data.datasets_utils import find_match
 import os
 import io
 import logging
@@ -21,7 +21,7 @@
 
 @wrap_split_argument
 @add_docstring_header()
-def AmazonReviewFull(root='.data', split=('train', 'test'), offset=0):
+def AmazonReviewFull(root='.data', split=('train', 'test')):
     def _create_data_from_csv(data_path):
         with io.open(data_path, encoding="utf8") as f:
             reader = unicode_csv_reader(f)
@@ -37,5 +37,5 @@ def _create_data_from_csv(data_path):
         path = find_match(item + '.csv', extracted_files)
         logging.info('Creating {} data'.format(item))
         datasets.append(RawTextIterableDataset("AmazonReviewFull", NUM_LINES[item],
-                                               _create_data_from_csv(path), offset=offset))
+                                               _create_data_from_csv(path)))
     return datasets

torchtext/datasets/amazonreviewpolarity.py (+6 -6)

@@ -1,8 +1,8 @@
 from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
-from torchtext.datasets.common import find_match
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
+from torchtext.data.datasets_utils import find_match
 import os
 import io
 
@@ -20,7 +20,7 @@
 
 @wrap_split_argument
 @add_docstring_header()
-def AmazonReviewPolarity(root='.data', split=('train', 'test'), offset=0):
+def AmazonReviewPolarity(root='.data', split=('train', 'test')):
     def _create_data_from_csv(data_path):
         with io.open(data_path, encoding="utf8") as f:
             reader = unicode_csv_reader(f)
@@ -35,5 +35,5 @@ def _create_data_from_csv(data_path):
     for item in split:
         path = find_match(item + '.csv', extracted_files)
         datasets.append(RawTextIterableDataset("AmazonReviewPolarity", NUM_LINES[item],
-                                               _create_data_from_csv(path), offset=offset))
+                                               _create_data_from_csv(path)))
     return datasets

torchtext/datasets/conll2000chunking.py (+6 -6)

@@ -1,8 +1,8 @@
 from torchtext.utils import download_from_url, extract_archive
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
-from torchtext.datasets.common import find_match
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
+from torchtext.data.datasets_utils import find_match
 
 URL = {
     'train': "https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz",
@@ -40,12 +40,12 @@ def _create_data_from_iob(data_path, separator):
 
 @wrap_split_argument
 @add_docstring_header()
-def CoNLL2000Chunking(root='.data', split=('train', 'test'), offset=0):
+def CoNLL2000Chunking(root='.data', split=('train', 'test')):
     datasets = []
     for item in split:
         dataset_tar = download_from_url(URL[item], root=root, hash_value=MD5[item], hash_type='md5')
         extracted_files = extract_archive(dataset_tar)
         data_filename = find_match(item + ".txt", extracted_files)
         datasets.append(RawTextIterableDataset("CoNLL2000Chunking", NUM_LINES[item],
-                                               _create_data_from_iob(data_filename, " "), offset=offset))
+                                               _create_data_from_iob(data_filename, " ")))
     return datasets

torchtext/datasets/dbpedia.py (+6 -6)

@@ -1,8 +1,8 @@
 from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
-from torchtext.datasets.common import find_match
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
+from torchtext.data.datasets_utils import find_match
 import os
 import io
 
@@ -20,7 +20,7 @@
 
 @wrap_split_argument
 @add_docstring_header()
-def DBpedia(root='.data', split=('train', 'test'), offset=0):
+def DBpedia(root='.data', split=('train', 'test')):
     def _create_data_from_csv(data_path):
         with io.open(data_path, encoding="utf8") as f:
             reader = unicode_csv_reader(f)
@@ -35,5 +35,5 @@ def _create_data_from_csv(data_path):
     for item in split:
         path = find_match(item + '.csv', extracted_files)
         datasets.append(RawTextIterableDataset("DBpedia", NUM_LINES[item],
-                                               _create_data_from_csv(path), offset=offset))
+                                               _create_data_from_csv(path)))
     return datasets

torchtext/datasets/enwik9.py (+5 -5)

@@ -1,8 +1,8 @@
 import logging
 from torchtext.utils import download_from_url, extract_archive
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
 import io
 
 URL = 'http://mattmahoney.net/dc/enwik9.zip'
@@ -16,10 +16,10 @@
 
 @wrap_split_argument
 @add_docstring_header()
-def EnWik9(root='.data', split='train', offset=0):
+def EnWik9(root='.data', split='train'):
     dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
     extracted_files = extract_archive(dataset_tar)
     path = extracted_files[0]
     logging.info('Creating {} data'.format(split[0]))
     return [RawTextIterableDataset('EnWik9',
-                                   NUM_LINES[split[0]], iter(io.open(path, encoding="utf8")), offset=offset)]
+                                   NUM_LINES[split[0]], iter(io.open(path, encoding="utf8")))]

torchtext/datasets/imdb.py (+5 -5)

@@ -1,7 +1,7 @@
 from torchtext.utils import download_from_url, extract_archive
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
 import io
 
 URL = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
@@ -18,7 +18,7 @@
 
 @wrap_split_argument
 @add_docstring_header()
-def IMDB(root='.data', split=('train', 'test'), offset=0):
+def IMDB(root='.data', split=('train', 'test')):
     def generate_imdb_data(key, extracted_files):
         for fname in extracted_files:
             if 'urls' in fname:
@@ -33,5 +33,5 @@ def generate_imdb_data(key, extracted_files):
     datasets = []
     for item in split:
         iterator = generate_imdb_data(item, extracted_files)
-        datasets.append(RawTextIterableDataset("IMDB", NUM_LINES[item], iterator, offset=offset))
+        datasets.append(RawTextIterableDataset("IMDB", NUM_LINES[item], iterator))
     return datasets

torchtext/datasets/iwslt.py (+4 -4)

@@ -3,9 +3,9 @@
 import codecs
 import xml.etree.ElementTree as ET
 from torchtext.utils import (download_from_url, extract_archive)
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
 
 URL = 'https://drive.google.com/uc?id=1l5y6Giag9aRPwGtuZHswh3w5v3qEz8D8'
 
@@ -278,6 +278,6 @@ def _iter(src_data_iter, tgt_data_iter):
             yield item
 
     datasets.append(
-        RawTextIterableDataset("IWSLT", NUM_LINES[key], _iter(src_data_iter, tgt_data_iter), offset=offset))
+        RawTextIterableDataset("IWSLT", NUM_LINES[key], _iter(src_data_iter, tgt_data_iter)))
 
     return datasets

torchtext/datasets/multi30k.py (+4 -4)

@@ -1,9 +1,9 @@
 import io
 import os
 from torchtext.utils import (download_from_url, extract_archive)
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
 
 _URL_BASE_ = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task'
 
@@ -171,6 +171,6 @@ def _iter(src_data_iter, tgt_data_iter):
             yield item
 
     datasets.append(
-        RawTextIterableDataset("Multi30k", NUM_LINES[key], _iter(src_data_iter, tgt_data_iter), offset=offset))
+        RawTextIterableDataset("Multi30k", NUM_LINES[key], _iter(src_data_iter, tgt_data_iter)))
 
     return datasets

torchtext/datasets/penntreebank.py (+5 -6)

@@ -1,8 +1,8 @@
 import logging
 from torchtext.utils import download_from_url
-from torchtext.datasets.common import RawTextIterableDataset
-from torchtext.datasets.common import wrap_split_argument
-from torchtext.datasets.common import add_docstring_header
+from torchtext.data.datasets_utils import RawTextIterableDataset
+from torchtext.data.datasets_utils import wrap_split_argument
+from torchtext.data.datasets_utils import add_docstring_header
 import io
 
 URL = {
@@ -26,7 +26,7 @@
 
 @wrap_split_argument
 @add_docstring_header()
-def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
+def PennTreebank(root='.data', split=('train', 'valid', 'test')):
     datasets = []
     for item in split:
         path = download_from_url(URL[item],
@@ -35,6 +35,5 @@ def PennTreebank(root='.data', split=('train', 'valid', 'test'), offset=0):
         logging.info('Creating {} data'.format(item))
         datasets.append(RawTextIterableDataset('PennTreebank',
                                                NUM_LINES[item],
-                                               iter(io.open(path, encoding="utf8")),
-                                               offset=offset))
+                                               iter(io.open(path, encoding="utf8"))))
     return datasets
