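# A convolutional sentence classifier with activation inspection.
# Trains a single-window-size CNN on 5-class sentiment data (tags 0-4,
# "very bad" through "very good") and then, for each held-out sentence,
# shows which n-gram each filter responded to most strongly and how much
# each filter contributes to every class score.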
from collections import defaultdict
import random
import sys
import time

import numpy as np
import torch


class CNNclass(torch.nn.Module):
    def __init__(self, nwords, emb_size, num_filters, window_size, ntags):
        super(CNNclass, self).__init__()

        """ layers """
        self.embedding = torch.nn.Embedding(nwords, emb_size)
        # uniform initialization
        torch.nn.init.uniform_(self.embedding.weight, -0.25, 0.25)
        # 1d convolution over word positions; padding=0 means only full
        # windows are used, so the output length is nwords - window_size + 1
        self.conv_1d = torch.nn.Conv1d(in_channels=emb_size, out_channels=num_filters,
                                       kernel_size=window_size, stride=1, padding=0,
                                       dilation=1, groups=1, bias=True)
        self.relu = torch.nn.ReLU()
        self.projection_layer = torch.nn.Linear(in_features=num_filters, out_features=ntags, bias=True)
        # Initializing the projection layer
        torch.nn.init.xavier_uniform_(self.projection_layer.weight)

    def forward(self, words, return_activations=False):
        emb = self.embedding(words)                 # nwords x emb_size
        emb = emb.unsqueeze(0).permute(0, 2, 1)     # 1 x emb_size x nwords
        h = self.conv_1d(emb)                       # 1 x num_filters x (nwords - window_size + 1)
        # argmax over window positions: which n-gram excites each filter most
        activations = h.squeeze(0).max(dim=1)[1]
        # Do max pooling over the sentence
        h = h.max(dim=2)[0]                         # 1 x num_filters
        h = self.relu(h)
        features = h.squeeze(0)                     # num_filters
        out = self.projection_layer(h)              # 1 x ntags
        if return_activations:
            return out, activations.data.cpu().numpy(), features.data.cpu().numpy()
        return out

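# Shape walk-through for a hypothetical 7-word sentence with the settings
# used below (EMB_SIZE=10, NUM_FILTERS=8, WIN_SIZE=3):
#   words           -> (7,)        token ids
#   emb             -> (1, 10, 7)  after unsqueeze + permute
#   conv_1d(emb)    -> (1, 8, 5)   5 = 7 - 3 + 1 valid windows
#   max over dim 2  -> (1, 8)      one pooled value per filter
#   projection      -> (1, 5)      one score per sentiment class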

# Print arrays in full on one line (np.nan is rejected by newer numpy;
# sys.maxsize is the documented way to disable truncation)
np.set_printoptions(linewidth=sys.maxsize, threshold=sys.maxsize)

# Functions to read in the corpus
w2i = defaultdict(lambda: len(w2i))
UNK = w2i["<unk>"]


def read_dataset(filename):
    with open(filename, "r") as f:
        for line in f:
            tag, words = line.lower().strip().split(" ||| ")
            words = words.split(" ")
            yield (words, [w2i[x] for x in words], int(tag))

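# Each input line is "<tag> ||| <sentence>"; a hypothetical example:
#   3 ||| a quietly moving portrait of resilience
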
# Read in the data
train = list(read_dataset("../data/classes/train.txt"))
# Freeze the vocabulary: words unseen in training now map to UNK
w2i = defaultdict(lambda: UNK, w2i)
dev = list(read_dataset("../data/classes/test.txt"))
nwords = len(w2i)
ntags = 5

# Define the model
EMB_SIZE = 10
WIN_SIZE = 3
NUM_FILTERS = 8  # number of convolutional filters (feature detectors)

# initialize the model
model = CNNclass(nwords, EMB_SIZE, NUM_FILTERS, WIN_SIZE, ntags)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())  # Adam with default lr=1e-3

# tensor type for token ids; switch to the CUDA variant when available
ltype = torch.LongTensor
use_cuda = torch.cuda.is_available()

if use_cuda:
    ltype = torch.cuda.LongTensor
    model.cuda()


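# Aside: `.type(ltype)` is the legacy idiom; on newer PyTorch one would
# typically write, e.g.:
#   device = torch.device("cuda" if use_cuda else "cpu")
#   words_tensor = torch.tensor(wids, dtype=torch.long, device=device)
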
def calc_predict_and_activations(wids, tag, words):
    if len(wids) < WIN_SIZE:
        wids += [0] * (WIN_SIZE - len(wids))
    words_tensor = torch.tensor(wids).type(ltype)
    scores, activations, features = model(words_tensor, return_activations=True)
    scores = scores.squeeze().cpu().data.numpy()
    print('%d ||| %s' % (tag, ' '.join(words)))
    predict = np.argmax(scores)
    print(display_activations(words, activations))
    W = model.projection_layer.weight.data.cpu().numpy()
    bias = model.projection_layer.bias.data.cpu().numpy()
    print('scores=%s, predict: %d' % (scores, predict))
    print('  bias=%s' % bias)
    # Row t decomposes class t's score into per-filter parts:
    # scores[t] = sum_f W[t, f] * features[f] + bias[t]
    contributions = W * features
    print(' very bad (%.4f): %s' % (scores[0], contributions[0]))
    print('      bad (%.4f): %s' % (scores[1], contributions[1]))
    print('  neutral (%.4f): %s' % (scores[2], contributions[2]))
    print('     good (%.4f): %s' % (scores[3], contributions[3]))
    print('very good (%.4f): %s' % (scores[4], contributions[4]))


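# A minimal sanity-check sketch for the decomposition above (hypothetical
# usage, assuming `scores`, `contributions`, and `bias` from one call):
#   np.allclose(contributions.sum(axis=1) + bias, scores, atol=1e-4)
# should hold, since the linear layer computes W @ features + bias.
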
def display_activations(words, activations):
    # With padding=0 in the conv layer, the filter output at position `act`
    # covers words[act:act+WIN_SIZE]. Pad the end in case the sentence was
    # shorter than WIN_SIZE (the ids were padded the same way before the
    # forward pass).
    words_padded = words + ['<pad>'] * max(0, WIN_SIZE - len(words))

    ngrams = []
    for act in activations:
        ngrams.append('[' + ', '.join(words_padded[act:act + WIN_SIZE]) + ']')

    return ngrams
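
# Hypothetical return value for WIN_SIZE=3 and activations=[1, 0, ...],
# given the example sentence above:
#   ['[quietly, moving, portrait]', '[a, quietly, moving]', ...]
# i.e. one n-gram per filter: the window where that filter fired hardest.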


for ITER in range(10):
    # Perform training
    random.shuffle(train)
    train_loss = 0.0
    train_correct = 0.0
    start = time.time()
    for _, wids, tag in train:
        # Pad short sentences so at least one conv window fits
        # (this could also be done with padding in the conv layer)
        if len(wids) < WIN_SIZE:
            wids += [0] * (WIN_SIZE - len(wids))
        words_tensor = torch.tensor(wids).type(ltype)
        tag_tensor = torch.tensor([tag]).type(ltype)
        scores = model(words_tensor)
        predict = scores[0].argmax().item()
        if predict == tag:
            train_correct += 1

        my_loss = criterion(scores, tag_tensor)
        train_loss += my_loss.item()
        # Do back-prop
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
    print("iter %r: train loss/sent=%.4f, acc=%.4f, time=%.2fs" %
          (ITER, train_loss / len(train), train_correct / len(train), time.time() - start))
    # Perform testing
    test_correct = 0.0
    for _, wids, tag in dev:
        # Padding (can be done in the conv layer as well)
        if len(wids) < WIN_SIZE:
            wids += [0] * (WIN_SIZE - len(wids))
        words_tensor = torch.tensor(wids).type(ltype)
        scores = model(words_tensor)
        predict = scores[0].argmax().item()
        if predict == tag:
            test_correct += 1
    print("iter %r: test acc=%.4f" % (ITER, test_correct / len(dev)))


# Walk through the dev set one sentence at a time; press Enter to advance
for words, wids, tag in dev:
    calc_predict_and_activations(wids, tag, words)
    input()