
Commit 1f2b398
Run formatter
1 parent 5b07671

1 file changed: tests/prompting/llms/vllm_llm.py (+6 −13 lines)
@@ -9,20 +9,14 @@
 def _fake_tokenizer():
     tok = MagicMock()
     tok.apply_chat_template.side_effect = (
-        lambda conversation, tokenize, add_generation_prompt, continue_final_message:
-        f"TEMPLATE::{conversation[-1]['role']}::{conversation[-1]['content']}"
+        lambda conversation, tokenize, add_generation_prompt, continue_final_message: f"TEMPLATE::{conversation[-1]['role']}::{conversation[-1]['content']}"
     )
     tok.decode.side_effect = lambda ids: "<s>" if ids == [0] else f"tok{ids[0]}"
     return tok


 def _fake_llm(return_logprobs):
-    out_obj = SimpleNamespace(
-        outputs=[SimpleNamespace(
-            text="dummy",
-            logprobs=[return_logprobs]
-        )]
-    )
+    out_obj = SimpleNamespace(outputs=[SimpleNamespace(text="dummy", logprobs=[return_logprobs])])
     llm = MagicMock()
     llm.generate.return_value = [out_obj]
     return llm
@@ -41,15 +35,14 @@ async def test_generate_logits(monkeypatch, messages, continue_last):
     fake_logprobs = {
         3: SimpleNamespace(logprob=-0.1),
         2: SimpleNamespace(logprob=-0.5),
-        1: SimpleNamespace(logprob=-1.0)
+        1: SimpleNamespace(logprob=-1.0),
     }

     tokenizer_stub = _fake_tokenizer()
     llm_stub = _fake_llm(fake_logprobs)

-    with (
-        patch("prompting.llms.vllm_llm.LLM", return_value=llm_stub),
-        patch("prompting.llms.vllm_llm.SamplingParams", lambda **kw: kw)
+    with patch("prompting.llms.vllm_llm.LLM", return_value=llm_stub), patch(
+        "prompting.llms.vllm_llm.SamplingParams", lambda **kw: kw
     ):
         model = ReproducibleVLLM(model_id="mock-model")
         # Swap tokenizer (LLM stub has none).
@@ -72,4 +65,4 @@ async def test_generate_logits(monkeypatch, messages, continue_last):
     assert all(a >= b for a, b in zip(out_dict.values(), list(out_dict.values())[1:]))

     # 3. generate() was invoked with that exact prompt.
-    llm_stub.generate.assert_called_once_with(rendered_prompt, {'max_tokens': 1, 'logprobs': 3})
+    llm_stub.generate.assert_called_once_with(rendered_prompt, {"max_tokens": 1, "logprobs": 3})
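
The hunks above only reflow code, so the stubs behave exactly as before. As a sanity check, here is a minimal standalone sketch of what the collapsed apply_chat_template lambda produces; the conversation literal is an illustrative assumption, not taken from the test:

from unittest.mock import MagicMock

tok = MagicMock()
tok.apply_chat_template.side_effect = (
    lambda conversation, tokenize, add_generation_prompt, continue_final_message: f"TEMPLATE::{conversation[-1]['role']}::{conversation[-1]['content']}"
)

# The stub renders only the final message, tagged with its role.
rendered = tok.apply_chat_template(
    conversation=[{"role": "user", "content": "hi"}],  # illustrative conversation
    tokenize=False,
    add_generation_prompt=True,
    continue_final_message=False,
)
assert rendered == "TEMPLATE::user::hi"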
