20 changes: 15 additions & 5 deletions tests/models/language/pooling/test_extract_hidden_states.py
@@ -3,7 +3,7 @@
import pytest
import torch

from vllm import TokensPrompt
from vllm import SamplingParams, TokensPrompt


@pytest.mark.parametrize(
@@ -14,16 +14,17 @@
def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
n_prompt_tokens = [55, 56, 57]
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
prompts = [TokensPrompt(prompt_token_ids=t) for t in token_prompts]

with vllm_runner(
model,
max_model_len=128,
enforce_eager=True,
runner="pooling",
runner="generate",
enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
prompts=prompts,
pooling_task="token_embed",
)

@@ -36,7 +37,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
# we need to skip reading the cache for this request via
# request.skip_reading_prefix_cache
pooling_outputs = vllm_model.llm.encode(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
prompts=prompts,
pooling_task="token_embed",
)

@@ -48,10 +49,19 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
# skip_reading_prefix_cache can still write to the cache
# to accelerate subsequent requests
pooling_outputs = vllm_model.llm.encode(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
prompts=prompts,
pooling_task="embed",
)

for n, output in zip(n_prompt_tokens, pooling_outputs):
assert len(output.prompt_token_ids) == n
assert output.num_cached_tokens > 0

# Generating text while returning prompt hidden states is also supported
generate_outputs = vllm_model.llm.generate(
prompts=prompts,
sampling_params=SamplingParams(max_tokens=1),
)
for n, output in zip(n_prompt_tokens, generate_outputs):
assert len(output.prompt_token_ids) == n
assert output.num_cached_tokens > 0
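
The test above is the end-to-end picture: one engine loaded with the generate runner serves token-level pooling, sequence embedding, and text generation, with prefix caching on. A standalone sketch of the same flow (hedged: the model name is a placeholder, and the constructor arguments are assumed to be accepted by LLM exactly as vllm_runner forwards them in the test):

from vllm import LLM, SamplingParams, TokensPrompt

# Placeholder model; the test parametrizes its own checkpoints.
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    runner="generate",
    enable_prefix_caching=True,
    enforce_eager=True,
    max_model_len=128,
)
prompts = [
    TokensPrompt(prompt_token_ids=[1024 + i for i in range(n)]) for n in (55, 56, 57)
]

# Hidden states for every prompt token.
token_outputs = llm.encode(prompts, pooling_task="token_embed")

# One sequence-level embedding per prompt.
embed_outputs = llm.encode(prompts, pooling_task="embed")

# Ordinary generation on the same engine instance.
gen_outputs = llm.generate(prompts, sampling_params=SamplingParams(max_tokens=1))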
12 changes: 2 additions & 10 deletions vllm/entrypoints/llm.py
@@ -1025,13 +1025,8 @@ def encode(
raise ValueError(error_str)

model_config = self.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
raise ValueError(
"LLM.encode() is only supported for pooling models. "
"Try passing `--runner pooling` to use the model as a "
"pooling model."
)
if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
@@ -1069,9 +1064,6 @@ def encode(
# Use default pooling params.
pooling_params = PoolingParams()

if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
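With the runner-type check gone, LLM.encode() only validates that the requested pooling_task is among the engine's supported tasks, and it does so up front rather than after prompt parsing. Continuing the sketch above (a hedged illustration, not a definitive error message):

# A task the generate runner does not advertise (only "generate", "embed" and
# "token_embed" are added for text-generation models in this PR) now fails fast:
try:
    llm.encode(prompts, pooling_task="token_classify")
except ValueError as err:
    print(err)  # pooling_task must be one of {...}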
18 changes: 9 additions & 9 deletions vllm/pooling_params.py
@@ -108,6 +108,15 @@ def verify(
def _merge_default_parameters(
self, model_config: Optional["ModelConfig"] = None
) -> None:
if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of all-token pooling may be shorter than n_prompt_tokens,
# so we need to skip reading the cache for this request.
if self.task in ["token_embed", "token_classify"]:
self.skip_reading_prefix_cache = True
else:
self.skip_reading_prefix_cache = False

if model_config is None:
return

@@ -125,15 +134,6 @@ def _merge_default_parameters(
if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k))

if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of all pooling may less than n_prompt_tokens,
# we need to skip reading cache at this request.
if self.task in ["token_embed", "token_classify"]:
self.skip_reading_prefix_cache = True
else:
self.skip_reading_prefix_cache = False

self._verify_step_pooling(pooler_config, valid_parameters)

def _verify_step_pooling(
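Moving the skip_reading_prefix_cache default above the early `model_config is None` return means the per-task default is always resolved, even when no model config is available. The rule itself is small; a hypothetical standalone helper (not the vLLM method) expressing it:

def default_skip_reading_prefix_cache(task: str, current: bool | None) -> bool:
    # An explicit user choice wins.
    if current is not None:
        return current
    # Token-level tasks need a hidden state for every prompt token; a prefix-cache
    # hit would skip the cached tokens and truncate the pooled output, so reading
    # is disabled. Writing to the cache is still allowed to speed up later requests.
    return task in ("token_embed", "token_classify")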
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_input_batch.py
@@ -479,8 +479,8 @@ def remove_request(self, req_id: str) -> int | None:
del self.lora_id_to_lora_request[lora_id]
self.request_lora_mapping[req_index] = 0

if self.is_pooling_model:
self.pooling_params.pop(req_id, None)
pooling_params = self.pooling_params.pop(req_id, None)
if pooling_params is not None:
self.pooling_states.pop(req_id, None)
return req_index

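Because a batch can now mix generation and pooling requests, per-request cleanup is keyed on whether the departing request actually carried pooling params rather than on the model-wide is_pooling_model flag. A toy sketch of the pattern (hypothetical dicts standing in for the InputBatch bookkeeping):

pooling_params: dict[str, object] = {}
pooling_states: dict[str, object] = {}

def remove_request(req_id: str) -> None:
    # Only requests that had pooling params have pooling state to drop.
    if pooling_params.pop(req_id, None) is not None:
        pooling_states.pop(req_id, None)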
22 changes: 16 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -35,7 +35,7 @@
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
update_config,
update_config,
PoolerConfig,
set_current_vllm_config,
)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
@@ -173,6 +173,7 @@
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders,
)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -817,8 +818,7 @@
else:
generator = None

if self.is_pooling_model:
assert pooling_params is not None
if pooling_params is not None:
task = pooling_params.task
assert task is not None, "You did not set `task` in the API"

@@ -2295,14 +2295,14 @@
return self.model.unwrap()
return self.model

def get_supported_generation_tasks(self) -> list[GenerationTask]:
def get_supported_generation_tasks(self) -> list[GenerationTask | PoolingTask]:
model = self.get_model()
supported_tasks = list[GenerationTask]()

if is_text_generation_model(model):
supported_tasks.append("generate")
supported_tasks.extend(["generate", "embed", "token_embed"])

if supports_transcription(model):

Check failure on line 2305 in vllm/v1/worker/gpu_model_runner.py (GitHub Actions / pre-commit):
List item 1 has incompatible type "Literal['embed']"; expected "Literal['generate', 'transcription']" [list-item]
List item 2 has incompatible type "Literal['token_embed']"; expected "Literal['generate', 'transcription']" [list-item]

if model.supports_transcription_only:
return ["transcription"]

@@ -2310,7 +2310,7 @@

return supported_tasks

def get_supported_pooling_tasks(self) -> list[PoolingTask]:

Check failure on line 2313 in vllm/v1/worker/gpu_model_runner.py (GitHub Actions / pre-commit):
Incompatible return value type (got "list[Literal['generate', 'transcription']]", expected "list[Literal['generate', 'transcription'] | Literal['embed', 'classify', 'score', 'token_embed', 'token_classify', 'plugin']]") [return-value]

model = self.get_model()
if not is_pooling_model(model):
return []
@@ -3110,7 +3110,7 @@
self.kv_connector_output = kv_connector_output
return hidden_states

if self.is_pooling_model:
if len(self.input_batch.pooling_params) > 0:
# Return the pooling output.
output = self._pool(
hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
@@ -3674,6 +3674,16 @@
and mm_config.is_multimodal_pruning_enabled()
)

if not self.is_pooling_model:
with set_current_vllm_config(self.vllm_config):
pooler_config = PoolerConfig(pooling_type="LAST")
self.model.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
},
)

if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
global_expert_load = (
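The final hunk is what makes the pooling tasks available on a generative model: at load time the runner attaches a DispatchPooler built from a last-token PoolerConfig, so "embed" and "token_embed" requests can reuse the hidden states the model already computes. As a rough illustration of what LAST pooling reduces to over a packed batch (hypothetical helper, not vLLM's Pooler implementation):

import torch

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: (total_tokens, hidden_size), sequences packed back to back.
    # seq_lens: (num_seqs,) integer token count per sequence.
    last_indices = torch.cumsum(seq_lens, dim=0) - 1
    # One embedding per sequence: its final token's hidden state. Token-level
    # tasks ("token_embed") instead return all hidden states of each sequence.
    return hidden_states[last_indices]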