Skip to content

Commit bf40b9e

Browse files
committed
get_embeddings
1 parent cd055d3 commit bf40b9e

File tree

1 file changed

+78
-0
lines changed

1 file changed

+78
-0
lines changed

get_embeddings.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
'''
2+
Code to extract all the vocabulary embeddings from a neural language model.
3+
'''
4+
5+
from __future__ import print_function
6+
import argparse
7+
import time
8+
import math
9+
import sys
10+
import warnings
11+
import torch
12+
import torch.nn as nn
13+
import data
14+
import model
15+
16+
# Optional dependency: the `progress` package is only used for progress bars.
# ImportError is caught (rather than ModuleNotFoundError) because it is the
# superclass and also exists on Python < 3.6, which this file still targets
# via `from __future__ import print_function`.
try:
    from progress.bar import Bar
    PROGRESS = True
except ImportError:
    PROGRESS = False

# Blanket filter: silences *all* warnings, not only the SourceChangeWarnings
# it was originally added for.
warnings.filterwarnings("ignore")

sys.stderr.write('Libraries loaded\n')
26+
27+
# Parallelization notes:
28+
# Does not currently operate across multiple nodes
29+
# Single GPU is better for default: tied,emsize:200,nhid:200,nlayers:2,dropout:0.2
30+
#
31+
# Multiple GPUs are better for tied,emsize:1500,nhid:1500,nlayers:2,dropout:0.65
32+
# 4 GPUs train on wikitext-2 in 1/2 - 2/3 the time of 1 GPU
33+
34+
# Command-line interface. `--seed` and `--single` are declared here because
# the CUDA setup below reads `args.seed` and `args.single`; without them the
# script raises AttributeError as soon as `--cuda` is passed.
parser = argparse.ArgumentParser(description='PyTorch RNN/LSTM Language Model')

# Model parameters
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--single', action='store_true',
                    help='use only a single GPU, even if more are available')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')

# Data parameters
parser.add_argument('--model_file', type=str, default='model.pt',
                    help='path to the trained model to load')

args = parser.parse_args()

if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)
        if torch.cuda.device_count() == 1:
            # Only one GPU is visible, so multi-GPU mode cannot apply.
            args.single = True

# Device that the loaded model is moved to.
device = torch.device("cuda" if args.cuda else "cpu")
55+
56+
###############################################################################
57+
# Load the model
58+
###############################################################################
59+
60+
# NOTE: torch.load unpickles arbitrary Python objects; only point
# --model_file at checkpoints from a trusted source.
# NOTE: the name `model` is rebound here from the imported `model` module to
# the loaded network, as in the original script.
with open(args.model_file, 'rb') as f:
    if args.cuda:
        model = torch.load(f).to(device)
    else:
        # Remap any GPU-saved tensors onto the CPU while loading.
        model = torch.load(f, map_location='cpu')

# A checkpoint saved from a multi-GPU run may be wrapped in DataParallel;
# unwrap it to reach the real model. (The previous wrap-then-unwrap step was
# a no-op and has been dropped, together with its read of `args.single`.)
if isinstance(model, torch.nn.DataParallel):
    model = model.module

# After loading, the RNN parameters are not one contiguous chunk of memory;
# this makes them contiguous again.
model.rnn.flatten_parameters()

# Look up every vocabulary index in the embedding layer. The index tensor is
# created on the same device as the embedding weights, and the result is
# moved to the CPU before the NumPy conversion, so this also works with --cuda.
with torch.no_grad():
    vocab_ids = torch.arange(model.encoder.num_embeddings,
                             dtype=torch.long,
                             device=model.encoder.weight.device)
    embeddings = model.encoder(vocab_ids).cpu().numpy().tolist()

# One line per vocabulary entry: space-separated embedding components.
for word in embeddings:
    print(' '.join(str(value) for value in word))

0 commit comments

Comments
 (0)