diff --git a/src/lighteval/models/abstract_model.py b/src/lighteval/models/abstract_model.py
index 81d725e6a..84fcc085e 100644
--- a/src/lighteval/models/abstract_model.py
+++ b/src/lighteval/models/abstract_model.py
@@ -30,7 +30,7 @@
 from pydantic import BaseModel
 from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
 
-from lighteval.models.model_input import GenerationParameters
+from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters
 from lighteval.models.model_output import ModelResponse
 from lighteval.tasks.requests import Doc
 
@@ -51,6 +51,9 @@ class ModelConfig(BaseModel, extra="forbid"):
         generation_parameters (GenerationParameters):
             Configuration parameters that control text generation behavior, including
             temperature, top_p, max_new_tokens, etc. Defaults to empty GenerationParameters.
+        chat_template_parameters (ChatTemplateParameters):
+            Configuration parameters that control chat template behavior, including
+            reasoning_effort, enable_thinking, etc. Defaults to empty ChatTemplateParameters.
         system_prompt (str | None):
             Optional system prompt to be used with chat models. This prompt sets the
             behavior and context for the model during evaluation.
@@ -85,6 +88,7 @@ class ModelConfig(BaseModel, extra="forbid"):
 
     model_name: str = None
     generation_parameters: GenerationParameters = GenerationParameters()
+    chat_template_parameters: ChatTemplateParameters = ChatTemplateParameters()
     system_prompt: str | None = None
 
     cache_dir: str = "~/.cache/huggingface/lighteval"
@@ -128,7 +132,7 @@ def _parse_args(args: str) -> dict:
             'model': {'model_name': 'gpt2', 'generation_parameters': {'temperature': 0.7, 'top_p': 0.9}, }
 
-        >>> parse_args("model_name=gpt2,use_cache,generation_parameters={temperature:0.7}")
+        >>> parse_args("model_name=gpt2,use_cache,generation_parameters={temperature:0.7},chat_template_parameters={reasoning_effort:low}")
         {
             'model': {'model_name': 'gpt2', 'use_cache': True, 'generation_parameters': {'temperature': 0.7}},
         }
 
diff --git a/src/lighteval/models/custom/custom_model.py b/src/lighteval/models/custom/custom_model.py
index ea11fe3c6..d5e2159e5 100644
--- a/src/lighteval/models/custom/custom_model.py
+++ b/src/lighteval/models/custom/custom_model.py
@@ -69,5 +69,4 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
     An example of a custom model can be found in `examples/custom_models/google_translate_model.py`.
""" - model_name: str model_definition_file_path: str diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index 6b08be575..b05e41990 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -100,7 +100,6 @@ class ServerlessEndpointModelConfig(ModelConfig): ``` """ - model_name: str add_special_tokens: bool = True batch_size: int = 1 diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index e620fba70..630822881 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -103,7 +103,6 @@ class LiteLLMModelConfig(ModelConfig): ``` """ - model_name: str provider: str | None = None base_url: str | None = None api_key: str | None = None diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index ad41c23eb..92cf5e5c3 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -241,3 +241,20 @@ def to_sglang_dict(self) -> dict: "min_new_tokens": self.min_new_tokens, } return {k: v for k, v in args.items() if v is not None} + + +class ChatTemplateParameters(BaseModel): + reasoning_effort: str | None = None + enable_thinking: bool | None = None + + def to_transformers_dict(self) -> dict: + """Selects relevant chat template parameters for transformers models. + + Returns: + dict: Valid parameters for the chat template + """ + args = { + "reasoning_effort": self.reasoning_effort, + "enable_thinking": self.enable_thinking, + } + return {k: v for k, v in args.items() if v is not None} diff --git a/src/lighteval/models/sglang/sglang_model.py b/src/lighteval/models/sglang/sglang_model.py index e5c0f4d87..603fe88ba 100644 --- a/src/lighteval/models/sglang/sglang_model.py +++ b/src/lighteval/models/sglang/sglang_model.py @@ -118,7 +118,6 @@ class SGLangModelConfig(ModelConfig): ``` """ - model_name: str load_format: str = "auto" dtype: str = "auto" tp_size: PositiveInt = 1 # how many GPUs to use for tensor parallelism diff --git a/src/lighteval/models/transformers/transformers_model.py b/src/lighteval/models/transformers/transformers_model.py index ed97faf84..c74515ef0 100644 --- a/src/lighteval/models/transformers/transformers_model.py +++ b/src/lighteval/models/transformers/transformers_model.py @@ -141,7 +141,6 @@ class TransformersModelConfig(ModelConfig): (bitsandbytes for 4-bit/8-bit quantization). """ - model_name: str tokenizer: str | None = None subfolder: str | None = None revision: str = "main" @@ -234,7 +233,10 @@ def __init__( model_size = -1 self.prompt_manager = PromptManager( - use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt + use_chat_template=self.use_chat_template, + tokenizer=self.tokenizer, + system_prompt=config.system_prompt, + chat_template_parameters=config.chat_template_parameters, ) # Initialize cache for tokenization and predictions diff --git a/src/lighteval/models/transformers/vlm_transformers_model.py b/src/lighteval/models/transformers/vlm_transformers_model.py index 0697ab729..433ae67a5 100644 --- a/src/lighteval/models/transformers/vlm_transformers_model.py +++ b/src/lighteval/models/transformers/vlm_transformers_model.py @@ -112,7 +112,6 @@ class VLMTransformersModelConfig(ModelConfig): cache_dir (str, optional, defaults to "~/.cache/huggingface/lighteval"): Directory to cache the model. 
""" - model_name: str processor: str | None = None use_fast_image_processor: bool | None = None subfolder: str | None = None diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 969caf8fa..983947909 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -153,7 +153,6 @@ class VLLMModelConfig(ModelConfig): ``` """ - model_name: str tokenizer: str | None = None revision: str = "main" # revision of the model dtype: str = "bfloat16" diff --git a/src/lighteval/tasks/prompt_manager.py b/src/lighteval/tasks/prompt_manager.py index 2c854281d..2454984ea 100644 --- a/src/lighteval/tasks/prompt_manager.py +++ b/src/lighteval/tasks/prompt_manager.py @@ -28,6 +28,7 @@ from itertools import cycle from typing import TYPE_CHECKING +from lighteval.models.model_input import ChatTemplateParameters from lighteval.tasks.requests import Doc from lighteval.utils.utils import as_list @@ -40,10 +41,17 @@ class PromptManager: - def __init__(self, use_chat_template: bool = False, tokenizer=None, system_prompt: str | None = None): + def __init__( + self, + use_chat_template: bool = False, + tokenizer=None, + system_prompt: str | None = None, + chat_template_parameters: ChatTemplateParameters | None = None, + ): self.use_chat_template = use_chat_template self.tokenizer = tokenizer self.system_prompt = system_prompt # System prompt to be used in chat templates + self.chat_template_parameters = chat_template_parameters if chat_template_parameters else {} def prepare_prompt(self, doc: Doc) -> str: """Prepare a prompt from a document, either using chat template or plain text format. @@ -133,6 +141,7 @@ def _prepare_chat_template(self, doc: Doc, tokenize: bool = True) -> str: messages, tokenize=False, add_generation_prompt=True, + **self.chat_template_parameters.to_transformers_dict(), ) else: # for apis diff --git a/tests/unit/prompt/test_prompt_manager_class.py b/tests/unit/prompt/test_prompt_manager_class.py index f552a9c31..9fa21139a 100644 --- a/tests/unit/prompt/test_prompt_manager_class.py +++ b/tests/unit/prompt/test_prompt_manager_class.py @@ -24,6 +24,7 @@ import pytest +from lighteval.models.model_input import ChatTemplateParameters from lighteval.tasks.prompt_manager import PromptManager from lighteval.tasks.requests import Doc @@ -47,6 +48,22 @@ def test_init_with_chat_template(self): assert pm.tokenizer == tokenizer assert pm.system_prompt == system_prompt + def test_init_with_chat_template_and_chat_template_parameters(self): + """Test PromptManager initialization with chat template enabled and chat template parameters.""" + tokenizer = Mock() + system_prompt = "You are a helpful assistant." 
+        pm = PromptManager(
+            use_chat_template=True,
+            tokenizer=tokenizer,
+            system_prompt=system_prompt,
+            chat_template_parameters=ChatTemplateParameters(reasoning_effort="medium"),
+        )
+        assert pm.use_chat_template is True
+        assert pm.tokenizer == tokenizer
+        assert pm.system_prompt == system_prompt
+        assert pm.chat_template_parameters is not None
+        assert pm.chat_template_parameters.reasoning_effort == "medium"
+
     def test_prepare_prompt_plain_text_basic(self):
         """Test prepare_prompt with plain text format and basic document."""
         pm = PromptManager()
diff --git a/tests/unit/utils/test_model_config.py b/tests/unit/utils/test_model_config.py
new file mode 100644
index 000000000..f89b4d7d2
--- /dev/null
+++ b/tests/unit/utils/test_model_config.py
@@ -0,0 +1,84 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import unittest
+
+from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters
+from lighteval.models.utils import ModelConfig
+
+
+class TestModelConfig(unittest.TestCase):
+    def test_model_config_init(self):
+        config = ModelConfig(
+            model_name="meta-llama/Llama-3.1-8B-Instruct",
+            generation_parameters=GenerationParameters(temperature=0.7),
+            system_prompt="You are a helpful assistant.",
+            chat_template_parameters=ChatTemplateParameters(reasoning_effort="low"),
+        )
+
+        self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
+        self.assertEqual(config.generation_parameters.temperature, 0.7)
+        self.assertEqual(config.system_prompt, "You are a helpful assistant.")
+        self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")
+
+    def test_model_config_init_command_line(self):
+        config = ModelConfig.from_args(
+            'model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt="You are a helpful assistant.",generation_parameters={temperature:0.7},chat_template_parameters={reasoning_effort:low}'
+        )
+
+        self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
+        self.assertEqual(config.generation_parameters.temperature, 0.7)
+        self.assertEqual(config.system_prompt, '"You are a helpful assistant."')  # is this what we want?
+        self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")
+
+    def test_model_config_generation_parameters_parse_single_int(self):
+        config = ModelConfig.from_args(
+            "model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7}"
+        )
+        self.assertEqual(config.generation_parameters.temperature, 0.7)
+
+    def test_model_config_generation_parameters_parse_multiple_int(self):
+        config = ModelConfig.from_args(
+            "model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7,top_k:42}"
+        )
+        self.assertEqual(config.generation_parameters.temperature, 0.7)
+        self.assertEqual(config.generation_parameters.top_k, 42)
+
+    @unittest.skip("This is not working at this time")
+    def test_model_config_generation_parameters_parse_string(self):
+        config = ModelConfig.from_args(
+            'model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={response_format:{"type":"json_object"}}'
+        )
+        self.assertEqual(config.generation_parameters.temperature, 0.7)
+
+    @unittest.skip("This is not working at this time")
+    def test_model_config_chat_template_parameters_parse_single_int(self):
+        config = ModelConfig.from_args(
+            "model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={temperature:0.7}"
+        )
+        self.assertEqual(config.chat_template_parameters.temperature, 0.7)
+
+    def test_model_config_chat_template_parameters_parse_string(self):
+        config = ModelConfig.from_args(
+            "model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={reasoning_effort:low}"
+        )
+        self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")
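
Note (not part of the patch): a minimal usage sketch of how the new chat_template_parameters plumbing fits together end to end. The model name below is only a placeholder, and whether reasoning_effort or enable_thinking actually change the rendered prompt depends on the model's chat template.

# Illustrative only: exercises the ChatTemplateParameters -> PromptManager ->
# tokenizer.apply_chat_template path added in this diff.
from transformers import AutoTokenizer

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")  # placeholder model
pm = PromptManager(
    use_chat_template=True,
    tokenizer=tokenizer,
    system_prompt="You are a helpful assistant.",
    chat_template_parameters=ChatTemplateParameters(enable_thinking=False),
)
doc = Doc(query="What is 2 + 2?", choices=["4"], gold_index=0, task_name="demo")
# enable_thinking=False is forwarded to apply_chat_template via to_transformers_dict()
print(pm.prepare_prompt(doc))

The same parameters can also be supplied in the CLI-style argument string covered by the new tests, e.g. chat_template_parameters={reasoning_effort:low}.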