Skip to content

Commit c18372d

Browse files
committed
changed cnn-activations to work with python3 (erstwhile python2 compatible)
1 parent aaad004 commit c18372d

File tree

1 file changed

+20
-27
lines changed

1 file changed

+20
-27
lines changed

05-cnn/cnn-activation.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def read_dataset(filename):
2424
ntags = 5
2525

2626
# Start DyNet and define trainer
27-
model = dy.Model()
27+
model = dy.ParameterCollection()
2828
trainer = dy.AdamTrainer(model)
2929

3030
# Define the model
@@ -40,58 +40,50 @@ def read_dataset(filename):
4040

4141
def calc_scores(wids):
4242
dy.renew_cg()
43-
W_cnn_express = dy.parameter(W_cnn)
44-
b_cnn_express = dy.parameter(b_cnn)
45-
W_sm_express = dy.parameter(W_sm)
46-
b_sm_express = dy.parameter(b_sm)
4743
if len(wids) < WIN_SIZE:
4844
wids += [0] * (WIN_SIZE-len(wids))
4945

5046
cnn_in = dy.concatenate([dy.lookup(W_emb, x) for x in wids], d=1)
51-
cnn_out = dy.conv2d_bias(cnn_in, W_cnn_express, b_cnn_express, stride=(1, 1), is_valid=False)
47+
cnn_out = dy.conv2d_bias(cnn_in, W_cnn, b_cnn, stride=(1, 1), is_valid=False)
5248
pool_out = dy.max_dim(cnn_out, d=1)
5349
pool_out = dy.reshape(pool_out, (FILTER_SIZE,))
5450
pool_out = dy.rectify(pool_out)
55-
return W_sm_express * pool_out + b_sm_express
51+
return W_sm * pool_out + b_sm
5652

5753
def calc_predict_and_activations(wids, tag, words):
5854
dy.renew_cg()
59-
W_cnn_express = dy.parameter(W_cnn)
60-
b_cnn_express = dy.parameter(b_cnn)
61-
W_sm_express = dy.parameter(W_sm)
62-
b_sm_express = dy.parameter(b_sm)
6355
if len(wids) < WIN_SIZE:
6456
wids += [0] * (WIN_SIZE-len(wids))
6557

6658
cnn_in = dy.concatenate([dy.lookup(W_emb, x) for x in wids], d=1)
67-
cnn_out = dy.conv2d_bias(cnn_in, W_cnn_express, b_cnn_express, stride=(1, 1), is_valid=False)
59+
cnn_out = dy.conv2d_bias(cnn_in, W_cnn, b_cnn, stride=(1, 1), is_valid=False)
6860
filters = (dy.reshape(cnn_out, (len(wids), FILTER_SIZE))).npvalue()
6961
activations = filters.argmax(axis=0)
7062

7163
pool_out = dy.max_dim(cnn_out, d=1)
7264
pool_out = dy.reshape(pool_out, (FILTER_SIZE,))
7365
pool_out = dy.rectify(pool_out)
7466

75-
scores = (W_sm_express * pool_out + b_sm_express).npvalue()
76-
print '%d ||| %s' % (tag, ' '.join(words))
67+
scores = (W_sm * pool_out + b_sm).npvalue()
68+
print ('%d ||| %s' % (tag, ' '.join(words)))
7769
predict = np.argmax(scores)
78-
print display_activations(words, activations)
79-
print 'scores=%s, predict: %d' % (scores, predict)
70+
print (display_activations(words, activations))
71+
print ('scores=%s, predict: %d' % (scores, predict))
8072
features = pool_out.npvalue()
81-
W = W_sm_express.npvalue()
82-
bias = b_sm_express.npvalue()
83-
print ' bias=%s' % bias
73+
W = W_sm.npvalue()
74+
bias = b_sm.npvalue()
75+
print (' bias=%s' % bias)
8476
contributions = W * features
85-
print ' very bad (%.4f): %s' % (scores[0], contributions[0])
86-
print ' bad (%.4f): %s' % (scores[1], contributions[1])
87-
print ' neutral (%.4f): %s' % (scores[2], contributions[2])
88-
print ' good (%.4f): %s' % (scores[3], contributions[3])
89-
print 'very good (%.4f): %s' % (scores[4], contributions[4])
77+
print (' very bad (%.4f): %s' % (scores[0], contributions[0]))
78+
print (' bad (%.4f): %s' % (scores[1], contributions[1]))
79+
print (' neutral (%.4f): %s' % (scores[2], contributions[2]))
80+
print (' good (%.4f): %s' % (scores[3], contributions[3]))
81+
print ('very good (%.4f): %s' % (scores[4], contributions[4]))
9082

9183

9284
def display_activations(words, activations):
93-
pad_begin = (WIN_SIZE - 1) / 2
94-
pad_end = WIN_SIZE - 1 - pad_begin
85+
pad_begin = int((WIN_SIZE - 1) / 2)
86+
pad_end = int(WIN_SIZE - 1 - pad_begin)
9587
words_padded = ['pad' for i in range(pad_begin)] + words + ['pad' for i in range(pad_end)]
9688

9789
ngrams = []
@@ -129,4 +121,5 @@ def display_activations(words, activations):
129121

130122
for words, wids, tag in dev:
131123
calc_predict_and_activations(wids, tag, words)
132-
raw_input()
124+
# input prompt so that the next example is revealed on a key press
125+
input()

0 commit comments

Comments
 (0)