
Commit 7a2e442

sivareddyg authored and jekbradbury committed
Sequence Labeling Dataset (#157)
1 parent 08adbbf commit 7a2e442

4 files changed: 119 additions, 3 deletions

README.md

Lines changed: 3 additions & 2 deletions
@@ -74,8 +74,9 @@ The datasets module currently contains:
 - Sentiment analysis: SST and IMDb
 - Question classification: TREC
 - Entailment: SNLI
-- Language modeling: WikiText-2
-- Machine translation: Multi30k, IWSLT, WMT14
+- Language modeling: abstract class + WikiText-2
+- Machine translation: abstract class + Multi30k, IWSLT, WMT14
+- Sequence tagging (e.g. POS/NER): abstract class + UDPOS

 Others are planned or a work in progress:

test/sequence_tagging.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
from torchtext import data
from torchtext import datasets

# Define the fields associated with the sequences.
WORD = data.Field(init_token="<bos>", eos_token="<eos>")
UD_TAG = data.Field(init_token="<bos>", eos_token="<eos>")

# Download and load the default data.
train, val, test = datasets.UDPOS.splits(
    fields=(('word', WORD), ('udtag', UD_TAG), (None, None)))

print(train.fields)
print(len(train))
print(vars(train[0]))

# We can also define more than two columns.
WORD = data.Field(init_token="<bos>", eos_token="<eos>")
UD_TAG = data.Field(init_token="<bos>", eos_token="<eos>")
PTB_TAG = data.Field(init_token="<bos>", eos_token="<eos>")

# Load the specified data.
train, val, test = datasets.UDPOS.splits(
    fields=(('word', WORD), ('udtag', UD_TAG), ('ptbtag', PTB_TAG)),
    path=".data/sequence-labeling/en-ud-v2",
    train="en-ud-tag.v2.train.txt",
    validation="en-ud-tag.v2.dev.txt",
    test="en-ud-tag.v2.test.txt")

print(train.fields)
print(len(train))
print(vars(train[0]))

WORD.build_vocab(train.word, min_freq=3)
UD_TAG.build_vocab(train.udtag)
PTB_TAG.build_vocab(train.ptbtag)

print(UD_TAG.vocab.freqs)
print(PTB_TAG.vocab.freqs)

train_iter, val_iter = data.BucketIterator.splits(
    (train, val), batch_size=3, device=0)

batch = next(iter(train_iter))

print("words", batch.word)
print("udtags", batch.udtag)
print("ptbtags", batch.ptbtag)

torchtext/datasets/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -2,6 +2,7 @@
 from .snli import SNLI
 from .sst import SST
 from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA
+from .sequence_tagging import SequenceTaggingDataset, UDPOS # NOQA
 from .trec import TREC
 from .imdb import IMDB

@@ -15,4 +16,6 @@
 'WMT14'
 'WikiText2',
 'TREC',
-'IMDB']
+'IMDB',
+'SequenceTaggingDataset',
+'UDPOS']
torchtext/datasets/sequence_tagging.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
from .. import data


class SequenceTaggingDataset(data.Dataset):
    """Defines a dataset for sequence tagging. Examples in this dataset
    contain paired lists -- paired list of words and tags.

    For example, in the case of part-of-speech tagging, an example is of the
    form
    [I, love, PyTorch, .] paired with [PRON, VERB, PROPN, PUNCT]

    See torchtext/test/sequence_tagging.py on how to use this class.
    """

    @staticmethod
    def sort_key(example):
        for attr in dir(example):
            if not callable(getattr(example, attr)) and \
                    not attr.startswith("__"):
                return len(getattr(example, attr))
        return 0

    def __init__(self, path, fields, **kwargs):
        examples = []
        columns = []

        with open(path) as input_file:
            for line in input_file:
                line = line.strip()
                if line == "":
                    if columns:
                        examples.append(data.Example.fromlist(columns, fields))
                    columns = []
                else:
                    for i, column in enumerate(line.split("\t")):
                        if len(columns) < i + 1:
                            columns.append([])
                        columns[i].append(column)

            if columns:
                examples.append(data.Example.fromlist(columns, fields))
        super(SequenceTaggingDataset, self).__init__(examples, fields,
                                                     **kwargs)


class UDPOS(SequenceTaggingDataset):

    # Universal Dependencies English Web Treebank.
    # Download original at http://universaldependencies.org/
    # License: http://creativecommons.org/licenses/by-sa/4.0/
    urls = ['https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip']
    dirname = 'en-ud-v2'
    name = 'udpos'

    @classmethod
    def splits(cls, fields, root=".data", train="en-ud-tag.v2.train.txt",
               validation="en-ud-tag.v2.dev.txt",
               test="en-ud-tag.v2.test.txt", **kwargs):
        """Downloads and loads the Universal Dependencies Version 2 POS Tagged
        data.
        """

        return super(UDPOS, cls).splits(
            fields=fields, root=root, train=train, validation=validation,
            test=test, **kwargs)
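
To make the parser in SequenceTaggingDataset.__init__ concrete: it expects one token per line, tab-separated columns (one column per field), and a blank line between sentences. A minimal sketch of that layout, loaded directly with the class; the toy.tsv file, its contents, and the field names here are illustrative and not part of the commit:

from torchtext import data
from torchtext.datasets import SequenceTaggingDataset

# Illustrative file in the expected layout: word<TAB>tag, blank line = sentence break.
with open("toy.tsv", "w") as f:
    f.write("I\tPRON\nlove\tVERB\nPyTorch\tPROPN\n.\tPUNCT\n"
            "\n"
            "It\tPRON\nworks\tVERB\n.\tPUNCT\n")

WORD = data.Field()
TAG = data.Field()
toy = SequenceTaggingDataset("toy.tsv", fields=[('word', WORD), ('tag', TAG)])
print(len(toy))       # 2 sentences
print(vars(toy[0]))   # {'word': ['I', 'love', 'PyTorch', '.'], 'tag': ['PRON', ...]}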
