Commit 24460f1

Merge pull request #26 from danishpruthi/master
Translating 04-efficiency to pytorch
2 parents: e947583 + aaad004

3 files changed: +160 -2 lines changed
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
from collections import defaultdict
import math
import numpy as np
import time
import random
import torch
import torch.nn.functional as F


class WordEmbSkip(torch.nn.Module):
    def __init__(self, nwords, emb_size):
        super(WordEmbSkip, self).__init__()

        """ word embeddings """
        self.word_embedding = torch.nn.Embedding(nwords, emb_size, sparse=True)
        # initialize the weights with xavier uniform (Glorot, X. & Bengio, Y. (2010))
        torch.nn.init.xavier_uniform_(self.word_embedding.weight)
        """ context embeddings """
        self.context_embedding = torch.nn.Embedding(nwords, emb_size, sparse=True)
        # initialize the weights with xavier uniform (Glorot, X. & Bengio, Y. (2010))
        torch.nn.init.xavier_uniform_(self.context_embedding.weight)

    # useful ref: https://arxiv.org/abs/1402.3722
    def forward(self, word_pos, context_positions, negative_sample=False):
        embed_word = self.word_embedding(word_pos)                  # 1 * emb_size
        embed_context = self.context_embedding(context_positions)  # n * emb_size
        score = torch.matmul(embed_context, embed_word.transpose(dim0=1, dim1=0))  # score = n * 1

        # following is an example of something you can only do in a framework that allows
        # dynamic graph creation
        if negative_sample:
            score = -1 * score
        obj = -1 * torch.sum(F.logsigmoid(score))
        return obj

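# Note: forward() computes the two halves of the skip-gram negative-sampling objective:
# for observed context words it returns -sum_c log sigmoid(u_c . v_w), and with
# negative_sample=True (score negated) it returns -sum_n log sigmoid(-u_n . v_w),
# pushing true word-context pairs toward high scores and sampled noise words toward low ones.
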
K = 3           # number of negative samples
N = 2           # length of window on each side (so N=2 gives a total window size of 5, as in t-2 t-1 t t+1 t+2)
EMB_SIZE = 128  # The size of the embedding

embeddings_location = "embeddings.txt"  # the file to write the word embeddings to
labels_location = "labels.txt"          # the file to write the labels to

# We reuse the data reading from the language modeling class
w2i = defaultdict(lambda: len(w2i))

# word counts for negative sampling
word_counts = defaultdict(int)

S = w2i["<s>"]
UNK = w2i["<unk>"]
def read_dataset(filename):
    with open(filename, "r") as f:
        for line in f:
            line = line.strip().split(" ")
            for word in line:
                word_counts[w2i[word]] += 1
            yield [w2i[x] for x in line]


# Read in the data
train = list(read_dataset("../data/ptb/train.txt"))
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("../data/ptb/valid.txt"))
i2w = {v: k for k, v in w2i.items()}
nwords = len(w2i)


# take the word counts to the 3/4, normalize
counts = np.array([list(x) for x in word_counts.items()])[:, 1] ** .75
normalizing_constant = sum(counts)
word_probabilities = np.zeros(nwords)
for word_id in word_counts:
    word_probabilities[word_id] = word_counts[word_id] ** .75 / normalizing_constant

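# Worked example of the 3/4-power smoothing above: for raw counts {a: 1, b: 8} the plain
# unigram distribution is p(a)=1/9≈0.11, p(b)=8/9≈0.89, whereas 1**.75=1.00 and 8**.75≈4.76
# give p(a)≈0.17 and p(b)≈0.83, so frequent words are sampled as negatives a bit less often
# and rare words a bit more often.
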
with open(labels_location, 'w') as labels_file:
    for i in range(nwords):
        labels_file.write(i2w[i] + '\n')

# initialize the model
model = WordEmbSkip(nwords, EMB_SIZE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

type = torch.LongTensor
use_cuda = torch.cuda.is_available()

if use_cuda:
    type = torch.cuda.LongTensor
    model.cuda()


# Calculate the loss value for the entire sentence
def calc_sent_loss(sent):
    # add padding to the sentence equal to the size of the window
    # as we need to predict the eos as well, the future window at that point is N past it
    all_neg_words = np.random.choice(nwords, size=2*N*K*len(sent), replace=True, p=word_probabilities)

    # Step through the sentence
    losses = []
    for i, word in enumerate(sent):
        pos_words = [sent[x] if x >= 0 else S for x in range(i-N, i)] + \
                    [sent[x] if x < len(sent) else S for x in range(i+1, i+N+1)]
        pos_words_tensor = torch.tensor(pos_words).type(type)
        neg_words = all_neg_words[i*K*2*N:(i+1)*K*2*N]
        neg_words_tensor = torch.tensor(neg_words).type(type)
        target_word_tensor = torch.tensor([word]).type(type)

        # NOTE: technically, one should ensure that the neg words don't contain the
        # context (i.e. positive) words, but it is very unlikely, so we can ignore that

        pos_loss = model(target_word_tensor, pos_words_tensor)
        neg_loss = model(target_word_tensor, neg_words_tensor, negative_sample=True)

        losses.append(pos_loss + neg_loss)

    return torch.stack(losses).sum()

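# calc_sent_loss builds a single graph for the whole sentence: the positive and negative
# terms for every position are stacked and summed, so one backward() call in the training
# loop below propagates gradients through all of them at once.
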
MAX_LEN = 100

for ITER in range(100):
    print("started iter %r" % ITER)
    # Perform training
    random.shuffle(train)
    train_words, train_loss = 0, 0.0
    start = time.time()
    model.train()
    for sent_id, sent in enumerate(train):
        my_loss = calc_sent_loss(sent)
        train_loss += my_loss.item()
        train_words += len(sent)
        # Back prop while training
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
        if (sent_id + 1) % 50 == 0:
            print("--finished %r sentences" % (sent_id + 1))
            train_ppl = float('inf') if train_loss / train_words > 709 else math.exp(train_loss / train_words)
            print("after sentences %r: train loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
                sent_id + 1, train_loss / train_words, train_ppl, time.time() - start))
    train_ppl = float('inf') if train_loss / train_words > 709 else math.exp(train_loss / train_words)
    print("iter %r: train loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
        ITER, train_loss / train_words, train_ppl, time.time() - start))
    # Evaluate on dev set
    dev_words, dev_loss = 0, 0.0
    start = time.time()
    model.eval()
    for sent_id, sent in enumerate(dev):
        my_loss = calc_sent_loss(sent)
        dev_loss += my_loss.item()
        dev_words += len(sent)
    dev_ppl = float('inf') if dev_loss / dev_words > 709 else math.exp(dev_loss / dev_words)
    print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
        ITER, dev_loss / dev_words, dev_ppl, time.time() - start))

    print("saving embedding files")
    with open(embeddings_location, 'w') as embeddings_file:
        W_w_np = model.word_embedding.weight.data.cpu().numpy()
        for i in range(nwords):
            ith_embedding = '\t'.join(map(str, W_w_np[i]))
            embeddings_file.write(ith_embedding + '\n')
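
As an aside (not part of this commit), the WordEmbSkip module above can be sanity-checked on toy indices; the vocabulary size, embedding size, and word indices below are made up purely for illustration:

    import torch

    toy_model = WordEmbSkip(nwords=10, emb_size=4)  # assumes the class defined above

    target = torch.LongTensor([3])              # centre word
    pos_context = torch.LongTensor([2, 4])      # words actually seen in the window
    neg_context = torch.LongTensor([7, 0, 9])   # words drawn from the noise distribution

    pos_loss = toy_model(target, pos_context)                        # -sum log sigmoid(u_c . v_w)
    neg_loss = toy_model(target, neg_context, negative_sample=True)  # -sum log sigmoid(-u_n . v_w)
    print((pos_loss + neg_loss).item())         # scalar loss for this single position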

04-efficiency/wordemb-skip-binary.py

Lines changed: 0 additions & 1 deletion
@@ -92,7 +92,6 @@ def calc_sent_loss(sent):
     my_loss = calc_sent_loss(sent)
     dev_loss += my_loss.value()
     dev_words += len(sent)
-    trainer.update()
   print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (ITER, dev_loss/dev_words, math.exp(dev_loss/dev_words), time.time()-start))
 
   print("saving embedding files")

04-efficiency/wordemb-skip-ns.py

Lines changed: 0 additions & 1 deletion
@@ -104,7 +104,6 @@ def calc_sent_loss(sent):
     my_loss = calc_sent_loss(sent)
     dev_loss += my_loss.value()
     dev_words += len(sent)
-    trainer.update()
   print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (ITER, dev_loss/dev_words, math.exp(dev_loss/dev_words), time.time()-start))
 
   print("saving embedding files")
