
Commit acb04a5: Code release

13 files changed: +1107 -0 lines

LICENSE (+29 lines)
BSD 3-Clause License

Copyright (c) 2017,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

README.md (+103 lines)
# AWD-LSTM Language Model

### Averaged Stochastic Gradient Descent with Weight Dropped LSTM

This repository contains the code used for [Salesforce Research](https://einstein.ai/)'s [Regularizing and Optimizing LSTM Language Models](https://arxiv.org/abs/1708.02182) paper, originally forked from the [PyTorch word level language modeling example](https://github.com/pytorch/examples/tree/master/word_language_model).
The model comes with instructions to train a word level language model over the Penn Treebank (PTB) and [WikiText-2](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) (WT2) datasets, though the model is likely extensible to many other datasets.

+ Install PyTorch 0.1.12_2
+ Run `getdata.sh` to acquire the Penn Treebank and WikiText-2 datasets
+ Train the base model using `main.py`
+ Finetune the model using `finetune.py`
+ Apply the [continuous cache pointer](http://xxx.lanl.gov/abs/1612.04426) to the finetuned model using `pointer.py`
If you use this code or our results in your research, please cite:

```
@article{merityRegOpt,
  title={{Regularizing and Optimizing LSTM Language Models}},
  author={Merity, Stephen and Keskar, Nitish Shirish and Socher, Richard},
  journal={arXiv preprint arXiv:1708.02182},
  year={2017}
}
```
## Software Requirements

This codebase requires Python 3 and PyTorch 0.1.12_2.

Note the older PyTorch version: upgrading to a later version would require minor code changes and would prevent exact reproduction of the results below.
Pull requests that update the code to later PyTorch versions are welcome, especially if they report baseline numbers too :)
## Experiments

The codebase was modified while the paper was being written, so exact reproduction is not possible due to minor differences in random seeds and similar details.
The guide below produces results largely similar to the numbers reported.

For data setup, run `./getdata.sh`.
This script collects the Mikolov pre-processed Penn Treebank and WikiText-2 datasets and places them in the `data` directory.

**Important:** If you're going to continue experimentation beyond reproduction, comment out the test code and use only the validation metrics until you report your final results.
This is proper experimental practice and is especially important when tuning hyperparameters, such as those used by the pointer.
#### Penn Treebank (PTB)

The instructions below train a PTB model that, without finetuning, achieves validation / test perplexities of `61.2` / `58.9`; with finetuning, `58.8` / `56.6`; and with the continuous cache pointer augmentation, `53.5` / `53.0`.

First, train the model:

`python main.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt`

The first epoch should result in a validation perplexity of `308.03`.

To then fine-tune that model:

`python finetune.py --batch_size 20 --data data/penn --dropouti 0.4 --seed 28 --epoch 300 --save PTB.pt`

The validation perplexity after the first epoch should be `60.85`.

**Note:** Fine-tuning modifies the original saved model in `PTB.pt`; if you wish to keep the original weights, copy the file first.

Finally, to run the pointer:

`python pointer.py --data data/penn --save PTB.pt --lambdasm 0.1 --theta 1.0 --window 500 --bptt 5000`

Note that the model in the paper was trained for 500 epochs with a batch size of 40, compared to 300 epochs and a batch size of 20 for the model above.
The window size for this pointer is 500 rather than the 2000 used in the paper.

**Note:** The BPTT value here only changes the length of the sequence pushed onto the GPU at a time and won't impact the final result.
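For intuition about what `--lambdasm`, `--theta` and `--window` control, the continuous cache pointer interpolates the model's softmax output with a cache distribution built from the hidden states of the most recent window of tokens. The sketch below is only an illustration of that idea, written against a recent PyTorch API rather than the 0.1.12 release this code targets; it is not the repository's `pointer.py`, and treating `lambdasm` as the interpolation weight and `theta` as the attention sharpness is an assumption here.

```python
import torch

def cache_pointer_step(model_probs, hidden, cache_hiddens, cache_targets,
                       theta=1.0, lambdasm=0.1):
    # model_probs:   (vocab,)        softmax output of the base language model
    # hidden:        (nhid,)         hidden state at the current timestep
    # cache_hiddens: (window, nhid)  hidden states from the preceding window
    # cache_targets: (window,)       token ids that followed those hidden states
    # Attention over the cache, sharpened or flattened by theta.
    attn = torch.softmax(theta * cache_hiddens.matmul(hidden), dim=0)
    # Scatter the attention mass onto the vocabulary slots of the cached tokens.
    cache_probs = torch.zeros_like(model_probs)
    cache_probs.index_add_(0, cache_targets, attn)
    # Interpolate between the model and cache distributions.
    return (1 - lambdasm) * model_probs + lambdasm * cache_probs
```

In the runs above, the cache covers the last 500 hidden states for PTB and 3785 for WT2.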
#### WikiText-2 (WT2)

The instructions below train a WT2 model that, without finetuning, achieves validation / test perplexities of `69.1` / `66.1`; with finetuning, `68.7` / `65.8`; and with the continuous cache pointer augmentation, `53.6` / `52.0` (`51.95` specifically).

`python main.py --seed 20923 --epochs 750 --data data/wikitext-2 --save WT2.pt`

The first epoch should result in a validation perplexity of `629.93`.

`python -u finetune.py --seed 1111 --epochs 750 --data data/wikitext-2 --save WT2.pt`

The validation perplexity after the first epoch should be `69.14`.

**Note:** Fine-tuning modifies the original saved model in `WT2.pt`; if you wish to keep the original weights, copy the file first.

Finally, run the pointer:

`python pointer.py --save WT2.pt --lambdasm 0.1279 --theta 0.662 --window 3785 --bptt 2000 --data data/wikitext-2`

**Note:** The BPTT value here only changes the length of the sequence pushed onto the GPU at a time and won't impact the final result.
## Speed

All of the augmentations to the LSTM, including our variant of [DropConnect (Wan et al. 2013)](https://cs.nyu.edu/~wanli/dropc/dropc.pdf) termed weight dropping, which adds recurrent dropout, allow for the use of NVIDIA's cuDNN LSTM implementation.
PyTorch will automatically use the cuDNN backend if run on CUDA with cuDNN installed.
This ensures the model is fast to train even when convergence may take many hundreds of epochs.
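The trick that keeps cuDNN usable is that the DropConnect mask is applied to the recurrent (hidden-to-hidden) weight matrix before the standard LSTM forward call, rather than to activations inside the recurrence. The sketch below is a minimal illustration of where that mask lands, written against a recent PyTorch API and stepping an `nn.LSTMCell` by hand purely for clarity; it is not the repository's weight-dropped LSTM, which keeps calling the ordinary cuDNN-backed `nn.LSTM` with the modified weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightDropLSTMCell(nn.Module):
    """Illustration only: DropConnect on the hidden-to-hidden weight matrix."""

    def __init__(self, input_size, hidden_size, weight_dropout=0.5):
        super().__init__()
        self.cell = nn.LSTMCell(input_size, hidden_size)
        self.weight_dropout = weight_dropout

    def forward(self, inputs, state):
        # One DropConnect mask per forward pass, shared across all timesteps,
        # applied only to the recurrent (hidden-to-hidden) weights.
        dropped_w_hh = F.dropout(self.cell.weight_hh, p=self.weight_dropout,
                                 training=self.training)
        h, c = state
        outputs = []
        for x_t in inputs:  # inputs: (seq_len, batch, input_size)
            gates = (x_t @ self.cell.weight_ih.t() + self.cell.bias_ih
                     + h @ dropped_w_hh.t() + self.cell.bias_hh)
            i, f, g, o = gates.chunk(4, dim=1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
            outputs.append(h)
        return torch.stack(outputs), (h, c)
```

Because only the weight tensor changes, the same masking idea can be applied to `nn.LSTM`'s `weight_hh` parameters and the fused cuDNN kernel still runs unchanged.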
The default speeds for the model during training on an NVIDIA Quadro GP100:

+ Penn Treebank: approximately 45 seconds per epoch with batch size 40, approximately 65 seconds per epoch with batch size 20
+ WikiText-2: approximately 105 seconds per epoch with batch size 80

Speeds are approximately three times slower on a K80. On a K80 or other cards with less memory, you may wish to enable [the cap on the maximum sampled sequence length](https://github.com/salesforce/awd-lstm-lm/blob/ef9369d277f8326b16a9f822adae8480b6d492d0/main.py#L131) to prevent out-of-memory (OOM) errors, especially for WikiText-2.

If speed is a major issue, SGD converges more quickly than our non-monotonically triggered variant of ASGD, though it achieves a worse overall perplexity.
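For context, the non-monotonic trigger switches the optimizer from plain SGD to averaged SGD once the validation metric has stopped improving on its best value from more than a few checks ago. The snippet below is a schematic of that decision rule as described in the paper, with an assumed patience parameter `n`; it is not the repository's training loop.

```python
def should_switch_to_asgd(val_losses, n=5):
    # Schematic non-monotonic trigger: switch once the newest validation loss
    # is no longer beating the best loss observed more than `n` checks ago.
    if len(val_losses) <= n:
        return False
    return val_losses[-1] > min(val_losses[:-n])
```

Once triggered, training continues with ASGD, which averages the SGD iterates rather than restarting optimization.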

data.py (+56 lines)
import os
import torch

from collections import Counter


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
        self.counter = Counter()
        self.total = 0

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        token_id = self.word2idx[word]
        self.counter[token_id] += 1
        self.total += 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return ids
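To make the data pipeline concrete, the sketch below shows one way the `Corpus` above might be consumed: build it from a data directory, then reshape the flat 1-D token tensor into `batch_size` columns of contiguous text, as in the PyTorch word-level language modeling example this repository was forked from. The `batchify` helper here is illustrative and not necessarily identical to the one used elsewhere in this codebase.

```python
import data

def batchify(ids, batch_size):
    # Trim off any tokens that would not fit cleanly, then lay the corpus
    # out as `batch_size` columns of contiguous text.
    nbatch = ids.size(0) // batch_size
    ids = ids.narrow(0, 0, nbatch * batch_size)
    return ids.view(batch_size, -1).t().contiguous()  # (nbatch, batch_size)

corpus = data.Corpus('data/penn')   # expects train.txt / valid.txt / test.txt
train_data = batchify(corpus.train, batch_size=20)
print(len(corpus.dictionary), train_data.size())
```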

embed_regularize.py (+41 lines)
import numpy as np

import torch
from torch.autograd import Variable


def embedded_dropout(embed, words, dropout=0.1, scale=None):
    if dropout:
        # Drop entire embedding rows (whole words) and rescale the survivors,
        # so a dropped word is zeroed everywhere it occurs in the batch.
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
        mask = Variable(mask)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    padding_idx = embed.padding_idx
    if padding_idx is None:
        padding_idx = -1
    # Perform the lookup with the masked weight matrix (PyTorch 0.1.12 backend API).
    X = embed._backend.Embedding(
        padding_idx, embed.max_norm, embed.norm_type,
        embed.scale_grad_by_freq, embed.sparse
    )(words, masked_embed_weight)
    return X


if __name__ == '__main__':
    V = 50
    h = 4
    bptt = 10
    batch_size = 2

    embed = torch.nn.Embedding(V, h)

    words = np.random.random_integers(low=0, high=V-1, size=(batch_size, bptt))
    words = torch.LongTensor(words)
    words = Variable(words)

    origX = embed(words)
    X = embedded_dropout(embed, words)

    print(origX)
    print(X)
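As the README notes, pull requests that update the code to newer PyTorch versions are welcome. On recent PyTorch the private `embed._backend.Embedding(...)` call above can be replaced with the public functional interface; the following is a hedged sketch of such an update, not part of this release.

```python
import torch
import torch.nn.functional as F

def embedded_dropout_modern(embed, words, dropout=0.1, scale=None):
    # Same idea as above: drop whole rows of the embedding matrix, then look up.
    if dropout:
        keep = 1 - dropout
        mask = embed.weight.new_empty((embed.weight.size(0), 1)).bernoulli_(keep) / keep
        masked_weight = mask * embed.weight
    else:
        masked_weight = embed.weight
    if scale is not None:
        masked_weight = scale * masked_weight
    return F.embedding(words, masked_weight, embed.padding_idx, embed.max_norm,
                       embed.norm_type, embed.scale_grad_by_freq, embed.sparse)
```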
