Skip to content

Commit e34f441

Browse files
committed
Hotfix: logits_all bug
1 parent 4d1eb88 commit e34f441

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

Diff for: llama_cpp/llama.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,8 @@ def eval(self, tokens: Sequence[int]):
443443
# Save logits
444444
rows = n_tokens if self.params.logits_all else 1
445445
cols = self._n_vocab
446-
self.scores[self.n_tokens : self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
446+
offset = 0 if self.params.logits_all else n_tokens - 1 # NOTE: Only save the last token logits if logits_all is False
447+
self.scores[self.n_tokens + offset: self.n_tokens + n_tokens, :].reshape(-1)[:] = llama_cpp.llama_get_logits(self.ctx)[:rows * cols]
447448
# Update n_tokens
448449
self.n_tokens += n_tokens
449450

0 commit comments

Comments
 (0)