Commit f31488d

Making the code consistent with the DyNet version
1 parent 6ae371f commit f31488d

File tree: 2 files changed, 41 additions and 45 deletions

03-wordemb-pytorch/wordemb-cbow.py
03-wordemb-pytorch/wordemb-skip.py


03-wordemb-pytorch/wordemb-cbow.py

Lines changed: 16 additions & 17 deletions
@@ -5,9 +5,9 @@
 import torch
 
 
-class CBoW(torch.nn.Module):
+class WordEmbCbow(torch.nn.Module):
     def __init__(self, nwords, emb_size):
-        super(CBoW, self).__init__()
+        super(WordEmbCbow, self).__init__()
 
         """ layers """
         self.embedding = torch.nn.Embedding(nwords, emb_size)
@@ -55,7 +55,7 @@ def read_dataset(filename):
         labels_file.write(i2w[i] + '\n')
 
 # initialize the model
-model = CBoW(nwords, EMB_SIZE)
+model = WordEmbCbow(nwords, EMB_SIZE)
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

@@ -68,27 +68,22 @@ def read_dataset(filename):
 
 
 # Calculate the loss value for the entire sentence
-def calc_sent_loss(sent, inference=False):
+def calc_sent_loss(sent):
 
     # add padding to the sentence equal to the size of the window
     # as we need to predict the eos as well, the future window at that point is N past it
     padded_sent = [S] * N + sent + [S] * N
 
     # Step through the sentence
-    total_loss = 0
+    losses = []
     for i in range(N, len(sent) + N):
         # c is the context vector
         c = torch.tensor(padded_sent[i - N:i] + padded_sent[i + 1:i + N + 1]).type(type)
         t = torch.tensor([padded_sent[i]]).type(type) # This is the target vector
-        log_prob = model(c)
-        loss = criterion(log_prob, t) # loss for predicting target from context vector
-        if not inference:
-            # Back prop while training only
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-        total_loss += loss.data.cpu().item()
-    return total_loss
+        logits = model(c)
+        loss = criterion(logits, t) # loss for predicting target from context vector
+        losses.append(loss)
+    return torch.stack(losses).sum()
 
 
 MAX_LEN = 100
@@ -101,8 +96,12 @@ def calc_sent_loss(sent, inference=False):
     start = time.time()
     for sent_id, sent in enumerate(train):
         my_loss = calc_sent_loss(sent)
-        train_loss += my_loss
+        train_loss += my_loss.item()
         train_words += len(sent)
+        # Taking the step after calculating loss for all the words in the sentence
+        optimizer.zero_grad()
+        my_loss.backward()
+        optimizer.step()
         if (sent_id + 1) % 5000 == 0:
             print("--finished %r sentences" % (sent_id + 1))
     print("iter %r: train loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
@@ -111,8 +110,8 @@ def calc_sent_loss(sent, inference=False):
     dev_words, dev_loss = 0, 0.0
     start = time.time()
     for sent_id, sent in enumerate(dev):
-        my_loss = calc_sent_loss(sent, inference=True)
-        dev_loss += my_loss
+        my_loss = calc_sent_loss(sent)
+        dev_loss += my_loss.item()
         dev_words += len(sent)
     print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
         ITER, dev_loss / dev_words, math.exp(dev_loss / dev_words), time.time() - start))
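
The net effect in wordemb-cbow.py is that calc_sent_loss no longer back-propagates after every target word: it collects the per-word losses, stacks and sums them, and the training loop takes a single optimizer step per sentence, matching the DyNet version. Below is a minimal, self-contained sketch of that pattern; the toy constants and the linear projection layer are assumptions for illustration, not code taken from the repository.

import torch

N = 2          # window size on each side
EMB_SIZE = 8   # toy embedding size (assumption)
nwords = 10    # toy vocabulary size (assumption)
S = 0          # index used for start/end padding (assumption)

class WordEmbCbow(torch.nn.Module):
    def __init__(self, nwords, emb_size):
        super(WordEmbCbow, self).__init__()
        self.embedding = torch.nn.Embedding(nwords, emb_size)
        self.linear = torch.nn.Linear(emb_size, nwords)  # assumed projection to vocabulary logits

    def forward(self, context):
        emb = self.embedding(context).sum(dim=0, keepdim=True)  # 1 * emb_size
        return self.linear(emb)                                 # 1 * nwords

model = WordEmbCbow(nwords, EMB_SIZE)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def calc_sent_loss(sent):
    # pad so the context window is defined at the sentence edges
    padded_sent = [S] * N + sent + [S] * N
    losses = []
    for i in range(N, len(sent) + N):
        c = torch.tensor(padded_sent[i - N:i] + padded_sent[i + 1:i + N + 1])
        t = torch.tensor([padded_sent[i]])
        losses.append(criterion(model(c), t))
    return torch.stack(losses).sum()  # one accumulated loss per sentence

my_loss = calc_sent_loss([3, 5, 2, 7])  # toy sentence of word ids
optimizer.zero_grad()
my_loss.backward()   # single backward pass for the whole sentence
optimizer.step()     # single parameter update per sentence
print(my_loss.item())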

03-wordemb-pytorch/wordemb-skip.py

Lines changed: 25 additions & 28 deletions
@@ -3,29 +3,23 @@
 import time
 import random
 import torch
-import torch.nn.functional as F
 
 
-class Skip(torch.nn.Module):
+class WordEmbSkip(torch.nn.Module):
     def __init__(self, nwords, emb_size):
-        super(Skip, self).__init__()
+        super(WordEmbSkip, self).__init__()
 
         """ word embeddings """
         self.word_embedding = torch.nn.Embedding(nwords, emb_size)
         # uniform initialization
         torch.nn.init.uniform_(self.word_embedding.weight, -0.25, 0.25)
         """ context embeddings"""
-        self.context_embedding = torch.nn.Embedding(nwords, emb_size)
-        # uniform initialization
-        torch.nn.init.uniform_(self.context_embedding.weight, -0.25, 0.25)
+        self.context_embedding = torch.nn.Parameter(torch.randn(emb_size, nwords))
 
-    def forward(self, word_pos, context_pos):
-        embed_word = self.word_embedding(word_pos) # 1 * emb_size
-        embed_context = self.context_embedding(context_pos) # 1 * emb_size
-        score = torch.mul(embed_word, embed_context)
-        score = torch.sum(score, dim=1)
-        log_target = -1 * F.logsigmoid(score).squeeze()
-        return log_target
+    def forward(self, word):
+        embed_word = self.word_embedding(word) # 1 * emb_size
+        out = torch.mm(embed_word, self.context_embedding) # 1 * nwords
+        return out
 
 
 N = 2 # length of window on each side (so N=2 gives a total window size of 5, as in t-2 t-1 t t+1 t+2)
@@ -58,7 +52,8 @@ def read_dataset(filename):
         labels_file.write(i2w[i] + '\n')
 
 # initialize the model
-model = Skip(nwords, EMB_SIZE)
+model = WordEmbSkip(nwords, EMB_SIZE)
+criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
 
 type = torch.LongTensor
@@ -70,26 +65,22 @@ def read_dataset(filename):
 
 
 # Calculate the loss value for the entire sentence
-def calc_sent_loss(sent, inference=False):
+def calc_sent_loss(sent):
     # add padding to the sentence equal to the size of the window
     # as we need to predict the eos as well, the future window at that point is N past it
 
     # Step through the sentence
-    total_loss = 0
+    losses = []
     for i, word in enumerate(sent):
-        c = torch.tensor([word]).type(type) # This is tensor for center word
         for j in range(1, N + 1):
             for direction in [-1, 1]:
+                c = torch.tensor([word]).type(type) # This is tensor for center word
                 context_id = sent[i + direction * j] if 0 <= i + direction * j < len(sent) else S
                 context = torch.tensor([context_id]).type(type) # Tensor for context word
-                loss = model(c, context)
-                if not inference:
-                    # Back prop while training only
-                    optimizer.zero_grad()
-                    loss.backward()
-                    optimizer.step()
-                total_loss += loss.data.cpu().item()
-    return total_loss
+                logits = model(c)
+                loss = criterion(logits, context)
+                losses.append(loss)
+    return torch.stack(losses).sum()
 
 
 MAX_LEN = 100
@@ -100,20 +91,26 @@ def calc_sent_loss(sent, inference=False):
     random.shuffle(train)
     train_words, train_loss = 0, 0.0
     start = time.time()
+    model.train()
     for sent_id, sent in enumerate(train):
         my_loss = calc_sent_loss(sent)
-        train_loss += my_loss
+        train_loss += my_loss.item()
         train_words += len(sent)
+        # Back prop while training
+        optimizer.zero_grad()
+        my_loss.backward()
+        optimizer.step()
         if (sent_id + 1) % 5000 == 0:
             print("--finished %r sentences" % (sent_id + 1))
     print("iter %r: train loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
         ITER, train_loss / train_words, math.exp(train_loss / train_words), time.time() - start))
     # Evaluate on dev set
     dev_words, dev_loss = 0, 0.0
     start = time.time()
+    model.eval()
     for sent_id, sent in enumerate(dev):
-        my_loss = calc_sent_loss(sent, inference=True)
-        dev_loss += my_loss
+        my_loss = calc_sent_loss(sent)
+        dev_loss += my_loss.item()
         dev_words += len(sent)
     print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
         ITER, dev_loss / dev_words, math.exp(dev_loss / dev_words), time.time() - start))
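
In wordemb-skip.py the change goes beyond the rename: the earlier log-sigmoid scoring of (center, context) pairs is replaced by a full-softmax formulation. The context embeddings become an emb_size x nwords parameter, forward() now takes only the center word and returns logits over the whole vocabulary, and CrossEntropyLoss against the context-word index supplies the loss, again matching the DyNet version. The following self-contained sketch (toy sizes are assumptions for illustration) shows how one (center, context) pair is scored under the new formulation.

import torch

EMB_SIZE = 8   # toy embedding size (assumption)
nwords = 10    # toy vocabulary size (assumption)

class WordEmbSkip(torch.nn.Module):
    def __init__(self, nwords, emb_size):
        super(WordEmbSkip, self).__init__()
        self.word_embedding = torch.nn.Embedding(nwords, emb_size)
        torch.nn.init.uniform_(self.word_embedding.weight, -0.25, 0.25)
        # one column of context weights per vocabulary word
        self.context_embedding = torch.nn.Parameter(torch.randn(emb_size, nwords))

    def forward(self, word):
        embed_word = self.word_embedding(word)               # 1 * emb_size
        return torch.mm(embed_word, self.context_embedding)  # 1 * nwords (logits)

model = WordEmbSkip(nwords, EMB_SIZE)
criterion = torch.nn.CrossEntropyLoss()

center = torch.tensor([3])   # center word id (toy)
context = torch.tensor([5])  # observed context word id (toy)
logits = model(center)
loss = criterion(logits, context)  # negative log-softmax probability of the context word
loss.backward()
print(loss.item())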
