Skip to content

Commit 499e327

Browse files
keitakuritamttk
authored and committed
Exposed parameters for csv reader to user through dataset (#432)
* Exposed parameters for csv reader to user through dataset * Added tests
1 parent 558fff6 commit 499e327

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

test/data/test_dataset.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import unicode_literals
33
import torchtext.data as data
4+
import tempfile
5+
import six
46

57
import pytest
68

@@ -206,6 +208,40 @@ def test_csv_file_no_header_one_col_multiple_fields(self):
206208
# 6 Fields including None for ids
207209
assert len(dataset.fields) == 6
208210

211+
def test_csv_dataset_quotechar(self):
    """Load a CSV whose text column contains stray double quotes.

    Based on issue #349. The rows are written to a temporary file and read
    back through ``TabularDataset`` with ``csv_reader_params`` set to
    ``{"quotechar": None}``. The test then asserts that the dataset holds
    one example per row after the first (header) row, that each example's
    text equals the lower-cased, whitespace-split source text, and that
    each label equals the source label string.
    """
    header = ("text", "label")
    rows = [('" hello world', "0"),
            ('goodbye " world', "1"),
            ('this is a pen " ', "0")]
    example_data = [header] + rows

    text_field = data.Field(lower=True, tokenize=lambda x: x.split())
    label_field = data.Field(use_vocab=False, sequential=False)
    fields = {
        "label": ("label", label_field),
        "text": ("text", text_field),
    }

    with tempfile.NamedTemporaryFile(dir=self.test_dir) as tmp_file:
        for record in example_data:
            line = "{}\n".format(",".join(record))
            tmp_file.write(six.b(line))

        # Rewind the handle before the dataset reopens the file by name.
        tmp_file.seek(0)

        dataset = data.TabularDataset(
            path=tmp_file.name,
            format="csv",
            skip_header=False,
            fields=fields,
            csv_reader_params={"quotechar": None})

        text_field.build_vocab(dataset)

        # The first row is not counted as an example.
        self.assertEqual(len(dataset), len(example_data) - 1)

        # Lengths were asserted equal above, so pairing with zip visits
        # exactly the same (example, row) pairs as indexing by position.
        for loaded, (raw_text, raw_label) in zip(dataset, example_data[1:]):
            self.assertEqual(loaded.text, raw_text.lower().split())
            self.assertEqual(loaded.label, raw_label)
244+
209245
def test_dataset_split_arguments(self):
210246
num_examples, num_labels = 30, 3
211247
self.write_test_splitting_dataset(num_examples=num_examples,

torchtext/data/dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,8 @@ def filter_examples(self, field_names):
217217
class TabularDataset(Dataset):
218218
"""Defines a Dataset of columns stored in CSV, TSV, or JSON format."""
219219

220-
def __init__(self, path, format, fields, skip_header=False, **kwargs):
220+
def __init__(self, path, format, fields, skip_header=False,
221+
csv_reader_params={}, **kwargs):
221222
"""Create a TabularDataset given a path, file format, and field list.
222223
223224
Arguments:
@@ -236,6 +237,11 @@ def __init__(self, path, format, fields, skip_header=False, **kwargs):
236237
This allows the user to rename columns from their JSON/CSV/TSV key names
237238
and also enables selecting a subset of columns to load.
238239
skip_header (bool): Whether to skip the first line of the input file.
240+
csv_reader_params(dict): Parameters to pass to the csv reader.
241+
Only relevant when format is csv or tsv.
242+
See
243+
https://docs.python.org/3/library/csv.html#csv.reader
244+
for more details.
239245
"""
240246
format = format.lower()
241247
make_example = {
@@ -244,9 +250,9 @@ def __init__(self, path, format, fields, skip_header=False, **kwargs):
244250

245251
with io.open(os.path.expanduser(path), encoding="utf8") as f:
246252
if format == 'csv':
247-
reader = unicode_csv_reader(f)
253+
reader = unicode_csv_reader(f, **csv_reader_params)
248254
elif format == 'tsv':
249-
reader = unicode_csv_reader(f, delimiter='\t')
255+
reader = unicode_csv_reader(f, delimiter='\t', **csv_reader_params)
250256
else:
251257
reader = f
252258

0 commit comments

Comments
 (0)