We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e646ed1 · commit f6ab5e7 (Copy full SHA for f6ab5e7)
logits_processor_zoo/transformers/base.py
@@ -26,11 +26,14 @@ def _reset(self):
         pass
 
     def _reset_if_new_batch(self, input_ids: torch.LongTensor):
-        if self.input_len is not None:
+        first_time = self.input_len is None
+        if first_time:
+            self._reset()
+        else:
+            # Assuming 1 new token is generated in sequential calls of the same batch. Resets if it is not the case.
+            # It is a hack to figure out if a new generation starts because transformers API doesn't provide it.
             if input_ids.shape[1] != self.input_len + 1:
                 self._reset()
-        else:
-            self._reset()
 
         self.input_len = input_ids.shape[1]
0 commit comments