Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions src/tests/test_engine_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import threading
from unittest.mock import MagicMock, patch

from vllm_router.stats.engine_stats import EngineStats, EngineStatsScraper


def make_scraper():
    """Build an ``EngineStatsScraper`` test double without calling ``__init__``.

    The instance is allocated with ``object.__new__`` and only the attributes
    listed below are populated, so the tests control every piece of state the
    scraper reads.
    """
    scraper = object.__new__(EngineStatsScraper)
    initial_state = {
        "engine_stats": {},
        "engine_stats_lock": threading.Lock(),
        "scrape_interval": 30.0,
        "admission_scrape_interval": 1.0,
        "on_metrics_update": None,
        "running": False,
    }
    for attr_name, attr_value in initial_state.items():
        setattr(scraper, attr_name, attr_value)
    return scraper


def test_queue_only_scrape_merges_waiting_count_without_mutating_existing_stats():
    """Queue-only scrape stores a new stats object with only the queue depth refreshed.

    Every other field must keep the value from the previously stored entry,
    and the previously stored object itself must not be reused.
    """
    engine_url = "http://engine1"
    scraper = make_scraper()

    stored = EngineStats(
        num_running_requests=7,
        num_queuing_requests=2,
        gpu_prefix_cache_hit_rate=0.2,
        gpu_prefix_cache_hits_total=11,
        gpu_prefix_cache_queries_total=17,
        gpu_cache_usage_perc=0.5,
    )
    scraper.engine_stats = {engine_url: stored}

    # Every field differs from ``stored`` so a leaked value is detectable.
    fresh = EngineStats(
        num_running_requests=99,
        num_queuing_requests=5,
        gpu_prefix_cache_hit_rate=0.9,
        gpu_prefix_cache_hits_total=99,
        gpu_prefix_cache_queries_total=99,
        gpu_cache_usage_perc=0.9,
    )

    discovery = MagicMock(
        get_endpoint_info=MagicMock(return_value=[MagicMock(url=engine_url)])
    )
    with (
        patch(
            "vllm_router.stats.engine_stats.get_service_discovery",
            return_value=discovery,
        ),
        patch.object(scraper, "_scrape_one_endpoint", return_value=fresh),
    ):
        scraper._scrape_metrics(queue_only=True)

    merged = scraper.engine_stats[engine_url]
    assert merged is not stored
    assert merged.num_queuing_requests == 5

    preserved = {
        "num_running_requests": 7,
        "gpu_prefix_cache_hit_rate": 0.2,
        "gpu_prefix_cache_hits_total": 11,
        "gpu_prefix_cache_queries_total": 17,
        "gpu_cache_usage_perc": 0.5,
    }
    for field_name, expected in preserved.items():
        assert getattr(merged, field_name) == expected


def test_scrape_one_endpoint_uses_mode_specific_timeout():
    """``_scrape_one_endpoint`` passes a timeout that depends on ``queue_only``.

    A full scrape must use ``scrape_interval`` and a queue-only scrape must use
    ``admission_scrape_interval`` as the ``timeout`` keyword of ``requests.get``.
    """
    scraper = make_scraper()

    fake_response = MagicMock()
    fake_response.text = ""
    fake_response.raise_for_status.return_value = None

    with (
        patch(
            "vllm_router.stats.engine_stats.requests.get",
            return_value=fake_response,
        ) as mock_get,
        patch(
            "vllm_router.stats.engine_stats.EngineStats.from_vllm_scrape",
            return_value=EngineStats(),
        ),
    ):
        for queue_only in (False, True):
            scraper._scrape_one_endpoint("http://engine1", queue_only=queue_only)

    recorded_calls = mock_get.call_args_list
    full_scrape_timeout = recorded_calls[0].kwargs["timeout"]
    queue_scrape_timeout = recorded_calls[1].kwargs["timeout"]
    assert full_scrape_timeout == scraper.scrape_interval
    assert queue_scrape_timeout == scraper.admission_scrape_interval
40 changes: 40 additions & 0 deletions src/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,46 @@ def test_validate_args_when_service_discovery_is_set_to_static_and_static_backen
)


def test_validate_args_when_router_queue_enabled_with_non_roundrobin_raises_value_error() -> (
    None
):
    """Enabling the router queue with a non-roundrobin routing logic is rejected.

    The ``match`` pattern pins the failure to the router-queue check; without
    it, any other ``ValueError`` raised by ``validate_args`` would let this
    test pass for the wrong reason.
    """
    with pytest.raises(
        ValueError,
        match="Router queue is only supported with roundrobin routing",
    ):
        parser.validate_args(
            MagicMock(
                routing_logic="session",
                service_discovery="static",
                static_backends="http://localhost:8000",
                static_models="m1",
                static_backend_health_checks=False,
                enable_router_queue=True,
                router_max_queued_requests=10,
                router_max_queue_wait_seconds=5.0,
                router_waiting_threshold_per_endpoint=1,
                router_admission_scrape_interval_seconds=1.0,
            )
        )


def test_validate_args_when_router_queue_size_is_not_positive_raises_value_error() -> (
    None
):
    """A router queue capacity of zero is rejected by ``validate_args``.

    The ``match`` pattern pins the failure to the queue-size check; without
    it, any other ``ValueError`` raised by ``validate_args`` would let this
    test pass for the wrong reason.
    """
    with pytest.raises(
        ValueError,
        match="Router max queued requests must be greater than 0",
    ):
        parser.validate_args(
            MagicMock(
                routing_logic="roundrobin",
                service_discovery="static",
                static_backends="http://localhost:8000",
                static_models="m1",
                static_backend_health_checks=False,
                enable_router_queue=True,
                router_max_queued_requests=0,
                router_max_queue_wait_seconds=5.0,
                router_waiting_threshold_per_endpoint=1,
                router_admission_scrape_interval_seconds=1.0,
            )
        )


def test_validate_static_model_types_when_model_types_is_not_defines_raises_value_error() -> (
None
):
Expand Down
10 changes: 9 additions & 1 deletion src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,15 @@ def initialize_all(app: FastAPI, args):
raise ValueError(f"Invalid service discovery type: {args.service_discovery}")

# Initialize singletons via custom functions.
initialize_engine_stats_scraper(args.engine_stats_interval)
app.state.admission_controller = None
initialize_engine_stats_scraper(
args.engine_stats_interval,
admission_scrape_interval=(
args.router_admission_scrape_interval_seconds
if args.enable_router_queue
else None
),
)
initialize_request_stats_monitor(args.request_stats_window)

if args.enable_batch_api:
Expand Down
56 changes: 54 additions & 2 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def validate_static_model_types(model_types: str | None) -> None:

# --- Argument Parsing and Initialization ---
def validate_args(args):
def _get_numeric_arg(name: str, default):
    """Return ``args.<name>`` when it is an ``int`` or ``float``, else ``default``.

    A missing attribute and a non-numeric attribute value are treated the same
    way: the supplied ``default`` is returned instead.
    """
    candidate = getattr(args, name, default)
    if isinstance(candidate, (int, float)):
        return candidate
    return default

verify_required_args_provided(args)
if args.service_discovery == "static":
if args.static_backends is None:
Expand All @@ -101,6 +105,25 @@ def validate_args(args):
raise ValueError(
"Session key must be provided when using session routing logic."
)
if (
getattr(args, "enable_router_queue", False) is True
and args.routing_logic != "roundrobin"
):
raise ValueError(
"Router queue is only supported with roundrobin routing in phase 1."
)
if _get_numeric_arg("router_max_queued_requests", 256) <= 0:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only check these when enable_router_queue is true

raise ValueError("Router max queued requests must be greater than 0.")
if _get_numeric_arg("router_max_queue_wait_seconds", 5.0) <= 0:
raise ValueError("Router max queue wait seconds must be greater than 0.")
if _get_numeric_arg("router_waiting_threshold_per_endpoint", 1) <= 0:
raise ValueError(
"Router waiting threshold per endpoint must be greater than 0."
)
if _get_numeric_arg("router_admission_scrape_interval_seconds", 1.0) <= 0:
raise ValueError(
"Router admission scrape interval seconds must be greater than 0."
)
if args.log_stats and args.log_stats_interval <= 0:
raise ValueError("Log stats interval must be greater than 0.")
if args.engine_stats_interval <= 0:
Expand Down Expand Up @@ -246,6 +269,35 @@ def parse_args():
choices=["noop"],
help="The request rewriter to use. Default is 'noop' (no rewriting).",
)
parser.add_argument(
"--enable-router-queue",
action="store_true",
help="Enable router-side request queueing (phase 1 supports roundrobin only).",
)
parser.add_argument(
"--router-max-queued-requests",
type=int,
default=256,
help="Maximum number of requests that can wait in the router queue.",
)
parser.add_argument(
"--router-max-queue-wait-seconds",
type=float,
default=5.0,
help="Maximum time a request may wait in the router queue.",
)
parser.add_argument(
"--router-waiting-threshold-per-endpoint",
type=int,
default=1,
help="Maximum backend waiting depth per endpoint before router queueing kicks in.",
)
parser.add_argument(
"--router-admission-scrape-interval-seconds",
type=float,
default=1.0,
help="Admission-oriented metrics refresh interval when router queueing is enabled.",
)

# Batch API
# TODO(gaocegege): Make these batch api related arguments to a separate config.
Expand Down Expand Up @@ -363,14 +415,14 @@ def parse_args():
"--sentry-traces-sample-rate",
type=float,
default=0.1,
help="The sample rate for Sentry traces. Default is 0.1 (10%)",
help="The sample rate for Sentry traces. Default is 0.1 (10%%)",
)

parser.add_argument(
"--sentry-profile-session-sample-rate",
type=float,
default=1.0,
help="The sample rate for Sentry profiling sessions. Default is 1.0 (100%)",
help="The sample rate for Sentry profiling sessions. Default is 1.0 (100%%)",
)

# OpenTelemetry tracing arguments
Expand Down
40 changes: 31 additions & 9 deletions src/vllm_router/routers/routing_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,38 @@ def __init__(self):
self.sorted_endpoints = []
self.last_endpoints_id = None
self.last_endpoints_hash = None
self._lock = threading.Lock()
self._initialized = True

def _refresh_sorted_endpoints(self, endpoints: List[EndpointInfo]) -> None:
    """Keep ``self.sorted_endpoints`` in sync with ``endpoints``.

    The check is two-staged: if the list object is the one seen on the
    previous call (same ``id``) nothing is done. Otherwise the hash of the
    endpoint URLs is compared with the cached hash, and the URL-sorted copy
    is rebuilt only when that hash differs.

    Args:
        endpoints: The current list of candidate endpoints.
    """
    list_identity = id(endpoints)
    if list_identity == self.last_endpoints_id:
        return

    url_fingerprint = hash(tuple(ep.url for ep in endpoints))
    if url_fingerprint != self.last_endpoints_hash:
        self.sorted_endpoints = sorted(endpoints, key=lambda ep: ep.url)
        self.last_endpoints_hash = url_fingerprint
    self.last_endpoints_id = list_identity

def pick_admissible_endpoint(
    self,
    endpoints: List[EndpointInfo],
    is_admissible,
) -> Optional[EndpointInfo]:
    """Pick the next endpoint in round-robin order that passes ``is_admissible``.

    The scan starts at ``self.req_id`` modulo the number of sorted endpoints
    and visits each endpoint at most once. On success ``self.req_id`` is
    advanced past the chosen endpoint, so the next call starts after it; when
    nothing is admissible ``self.req_id`` is left unchanged. The whole
    operation runs while holding ``self._lock``.

    Args:
        endpoints: The current list of candidate endpoints.
        is_admissible: Callable invoked with one endpoint; a truthy result
            means the endpoint may be chosen. It is called while the lock is
            held.

    Returns:
        The selected endpoint, or ``None`` when there are no endpoints or no
        endpoint is admissible.
    """
    with self._lock:
        self._refresh_sorted_endpoints(endpoints)
        candidates = self.sorted_endpoints
        if not candidates:
            return None

        total = len(candidates)
        first = self.req_id % total
        for step in range(total):
            candidate = candidates[(first + step) % total]
            if is_admissible(candidate):
                # Skip over rejected endpoints as well as the chosen one.
                self.req_id += step + 1
                return candidate
        return None

def route_request(
self,
endpoints: List[EndpointInfo],
Expand All @@ -168,15 +198,7 @@ def route_request(
indicating the request-level performance of each engine
request (Request): The incoming request
"""
endpoints_id = id(endpoints)
if endpoints_id != self.last_endpoints_id:
current_hash = hash(tuple(e.url for e in endpoints))
if current_hash != self.last_endpoints_hash:
self.sorted_endpoints = sorted(endpoints, key=lambda e: e.url)
self.last_endpoints_hash = current_hash
self.last_endpoints_id = endpoints_id
chosen = self.sorted_endpoints[self.req_id % len(self.sorted_endpoints)]
self.req_id += 1
chosen = self.pick_admissible_endpoint(endpoints, lambda _: True)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If chosen is None, chosen.url will crash

return chosen.url


Expand Down
Loading
Loading