Description
Hi there,
I've been fiddling with perplexity computation in the past few days, and I was struck by a few discrepancies:
- In your manual implementation, I read:

  > When we run the above with `stride = 1024`, i.e. no overlap, the resulting PPL is 19.44, which is about the same as the 19.93 reported in the GPT-2 paper (note: p. 5).

  (A minimal sketch of what I mean by this strided evaluation is included right after the docs example below.)

- On the reference page (see "Examples"):
(Of course this is run on a very small selection, only the first 50 rows!)

```python
perplexity = evaluate.load("perplexity", module_type="metric")
input_texts = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:50]
input_texts = [s for s in input_texts if s!='']
results = perplexity.compute(model_id='gpt2', predictions=input_texts)
print(list(results.keys()))
>>>['perplexities', 'mean_perplexity']
print(round(results["mean_perplexity"], 2))
>>>576.76
print(round(results["perplexities"][0], 2))
>>>889.28
```
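For reference, this is roughly what I mean by the manual strided evaluation: a minimal sketch along the lines of the guide quoted in the first bullet (the notebook version differs in details such as batching, and the hyperparameters here are just the guide's, so take this as an approximation rather than the exact code I run).

```python
import torch
import datasets
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Concatenate the whole test split into one long sequence before tokenizing.
test = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")

max_length = model.config.n_positions  # 1024 for GPT-2
stride = 1024                          # no overlap between windows
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc       # tokens scored in this window
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100        # mask the pure-context part of the window

    with torch.no_grad():
        # loss is the mean NLL over the non-masked shifted targets;
        # re-scaling by trg_len is the approximation used in the guide
        neg_log_likelihood = model(input_ids, labels=target_ids).loss * trg_len

    nlls.append(neg_log_likelihood)
    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
print(ppl.item())
```

With `stride = 1024` this is the computation that, per the quote above, is supposed to land around 19.44.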
Here's a Colab notebook with my experiments.
Question 1
Even when using `evaluate` directly, I get fairly different results. For the same selection of 50 rows, I obtain:
- 320.85
- 567.91
Has something in the implementation changed since the docs were written? What am I missing?
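In case it's just version drift, here is the trivial check I'm running to record which library versions Colab is actually using (this is not a claim about which versions the docs were generated with):

```python
# Record the installed versions, to rule out simple implementation drift
# between the documented example and the packages actually installed.
import evaluate, transformers, datasets, torch

for mod in (evaluate, transformers, datasets, torch):
    print(mod.__name__, mod.__version__)
```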
Question 2
There is still a huge difference between the manual computation (the first page I quoted above, ported in the notebook; I get similar results when experimenting with a faster, batched implementation that uses `stride` in the tokenizer) and what the metric produces. For wikitext I'm expecting something on the order of 20-30, and lower for bigger models, not values in the hundreds, let alone the thousands (see the notebook, in particular the result when joining all rows with `\n\n`). At first I thought it was because pad tokens were being counted, but that turned out to be wrong. I'm now at a loss as to where the difference might come from. Any ideas?
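To make the gap concrete, here is a minimal (and possibly naive) comparison of the two quantities I may be conflating: the mean of per-row perplexities on short wikitext lines versus the perplexity of the same rows concatenated and scored with full left context. The 1024-token truncation and the variable names are just illustrative choices, not how either implementation actually works:

```python
import torch
import datasets
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

rows = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:50]
rows = [s for s in rows if s.strip()]

# (a) each row scored on its own, with no context from the preceding rows
per_row_ppl = []
for text in rows:
    ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    if ids.size(1) < 2:  # need at least one target token
        continue
    with torch.no_grad():
        loss = model(ids, labels=ids).loss  # mean NLL over the shifted targets
    per_row_ppl.append(torch.exp(loss).item())
print("mean per-row PPL:", sum(per_row_ppl) / len(per_row_ppl))

# (b) the same rows concatenated and scored as one sequence
#     (truncated to the model's 1024-token window, just to keep the example short)
joined = "\n\n".join(rows)
ids = tokenizer(joined, return_tensors="pt", truncation=True, max_length=1024).input_ids.to(device)
with torch.no_grad():
    loss = model(ids, labels=ids).loss
print("PPL of the concatenation:", torch.exp(loss).item())
```

If the metric computes something closer to (a) while the guide computes (b), that alone could account for part of the hundreds-versus-twenties gap, but I'd like to confirm.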
Question 3
Also, I noticed that I can pass the dataset either as a list of strings of varying lengths or as one long string, and it runs. However, the perplexity class itself, on its own, does not seem to handle that (for instance, no `return_overflowing_tokens = True` is passed to the tokenizer; I ran a few tests with the class alone). I'm guessing the dataset is processed somewhere before being fed to it? I'm still working through the whole invocation pipeline of the metric, but so far I haven't found where that happens.
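For what it's worth, this is the kind of pre-chunking I was expecting to find somewhere in the pipeline. It is a purely hypothetical sketch, with arbitrary `max_length`/`stride` values, not a description of what the metric actually does:

```python
# Hypothetical pre-processing: split one long string into model-sized chunks
# using the fast tokenizer's overflow mechanism, then hand the decoded chunks
# to the metric as a list of strings.
import datasets
import evaluate
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
long_text = "\n\n".join(
    datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:50]
)

enc = tokenizer(
    long_text,
    truncation=True,
    max_length=512,                  # arbitrary chunk length
    stride=0,                        # no overlap between chunks
    return_overflowing_tokens=True,  # keep the overflow as extra chunks
)
chunks = tokenizer.batch_decode(enc["input_ids"])

perplexity = evaluate.load("perplexity", module_type="metric")
results = perplexity.compute(model_id="gpt2", predictions=chunks)
print(round(results["mean_perplexity"], 2))
```

(I realize that decoding and re-encoding the chunks won't preserve token boundaries exactly; this is only meant to illustrate where I expected the splitting to happen.)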
Apologies if the first version of this issue had quite a few silly mistakes, and thanks for reading!