Skip to content

Commit 4c65331

Browse files
Suggested changes
1 parent 166b256 commit 4c65331

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

05-cnn-pytorch/cnn-activation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def forward(self, words, return_activations=False):
2525
emb = self.embedding(words) # nwords x emb_size
2626
emb = emb.unsqueeze(0).permute(0, 2, 1) # 1 x emb_size x nwords
2727
h = self.conv_1d(emb) # 1 x num_filters x nwords
28-
activations = h.squeeze().max(dim=1)[1] # argmax along length of the sentence
28+
activations = h.squeeze(0).max(dim=1)[1] # argmax along length of the sentence
2929
# Do max pooling
3030
h = h.max(dim=2)[0] # 1 x num_filters
3131
h = self.relu(h)
@@ -146,4 +146,5 @@ def display_activations(words, activations):
146146

147147

148148
for words, wids, tag in dev:
149-
calc_predict_and_activations(wids, tag, words)
149+
calc_predict_and_activations(wids, tag, words)
150+
input()

0 commit comments

Comments
 (0)