-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathlogit_processor.py
More file actions
46 lines (40 loc) · 1.87 KB
/
logit_processor.py
File metadata and controls
46 lines (40 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from collections.abc import Callable
from typing import Any
import torch
from transformers.generation import LogitsProcessor
class ConstrainedLogitsProcessor(LogitsProcessor):
    """Restrict next-token candidates (e.g. via a trie) during generation.

    Once the generated sequence ends with the marker ``prefix_ids`` (of length
    ``prefix_index``), each subsequent step keeps only the tokens returned by
    ``prefix_allowed_tokens_fn`` for the current suffix; every other vocabulary
    entry is masked to ``-inf`` on top of the log-softmaxed scores.
    """

    def __init__(
        self,
        prefix_allowed_tokens_fn: Callable[[list[int]], list[int]],
        num_beams: int,
        prefix_index: int = 3,
        prefix_ids: list[int] | None = None,
        eos_token_id: int | None = None,
    ):
        # Callback mapping the current constrained suffix to the allowed token ids.
        # (Annotation fixed: it is invoked with a single list[int] argument below.)
        self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams
        # Length of the marker sequence that (re)starts the constrained region.
        self.prefix_index = prefix_index
        self.eos_token_id = eos_token_id
        self.prefix_ids = prefix_ids
        # Number of tokens generated since the marker was last seen; extends the
        # suffix handed to the callback.  NOTE(review): this counter is shared
        # across all beams/batch items and is incremented once per beam per
        # step, so with num_beams > 1 it advances num_beams times per
        # generation step — confirm that is the intended behavior.
        self.count = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return log-softmaxed ``scores`` with disallowed tokens set to ``-inf``.

        ``input_ids`` arrives in the flat beam-search layout
        ``(batch * num_beams, seq_len)``.

        Raises:
            ValueError: if the callback returns no allowed tokens for a suffix.
        """
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        mask = torch.full_like(scores, float("-inf"))
        # support beam search: (batch*beams, seq) -> (batch, beams, seq)
        assert input_ids.dim() == 2, "input_ids must be a 2D tensor"
        seq_len = input_ids.shape[1]
        beam_sents = input_ids.view(-1, self._num_beams, seq_len)
        for batch_id, beam_sent in enumerate(beam_sents):
            for beam_id, sent in enumerate(beam_sent):
                # Marker just emitted -> restart the suffix window.
                if sent[-self.prefix_index :].tolist() == self.prefix_ids:
                    self.count = 0
                prefix_ids = sent[-self.prefix_index - self.count :].tolist()
                prefix_allowed_tokens = self._prefix_allowed_tokens_fn(prefix_ids)
                if len(prefix_allowed_tokens) == 0:
                    # Raise instead of the original always-failing `assert`
                    # (asserts are stripped under `python -O`).
                    raise ValueError(f"No valid tokens for prefix_ids {prefix_ids}")
                idx = batch_id * self._num_beams + beam_id
                mask[idx, prefix_allowed_tokens] = 0
                self.count += 1
        return scores + mask