9
9
def _fake_tokenizer():
    """Return a ``MagicMock`` that stands in for a chat tokenizer.

    The stub provides the two methods configured here:

    * ``apply_chat_template(conversation, tokenize, add_generation_prompt,
      continue_final_message)`` renders only the LAST message of
      ``conversation`` as ``"TEMPLATE::<role>::<content>"``. The other three
      arguments are accepted (positionally or by keyword) and ignored.
    * ``decode(ids)`` returns ``"<s>"`` when ``ids == [0]`` and otherwise
      ``"tok<N>"``, where ``N`` is the first element of ``ids``.
    """
    tok = MagicMock()
    # The parameter names must stay as they are: callers may pass them as
    # keywords, and MagicMock forwards call arguments to side_effect verbatim.
    tok.apply_chat_template.side_effect = (
        lambda conversation, tokenize, add_generation_prompt, continue_final_message: (
            f"TEMPLATE::{conversation[-1]['role']}::{conversation[-1]['content']}"
        )
    )
    tok.decode.side_effect = lambda ids: "<s>" if ids == [0] else f"tok{ids[0]}"
    return tok
17
16
18
17
19
18
def _fake_llm(return_logprobs):
    """Return a ``MagicMock`` LLM whose ``generate`` yields one canned result.

    Args:
        return_logprobs: Object placed, unchanged, as the single element of
            the completion's ``logprobs`` list.

    Returns:
        A ``MagicMock`` whose ``generate(...)`` always returns a one-element
        list. That element has an ``outputs`` list holding a single entry
        with ``text == "dummy"`` and ``logprobs == [return_logprobs]``.
    """
    completion = SimpleNamespace(text="dummy", logprobs=[return_logprobs])
    out_obj = SimpleNamespace(outputs=[completion])
    llm = MagicMock()
    llm.generate.return_value = [out_obj]
    return llm
@@ -41,15 +35,14 @@ async def test_generate_logits(monkeypatch, messages, continue_last):
41
35
fake_logprobs = {
42
36
3 : SimpleNamespace (logprob = - 0.1 ),
43
37
2 : SimpleNamespace (logprob = - 0.5 ),
44
- 1 : SimpleNamespace (logprob = - 1.0 )
38
+ 1 : SimpleNamespace (logprob = - 1.0 ),
45
39
}
46
40
47
41
tokenizer_stub = _fake_tokenizer ()
48
42
llm_stub = _fake_llm (fake_logprobs )
49
43
50
- with (
51
- patch ("prompting.llms.vllm_llm.LLM" , return_value = llm_stub ),
52
- patch ("prompting.llms.vllm_llm.SamplingParams" , lambda ** kw : kw )
44
+ with patch ("prompting.llms.vllm_llm.LLM" , return_value = llm_stub ), patch (
45
+ "prompting.llms.vllm_llm.SamplingParams" , lambda ** kw : kw
53
46
):
54
47
model = ReproducibleVLLM (model_id = "mock-model" )
55
48
# Swap tokenizer (LLM stub has none).
@@ -72,4 +65,4 @@ async def test_generate_logits(monkeypatch, messages, continue_last):
72
65
assert all (a >= b for a , b in zip (out_dict .values (), list (out_dict .values ())[1 :]))
73
66
74
67
# 3. generate() was invoked with that exact prompt.
75
- llm_stub .generate .assert_called_once_with (rendered_prompt , {' max_tokens' : 1 , ' logprobs' : 3 })
68
+ llm_stub .generate .assert_called_once_with (rendered_prompt , {" max_tokens" : 1 , " logprobs" : 3 })
0 commit comments