Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion tests/v1/core/test_async_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import deque
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList
Expand Down Expand Up @@ -247,3 +249,66 @@ def test_prefix_caching_for_multi_turn():
# requests.
for req in next_turn_requests:
assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE


def test_abort_request_when_structured_output_fsm_cannot_advance():
    """Async scheduler: when the structured-output grammar rejects the sampled
    tokens, the request must be terminated with FINISHED_ERROR, marked
    non-resumable, freed, and removed from the running/requests collections.
    """
    scheduler = object.__new__(AsyncScheduler)
    request = create_requests(num_requests=1, num_tokens=1)[0]
    # Grammar mock that refuses every token, simulating an FSM that cannot
    # advance on the sampled output.
    request.structured_output_request = Mock()
    request.structured_output_request.grammar = Mock()
    request.structured_output_request.grammar.accept_tokens.return_value = False
    request.status = RequestStatus.RUNNING
    request.num_computed_tokens = request.num_tokens
    request.num_output_placeholders = 1

    # Minimal attribute surface that update_from_output touches; __init__ is
    # deliberately bypassed via object.__new__ above.
    scheduler.perf_metrics = None
    scheduler.connector = None
    scheduler.structured_output_manager = Mock()
    scheduler.structured_output_manager.should_advance.return_value = True
    scheduler.requests = {request.request_id: request}
    scheduler.running = [request]
    scheduler.waiting = Mock()
    scheduler.kv_cache_manager = Mock()
    scheduler.kv_cache_manager.take_events.return_value = None
    scheduler.kv_event_publisher = Mock()
    scheduler.finished_req_ids = set()
    scheduler.finished_req_ids_dict = None
    scheduler.vllm_config = Mock()
    scheduler.vllm_config.model_config.enable_return_routed_experts = False
    scheduler.recompute_kv_load_failures = False
    scheduler.make_stats = Mock(return_value=None)
    scheduler.max_model_len = 128

    # Stand-in for Scheduler._free_request: mimic the bookkeeping the real
    # implementation performs so the post-conditions below are observable.
    def free_request(req, delay_free_blocks=False):
        scheduler.finished_req_ids.add(req.request_id)
        scheduler.requests.pop(req.request_id, None)
        return None

    scheduler._free_request = Mock(side_effect=free_request)

    output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={request.request_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id],
        req_id_to_index={request.request_id: 0},
        sampled_token_ids=[[123]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    scheduler.update_from_output(output, model_runner_output)

    # The error path must actually have consulted the grammar (mirrors the
    # sync-scheduler test); FINISHED_ERROR without this call would mean the
    # status was set by some other path.
    assert request.structured_output_request.grammar.accept_tokens.called
    assert request.resumable is False
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.request_id not in scheduler.requests
    assert not scheduler.running
81 changes: 81 additions & 0 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import FinishReason
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
Expand Down Expand Up @@ -2463,6 +2464,86 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
assert len(scheduler.skipped_waiting) == 1


def test_abort_request_when_structured_output_fsm_cannot_advance():
    """A request whose grammar rejects the sampled tokens is aborted:
    FINISHED_ERROR status, non-resumable, freed, and reported to the engine
    with FinishReason.ERROR.
    """
    params = SamplingParams(ignore_eos=True, max_tokens=4)
    params.update_from_generation_config({}, EOS_TOKEN_ID)

    request = Request(
        request_id="0",
        prompt_token_ids=[0, 1],
        mm_features=None,
        sampling_params=params,
        pooling_params=None,
    )
    # Grammar stub that refuses every token batch.
    grammar = Mock()
    grammar.accept_tokens.return_value = False
    request.structured_output_request = Mock()
    request.structured_output_request.grammar = grammar
    request.status = RequestStatus.RUNNING
    request.num_computed_tokens = request.num_tokens

    # Skeleton Scheduler: __init__ is bypassed, only the attributes that
    # update_from_output reads are populated.
    scheduler = object.__new__(Scheduler)
    scheduler.perf_metrics = None
    scheduler.connector = None
    scheduler.structured_output_manager = Mock()
    scheduler.structured_output_manager.should_advance.return_value = True
    scheduler.requests = {request.request_id: request}
    scheduler.running = [request]
    scheduler.waiting = Mock()
    scheduler.kv_cache_manager = Mock()
    scheduler.kv_cache_manager.take_events.return_value = None
    scheduler.kv_event_publisher = Mock()
    scheduler.finished_req_ids = set()
    scheduler.finished_req_ids_dict = None
    scheduler.vllm_config = Mock()
    scheduler.vllm_config.model_config.enable_return_routed_experts = False
    scheduler.recompute_kv_load_failures = False
    scheduler.make_stats = Mock(return_value=None)
    scheduler.max_model_len = 128

    # Replacement for _free_request that performs the same observable
    # bookkeeping as the real implementation.
    def fake_free(req: Request, delay_free_blocks: bool = False):
        scheduler.finished_req_ids.add(req.request_id)
        scheduler.requests.pop(req.request_id, None)
        return None

    scheduler._free_request = Mock(side_effect=fake_free)

    sched_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={request.request_id: 1},
        total_num_scheduled_tokens=1,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
    )
    runner_output = ModelRunnerOutput(
        req_ids=[request.request_id],
        req_id_to_index={request.request_id: 0},
        sampled_token_ids=[[123]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )

    engine_core_outputs = scheduler.update_from_output(sched_output, runner_output)

    # The grammar was asked exactly once about the sampled batch.
    grammar.accept_tokens.assert_called_once_with(request.request_id, [123])
    # Request-side post-conditions.
    assert request.resumable is False
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.request_id not in scheduler.requests
    assert not scheduler.running
    scheduler._free_request.assert_called_once_with(request)
    # Engine-side post-conditions: one output carrying the error finish reason.
    assert len(engine_core_outputs[0].outputs) == 1
    engine_core_output = engine_core_outputs[0].outputs[0]
    assert engine_core_output.request_id == request.request_id
    assert engine_core_output.new_token_ids == [123]
    assert engine_core_output.finish_reason == FinishReason.ERROR


@pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
)
Expand Down
29 changes: 17 additions & 12 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,23 @@ def update_from_output(
request.status = RequestStatus.FINISHED_STOPPED
stopped = True

if new_token_ids and self.structured_output_manager.should_advance(request):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
if not struct_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids
):
logger.error(
"Unexpected: grammar rejected tokens %s for request %s. "
"Terminating request.",
new_token_ids,
req_id,
)
request.status = RequestStatus.FINISHED_ERROR
request.resumable = False
stopped = True

routed_experts = None
finish_reason = None
if stopped:
Expand All @@ -1431,18 +1448,6 @@ def update_from_output(
):
new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))

if new_token_ids and self.structured_output_manager.should_advance(request):
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
ok = struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if not ok:
logger.warning(
"Unexpected: grammar rejected tokens %s for request %s.",
new_token_ids,
req_id,
)

if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]

Expand Down
Loading