"""
A class for basic vocab operations.
"""
from __future__ import print_function
import os
import random
import numpy as np
import pickle
from utils import constant
random.seed(1234)
np.random.seed(1234)

def build_embedding(wv_file, vocab, wv_dim):
    """
    Build an embedding matrix for `vocab`, copying rows from a GloVe file
    where available and keeping the random init elsewhere.
    """
    vocab_size = len(vocab)
    emb = np.random.uniform(-1, 1, (vocab_size, wv_dim))
    emb[constant.PAD_ID] = 0  # <pad> should be all 0
    w2id = {w: i for i, w in enumerate(vocab)}
    with open(wv_file, encoding="utf8") as f:
        for line in f:
            elems = line.split()
            # the last wv_dim fields are the vector; everything before is the token
            token = ''.join(elems[0:-wv_dim])
            if token in w2id:
                emb[w2id[token]] = [float(v) for v in elems[-wv_dim:]]
    return emb
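
# A minimal usage sketch for build_embedding (the GloVe path and dimension
# below are assumptions, not shipped with this module):
#
#   words = [constant.PAD_TOKEN, constant.UNK_TOKEN, 'the', 'cat']
#   emb = build_embedding('glove.6B.100d.txt', words, 100)
#   assert emb.shape == (len(words), 100)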

def load_glove_vocab(file, wv_dim):
    """
    Load the set of all words covered by a GloVe file.
    """
    vocab = set()
    with open(file, encoding='utf8') as f:
        for line in f:
            elems = line.split()
            token = ''.join(elems[0:-wv_dim])
            vocab.add(token)
    return vocab
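
# A minimal usage sketch (filename and dimension are assumptions); a typical
# use is filtering a corpus counter down to tokens GloVe actually covers:
#
#   glove_vocab = load_glove_vocab('glove.6B.100d.txt', 100)
#   covered = {w: c for w, c in word_counter.items() if w in glove_vocab}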

class Vocab(object):
    def __init__(self, filename, load=False, word_counter=None, threshold=0):
        if load:
            assert os.path.exists(filename), "Vocab file does not exist at " + filename
            # load from file and ignore all other params
            self.id2word, self.word2id = self.load(filename)
            self.size = len(self.id2word)
            print("Vocab size {} loaded from file".format(self.size))
        else:
            print("Creating vocab from scratch...")
            assert word_counter is not None, "word_counter is not provided for vocab creation."
            self.word_counter = word_counter
            if threshold > 1:
                # remove words that occur fewer than `threshold` times
                self.word_counter = dict([(k, v) for k, v in self.word_counter.items() if v >= threshold])
            # sort words by frequency, most frequent first
            self.id2word = sorted(self.word_counter, key=lambda k: self.word_counter[k], reverse=True)
            # add special tokens to the beginning
            self.id2word = [constant.PAD_TOKEN, constant.UNK_TOKEN] + self.id2word
            self.word2id = dict([(self.id2word[idx], idx) for idx in range(len(self.id2word))])
            self.size = len(self.id2word)
            self.save(filename)
            print("Vocab size {} saved to file {}".format(self.size, filename))

    def load(self, filename):
        with open(filename, 'rb') as infile:
            id2word = pickle.load(infile)
            word2id = dict([(id2word[idx], idx) for idx in range(len(id2word))])
        return id2word, word2id

    def save(self, filename):
        if os.path.exists(filename):
            print("Overwriting old vocab file at " + filename)
            os.remove(filename)
        with open(filename, 'wb') as outfile:
            pickle.dump(self.id2word, outfile)
        return

    def map(self, token_list):
        """
        Map a list of tokens to their ids.
        """
        return [self.word2id[w] if w in self.word2id else constant.VOCAB_UNK_ID for w in token_list]

    def unmap(self, idx_list):
        """
        Unmap ids back to tokens.
        """
        return [self.id2word[idx] for idx in idx_list]

    def get_embeddings(self, word_vectors=None, dim=100):
        # uniform init in [-EMB_INIT_RANGE, EMB_INIT_RANGE]
        self.embeddings = 2 * constant.EMB_INIT_RANGE * np.random.rand(self.size, dim) - constant.EMB_INIT_RANGE
        if word_vectors is not None:
            assert len(list(word_vectors.values())[0]) == dim, \
                "Word vectors do not have the required dimension {}.".format(dim)
            for w, idx in self.word2id.items():
                if w in word_vectors:
                    self.embeddings[idx] = np.asarray(word_vectors[w])
        return self.embeddings
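
# A minimal end-to-end sketch, guarded so importing the module stays
# side-effect free. The toy corpus, the 'vocab.pkl' path, and the embedding
# dimension are illustrative assumptions, not part of the original module.
if __name__ == '__main__':
    from collections import Counter

    word_counter = Counter("the cat sat on the mat the cat".split())

    # Build (and pickle) a vocab, dropping tokens seen fewer than 2 times;
    # PAD and UNK are prepended, so ids for real words start at 2.
    v = Vocab('vocab.pkl', load=False, word_counter=word_counter, threshold=2)

    ids = v.map(['the', 'cat', 'dog'])  # 'dog' maps to the UNK id
    print(ids, v.unmap(ids))

    # Randomly initialized embeddings; pass a {word: vector} dict to
    # overwrite rows for words with pretrained vectors.
    emb = v.get_embeddings(dim=100)
    print(emb.shape)  # (v.size, 100)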