Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions vllm/v1/worker/gpu/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,17 +438,20 @@ def _post_update_kernel(

for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
tl.store(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
token_id,
)

if output_bin_counts_ptr is not None:
token_ptr = (
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ token_id
)
count = tl.load(token_ptr)
tl.store(token_ptr, count + 1)

query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
Expand All @@ -467,7 +470,7 @@ def post_update(
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
output_bin_counts: torch.Tensor | None,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
Expand All @@ -487,7 +490,7 @@ def post_update(
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
output_bin_counts.stride(0) if output_bin_counts is not None else 0,
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
Expand Down
92 changes: 55 additions & 37 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)

# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None

# General request states.
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
Expand All @@ -199,20 +203,34 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
max_num_tokens=self.max_num_tokens,
device=self.device,
)
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

self.sampler: Sampler | None = None
self.rejection_sampler: RejectionSampler | None = None
self.prompt_logprobs_worker: PromptLogprobsWorker | None = None
self.structured_outputs_worker: StructuredOutputsWorker | None = None
if self.is_last_pp_rank and not self.is_pooling_model:
# Initialize sampling-related workers.
# These components are only set up on the last PP rank and
# for generative (non-pooling) models.
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)

# CUDA graphs.
self.decode_query_len = self.num_speculative_steps + 1
Expand All @@ -222,21 +240,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
# LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None

# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: ExecuteModelState | None = None

Expand Down Expand Up @@ -289,7 +297,7 @@ def load_model(self, *args, **kwargs) -> None:
self.model_state = init_model_state(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model:
if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may still need to get/store the supported tasks here, since they are queried from the front-end and the executor returns the result from rank 0.


def get_model(self) -> nn.Module:
Expand Down Expand Up @@ -420,6 +428,7 @@ def _dummy_run(

# dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None:
assert self.sampler is not None
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
Expand Down Expand Up @@ -457,10 +466,8 @@ def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(
logits,
dummy_input_batch,
)
assert self.sampler is not None
self.sampler(logits, dummy_input_batch)

@torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
Expand Down Expand Up @@ -558,7 +565,8 @@ def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
self.req_states.remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)

def free_states(self, scheduler_output: SchedulerOutput) -> None:
Expand Down Expand Up @@ -589,18 +597,21 @@ def add_requests(self, scheduler_output: SchedulerOutput) -> None:
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

if new_req_data.sampling_params is not None:
if self.is_last_pp_rank and new_req_data.sampling_params is not None:
assert self.sampler is not None
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
assert self.prompt_logprobs_worker is not None
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)

if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes()
self.model_state.apply_staged_writes()
if self.sampler is not None:
self.sampler.apply_staged_writes()

def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests.
Expand Down Expand Up @@ -788,6 +799,7 @@ def sample(
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
assert self.structured_outputs_worker is not None
self.structured_outputs_worker.apply_grammar_bitmask(
logits,
input_batch,
Expand All @@ -797,12 +809,11 @@ def sample(

if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
sampler_output = self.sampler(
logits,
input_batch,
)
assert self.sampler is not None
sampler_output = self.sampler(logits, input_batch)
else:
# Rejection sampling for spec decoding.
assert self.rejection_sampler is not None
sampler_output = self.rejection_sampler(
logits,
input_batch,
Expand Down Expand Up @@ -831,11 +842,16 @@ def postprocess(
num_rejected: torch.Tensor,
) -> None:
# Update the number of computed tokens.
if self.is_last_pp_rank:
assert self.sampler is not None
output_bin_counts = self.sampler.penalties_state.output_bin_counts
else:
output_bin_counts = None
post_update(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens,
self.sampler.penalties_state.output_bin_counts,
output_bin_counts,
sampled_tokens,
num_sampled,
num_rejected,
Expand Down Expand Up @@ -1076,6 +1092,7 @@ def sample_tokens(
# Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)

assert self.prompt_logprobs_worker is not None
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
self.model.compute_logits,
hidden_states,
Expand Down Expand Up @@ -1115,6 +1132,7 @@ def sample_tokens(
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
if self.speculator is not None:
assert self.sampler is not None
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,
Expand Down
Loading