Commit 5e62157 (1 parent: 408c4cf)
Raw datasets cleanup (#1233)

23 files changed: +308 -270 lines

torchtext/data/datasets_utils.py (+78 -4)
@@ -2,10 +2,14 @@
 import inspect
 import os
 import io
+import json
 import torch
-from torchtext.utils import validate_file
-from torchtext.utils import download_from_url
-from torchtext.utils import extract_archive
+from torchtext.utils import (
+    validate_file,
+    download_from_url,
+    extract_archive,
+    unicode_csv_reader,
+)
 import codecs
 import xml.etree.ElementTree as ET
 """
@@ -40,6 +44,53 @@ def _clean_tags_file(f_orig):
                fd_txt.write(line.strip() + '\n')


+def _create_data_from_json(data_path):
+    with open(data_path) as json_file:
+        raw_json_data = json.load(json_file)['data']
+        for layer1 in raw_json_data:
+            for layer2 in layer1['paragraphs']:
+                for layer3 in layer2['qas']:
+                    _context, _question = layer2['context'], layer3['question']
+                    _answers = [item['text'] for item in layer3['answers']]
+                    _answer_start = [item['answer_start'] for item in layer3['answers']]
+                    if len(_answers) == 0:
+                        _answers = [""]
+                        _answer_start = [-1]
+                    # yield the raw data in the order of context, question, answers, answer_start
+                    yield (_context, _question, _answers, _answer_start)
+
+
+def _create_data_from_iob(data_path, separator='\t'):
+    with open(data_path, encoding="utf-8") as input_file:
+        columns = []
+        for line in input_file:
+            line = line.strip()
+            if line == "":
+                if columns:
+                    yield columns
+                columns = []
+            else:
+                for i, column in enumerate(line.split(separator)):
+                    if len(columns) < i + 1:
+                        columns.append([])
+                    columns[i].append(column)
+        if len(columns) > 0:
+            yield columns
+
+
+def _read_text_iterator(path):
+    with io.open(path, encoding="utf8") as f:
+        for row in f:
+            yield row
+
+
+def _create_data_from_csv(data_path):
+    with io.open(data_path, encoding="utf8") as f:
+        reader = unicode_csv_reader(f)
+        for row in reader:
+            yield int(row[0]), ' '.join(row[1:])
+
+
 def _check_default_set(split, target_select, dataset_name):
     # Check whether given object split is either a tuple of strings or string
     # and represents a valid selection of options given by the tuple of strings
@@ -148,7 +199,6 @@ def _wrap_split_argument_with_fn(fn, splits):
        train = AG_NEWS(split='train')
        train, valid = AG_NEWS(split=('train', 'valid'))
    """
-
    argspec = inspect.getfullargspec(fn)
    if not (argspec.args[0] == "root" and
            argspec.args[1] == "split" and
@@ -184,6 +234,30 @@ def new_fn(fn):
     return new_fn


+def _create_dataset_directory(dataset_name):
+    def decorator(func):
+        argspec = inspect.getfullargspec(func)
+        if not (argspec.args[0] == "root" and
+                argspec.args[1] == "split" and
+                argspec.varargs is None and
+                argspec.varkw is None and
+                len(argspec.kwonlyargs) == 0 and
+                len(argspec.annotations) == 0
+                ):
+            raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(func))
+
+        @functools.wraps(func)
+        def wrapper(root='.data', *args, **kwargs):
+            new_root = os.path.join(root, dataset_name)
+            if not os.path.exists(new_root):
+                os.makedirs(new_root)
+            return func(root=new_root, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 def _download_extract_validate(root, url, url_md5, downloaded_file, extracted_file, extracted_file_md5,
                                hash_type="sha256"):
     root = os.path.abspath(root)
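
Usage sketch for the new _create_dataset_directory decorator (the fake_data function below is hypothetical, standing in for a real dataset builder): the decorator rewrites root to a per-dataset subfolder and creates it on disk before the wrapped function runs.

    from torchtext.data.datasets_utils import _create_dataset_directory

    @_create_dataset_directory(dataset_name="AG_NEWS")
    def fake_data(root, split):
        # By the time the body runs, root has been rewritten to the
        # per-dataset subfolder and that folder exists on disk.
        return root

    print(fake_data(root='.data', split='train'))  # '.data/AG_NEWS'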

torchtext/datasets/ag_news.py (+13 -13)
@@ -1,9 +1,14 @@
-from torchtext.utils import download_from_url, unicode_csv_reader
-from torchtext.data.datasets_utils import _RawTextIterableDataset
-from torchtext.data.datasets_utils import _wrap_split_argument
-from torchtext.data.datasets_utils import _add_docstring_header
+from torchtext.utils import (
+    download_from_url,
+)
+from torchtext.data.datasets_utils import (
+    _RawTextIterableDataset,
+    _wrap_split_argument,
+    _add_docstring_header,
+    _create_dataset_directory,
+    _create_data_from_csv,
+)
 import os
-import io

 URL = {
     'train': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv",
@@ -20,19 +25,14 @@
     'test': 7600,
 }

-
+DATASET_NAME = "AG_NEWS"
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=4)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
 def AG_NEWS(root, split):
-    def _create_data_from_csv(data_path):
-        with io.open(data_path, encoding="utf8") as f:
-            reader = unicode_csv_reader(f)
-            for row in reader:
-                yield int(row[0]), ' '.join(row[1:])
-
     path = download_from_url(URL[split], root=root,
                              path=os.path.join(root, split + ".csv"),
                              hash_value=MD5[split],
                              hash_type='md5')
-    return _RawTextIterableDataset("AG_NEWS", NUM_LINES[split],
+    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                    _create_data_from_csv(path))
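
As a usage sketch (assuming network access; the printed label is illustrative), the public API is unchanged, but the CSV now lands in a per-dataset subfolder:

    from torchtext.datasets import AG_NEWS

    # With the default root='.data', the file is downloaded to
    # .data/AG_NEWS/train.csv instead of directly into .data.
    train_iter = AG_NEWS(split='train')
    label, text = next(iter(train_iter))
    print(label)  # an int in 1..4, e.g. 3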

torchtext/datasets/amazonreviewfull.py (+11 -14)
@@ -1,9 +1,11 @@
-from torchtext.utils import unicode_csv_reader
-from torchtext.data.datasets_utils import _RawTextIterableDataset
-from torchtext.data.datasets_utils import _wrap_split_argument
-from torchtext.data.datasets_utils import _add_docstring_header
-from torchtext.data.datasets_utils import _download_extract_validate
-import io
+from torchtext.data.datasets_utils import (
+    _RawTextIterableDataset,
+    _wrap_split_argument,
+    _add_docstring_header,
+    _download_extract_validate,
+    _create_dataset_directory,
+    _create_data_from_csv,
+)
 import os
 import logging

@@ -28,18 +30,13 @@
     'test': "0f1e78ab60f625f2a30eab6810ef987c"
 }

-
+DATASET_NAME = "AmazonReviewFull"
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=5)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
 def AmazonReviewFull(root, split):
-    def _create_data_from_csv(data_path):
-        with io.open(data_path, encoding="utf8") as f:
-            reader = unicode_csv_reader(f)
-            for row in reader:
-                yield int(row[0]), ' '.join(row[1:])
-
     path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]),
                                       _EXTRACTED_FILES_MD5[split], hash_type="md5")
     logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset("AmazonReviewFull", NUM_LINES[split],
+    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                    _create_data_from_csv(path))

torchtext/datasets/amazonreviewpolarity.py (+11 -13)
@@ -1,9 +1,11 @@
-from torchtext.utils import unicode_csv_reader
-from torchtext.data.datasets_utils import _RawTextIterableDataset
-from torchtext.data.datasets_utils import _wrap_split_argument
-from torchtext.data.datasets_utils import _add_docstring_header
-from torchtext.data.datasets_utils import _download_extract_validate
-import io
+from torchtext.data.datasets_utils import (
+    _RawTextIterableDataset,
+    _wrap_split_argument,
+    _add_docstring_header,
+    _download_extract_validate,
+    _create_dataset_directory,
+    _create_data_from_csv,
+)
 import os
 import logging

@@ -28,17 +30,13 @@
     'test': "f4c8bded2ecbde5f996b675db6228f16"
 }

-
+DATASET_NAME = "AmazonReviewPolarity"
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
 def AmazonReviewPolarity(root, split):
-    def _create_data_from_csv(data_path):
-        with io.open(data_path, encoding="utf8") as f:
-            reader = unicode_csv_reader(f)
-            for row in reader:
-                yield int(row[0]), ' '.join(row[1:])
     path = _download_extract_validate(root, URL, MD5, os.path.join(root, _PATH), os.path.join(root, _EXTRACTED_FILES[split]),
                                       _EXTRACTED_FILES_MD5[split], hash_type="md5")
     logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset("AmazonReviewPolarity", NUM_LINES[split],
+    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                    _create_data_from_csv(path))

torchtext/datasets/conll2000chunking.py (+11 -23)
@@ -1,7 +1,11 @@
-from torchtext.data.datasets_utils import _RawTextIterableDataset
-from torchtext.data.datasets_utils import _wrap_split_argument
-from torchtext.data.datasets_utils import _add_docstring_header
-from torchtext.data.datasets_utils import _download_extract_validate
+from torchtext.data.datasets_utils import (
+    _RawTextIterableDataset,
+    _wrap_split_argument,
+    _add_docstring_header,
+    _download_extract_validate,
+    _create_dataset_directory,
+    _create_data_from_iob,
+)
 import os
 import logging

@@ -31,25 +35,9 @@
 }


-def _create_data_from_iob(data_path, separator):
-    with open(data_path, encoding="utf-8") as input_file:
-        columns = []
-        for line in input_file:
-            line = line.strip()
-            if line == "":
-                if columns:
-                    yield columns
-                columns = []
-            else:
-                for i, column in enumerate(line.split(separator)):
-                    if len(columns) < i + 1:
-                        columns.append([])
-                    columns[i].append(column)
-        if len(columns) > 0:
-            yield columns
-
-
+DATASET_NAME = "CoNLL2000Chunking"
 @_add_docstring_header(num_lines=NUM_LINES)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
 def CoNLL2000Chunking(root, split):
     # Create a dataset specific subfolder to deal with generic download filenames
@@ -58,5 +46,5 @@ def CoNLL2000Chunking(root, split):
     data_filename = _download_extract_validate(root, URL[split], MD5[split], path, os.path.join(root, _EXTRACTED_FILES[split]),
                                                _EXTRACTED_FILES_MD5[split], hash_type="md5")
     logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset("CoNLL2000Chunking", NUM_LINES[split],
+    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                    _create_data_from_iob(data_filename, " "))
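
To make the shared helper's contract concrete, here is a small sketch (the temp file and tags are throwaway examples, not part of the commit): _create_data_from_iob yields one list per sentence block, holding one inner list per column of the IOB file.

    import tempfile
    from torchtext.data.datasets_utils import _create_data_from_iob

    iob = "Confidence NN B-NP\nin IN B-PP\n\nthe DT B-NP\npound NN I-NP\n"
    with tempfile.NamedTemporaryFile('w', suffix='.iob', delete=False) as f:
        f.write(iob)

    # CoNLL2000Chunking passes " " explicitly, overriding the '\t' default.
    for columns in _create_data_from_iob(f.name, separator=" "):
        print(columns)
    # [['Confidence', 'in'], ['NN', 'IN'], ['B-NP', 'B-PP']]
    # [['the', 'pound'], ['DT', 'NN'], ['B-NP', 'I-NP']]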

torchtext/datasets/dbpedia.py (+15 -13)
@@ -1,10 +1,16 @@
-from torchtext.utils import download_from_url, extract_archive, unicode_csv_reader
-from torchtext.data.datasets_utils import _RawTextIterableDataset
-from torchtext.data.datasets_utils import _wrap_split_argument
-from torchtext.data.datasets_utils import _add_docstring_header
-from torchtext.data.datasets_utils import _find_match
+from torchtext.utils import (
+    download_from_url,
+    extract_archive,
+)
+from torchtext.data.datasets_utils import (
+    _RawTextIterableDataset,
+    _wrap_split_argument,
+    _add_docstring_header,
+    _find_match,
+    _create_dataset_directory,
+    _create_data_from_csv,
+)
 import os
-import io

 URL = 'https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k'

@@ -17,20 +23,16 @@

 _PATH = 'dbpedia_csv.tar.gz'

-
+DATASET_NAME = "DBpedia"
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=14)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
 def DBpedia(root, split):
-    def _create_data_from_csv(data_path):
-        with io.open(data_path, encoding="utf8") as f:
-            reader = unicode_csv_reader(f)
-            for row in reader:
-                yield int(row[0]), ' '.join(row[1:])
     dataset_tar = download_from_url(URL, root=root,
                                     path=os.path.join(root, _PATH),
                                     hash_value=MD5, hash_type='md5')
     extracted_files = extract_archive(dataset_tar)

     path = _find_match(split + '.csv', extracted_files)
-    return _RawTextIterableDataset("DBpedia", NUM_LINES[split],
+    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                    _create_data_from_csv(path))

torchtext/datasets/enwik9.py (+15 -8)
@@ -1,9 +1,15 @@
 import logging
-from torchtext.utils import download_from_url, extract_archive
-from torchtext.data.datasets_utils import _RawTextIterableDataset
-from torchtext.data.datasets_utils import _wrap_split_argument
-from torchtext.data.datasets_utils import _add_docstring_header
-import io
+from torchtext.utils import (
+    download_from_url,
+    extract_archive,
+)
+from torchtext.data.datasets_utils import (
+    _RawTextIterableDataset,
+    _wrap_split_argument,
+    _add_docstring_header,
+    _create_dataset_directory,
+    _read_text_iterator,
+)

 URL = 'http://mattmahoney.net/dc/enwik9.zip'

@@ -13,13 +19,14 @@
     'train': 13147026
 }

-
+DATASET_NAME = "EnWik9"
 @_add_docstring_header(num_lines=NUM_LINES)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train',))
 def EnWik9(root, split):
     dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
     extracted_files = extract_archive(dataset_tar)
     path = extracted_files[0]
     logging.info('Creating {} data'.format(split))
-    return _RawTextIterableDataset('EnWik9',
-                                   NUM_LINES[split], iter(io.open(path, encoding="utf8")))
+    return _RawTextIterableDataset(DATASET_NAME,
+                                   NUM_LINES[split], _read_text_iterator(path))
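
The EnWik9 change is not just cosmetic: iter(io.open(path, encoding="utf8")) leaves closing the file to the garbage collector, whereas _read_text_iterator reads inside a with block, so the handle is closed once the generator is exhausted or finalized. A minimal sketch (the path is hypothetical):

    from torchtext.data.datasets_utils import _read_text_iterator

    # Yields raw lines like iter(open(path)); the underlying handle is
    # released deterministically when iteration completes.
    for row in _read_text_iterator('enwik9.txt'):
        pass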

torchtext/datasets/imdb.py (+4 -2)
@@ -2,6 +2,7 @@
 from torchtext.data.datasets_utils import _RawTextIterableDataset
 from torchtext.data.datasets_utils import _wrap_split_argument
 from torchtext.data.datasets_utils import _add_docstring_header
+from torchtext.data.datasets_utils import _create_dataset_directory
 import io

 URL = 'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'
@@ -15,8 +16,9 @@

 _PATH = 'aclImdb_v1.tar.gz'

-
+DATASET_NAME = "IMDB"
 @_add_docstring_header(num_lines=NUM_LINES, num_classes=2)
+@_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(('train', 'test'))
 def IMDB(root, split):
     def generate_imdb_data(key, extracted_files):
@@ -31,4 +33,4 @@ def generate_imdb_data(key, extracted_files):
                              hash_value=MD5, hash_type='md5')
     extracted_files = extract_archive(dataset_tar)
     iterator = generate_imdb_data(split, extracted_files)
-    return _RawTextIterableDataset("IMDB", NUM_LINES[split], iterator)
+    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], iterator)
