Skip to content

Commit 15272b2

Browse files
akristing22, erinin, finitearth, Copilot
authored
fix: chat template for local model (#72)
* fix: chat template for local model * added regression test * added regression test * remove tasks * remove tasks * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * resolving precommit * replaced gated models for ungated models * alles raus was keine miete zahlt --------- Co-authored-by: erinin <erinin@altara.zitis.lan> Co-authored-by: finitearth <t.zehle@gmail.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 0523fb3 commit 15272b2

4 files changed

Lines changed: 46 additions & 2 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,3 +12,4 @@ poetry.lock
1212
CLAUDE.md
1313
**/CLAUDE.local.md
1414
.mypy_cache/
15+
token.txt

promptolution/llms/local_llm.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[s
7979
"""
8080
inputs: List[List[Dict[str, str]]] = []
8181
for prompt, sys_prompt in zip(prompts, system_prompts):
82-
inputs.append([{"role": "system", "prompt": sys_prompt}, {"role": "user", "prompt": prompt}])
82+
inputs.append([{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}])
8383

8484
with torch.no_grad():
8585
response = self.pipeline(inputs, pad_token_id=self.eos_token_id)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,9 @@ pytest = ">=8.3.5"
5252
pytest-cov = ">=6.1.1"
5353
openai = ">=1.0.0"
5454
requests = ">=2.31.0"
55-
vllm = ">=0.13.0"
5655
transformers = ">=4.48.0"
56+
vllm = ">=0.13.0"
57+
torch = ">=2.0.0"
5758

5859
[tool.poetry.group.docs.dependencies]
5960
mkdocs = ">=1.6.1"

tests/llms/test_local_llm.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import MagicMock, patch
22

33
import pytest
4+
from transformers import AutoTokenizer
45

56
from promptolution.llms import LocalLLM
67

@@ -67,3 +68,44 @@ def test_local_llm_get_response(mock_local_dependencies):
6768
assert len(responses) == 2
6869
assert responses[0] == "Mock response 1"
6970
assert responses[1] == "Mock response 2"
71+
72+
73+
@pytest.mark.parametrize(
    "model_id",
    [
        "Qwen/Qwen2.5-0.5B-Instruct",
        "HuggingFaceTB/SmolLM2-135M-Instruct",
        "microsoft/Phi-3.5-mini-instruct",
        "mistralai/Mistral-Nemo-Instruct-2407",
    ],
)
def test_local_llm_chat_template_renders(model_id):
    """Verify that system and user text survive chat-template rendering.

    Regression test for #71: the message dicts built by
    ``LocalLLM._get_response`` must carry their text under the ``"content"``
    key, otherwise the tokenizer's chat template has nothing to render.

    The tokenizer is obtained with ``AutoTokenizer.from_pretrained(model_id)``;
    the generation pipeline and ``torch`` inside
    ``promptolution.llms.local_llm`` are replaced by mocks, so no model is run.
    The mocked pipeline simply echoes each conversation back as the text
    produced by ``tokenizer.apply_chat_template``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def render_conversations(inputs, **_unused_kwargs):
        # Stand-in for the pipeline call: one single-candidate result per
        # conversation, whose "generated_text" is the rendered chat template.
        rendered = []
        for conversation in inputs:
            templated = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
            rendered.append([{"generated_text": templated}])
        return rendered

    # Build the fake pipeline object up front; it exposes the tokenizer and
    # delegates calls to the renderer above.
    pipeline_stub = MagicMock()
    pipeline_stub.tokenizer = tokenizer
    pipeline_stub.side_effect = render_conversations

    user_prompts = ["What is 2 + 2?", "Name a colour."]
    system_prompts = ["You are a math tutor.", "You are concise."]

    with patch("promptolution.llms.local_llm.pipeline") as pipeline_factory, patch(
        "promptolution.llms.local_llm.torch"
    ):
        pipeline_factory.return_value = pipeline_stub

        llm = LocalLLM(model_id=model_id, batch_size=2)
        outputs = llm._get_response(user_prompts, system_prompts=system_prompts)

        assert len(outputs) == 2
        # Both the user prompt and its system prompt must appear in the output.
        for rendered_text, user_text, system_text in zip(outputs, user_prompts, system_prompts):
            assert user_text in rendered_text
            assert system_text in rendered_text

0 commit comments

Comments
 (0)