@@ -2,10 +2,14 @@
 import inspect
 import os
 import io
+import json
 import torch
-from torchtext.utils import validate_file
-from torchtext.utils import download_from_url
-from torchtext.utils import extract_archive
+from torchtext.utils import (
+    validate_file,
+    download_from_url,
+    extract_archive,
+    unicode_csv_reader,
+)
 import codecs
 import xml.etree.ElementTree as ET
 """
@@ -40,6 +44,53 @@ def _clean_tags_file(f_orig):
                 fd_txt.write(line.strip() + '\n')


+def _create_data_from_json(data_path):
+    with open(data_path) as json_file:
+        raw_json_data = json.load(json_file)['data']
+        for layer1 in raw_json_data:
+            for layer2 in layer1['paragraphs']:
+                for layer3 in layer2['qas']:
+                    _context, _question = layer2['context'], layer3['question']
+                    _answers = [item['text'] for item in layer3['answers']]
+                    _answer_start = [item['answer_start'] for item in layer3['answers']]
+                    if len(_answers) == 0:
+                        _answers = [""]
+                        _answer_start = [-1]
+                    # yield the raw data in the order of context, question, answers, answer_start
+                    yield (_context, _question, _answers, _answer_start)
+
+
+def _create_data_from_iob(data_path, separator='\t'):
+    with open(data_path, encoding="utf-8") as input_file:
+        columns = []
+        for line in input_file:
+            line = line.strip()
+            if line == "":
+                if columns:
+                    yield columns
+                columns = []
+            else:
+                for i, column in enumerate(line.split(separator)):
+                    if len(columns) < i + 1:
+                        columns.append([])
+                    columns[i].append(column)
+        if len(columns) > 0:
+            yield columns
+
+
+def _read_text_iterator(path):
+    with io.open(path, encoding="utf8") as f:
+        for row in f:
+            yield row
+
+
+def _create_data_from_csv(data_path):
+    with io.open(data_path, encoding="utf8") as f:
+        reader = unicode_csv_reader(f)
+        for row in reader:
+            yield int(row[0]), ' '.join(row[1:])
+
+
 def _check_default_set(split, target_select, dataset_name):
     # Check whether the given split object is either a tuple of strings or a string
     # and represents a valid selection of options given by the tuple of strings
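
Note: the reader helpers added in this hunk are plain generators over a whole file, one example per `yield`. Below is a minimal usage sketch, assuming the helpers are in scope; the file names and contents (`sample.json`, `sample.iob`, `sample.csv`) are made up for illustration, and the JSON must follow the SQuAD-style `data -> paragraphs -> qas` nesting that `_create_data_from_json` walks:

    import json

    # Hypothetical miniature SQuAD-style file for _create_data_from_json.
    squad_like = {
        "data": [{
            "paragraphs": [{
                "context": "TorchText is a library.",
                "qas": [{"question": "What is TorchText?",
                         "answers": [{"text": "a library", "answer_start": 13}]}],
            }],
        }],
    }
    with open("sample.json", "w") as f:
        json.dump(squad_like, f)
    for context, question, answers, answer_start in _create_data_from_json("sample.json"):
        print(question, answers, answer_start)  # What is TorchText? ['a library'] [13]

    # Hypothetical token<TAB>tag IOB file; a blank line separates sentences.
    with open("sample.iob", "w", encoding="utf-8") as f:
        f.write("John\tB-PER\nlives\tO\n\nParis\tB-LOC\n")
    for tokens, tags in _create_data_from_iob("sample.iob"):
        print(tokens, tags)  # one ([tokens], [tags]) pair per sentence

    # Hypothetical label,text CSV for _create_data_from_csv.
    with open("sample.csv", "w", encoding="utf8") as f:
        f.write('1,"some text"\n')
    for label, text in _create_data_from_csv("sample.csv"):
        print(label, text)  # 1 some text

Unanswerable questions (an empty `answers` list) come out as `[""]` with `answer_start` `[-1]`, so downstream code always sees parallel, non-empty lists.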
@@ -148,7 +199,6 @@ def _wrap_split_argument_with_fn(fn, splits):
         train = AG_NEWS(split='train')
         train, valid = AG_NEWS(split=('train', 'valid'))
     """
-
     argspec = inspect.getfullargspec(fn)
     if not (argspec.args[0] == "root" and
             argspec.args[1] == "split" and
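
Note: the guard above (repeated in `_create_dataset_directory` below) is what "standard signature" means concretely: positional args `root` then `split`, no `*args`/`**kwargs`, no keyword-only args, no annotations. A tiny illustration of what `inspect.getfullargspec` reports, with `good` and `bad` made up for this sketch:

    import inspect

    def good(root='.data', split=('train', 'test')):
        pass

    def bad(root='.data', split='train', **kwargs):
        pass

    spec = inspect.getfullargspec(good)
    # args=['root', 'split'], varargs=None, varkw=None -> accepted
    assert spec.args[:2] == ['root', 'split'] and spec.varkw is None

    spec = inspect.getfullargspec(bad)
    print(spec.varkw)  # 'kwargs' -> the check fails and ValueError is raised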
@@ -184,6 +234,30 @@ def new_fn(fn):
     return new_fn


+def _create_dataset_directory(dataset_name):
+    def decorator(func):
+        argspec = inspect.getfullargspec(func)
+        if not (argspec.args[0] == "root" and
+                argspec.args[1] == "split" and
+                argspec.varargs is None and
+                argspec.varkw is None and
+                len(argspec.kwonlyargs) == 0 and
+                len(argspec.annotations) == 0
+                ):
+            raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(func))
+
+        @functools.wraps(func)
+        def wrapper(root='.data', *args, **kwargs):
+            new_root = os.path.join(root, dataset_name)
+            if not os.path.exists(new_root):
+                os.makedirs(new_root)
+            return func(root=new_root, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 def _download_extract_validate(root, url, url_md5, downloaded_file, extracted_file, extracted_file_md5,
                                hash_type="sha256"):
     root = os.path.abspath(root)
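
Note: `_create_dataset_directory` is a decorator factory: it nests a per-dataset folder under `root` and creates it before the wrapped builder runs. A sketch under the same signature rules, with `DummyDataset` and its return value made up for illustration:

    @_create_dataset_directory(dataset_name='DummyDataset')
    def DummyDataset(root='.data', split=('train', 'test')):
        # By the time the body runs, root has already been rewritten to
        # '<original root>/DummyDataset' and the directory exists.
        return [os.path.join(root, s) for s in split]

    print(DummyDataset())
    # ['.data/DummyDataset/train', '.data/DummyDataset/test']
    print(DummyDataset(root='/tmp'))
    # ['/tmp/DummyDataset/train', '/tmp/DummyDataset/test']

Since `wrapper` forwards with `func(root=new_root, *args, **kwargs)`, extra arguments such as `split` should be passed by keyword; a positional `split` would collide with the `root` keyword in the forwarding call.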