Skip to content

Commit 4b85c72

Browse files
authored
Merge pull request #3 from forrestdavis/master
Added embedding flag --view_emb. Gets embeddings for input words
2 parents bf40b9e + f0a1b0b commit 4b85c72

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

main.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@
101101
help='adapt model weights during evaluation')
102102
parser.add_argument('--interact', action='store_true',
103103
help='run a trained network interactively')
104+
105+
#For getting embeddings
106+
parser.add_argument('--view_emb', action='store_true',
107+
help='output the word embedding rather than the cell state')
108+
104109
parser.add_argument('--view_layer', type=int, default=-1,
105110
help='which layer should output cell states')
106111
parser.add_argument('--view_hidden', action='store_true',
@@ -459,6 +464,13 @@ def test_evaluate(test_sentences, data_source):
459464
if args.view_hidden:
460465
# output hidden state
461466
print(*list(hidden[0][args.view_layer].view(1, -1).data.cpu().numpy().flatten()), sep=' ')
467+
468+
elif args.view_emb:
469+
#Get embedding for input word
470+
emb = model.encoder(word_input)
471+
# output embedding
472+
print(*list(emb[0].view(1,-1).data.cpu().numpy().flatten()), sep=' ')
473+
462474
else:
463475
# output cell state
464476
print(*list(hidden[1][args.view_layer].view(1, -1).data.cpu().numpy().flatten()), sep=' ')

0 commit comments

Comments
 (0)