Commit 5302f2a

Fix duplicate BOS issue; add_bos_token defaults to None (EleutherAI#3347)
* fix duplicate `bos` token when `context==""`
* add docs
* check tokenizer.add_bos_token for bos control
* fix params
* skip duplicate bos
* fix bos token handling
* fix bos token handling
* fix box_token handling
* fixup! default add_special_tokens as unset
* `self.tokenizer.bos_token` can be None
* fix type
* Update lm_eval/models/huggingface.py

  Co-authored-by: Cyrus Leung <[email protected]>

* refactor bos token handling logic
* add tests for bos
* fix tests

---------

Co-authored-by: Cyrus Leung <[email protected]>
1 parent 90950a8 commit 5302f2a

File tree

5 files changed: +772 -72 lines changed

lm_eval/api/model.py

Lines changed: 72 additions & 12 deletions
@@ -324,10 +324,11 @@ class TemplateLM(LM):
     """
 
     tokenizer = None
+    backend = "causal"
 
     @property
     @abc.abstractmethod
-    def eot_token_id(self):
+    def eot_token_id(self) -> int:
         pass
 
     @property
@@ -336,9 +337,13 @@ def prefix_token_id(self):
         return self.eot_token_id
 
     @abc.abstractmethod
-    def tok_encode(self, string: str, **kwargs) -> list[int]:
+    def tok_encode(
+        self, string: str, add_special_tokens: Optional[bool] = None, **kwargs
+    ) -> list[int]:
         """
         Tokenize a string using the model's tokenizer and return a list of token IDs.
+        NOTE: This method is expected to handle strings which already contain the BOS token (when add_special_tokens=None).
+        Otherwise, will use add_special_tokens if specified.
         """
         pass
 
@@ -351,38 +356,93 @@ def _loglikelihood_tokens(
     def _encode_pair(
         self, context: str, continuation: str
     ) -> tuple[list[int], list[int]]:
-        import transformers
+        """
+        Encode a context-continuation pair into separate token ID lists.
+
+        This method handles the tokenization of context and continuation strings while
+        preserving proper boundary handling. Trailing spaces in the context are moved
+        to the beginning of the continuation to ensure correct tokenization at the
+        word boundary.
+
+        For Seq2Seq models (encoder-decoder), context and continuation are encoded
+        separately. For other model types (decoder-only), the full sequence is encoded
+        together to ensure proper tokenization, then split at the context boundary.
+
+        :param context: str
+            The context string. Can be empty (will be handled by the caller).
+        :param continuation: str
+            The continuation string to be scored.
+
+        :return: tuple[list[int], list[int]]
+            A tuple of (context_enc, continuation_enc) where:
+            - context_enc: Token IDs for the context
+            - continuation_enc: Token IDs for the continuation
+
+        Note:
+            This method does NOT handle empty context. The caller should
+            handle empty context (see loglikelihood method).
+        """
+        assert context, "Context cannot be empty!"
 
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]
 
-        model_class = getattr(self, "AUTO_MODEL_CLASS", None)
-
-        if model_class == transformers.AutoModelForSeq2SeqLM:
-            context_enc = self.tok_encode(context)
-            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
-        else:
+        if self.backend == "causal":
             whole_enc = self.tok_encode(context + continuation)
             context_enc = self.tok_encode(context)
 
             context_enc_len = len(context_enc)
             continuation_enc = whole_enc[context_enc_len:]
+        else:
+            # for SEQ2SEQ case we need to encode separately
+            context_enc = self.tok_encode(context)
+            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
 
         return context_enc, continuation_enc
 
     def loglikelihood(
         self, requests: list["Instance"], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
+        """
+        Compute log-likelihood of generating continuations from contexts.
+
+        This is the concrete implementation for TemplateLM and its subclasses.
+        It tokenizes context-continuation pairs and delegates scoring to
+        _loglikelihood_tokens.
+
+        **IMPORTANT**: This method is expected to handle empty context strings.
+        When context is empty (""), it uses the model's prefix_token_id (typically
+        BOS or EOS token) as context. If the continuation already starts with the
+        prefix token, it reuses that token as context instead of duplicating it.
+
+        :param requests: list[Instance]
+            List of Instance objects with property `args` returning (context, continuation) tuples.
+        :param disable_tqdm: bool
+            Whether to disable the progress bar in _loglikelihood_tokens.
+
+        :return: list[tuple[float, bool]]
+            List of (log_prob, is_greedy) tuples for each request.
+
+        Implementation details:
+        - Empty context: Uses prefix_token_id (BOS/EOS) as context
+        - Non-empty context: Uses _encode_pair for proper tokenization
+        - Avoids token duplication when continuation starts with prefix_token_id
+        """
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
-                # BOS or EOS as context
+                continuation_enc = self.tok_encode(
+                    continuation, add_special_tokens=False
+                )
+                # BOS or EOS as context: handle when context is empty -> (context + continuation) -> (BOS + continuation
                 context_enc, continuation_enc = (
-                    [self.prefix_token_id],
-                    self.tok_encode(continuation),
+                    ([self.prefix_token_id], continuation_enc)
+                    if self.prefix_token_id != continuation_enc[0]
+                    else (continuation_enc[:1], continuation_enc[1:])
                 )
+                # BOS or EOS as context
             else:
                 context_enc, continuation_enc = self._encode_pair(context, continuation)
 
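To make the empty-context behaviour in the hunk above concrete, here is a minimal standalone sketch of the same dedup rule, outside the diff. The function name split_empty_context and the example BOS id 1 are illustrative only; in the actual class this logic lives inline in loglikelihood and uses self.prefix_token_id.

def split_empty_context(
    continuation_enc: list[int], prefix_token_id: int
) -> tuple[list[int], list[int]]:
    # Empty context: use the prefix token (BOS/EOS) as context, but if the
    # tokenizer already prepended it to the continuation, reuse that token
    # instead of scoring a duplicated BOS.
    if continuation_enc and continuation_enc[0] == prefix_token_id:
        return continuation_enc[:1], continuation_enc[1:]
    return [prefix_token_id], continuation_enc

# With a BOS id of 1:
#   split_empty_context([1, 52, 53], 1) -> ([1], [52, 53])   (BOS already present, not duplicated)
#   split_empty_context([52, 53], 1)    -> ([1], [52, 53])   (BOS absent, prefix token supplied)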
lm_eval/models/huggingface.py

Lines changed: 23 additions & 25 deletions
@@ -32,10 +32,12 @@
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import (
     Collator,
+    _add_special_kwargs,
     clear_torch_cache,
     configure_pad_token,
     get_dtype,
     handle_stop_sequences,
+    has_bos_prefix,
     pad_and_concat,
     postprocess_generated_text,
     stop_sequences_criteria,
@@ -84,7 +86,7 @@ def __init__(
         max_batch_size: int | None = 64,
         trust_remote_code: bool | None = False,
         use_fast_tokenizer: bool | None = True,
-        add_bos_token: bool | None = False,
+        add_bos_token: bool | None = None,
         prefix_token_id: int | None = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
@@ -258,11 +260,6 @@ def __init__(
         )
 
         self.add_bos_token = add_bos_token
-        if "gemma" in getattr(self.config, "model_type", ""):
-            self.add_bos_token = True
-            eval_logger.info(
-                f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
-            )
 
         self._max_length = max_length
         self.pretrained = pretrained
@@ -744,7 +741,7 @@ def _create_tokenizer(
         trust_remote_code: bool | None = False,
         use_fast_tokenizer: bool | None = True,
         gguf_file: str | None = None,
-        add_bos_token: bool | None = False,
+        add_bos_token: bool | None = None,
         subfolder: str | None = "",
     ) -> None:
         """Helper method during initialization.
@@ -763,8 +760,8 @@ def _create_tokenizer(
         else:
             kwargs["use_fast"] = use_fast_tokenizer
 
-        if add_bos_token:
-            kwargs["add_bos_token"] = True
+        if add_bos_token is not None:
+            kwargs["add_bos_token"] = add_bos_token
 
         if subfolder:
             kwargs["subfolder"] = subfolder
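The net effect of the two hunks above is a three-state add_bos_token: None means "defer to the tokenizer's own default" (presumably why the hard-coded Gemma override earlier in the diff could be dropped, since Gemma-style tokenizers already default to adding BOS), while an explicit True or False is forwarded to the tokenizer. A minimal sketch of that forwarding, assuming a tokenizer that exposes an add_bos_token attribute (e.g. Llama/Gemma-style tokenizers); load_tokenizer is an illustrative helper, not the class's _create_tokenizer.

from transformers import AutoTokenizer


def load_tokenizer(pretrained: str, add_bos_token: bool | None = None):
    kwargs = {}
    # None -> leave the tokenizer's own default untouched
    # True/False -> explicit user override, passed through to the tokenizer
    if add_bos_token is not None:
        kwargs["add_bos_token"] = add_bos_token
    return AutoTokenizer.from_pretrained(pretrained, **kwargs)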
@@ -858,24 +855,20 @@ def forward_batch(batch_size: int):
     def tok_encode(
         self,
         string: str,
-        left_truncate_len: int | None = None,
         add_special_tokens: bool | None = None,
+        left_truncate_len: int | None = None,
+        **kwargs,
     ) -> list[int]:
-        """ """
         # default for None - empty dict, use predefined tokenizer param
         # used for all models except for CausalLM or predefined value
-        special_tokens_kwargs = {}
-
-        # by default for CausalLM - false or self.add_bos_token is set
-        if add_special_tokens is None:
-            if self.backend == "causal":
-                special_tokens_kwargs = {
-                    "add_special_tokens": False or self.add_bos_token
-                }
-        # otherwise the method explicitly defines the value
-        else:
-            special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
-
+        special_tokens_kwargs = _add_special_kwargs(
+            add_special_tokens, self.add_bos_token
+        )
+        # set add_special_tokens=False if the string already starts with BOS token.
+        if add_special_tokens is None and has_bos_prefix(
+            string, self.tokenizer.decode(self.prefix_token_id)
+        ):
+            special_tokens_kwargs["add_special_tokens"] = False
         encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
 
         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
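Read together with the helpers added in lm_eval/models/utils.py further down, the rewritten tok_encode resolves special-token handling in this order: an explicit add_special_tokens argument always wins; otherwise, if the string already begins with the BOS text, automatic insertion is switched off; otherwise a user-set add_bos_token applies; and if nothing is set, the tokenizer's own default is used. A small sketch of that precedence, assuming the two helpers are importable; resolve_special_tokens_kwargs and its bos_text parameter are illustrative stand-ins for the inline logic and self.tokenizer.decode(self.prefix_token_id).

from lm_eval.models.utils import _add_special_kwargs, has_bos_prefix


def resolve_special_tokens_kwargs(
    string: str,
    add_special_tokens: bool | None,
    add_bos_token: bool | None,
    bos_text: str | None,
) -> dict:
    # explicit caller choice (or a configured add_bos_token) first
    kwargs = _add_special_kwargs(add_special_tokens, add_bos_token)
    # never let the tokenizer prepend a second BOS to a string that already has one
    if add_special_tokens is None and has_bos_prefix(string, bos_text):
        kwargs["add_special_tokens"] = False
    return kwargs

# e.g. a prompt that already starts with "<s>":
#   resolve_special_tokens_kwargs("<s>Hello", None, None, "<s>") -> {"add_special_tokens": False}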
@@ -897,7 +890,12 @@ def tok_batch_encode(
 
         add_special_tokens = {}
         if self.backend == "causal":
-            add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
+            if has_bos_prefix(strings[0], getattr(self.tokenizer, "bos_token", None)):
+                add_special_tokens = {"add_special_tokens": False}
+            elif self.add_bos_token is not None:
+                add_special_tokens = {"add_special_tokens": self.add_bos_token}
+            else:
+                add_special_tokens = {}
 
         encoding = self.tokenizer(
             strings,
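Note that the batched path inspects only strings[0] for a BOS prefix, on the assumption that all prompts in a batch share the same formatting: if the first prompt already carries BOS, automatic insertion is disabled for the whole batch; otherwise an explicit add_bos_token setting (or, failing that, the tokenizer default) applies, mirroring tok_encode above.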
@@ -971,7 +969,7 @@ def _model_generate(
         context,
         max_length: int,
         stop: list[str],
-        **generation_kwargs: dict[str, Any],
+        **generation_kwargs,
     ) -> torch.Tensor:
         # temperature = 0.0 if not set
         # if do_sample is false and temp==0.0:

lm_eval/models/utils.py

Lines changed: 18 additions & 1 deletion
@@ -150,7 +150,7 @@ def get_original(self, grouped_dict):
 
 def pad_and_concat(
     max_length: int,
-    tensors: List[torch.Tensor],
+    tensors: list[torch.Tensor],
     padding_side: Literal["right", "left"] = "right",
 ):
     """
@@ -881,3 +881,20 @@ def postprocess_generated_text(
         generation = generation.split(think_end_token)[-1].lstrip()
 
     return generation
+
+
+def has_bos_prefix(sequence: str, bos_str: str | Iterable[str] | None = None):
+    if bos_str is None:
+        return False
+    elif isinstance(bos_str, str):
+        return sequence.startswith(bos_str)
+    else:
+        return any(sequence.startswith(x) for x in bos_str)
+
+
+def _add_special_kwargs(add_special_tokens: bool | None, add_bos: bool | None = None):
+    if add_special_tokens is not None:
+        return {"add_special_tokens": add_special_tokens}
+    if add_bos is not None:
+        return {"add_special_tokens": add_bos}
+    return {}
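A few illustrative calls for the two helpers above (the BOS strings "<s>" and "<bos>" are examples, not constants from the library; the helpers are assumed to be in scope, e.g. imported from lm_eval.models.utils):

has_bos_prefix("<s>Once upon a time", "<s>")   # True  -> caller should suppress automatic BOS insertion
has_bos_prefix("Once upon a time", "<s>")      # False
has_bos_prefix("<bos>Hi", ["<s>", "<bos>"])    # True  -> any candidate in an iterable of BOS strings counts
_add_special_kwargs(None, None)                # {}    -> defer to the tokenizer's default
_add_special_kwargs(None, True)                # {"add_special_tokens": True}
_add_special_kwargs(False, True)               # {"add_special_tokens": False}  (explicit argument wins)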
