
Commit b4d40d4

Adding pytorch version for 05-cnn
1 parent 81ba549 commit b4d40d4

2 files changed: +255 −0 lines changed


05-cnn-pytorch/cnn-activation.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
from collections import defaultdict
import time
import random
import torch
import numpy as np


class CNNclass(torch.nn.Module):
    def __init__(self, nwords, emb_size, num_filters, window_size, ntags):
        super(CNNclass, self).__init__()

        """ layers """
        self.embedding = torch.nn.Embedding(nwords, emb_size)
        # uniform initialization
        torch.nn.init.uniform_(self.embedding.weight, -0.25, 0.25)
        # Conv 1d
        self.conv_1d = torch.nn.Conv1d(in_channels=emb_size, out_channels=num_filters, kernel_size=window_size,
                                       stride=1, padding=0, dilation=1, groups=1, bias=True)
        self.relu = torch.nn.ReLU()
        self.projection_layer = torch.nn.Linear(in_features=num_filters, out_features=ntags, bias=True)
        # Initializing the projection layer
        torch.nn.init.xavier_uniform_(self.projection_layer.weight)

    def forward(self, words, return_activations=False):
        emb = self.embedding(words)              # nwords x emb_size
        emb = emb.unsqueeze(0).permute(0, 2, 1)  # 1 x emb_size x nwords
        h = self.conv_1d(emb)                    # 1 x num_filters x nwords
        # For each filter, the position of the n-gram that fired most strongly
        # (argmax along the length of the sentence)
        activations = h.squeeze(0).max(dim=1)[1]
        # Do max pooling
        h = h.max(dim=2)[0]                      # 1 x num_filters
        h = self.relu(h)
        features = h.squeeze(0)
        out = self.projection_layer(h)           # size(out) = 1 x ntags
        if return_activations:
            return out, activations.data.cpu().numpy(), features.data.cpu().numpy()
        return out


# Print whole arrays on one line (np.nan is not a valid value for these
# options in recent NumPy; use a large linewidth and an infinite threshold)
np.set_printoptions(linewidth=10000, threshold=np.inf)

# Functions to read in the corpus
w2i = defaultdict(lambda: len(w2i))
UNK = w2i["<unk>"]


def read_dataset(filename):
    with open(filename, "r") as f:
        for line in f:
            tag, words = line.lower().strip().split(" ||| ")
            words = words.split(" ")
            yield (words, [w2i[x] for x in words], int(tag))


# Read in the data (only a small slice: this script is for inspecting
# activations rather than for serious training)
train = list(read_dataset("../data/classes/train.txt"))[:50]
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("../data/classes/test.txt"))[:10]
nwords = len(w2i)
ntags = 5

# Define the model
EMB_SIZE = 10
WIN_SIZE = 3
FILTER_SIZE = 8

# initialize the model
model = CNNclass(nwords, EMB_SIZE, FILTER_SIZE, WIN_SIZE, ntags)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Use the GPU tensor type if CUDA is available ("tensor_type" rather than
# "type", to avoid shadowing the builtin)
tensor_type = torch.LongTensor
use_cuda = torch.cuda.is_available()

if use_cuda:
    tensor_type = torch.cuda.LongTensor
    model.cuda()


def calc_predict_and_activations(wids, tag, words):
    if len(wids) < WIN_SIZE:
        wids += [0] * (WIN_SIZE - len(wids))
    words_tensor = torch.tensor(wids).type(tensor_type)
    scores, activations, features = model(words_tensor, return_activations=True)
    scores = scores.squeeze(0).cpu().data.numpy()
    print('%d ||| %s' % (tag, ' '.join(words)))
    predict = np.argmax(scores)
    print(display_activations(words, activations))
    W = model.projection_layer.weight.data.cpu().numpy()
    bias = model.projection_layer.bias.data.cpu().numpy()
    print('scores=%s, predict: %d' % (scores, predict))
    print('  bias=%s' % bias)
    # Per-class contribution of each filter to the score
    contributions = W * features
    print(' very bad (%.4f): %s' % (scores[0], contributions[0]))
    print('      bad (%.4f): %s' % (scores[1], contributions[1]))
    print('  neutral (%.4f): %s' % (scores[2], contributions[2]))
    print('     good (%.4f): %s' % (scores[3], contributions[3]))
    print('very good (%.4f): %s' % (scores[4], contributions[4]))


def display_activations(words, activations):
    pad_begin = (WIN_SIZE - 1) // 2
    pad_end = WIN_SIZE - 1 - pad_begin
    words_padded = ['pad'] * pad_begin + words + ['pad'] * pad_end

    ngrams = []
    for act in activations:
        ngrams.append('[' + ', '.join(words_padded[act:act + WIN_SIZE]) + ']')

    return ngrams


for ITER in range(10):
    # Perform training
    random.shuffle(train)
    train_loss = 0.0
    train_correct = 0.0
    start = time.time()
    for _, wids, tag in train:
        # Padding (can be done in the conv layer as well)
        if len(wids) < WIN_SIZE:
            wids += [0] * (WIN_SIZE - len(wids))
        words_tensor = torch.tensor(wids).type(tensor_type)
        tag_tensor = torch.tensor([tag]).type(tensor_type)
        scores = model(words_tensor)
        predict = scores[0].argmax().item()
        if predict == tag:
            train_correct += 1

        my_loss = criterion(scores, tag_tensor)
        train_loss += my_loss.item()
        # Do back-prop
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
    print("iter %r: train loss/sent=%.4f, acc=%.4f, time=%.2fs" % (ITER, train_loss / len(train), train_correct / len(train), time.time() - start))
    # Perform testing
    test_correct = 0.0
    for _, wids, tag in dev:
        # Padding (can be done in the conv layer as well)
        if len(wids) < WIN_SIZE:
            wids += [0] * (WIN_SIZE - len(wids))
        words_tensor = torch.tensor(wids).type(tensor_type)
        scores = model(words_tensor)
        predict = scores[0].argmax().item()
        if predict == tag:
            test_correct += 1
    print("iter %r: test acc=%.4f" % (ITER, test_correct / len(dev)))


for words, wids, tag in dev:
    calc_predict_and_activations(wids, tag, words)
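
As a sanity check on the shape comments above, here is a minimal sketch (not part of the commit) of one forward pass through CNNclass on a toy input; the vocabulary size and word ids are made up for illustration:

    # Illustrative sketch only -- toy sizes, not part of the commit
    toy_model = CNNclass(nwords=100, emb_size=10, num_filters=8, window_size=3, ntags=5)
    wids = torch.LongTensor([4, 17, 2, 9])   # one hypothetical 4-word sentence
    scores, acts, feats = toy_model(wids, return_activations=True)
    print(scores.shape)  # torch.Size([1, 5]) -- one score per tag
    print(acts.shape)    # (8,) -- per filter, argmax position along the sentence
    print(feats.shape)   # (8,) -- max-pooled (and ReLU'd) filter features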

05-cnn-pytorch/cnn-class.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
from collections import defaultdict
import time
import random
import torch


class CNNclass(torch.nn.Module):
    def __init__(self, nwords, emb_size, num_filters, window_size, ntags):
        super(CNNclass, self).__init__()

        """ layers """
        self.embedding = torch.nn.Embedding(nwords, emb_size)
        # uniform initialization
        torch.nn.init.uniform_(self.embedding.weight, -0.25, 0.25)
        # Conv 1d
        self.conv_1d = torch.nn.Conv1d(in_channels=emb_size, out_channels=num_filters, kernel_size=window_size,
                                       stride=1, padding=0, dilation=1, groups=1, bias=True)
        self.relu = torch.nn.ReLU()
        self.projection_layer = torch.nn.Linear(in_features=num_filters, out_features=ntags, bias=True)
        # Initializing the projection layer
        torch.nn.init.xavier_uniform_(self.projection_layer.weight)

    def forward(self, words):
        emb = self.embedding(words)              # nwords x emb_size
        emb = emb.unsqueeze(0).permute(0, 2, 1)  # 1 x emb_size x nwords
        h = self.conv_1d(emb)                    # 1 x num_filters x nwords
        # Do max pooling
        h = h.max(dim=2)[0]                      # 1 x num_filters
        h = self.relu(h)
        out = self.projection_layer(h)           # size(out) = 1 x ntags
        return out


# Functions to read in the corpus
w2i = defaultdict(lambda: len(w2i))
t2i = defaultdict(lambda: len(t2i))
UNK = w2i["<unk>"]


def read_dataset(filename):
    with open(filename, "r") as f:
        for line in f:
            tag, words = line.lower().strip().split(" ||| ")
            yield ([w2i[x] for x in words.split(" ")], t2i[tag])


# Read in the data
train = list(read_dataset("../data/classes/train.txt"))
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("../data/classes/test.txt"))
nwords = len(w2i)
ntags = len(t2i)

# Define the model
EMB_SIZE = 64
WIN_SIZE = 3
FILTER_SIZE = 64

# initialize the model
model = CNNclass(nwords, EMB_SIZE, FILTER_SIZE, WIN_SIZE, ntags)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Use the GPU tensor type if CUDA is available ("tensor_type" rather than
# "type", to avoid shadowing the builtin)
tensor_type = torch.LongTensor
use_cuda = torch.cuda.is_available()

if use_cuda:
    tensor_type = torch.cuda.LongTensor
    model.cuda()


for ITER in range(100):
    # Perform training
    random.shuffle(train)
    train_loss = 0.0
    train_correct = 0.0
    start = time.time()
    for words, tag in train:
        # Padding (can be done in the conv layer as well)
        if len(words) < WIN_SIZE:
            words += [0] * (WIN_SIZE - len(words))
        words_tensor = torch.tensor(words).type(tensor_type)
        tag_tensor = torch.tensor([tag]).type(tensor_type)
        scores = model(words_tensor)
        predict = scores[0].argmax().item()
        if predict == tag:
            train_correct += 1

        my_loss = criterion(scores, tag_tensor)
        train_loss += my_loss.item()
        # Do back-prop
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
    print("iter %r: train loss/sent=%.4f, acc=%.4f, time=%.2fs" % (
        ITER, train_loss / len(train), train_correct / len(train), time.time() - start))
    # Perform testing
    test_correct = 0.0
    for words, tag in dev:
        # Padding (can be done in the conv layer as well)
        if len(words) < WIN_SIZE:
            words += [0] * (WIN_SIZE - len(words))
        words_tensor = torch.tensor(words).type(tensor_type)
        scores = model(words_tensor)[0]
        predict = scores.argmax().item()
        if predict == tag:
            test_correct += 1
    print("iter %r: test acc=%.4f" % (ITER, test_correct / len(dev)))
