import time
import random
import torch
-import torch.nn.functional as F


-class Skip(torch.nn.Module):
+class WordEmbSkip(torch.nn.Module):
    def __init__(self, nwords, emb_size):
-        super(Skip, self).__init__()
+        super(WordEmbSkip, self).__init__()

        """ word embeddings """
        self.word_embedding = torch.nn.Embedding(nwords, emb_size)
        # uniform initialization
        torch.nn.init.uniform_(self.word_embedding.weight, -0.25, 0.25)
        """ context embeddings """
-        self.context_embedding = torch.nn.Embedding(nwords, emb_size)
-        # uniform initialization
-        torch.nn.init.uniform_(self.context_embedding.weight, -0.25, 0.25)
+        self.context_embedding = torch.nn.Parameter(torch.randn(emb_size, nwords))

-    def forward(self, word_pos, context_pos):
-        embed_word = self.word_embedding(word_pos)  # 1 * emb_size
-        embed_context = self.context_embedding(context_pos)  # 1 * emb_size
-        score = torch.mul(embed_word, embed_context)
-        score = torch.sum(score, dim=1)
-        log_target = -1 * F.logsigmoid(score).squeeze()
-        return log_target
+    def forward(self, word):
+        embed_word = self.word_embedding(word)  # 1 * emb_size
+        out = torch.mm(embed_word, self.context_embedding)  # 1 * nwords
+        return out


N = 2  # length of window on each side (so N=2 gives a total window size of 5, as in t-2 t-1 t t+1 t+2)
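The heart of this change is the forward pass. The old Skip.forward took a (center, context) pair and returned its negative logsigmoid score, fusing the model with a binary objective. WordEmbSkip.forward takes only the center word and returns one raw score per vocabulary entry, leaving the loss to an external criterion. A minimal shape sketch of the new forward using throwaway tensors (the sizes nwords=10, emb_size=4 are made up for illustration):

    import torch

    nwords, emb_size = 10, 4                           # toy sizes for this sketch only
    word_embedding = torch.nn.Embedding(nwords, emb_size)
    context_embedding = torch.randn(emb_size, nwords)  # same shape as the new Parameter

    word = torch.tensor([3])                           # index of one center word
    embed_word = word_embedding(word)                  # [1, emb_size]
    out = torch.mm(embed_word, context_embedding)      # [1, nwords]: a score per vocab word
    print(out.shape)                                   # torch.Size([1, 10])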
@@ -58,7 +52,8 @@ def read_dataset(filename):
        labels_file.write(i2w[i] + '\n')

# initialize the model
-model = Skip(nwords, EMB_SIZE)
+model = WordEmbSkip(nwords, EMB_SIZE)
+criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

type = torch.LongTensor
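criterion is constructed once here and reused for every window pair. torch.nn.CrossEntropyLoss expects raw logits (hence no softmax inside the model) together with a long-typed target index, and equals the negative log-softmax probability of the true context word. A self-contained check on made-up values:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 10)          # a [1, nwords] row of raw scores
    target = torch.tensor([7])           # true context-word index, LongTensor

    loss = F.cross_entropy(logits, target)
    manual = -F.log_softmax(logits, dim=1)[0, 7]
    assert torch.allclose(loss, manual)  # identical value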
@@ -70,26 +65,22 @@ def calc_sent_loss(sent, inference=False):


# Calculate the loss value for the entire sentence
-def calc_sent_loss(sent, inference=False):
+def calc_sent_loss(sent):
    # add padding to the sentence equal to the size of the window
    # as we need to predict the eos as well, the future window at that point is N past it

    # Step through the sentence
-    total_loss = 0
+    losses = []
    for i, word in enumerate(sent):
-        c = torch.tensor([word]).type(type)  # This is tensor for center word
        for j in range(1, N + 1):
            for direction in [-1, 1]:
+                c = torch.tensor([word]).type(type)  # This is tensor for center word
                context_id = sent[i + direction * j] if 0 <= i + direction * j < len(sent) else S
                context = torch.tensor([context_id]).type(type)  # Tensor for context word
-                loss = model(c, context)
-                if not inference:
-                    # Back prop while training only
-                    optimizer.zero_grad()
-                    loss.backward()
-                    optimizer.step()
-                total_loss += loss.data.cpu().item()
-    return total_loss
+                logits = model(c)
+                loss = criterion(logits, context)
+                losses.append(loss)
+    return torch.stack(losses).sum()


MAX_LEN = 100
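Note why calc_sent_loss now returns torch.stack(losses).sum() rather than a running float: each element of losses is still attached to the autograd graph, so the caller can do one backward() per sentence instead of one per window pair. A tiny self-contained demonstration with a made-up parameter:

    import torch

    w = torch.tensor([2.0], requires_grad=True)            # stand-in for a model weight
    losses = [(w * k).squeeze() for k in (1.0, 2.0, 3.0)]  # three per-pair losses
    total = torch.stack(losses).sum()                      # one graph-connected scalar
    total.backward()                                       # single backward for all terms
    print(w.grad)                                          # tensor([6.]) = 1 + 2 + 3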
@@ -100,20 +91,26 @@ def calc_sent_loss(sent, inference=False):
    random.shuffle(train)
    train_words, train_loss = 0, 0.0
    start = time.time()
+    model.train()
    for sent_id, sent in enumerate(train):
        my_loss = calc_sent_loss(sent)
-        train_loss += my_loss
+        train_loss += my_loss.item()
        train_words += len(sent)
+        # Back prop while training
+        optimizer.zero_grad()
+        my_loss.backward()
+        optimizer.step()
        if (sent_id + 1) % 5000 == 0:
            print("--finished %r sentences" % (sent_id + 1))
    print("iter %r: train loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
        ITER, train_loss / train_words, math.exp(train_loss / train_words), time.time() - start))
    # Evaluate on dev set
    dev_words, dev_loss = 0, 0.0
    start = time.time()
+    model.eval()
    for sent_id, sent in enumerate(dev):
-        my_loss = calc_sent_loss(sent, inference=True)
-        dev_loss += my_loss
+        my_loss = calc_sent_loss(sent)
+        dev_loss += my_loss.item()
        dev_words += len(sent)
    print("iter %r: dev loss/word=%.4f, ppl=%.4f, time=%.2fs" % (
        ITER, dev_loss / dev_words, math.exp(dev_loss / dev_words), time.time() - start))
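One caveat: model.eval() only toggles layers such as dropout or batch norm (none exist in this model); autograd still builds a graph on the dev pass. A possible variant, not part of this patch, that skips graph construction during evaluation:

    model.eval()
    with torch.no_grad():                  # no autograd graph while evaluating
        for sent_id, sent in enumerate(dev):
            my_loss = calc_sent_loss(sent)
            dev_loss += my_loss.item()
            dev_words += len(sent)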