Commit 64da7a6

very dirty fix
Signed-off-by: wang.yuqi <[email protected]>
1 parent 7873b25 commit 64da7a6

5 files changed: +41 -39 lines changed

tests/models/language/pooling/test_extract_hidden_states.py

Lines changed: 6 additions & 5 deletions
@@ -14,6 +14,7 @@
 def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
     n_prompt_tokens = [55, 56, 57]
     token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+    prompts = [TokensPrompt(prompt_token_ids=t) for t in token_prompts]

     with vllm_runner(
         model,
@@ -23,7 +24,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
-            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            prompts=prompts,
             pooling_task="token_embed",
         )

@@ -36,7 +37,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         # we need to skip reading cache at this request by
         # request.skip_reading_prefix_cache
         pooling_outputs = vllm_model.llm.encode(
-            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            prompts=prompts,
             pooling_task="token_embed",
         )

@@ -48,7 +49,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         # skip_reading_prefix_cache can still write to cache
         # to accelerate following requests
         pooling_outputs = vllm_model.llm.encode(
-            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            prompts=prompts,
             pooling_task="embed",
         )

@@ -57,8 +58,8 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
             assert output.num_cached_tokens > 0

         # Support generate text and returning Prompt Hidden States
-        generate_outputs = vllm_model.generate(
-            prompts=[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+        generate_outputs = vllm_model.llm.generate(
+            prompts=prompts,
             sampling_params=SamplingParams(max_tokens=1),
         )
         for n, output in zip(n_prompt_tokens, generate_outputs):

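The user-facing pattern this test exercises, driving the LLM class directly, looks roughly like the sketch below. The model name and engine arguments are illustrative placeholders rather than values taken from this commit.

# Rough sketch of the API usage exercised by the test above; the model name
# and engine arguments are placeholders, not part of this commit.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="Qwen/Qwen3-0.6B", enable_prefix_caching=True)

n_prompt_tokens = [55, 56, 57]
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
# Build the TokensPrompt objects once and reuse them, as the test now does.
prompts = [TokensPrompt(prompt_token_ids=t) for t in token_prompts]

# Per-token hidden states through the pooling path.
pooling_outputs = llm.encode(prompts=prompts, pooling_task="token_embed")

# Ordinary generation over the same prompt objects.
generate_outputs = llm.generate(
    prompts=prompts, sampling_params=SamplingParams(max_tokens=1)
)
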
vllm/entrypoints/llm.py

Lines changed: 2 additions & 10 deletions
@@ -1025,13 +1025,8 @@ def encode(
             raise ValueError(error_str)

         model_config = self.model_config
-        runner_type = model_config.runner_type
-        if runner_type != "pooling":
-            raise ValueError(
-                "LLM.encode() is only supported for pooling models. "
-                "Try passing `--runner pooling` to use the model as a "
-                "pooling model."
-            )
+        if pooling_task not in self.supported_tasks:
+            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

         io_processor_prompt = False
         if isinstance(prompts, dict) and "data" in prompts:
@@ -1069,9 +1064,6 @@
             # Use default pooling params.
             pooling_params = PoolingParams()

-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
         for param in as_iter(pooling_params):
             param.verify(pooling_task, model_config)
             # for backwards compatibility

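The practical effect is that LLM.encode() no longer rejects models whose runner type is not "pooling"; it only verifies that the requested pooling task is among the tasks the engine reports as supported, and it does so before any prompts are processed. A minimal sketch of the new gate, with an assumed supported-task tuple:

# Hedged illustration only; in vLLM the list comes from self.supported_tasks.
# The tuple below is just an example of what a generative model might report.
supported_tasks = ("generate", "embed", "token_embed")

def check_pooling_task(pooling_task: str) -> None:
    if pooling_task not in supported_tasks:
        raise ValueError(f"pooling_task must be one of {supported_tasks}.")

check_pooling_task("token_embed")       # accepted
# check_pooling_task("token_classify")  # would raise ValueError
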
vllm/pooling_params.py

Lines changed: 9 additions & 9 deletions
@@ -108,6 +108,15 @@ def verify(
     def _merge_default_parameters(
         self, model_config: Optional["ModelConfig"] = None
     ) -> None:
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled,
+            # the output of all pooling may less than n_prompt_tokens,
+            # we need to skip reading cache at this request.
+            if self.task in ["token_embed", "token_classify"]:
+                self.skip_reading_prefix_cache = True
+            else:
+                self.skip_reading_prefix_cache = False
+
         if model_config is None:
             return

@@ -125,15 +134,6 @@ def _merge_default_parameters(
             if getattr(self, k, None) is None:
                 setattr(self, k, getattr(pooler_config, k))

-        if self.skip_reading_prefix_cache is None:
-            # If prefix caching is enabled,
-            # the output of all pooling may less than n_prompt_tokens,
-            # we need to skip reading cache at this request.
-            if self.task in ["token_embed", "token_classify"]:
-                self.skip_reading_prefix_cache = True
-            else:
-                self.skip_reading_prefix_cache = False
-
         self._verify_step_pooling(pooler_config, valid_parameters)

     def _verify_step_pooling(

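Moving this block above the early return means the default is applied even when _merge_default_parameters() is called without a ModelConfig. A rough sketch of the resulting behaviour, assuming PoolingParams can be constructed with only the task set:

# Hedged sketch of the defaulting behaviour after the move.
from vllm.pooling_params import PoolingParams

p = PoolingParams(task="token_embed")
p._merge_default_parameters(model_config=None)
# Token-level tasks default to skipping prefix-cache *reads* so that every
# prompt token produces an output; cache writes are still allowed.
assert p.skip_reading_prefix_cache is True

p = PoolingParams(task="embed")
p._merge_default_parameters(model_config=None)
assert p.skip_reading_prefix_cache is False
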
vllm/v1/worker/gpu_input_batch.py

Lines changed: 2 additions & 2 deletions
@@ -479,8 +479,8 @@ def remove_request(self, req_id: str) -> int | None:
             del self.lora_id_to_lora_request[lora_id]
         self.request_lora_mapping[req_index] = 0

-        if self.is_pooling_model:
-            self.pooling_params.pop(req_id, None)
+        pooling_params = self.pooling_params.pop(req_id, None)
+        if pooling_params is not None:
             self.pooling_states.pop(req_id, None)
         return req_index


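Cleanup now keys off the request's own pooling parameters rather than the model-level is_pooling_model flag, which matters once a single batch can mix generation and pooling requests. A small self-contained sketch of the new removal logic; the plain dicts stand in for the input batch's bookkeeping:

# Illustrative stand-ins for InputBatch.pooling_params / pooling_states.
pooling_params: dict[str, object] = {"req-pool": object()}
pooling_states: dict[str, object] = {"req-pool": object()}

def remove_request(req_id: str) -> None:
    # Pop unconditionally; only requests that actually carried pooling
    # params also have pooling state to clean up.
    params = pooling_params.pop(req_id, None)
    if params is not None:
        pooling_states.pop(req_id, None)

remove_request("req-gen")   # generation request: nothing to remove
remove_request("req-pool")  # pooling request: both entries removed
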
vllm/v1/worker/gpu_model_runner.py

Lines changed: 22 additions & 13 deletions
@@ -35,7 +35,7 @@
     CUDAGraphMode,
     VllmConfig,
     get_layers_from_vllm_config,
-    update_config,
+    update_config, PoolerConfig, set_current_vllm_config,
 )
 from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.eplb.eplb_state import EplbState
@@ -173,6 +173,7 @@
     sanity_check_mm_encoder_outputs,
     scatter_mm_placeholders,
 )
+from ...model_executor.layers.pooler import DispatchPooler, Pooler

 if TYPE_CHECKING:
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -817,14 +818,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             else:
                 generator = None

-            if self.is_pooling_model:
-                assert pooling_params is not None
-                task = pooling_params.task
-                assert task is not None, "You did not set `task` in the API"
+            if pooling_params is not None:
+                task = pooling_params.task
+                assert task is not None, "You did not set `task` in the API"

-                model = cast(VllmModelForPooling, self.get_model())
-                to_update = model.pooler.get_pooling_updates(task)
-                to_update.apply(pooling_params)
+                model = cast(VllmModelForPooling, self.get_model())
+                to_update = model.pooler.get_pooling_updates(task)
+                to_update.apply(pooling_params)

             req_state = CachedRequestState(
                 req_id=req_id,
@@ -2295,12 +2295,12 @@ def get_model(self) -> nn.Module:
             return self.model.unwrap()
         return self.model

-    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+    def get_supported_generation_tasks(self) -> list[GenerationTask | PoolingTask]:
         model = self.get_model()
         supported_tasks = list[GenerationTask]()

         if is_text_generation_model(model):
-            supported_tasks.append("generate")
+            supported_tasks.extend(["generate", "embed", "token_embed"])

         if supports_transcription(model):
             if model.supports_transcription_only:
@@ -2400,8 +2400,7 @@ def _pool(
             num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
         )

-        model = cast(VllmModelForPooling, self.model)
-        raw_pooler_output: PoolerOutput = model.pooler(
+        raw_pooler_output: PoolerOutput = self.model.pooler(
             hidden_states=hidden_states,
             pooling_metadata=pooling_metadata,
         )
@@ -3110,7 +3109,7 @@ def execute_model(
                 self.kv_connector_output = kv_connector_output
                 return hidden_states

-            if self.is_pooling_model:
+            if len(self.input_batch.pooling_params) > 0:
                 # Return the pooling output.
                 output = self._pool(
                     hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
@@ -3674,6 +3673,16 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             and mm_config.is_multimodal_pruning_enabled()
         )

+        if not self.is_pooling_model:
+            with set_current_vllm_config(self.vllm_config):
+                pooler_config = PoolerConfig(pooling_type="LAST")
+                self.model.pooler = DispatchPooler(
+                    {
+                        "token_embed": Pooler.for_token_embed(pooler_config),
+                        "embed": Pooler.for_embed(pooler_config),
+                    },
+                )
+
         if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
             logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
             global_expert_load = (

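Taken together, the runner changes give a text-generation model a default last-token pooler at load time, so it can answer "embed" and "token_embed" requests alongside "generate". Below is a sketch of that attachment in isolation, using the same names the diff imports; attach_default_pooler is a hypothetical helper written only for illustration.

# Hedged sketch mirroring the new block in load_model().
from vllm.config import PoolerConfig, set_current_vllm_config
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler

def attach_default_pooler(model, vllm_config) -> None:
    with set_current_vllm_config(vllm_config):
        pooler_config = PoolerConfig(pooling_type="LAST")
        # Route each pooling task to a pooler built for it; "LAST" keeps the
        # final token's hidden state for the sequence-level "embed" task.
        model.pooler = DispatchPooler(
            {
                "token_embed": Pooler.for_token_embed(pooler_config),
                "embed": Pooler.for_embed(pooler_config),
            }
        )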