|
1 | 1 | from unittest.mock import MagicMock, patch |
2 | 2 |
|
3 | 3 | import pytest |
| 4 | +from transformers import AutoTokenizer |
4 | 5 |
|
5 | 6 | from promptolution.llms import LocalLLM |
6 | 7 |
|
@@ -67,3 +68,44 @@ def test_local_llm_get_response(mock_local_dependencies): |
67 | 68 | assert len(responses) == 2 |
68 | 69 | assert responses[0] == "Mock response 1" |
69 | 70 | assert responses[1] == "Mock response 2" |
| 71 | + |
| 72 | + |
| 73 | +@pytest.mark.parametrize( |
| 74 | + "model_id", |
| 75 | + [ |
| 76 | + "Qwen/Qwen2.5-0.5B-Instruct", |
| 77 | + "HuggingFaceTB/SmolLM2-135M-Instruct", |
| 78 | + "microsoft/Phi-3.5-mini-instruct", |
| 79 | + "mistralai/Mistral-Nemo-Instruct-2407", |
| 80 | + ], |
| 81 | +) |
| 82 | +def test_local_llm_chat_template_renders(model_id): |
| 83 | + """Regression for #71: message dicts must use 'content' key so the |
| 84 | + tokenizer's chat template renders the system and user text.""" |
| 85 | + tokenizer = AutoTokenizer.from_pretrained(model_id) |
| 86 | + |
| 87 | + with patch("promptolution.llms.local_llm.pipeline") as mock_pipeline_func, patch( |
| 88 | + "promptolution.llms.local_llm.torch" |
| 89 | + ): |
| 90 | + mock_pipeline_obj = MagicMock() |
| 91 | + mock_pipeline_obj.tokenizer = tokenizer |
| 92 | + mock_pipeline_func.return_value = mock_pipeline_obj |
| 93 | + |
| 94 | + def fake_call(inputs, **_): |
| 95 | + return [ |
| 96 | + [{"generated_text": tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)}] |
| 97 | + for msg in inputs |
| 98 | + ] |
| 99 | + |
| 100 | + mock_pipeline_obj.side_effect = fake_call |
| 101 | + |
| 102 | + local_llm = LocalLLM(model_id=model_id, batch_size=2) |
| 103 | + prompts = ["What is 2 + 2?", "Name a colour."] |
| 104 | + sys_prompts = ["You are a math tutor.", "You are concise."] |
| 105 | + |
| 106 | + responses = local_llm._get_response(prompts, system_prompts=sys_prompts) |
| 107 | + |
| 108 | + assert len(responses) == 2 |
| 109 | + for response, prompt, sys_prompt in zip(responses, prompts, sys_prompts): |
| 110 | + assert prompt in response |
| 111 | + assert sys_prompt in response |
0 commit comments