Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions vllm/v1/worker/gpu/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,17 +438,20 @@ def _post_update_kernel(

for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
tl.store(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
token_id,
)

if output_bin_counts_ptr is not None:
token_ptr = (
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ token_id
)
count = tl.load(token_ptr)
tl.store(token_ptr, count + 1)

query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
Expand All @@ -467,7 +470,7 @@ def post_update(
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
output_bin_counts: torch.Tensor | None,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
Expand All @@ -487,7 +490,7 @@ def post_update(
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
output_bin_counts.stride(0) if output_bin_counts is not None else 0,
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
Expand Down
89 changes: 53 additions & 36 deletions vllm/v1/worker/gpu/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)

# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None

# General request states.
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
Expand All @@ -199,20 +203,35 @@
max_num_tokens=self.max_num_tokens,
device=self.device,
)
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

if self.is_last_pp_rank and not self.is_pooling_model:
# Initialize sampling-related workers.
# These components are only set up on the last PP rank and
# for generative (non-pooling) models.
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
else:
self.sampler = None

Check failure on line 231 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "Sampler") [assignment]

Check failure on line 231 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "Sampler") [assignment]

Check failure on line 231 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "Sampler") [assignment]
self.rejection_sampler = None

Check failure on line 232 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "RejectionSampler") [assignment]

Check failure on line 232 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "RejectionSampler") [assignment]

Check failure on line 232 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "RejectionSampler") [assignment]
self.prompt_logprobs_worker = None

Check failure on line 233 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "PromptLogprobsWorker") [assignment]

Check failure on line 233 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "PromptLogprobsWorker") [assignment]
self.structured_outputs_worker = None

Check failure on line 234 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "StructuredOutputsWorker") [assignment]

Check failure on line 234 in vllm/v1/worker/gpu/model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Incompatible types in assignment (expression has type "None", variable has type "StructuredOutputsWorker") [assignment]

# CUDA graphs.
self.decode_query_len = self.num_speculative_steps + 1
Expand All @@ -222,21 +241,11 @@
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
# LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None

# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: ExecuteModelState | None = None

Expand Down Expand Up @@ -420,6 +429,7 @@

# dummy run the eagle speculator's propose to ensure DP/EP sync.
if self.speculator is not None:
assert self.sampler is not None
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
Expand Down Expand Up @@ -457,10 +467,8 @@
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(
logits,
dummy_input_batch,
)
assert self.sampler is not None
self.sampler(logits, dummy_input_batch)

@torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
Expand Down Expand Up @@ -558,7 +566,8 @@
self.req_states.remove_request(req_id)
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
self.prompt_logprobs_worker.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)

def free_states(self, scheduler_output: SchedulerOutput) -> None:
Expand Down Expand Up @@ -589,18 +598,21 @@
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

if new_req_data.sampling_params is not None:
if self.is_last_pp_rank and new_req_data.sampling_params is not None:
assert self.sampler is not None
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
assert self.prompt_logprobs_worker is not None
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)

if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes()
self.model_state.apply_staged_writes()
if self.sampler is not None:
self.sampler.apply_staged_writes()

def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests.
Expand Down Expand Up @@ -788,6 +800,7 @@
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
assert self.structured_outputs_worker is not None
self.structured_outputs_worker.apply_grammar_bitmask(
logits,
input_batch,
Expand All @@ -797,12 +810,11 @@

if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
sampler_output = self.sampler(
logits,
input_batch,
)
assert self.sampler is not None
sampler_output = self.sampler(logits, input_batch)
else:
# Rejection sampling for spec decoding.
assert self.rejection_sampler is not None
sampler_output = self.rejection_sampler(
logits,
input_batch,
Expand Down Expand Up @@ -831,11 +843,16 @@
num_rejected: torch.Tensor,
) -> None:
# Update the number of computed tokens.
if self.is_last_pp_rank:
assert self.sampler is not None
output_bin_counts = self.sampler.penalties_state.output_bin_counts
else:
output_bin_counts = None
post_update(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens,
self.sampler.penalties_state.output_bin_counts,
output_bin_counts,
sampled_tokens,
num_sampled,
num_rejected,
Expand Down
Loading