
Commit 0346166

init

Signed-off-by: wang.yuqi <[email protected]>

1 parent: e3a1cd1

File tree

1 file changed (+11, −2)


tests/models/language/pooling/test_extract_hidden_states.py

Lines changed: 11 additions & 2 deletions
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from vllm import TokensPrompt
+from vllm import SamplingParams, TokensPrompt
 
 
 @pytest.mark.parametrize(
@@ -19,7 +19,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         model,
         max_model_len=128,
         enforce_eager=True,
-        runner="pooling",
+        runner="generate",
         enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
@@ -55,3 +55,12 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         for n, output in zip(n_prompt_tokens, pooling_outputs):
             assert len(output.prompt_token_ids) == n
             assert output.num_cached_tokens > 0
+
+        # Support generating text while also returning prompt hidden states.
+        generate_outputs = vllm_model.generate(
+            prompts=[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            sampling_params=SamplingParams(max_tokens=1),
+        )
+        for n, output in zip(n_prompt_tokens, generate_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens > 0
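For context, here is a minimal standalone sketch of the pattern this test exercises, written against vLLM's public LLM API rather than the vllm_runner test fixture. The model name and toy token IDs are illustrative assumptions, and the exact arguments accepted by encode() (its argument list is truncated in the diff above) may differ across vLLM versions:

from vllm import LLM, SamplingParams, TokensPrompt

# Assumption: any small generate-capable model will do for this sketch.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    runner="generate",  # the generate runner, as switched to in this commit
    max_model_len=128,
    enforce_eager=True,
    enable_prefix_caching=True,
)

# Toy token-ID prompts (assumption; the real test builds token_prompts elsewhere).
token_prompts = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8]]
prompts = [TokensPrompt(prompt_token_ids=t) for t in token_prompts]

# First pass: a pooling-style encode extracts hidden states and, with prefix
# caching enabled, populates the prefix cache for these prompts.
pooling_outputs = llm.encode(prompts)

# Second pass: generate one token per prompt; the prompt tokens should now be
# served from the prefix cache, so num_cached_tokens is positive.
generate_outputs = llm.generate(prompts, SamplingParams(max_tokens=1))
for n, out in zip((len(t) for t in token_prompts), generate_outputs):
    assert len(out.prompt_token_ids) == n
    assert out.num_cached_tokens > 0

The point of the change is that a single engine started with runner="generate" can serve both the encode() pass (prompt hidden states) and a subsequent generate() pass over the same prompts, sharing the prefix cache between them.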
