Commit f6ab5e7: Comment reset_if_new_batch logic

Signed-off-by: aerdem4 <[email protected]>
Parent: e646ed1

logits_processor_zoo/transformers/base.py

Lines changed: 6 additions & 3 deletions
@@ -26,11 +26,14 @@ def _reset(self):
         pass
 
     def _reset_if_new_batch(self, input_ids: torch.LongTensor):
-        if self.input_len is not None:
+        first_time = self.input_len is None
+        if first_time:
+            self._reset()
+        else:
+            # Assuming 1 new token is generated in sequential calls of the same batch; resets if that is not the case.
+            # It is a hack to figure out whether a new generation has started, because the transformers API doesn't provide this.
             if input_ids.shape[1] != self.input_len + 1:
                 self._reset()
-        else:
-            self._reset()
 
         self.input_len = input_ids.shape[1]
 
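For context, here is a minimal sketch of how this reset pattern plugs into a stateful processor's __call__; the class name ExampleStatefulProcessor and its step counter are illustrative, not taken from the repo:

import torch
from transformers import LogitsProcessor


class ExampleStatefulProcessor(LogitsProcessor):
    """Illustrative sketch: keeps per-generation state and relies on
    _reset_if_new_batch to clear it when a new generation begins."""

    def __init__(self):
        self.input_len = None  # sequence length seen on the previous call
        self.step = 0          # hypothetical per-generation state

    def _reset(self):
        self.step = 0

    def _reset_if_new_batch(self, input_ids: torch.LongTensor):
        first_time = self.input_len is None
        if first_time:
            self._reset()
        else:
            # Within one generation, each call adds exactly one token; any
            # other length change means a new generate() call has started.
            if input_ids.shape[1] != self.input_len + 1:
                self._reset()
        self.input_len = input_ids.shape[1]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self._reset_if_new_batch(input_ids)
        self.step += 1
        # ...stateful adjustments to scores would go here...
        return scores

Passed via model.generate(..., logits_processor=LogitsProcessorList([ExampleStatefulProcessor()])), the processor runs once per decoding step on the growing input_ids, which is why a length jump of anything other than +1 signals a fresh generation.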
