32 changes: 28 additions & 4 deletions lm_eval/models/huggingface.py
@@ -2,6 +2,7 @@
 import torch
 import torch.nn.functional as F
 import transformers
+from transformers import LlamaForCausalLM, LlamaTokenizer
 import peft
 from peft import __version__ as PEFT_VERSION
 from pathlib import Path
@@ -323,6 +324,22 @@ def _create_auto_model_peft(
         )
         return model
 
+    # def _create_auto_tokenizer(
+    #     self,
+    #     *,
+    #     pretrained: str,
+    #     revision: str,
+    #     subfolder: str,
+    #     tokenizer: Optional[str] = None,
+    # ) -> transformers.PreTrainedTokenizer:
+    #     """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
+    #     tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
+    #         pretrained if tokenizer is None else tokenizer,
+    #         revision=revision + ("/" + subfolder if subfolder is not None else ""),
+    #     )
+    #     tokenizer.pad_token = tokenizer.eos_token
+    #     return tokenizer
+
     def _create_auto_tokenizer(
         self,
         *,
@@ -332,13 +349,20 @@ def _create_auto_tokenizer(
         tokenizer: Optional[str] = None,
     ) -> transformers.PreTrainedTokenizer:
         """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
-        tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
-            pretrained if tokenizer is None else tokenizer,
-            revision=revision + ("/" + subfolder if subfolder is not None else ""),
-        )
+        try:
+            tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
+                pretrained if tokenizer is None else tokenizer,
+                revision=revision + ("/" + subfolder if subfolder is not None else ""),
+            )
+        except Exception:
+            # AutoTokenizer cannot resolve the "LLaMATokenizer" class named in
+            # some older LLaMA checkpoints; fall back to LlamaTokenizer directly.
+            tokenizer = LlamaTokenizer.from_pretrained(
+                pretrained,
+                revision=revision + ("/" + subfolder if subfolder is not None else ""),
+            )
+
         tokenizer.pad_token = tokenizer.eos_token
         return tokenizer
 
+
     @property
     def add_special_tokens(self) -> bool:
         """Whether to include special tokens in encoded text. This should be
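
For reference, a minimal standalone sketch of the fallback this diff adds, assuming a checkpoint whose tokenizer_config.json names the misspelled class "LLaMATokenizer" (the model name below is illustrative, not taken from the PR):

from transformers import AutoTokenizer, LlamaTokenizer

# Hypothetical checkpoint that declares "LLaMATokenizer"; substitute your own.
pretrained = "decapoda-research/llama-7b-hf"

try:
    # AutoTokenizer resolves tokenizer_class from tokenizer_config.json and
    # raises if the recorded class name is one it cannot import.
    tokenizer = AutoTokenizer.from_pretrained(pretrained)
except Exception:
    # Fall back to the concrete LlamaTokenizer class, mirroring the diff.
    tokenizer = LlamaTokenizer.from_pretrained(pretrained)

# LLaMA tokenizers ship without a pad token, so the diff reuses EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

Catching a narrower exception here (recent transformers releases raise ValueError for an unresolvable tokenizer class) would avoid masking unrelated load failures; the broad catch above only mirrors the diff's behavior.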