from collections import defaultdict
import time
import random
import torch
import numpy as np


class CNNclass(torch.nn.Module):
    def __init__(self, nwords, emb_size, num_filters, window_size, ntags):
        super().__init__()

        # layers
        self.embedding = torch.nn.Embedding(nwords, emb_size)
        # uniform initialization of the embedding weights
        torch.nn.init.uniform_(self.embedding.weight, -0.25, 0.25)
        # 1d convolution over the embedded sequence
        self.conv_1d = torch.nn.Conv1d(in_channels=emb_size, out_channels=num_filters, kernel_size=window_size,
                                       stride=1, padding=0, dilation=1, groups=1, bias=True)
        self.relu = torch.nn.ReLU()
        self.projection_layer = torch.nn.Linear(in_features=num_filters, out_features=ntags, bias=True)
        # Xavier initialization of the projection layer
        torch.nn.init.xavier_uniform_(self.projection_layer.weight)

    def forward(self, words, return_activations=False):
        emb = self.embedding(words)              # nwords x emb_size
        emb = emb.unsqueeze(0).permute(0, 2, 1)  # 1 x emb_size x nwords
        h = self.conv_1d(emb)                    # 1 x num_filters x (nwords - window_size + 1)
        # For each filter, the position of the window that activates it most
        # (squeeze(0) rather than squeeze(), so a length-1 conv output survives)
        activations = h.squeeze(0).argmax(dim=1)
        # Do max pooling over the sentence positions
        h = h.max(dim=2)[0]                      # 1 x num_filters
        h = self.relu(h)
        features = h.squeeze(0)                  # num_filters
        out = self.projection_layer(h)           # size(out) = 1 x ntags
        if return_activations:
            return out, activations.detach().cpu().numpy(), features.detach().cpu().numpy()
        return out


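# A minimal shape sketch of the forward pass (hypothetical 4-word sentence;
# the names below are just the constructor arguments, not new code):
#   words (4 ids) -> emb: 4 x emb_size -> conv input: 1 x emb_size x 4
#   -> conv output: 1 x num_filters x (4 - window_size + 1)
#   -> max-pool + ReLU: 1 x num_filters -> out: 1 x ntags
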
# Print arrays at full width/length (newer numpy rejects np.nan here)
np.set_printoptions(linewidth=np.inf, threshold=np.inf)

# Functions to read in the corpus
w2i = defaultdict(lambda: len(w2i))
UNK = w2i["<unk>"]
def read_dataset(filename):
    with open(filename, "r") as f:
        for line in f:
            tag, words = line.lower().strip().split(" ||| ")
            words = words.split(" ")
            yield (words, [w2i[x] for x in words], int(tag))
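
# Expected file format (hypothetical example line):
#   3 ||| an endlessly fascinating movie
# which read_dataset yields as
#   (['an', 'endlessly', 'fascinating', 'movie'], [their w2i ids], 3)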

# Read in the data (only the first 50 train / 10 dev sentences are used here)
train = list(read_dataset("../data/classes/train.txt"))[:50]
# Freeze the vocabulary: words unseen in training now map to UNK
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("../data/classes/test.txt"))[:10]
nwords = len(w2i)
ntags = 5

# Define the model
EMB_SIZE = 10
WIN_SIZE = 3
FILTER_SIZE = 8

# initialize the model
model = CNNclass(nwords, EMB_SIZE, FILTER_SIZE, WIN_SIZE, ntags)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
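# CrossEntropyLoss applies log-softmax internally, so the model returns raw
# logits; Adam runs at its default learning rate (1e-3) since none is given.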

long_type = torch.LongTensor
use_cuda = torch.cuda.is_available()

if use_cuda:
    long_type = torch.cuda.LongTensor
    model.cuda()
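# (.type(long_type) is the older tensor-typing idiom; an equivalent modern
# sketch would be: device = torch.device('cuda' if use_cuda else 'cpu') and
# torch.tensor(wids, dtype=torch.long, device=device))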


def calc_predict_and_activations(wids, tag, words):
    if len(wids) < WIN_SIZE:
        wids += [0] * (WIN_SIZE - len(wids))
    words_tensor = torch.tensor(wids).type(long_type)
    scores, activations, features = model(words_tensor, return_activations=True)
    scores = scores.squeeze().detach().cpu().numpy()
    print('%d ||| %s' % (tag, ' '.join(words)))
    predict = np.argmax(scores)
    print(display_activations(words, activations))
    W = model.projection_layer.weight.detach().cpu().numpy()
    bias = model.projection_layer.bias.detach().cpu().numpy()
    print('scores=%s, predict: %d' % (scores, predict))
    print('  bias=%s' % bias)
    # Per-class contribution of each pooled filter feature
    contributions = W * features
    print(' very bad (%.4f): %s' % (scores[0], contributions[0]))
    print('      bad (%.4f): %s' % (scores[1], contributions[1]))
    print('  neutral (%.4f): %s' % (scores[2], contributions[2]))
    print('     good (%.4f): %s' % (scores[3], contributions[3]))
    print('very good (%.4f): %s' % (scores[4], contributions[4]))
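# Since scores = W @ features + bias, each contributions row sums (plus its
# bias term) to the corresponding class score:
#   scores[t] == contributions[t].sum() + bias[t]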


def display_activations(words, activations):
    # activations[i] is the start of the window that most activated filter i;
    # the conv uses no padding, so the index lines up with the raw sentence
    ngrams = []
    for act in activations:
        ngrams.append('[' + ', '.join(words[act:act + WIN_SIZE]) + ']')

    return ngrams

for ITER in range(10):
    # Perform training
    random.shuffle(train)
    train_loss = 0.0
    train_correct = 0.0
    start = time.time()
    for _, wids, tag in train:
        # Padding (can be done in the conv layer as well)
        if len(wids) < WIN_SIZE:
            wids += [0] * (WIN_SIZE - len(wids))
        words_tensor = torch.tensor(wids).type(long_type)
        tag_tensor = torch.tensor([tag]).type(long_type)
        scores = model(words_tensor)
        predict = scores[0].argmax().item()
        if predict == tag:
            train_correct += 1

        my_loss = criterion(scores, tag_tensor)
        train_loss += my_loss.item()
        # Do back-prop
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
    print("iter %r: train loss/sent=%.4f, acc=%.4f, time=%.2fs" % (ITER, train_loss / len(train), train_correct / len(train), time.time() - start))
    # Perform testing
    test_correct = 0.0
    for _, wids, tag in dev:
        # Padding (can be done in the conv layer as well)
        if len(wids) < WIN_SIZE:
            wids += [0] * (WIN_SIZE - len(wids))
        words_tensor = torch.tensor(wids).type(long_type)
        scores = model(words_tensor)
        predict = scores[0].argmax().item()
        if predict == tag:
            test_correct += 1
    print("iter %r: test acc=%.4f" % (ITER, test_correct / len(dev)))


for words, wids, tag in dev:
    calc_predict_and_activations(wids, tag, words)