Skip to content

Commit 4f7d79e

Browse files
committed
feat(gateway): make vLLM sampler requests asynchronous and rate-limited
- Dispatch vLLM token generation requests to a background task instead of blocking the FastAPI handler, aligning its async behavior with the Torch sampler backend.
- Introduce VLLM_CONCURRENCY_LIMIT (default 512) and _vllm_semaphore to prevent socket/file-descriptor exhaustion and dropped-connection errors under heavy surges.
- Maintain a global _background_tasks set that holds strong references to running background tasks, preventing premature garbage collection.
1 parent bc91e08 commit 4f7d79e

1 file changed

Lines changed: 39 additions & 27 deletions

File tree

src/server/gateway.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def filter(self, record: logging.LogRecord) -> bool:
4949

5050
TMP_DIR = os.getenv("OPEN_RL_TMP_DIR", "/tmp/open-rl")
5151
VLLM_URL = os.getenv("VLLM_URL", "http://127.0.0.1:8001")
52+
_background_tasks: set[asyncio.Task] = set()
53+
54+
# Cap the number of concurrent in-flight HTTP requests to the vLLM sampler
55+
# to prevent socket/file-descriptor exhaustion and dropped connections under heavy request surges.
56+
VLLM_CONCURRENCY_LIMIT = int(os.getenv("VLLM_CONCURRENCY_LIMIT", "512"))
57+
_vllm_semaphore = asyncio.Semaphore(VLLM_CONCURRENCY_LIMIT)
5258

5359

5460
# *** Helpers ***
@@ -536,33 +542,39 @@ async def asample(req: dict):
536542
headers: dict[str, str] = {"Content-Type": "application/json"}
537543
propagate.inject(headers)
538544

539-
try:
540-
async with httpx.AsyncClient(timeout=120.0) as client:
541-
resp = await client.post(
542-
f"{VLLM_URL.rstrip('/')}/generate",
543-
json={
544-
"request_id": req_id,
545-
"prompt_token_ids": prompt,
546-
"max_tokens": max_tokens,
547-
"temperature": temperature,
548-
"stop": stop,
549-
"top_p": top_p,
550-
"top_k": top_k,
551-
"num_samples": num_samples,
552-
"lora_id": model_id,
553-
"lora_path": lora_path,
554-
"include_prompt_logprobs": include_prompt_logprobs,
555-
},
556-
headers=headers,
557-
)
558-
resp.raise_for_status()
559-
data = resp.json()
560-
if data.get("type") != "RequestFailedResponse":
561-
data["type"] = "sample"
562-
await store.set_future(req_id, data)
563-
except Exception as e:
564-
traceback.print_exc()
565-
await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": str(e)})
545+
async def _dispatch_vllm_generate():
546+
async with _vllm_semaphore:
547+
try:
548+
async with httpx.AsyncClient(timeout=120.0) as client:
549+
resp = await client.post(
550+
f"{VLLM_URL.rstrip('/')}/generate",
551+
json={
552+
"request_id": req_id,
553+
"prompt_token_ids": prompt,
554+
"max_tokens": max_tokens,
555+
"temperature": temperature,
556+
"stop": stop,
557+
"top_p": top_p,
558+
"top_k": top_k,
559+
"num_samples": num_samples,
560+
"lora_id": model_id,
561+
"lora_path": lora_path,
562+
"include_prompt_logprobs": include_prompt_logprobs,
563+
},
564+
headers=headers,
565+
)
566+
resp.raise_for_status()
567+
data = resp.json()
568+
if data.get("type") != "RequestFailedResponse":
569+
data["type"] = "sample"
570+
await store.set_future(req_id, data)
571+
except Exception as e:
572+
traceback.print_exc()
573+
await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": str(e)})
574+
575+
task = asyncio.create_task(_dispatch_vllm_generate())
576+
_background_tasks.add(task)
577+
task.add_done_callback(_background_tasks.discard)
566578

567579
return {"request_id": req_id}
568580

0 commit comments

Comments
 (0)