Commit 088f046

Now works with Pytorch 1.8

1 parent: a91691e

File tree: README.md, main.py

2 files changed: +7 -11 lines

README.md (3 additions, 7 deletions)

@@ -1,20 +1,16 @@
 # Neural complexity
 A neural language model that computes various information-theoretic processing complexity measures (e.g., surprisal) for each word given the preceding context. Also, it can function as an adaptive language model ([van Schijndel and Linzen, 2018](http://aclweb.org/anthology/D18-1499)) which adapts to test domains.
 
+**Note**: Recent updates remove dependencies but break compatibility with pre-2021 models. To use older models, use version 1.1.0: `git checkout tags/v1.1.0`
+
 ### Dependencies
 Requires the following python packages (available through pip):
-* [pytorch](https://pytorch.org/) v1.0.0
-* nltk
+* [pytorch](https://pytorch.org/)
 
 The following python packages are optional:
 * progress
 * dill (to handle binarized vocabularies)
 
-Requires the `punkt` nltk module. Install it from within python:
-
-    import nltk
-    nltk.download('punkt')
-
 ### Quick Usage
 The below all use GPUs. To use CPUs instead, omit the `--cuda` flag.

main.py (4 additions, 4 deletions)

@@ -451,7 +451,7 @@ def test_evaluate(test_sentences, data_source):
 target = targets[word_index].unsqueeze(0)
 output, hidden = model(word_input, hidden)
 output_flat = output.view(-1, ntokens)
-loss = criterion(output_flat, target)
+loss = criterion(output_flat, target.long())
 total_loss += loss.item()
 input_word = corpus.dictionary.idx2word[int(word_input.data)]
 targ_word = corpus.dictionary.idx2word[int(target.data)]

@@ -482,7 +482,7 @@ def test_evaluate(test_sentences, data_source):
 except RuntimeError:
 print("Vocabulary Error! Most likely there weren't unks in training and unks are now needed for testing")
 raise
-loss = criterion(output_flat, targets)
+loss = criterion(output_flat, targets.long())
 total_loss += loss.item()
 if args.words:
 # output word-level complexity metrics

@@ -527,7 +527,7 @@ def evaluate(data_source):
 data, targets = get_batch(data_source, i)
 output, hidden = model(data, hidden)
 output_flat = output.view(-1, ntokens)
-total_loss += len(data) * criterion(output_flat, targets).item()
+total_loss += len(data) * criterion(output_flat, targets.long()).item()
 hidden = repackage_hidden(hidden)
 return total_loss / len(data_source)

@@ -546,7 +546,7 @@ def train():
 hidden = repackage_hidden(hidden)
 model.zero_grad()
 output, hidden = model(data, hidden)
-loss = criterion(output.view(-1, ntokens), targets)
+loss = criterion(output.view(-1, ntokens), targets.long())
 loss.backward()
 
 # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
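All four changes in main.py make the same fix: the targets are cast with `.long()` before being passed to `criterion`, since PyTorch's classification losses expect class-index targets of dtype torch.long (int64), and per the commit message the uncast targets failed under PyTorch 1.8. The snippet below is a minimal sketch, not taken from this repository; it assumes the criterion is `nn.CrossEntropyLoss` (the diff does not show how `criterion` is constructed) and only illustrates the dtype error the cast avoids.

    # Minimal sketch (not from this repo): why the .long() cast is needed.
    # Assumes criterion is nn.CrossEntropyLoss, which expects class-index
    # targets of dtype torch.long (int64).
    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    output_flat = torch.randn(4, 10)                         # (batch, ntokens) logits
    targets = torch.tensor([1, 2, 3, 4], dtype=torch.int32)  # indices stored as int32

    try:
        criterion(output_flat, targets)       # RuntimeError: expected scalar type Long
    except RuntimeError as err:
        print(err)

    loss = criterion(output_flat, targets.long())            # cast to int64 works
    print(loss.item())

The cast is a no-op when the targets are already LongTensors, so applying it unconditionally is a safe compatibility fix.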
