Skip to content

Commit 7c1a533

Browse files
committed
feat(gateway): make vLLM sampler requests asynchronous and rate-limited
- Dispatch vLLM token generation requests to a background task instead of blocking the FastAPI handler, aligning its async behavior with the Torch sampler backend. - Introduce VLLM_CONCURRENCY_LIMIT (default 512) and _vllm_semaphore to prevent socket/file-descriptor exhaustion and connection drop errors under heavy surges. - Maintain a global _background_tasks set to hold strong references to running background tasks and prevent premature garbage collection.
1 parent 3df2b6a commit 7c1a533

1 file changed

Lines changed: 39 additions & 27 deletions

File tree

src/server/gateway.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def filter(self, record: logging.LogRecord) -> bool:
5050

5151
TMP_DIR = os.getenv("OPEN_RL_TMP_DIR", "/tmp/open-rl")
5252
VLLM_URL = os.getenv("VLLM_URL", "http://127.0.0.1:8001")
53+
_background_tasks: set[asyncio.Task] = set()
54+
55+
# Limit the maximum concurrent active outgoing HTTP requests to the vLLM sampler
56+
# to prevent socket/file-descriptor exhaustion and connection dropped errors under heavy surges.
57+
VLLM_CONCURRENCY_LIMIT = int(os.getenv("VLLM_CONCURRENCY_LIMIT", "512"))
58+
_vllm_semaphore = asyncio.Semaphore(VLLM_CONCURRENCY_LIMIT)
5359

5460

5561
# *** Helpers ***
@@ -537,33 +543,39 @@ async def asample(req: dict):
537543
headers: dict[str, str] = {"Content-Type": "application/json"}
538544
propagate.inject(headers)
539545

540-
try:
541-
async with httpx.AsyncClient(timeout=120.0) as client:
542-
resp = await client.post(
543-
f"{VLLM_URL.rstrip('/')}/generate",
544-
json={
545-
"request_id": req_id,
546-
"prompt_token_ids": prompt,
547-
"max_tokens": max_tokens,
548-
"temperature": temperature,
549-
"stop": stop,
550-
"top_p": top_p,
551-
"top_k": top_k,
552-
"num_samples": num_samples,
553-
"lora_id": model_id,
554-
"lora_path": lora_path,
555-
"include_prompt_logprobs": include_prompt_logprobs,
556-
},
557-
headers=headers,
558-
)
559-
resp.raise_for_status()
560-
data = resp.json()
561-
if data.get("type") != "RequestFailedResponse":
562-
data["type"] = "sample"
563-
await store.set_future(req_id, data)
564-
except Exception as e:
565-
traceback.print_exc()
566-
await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": str(e)})
546+
async def _dispatch_vllm_generate():
547+
async with _vllm_semaphore:
548+
try:
549+
async with httpx.AsyncClient(timeout=120.0) as client:
550+
resp = await client.post(
551+
f"{VLLM_URL.rstrip('/')}/generate",
552+
json={
553+
"request_id": req_id,
554+
"prompt_token_ids": prompt,
555+
"max_tokens": max_tokens,
556+
"temperature": temperature,
557+
"stop": stop,
558+
"top_p": top_p,
559+
"top_k": top_k,
560+
"num_samples": num_samples,
561+
"lora_id": model_id,
562+
"lora_path": lora_path,
563+
"include_prompt_logprobs": include_prompt_logprobs,
564+
},
565+
headers=headers,
566+
)
567+
resp.raise_for_status()
568+
data = resp.json()
569+
if data.get("type") != "RequestFailedResponse":
570+
data["type"] = "sample"
571+
await store.set_future(req_id, data)
572+
except Exception as e:
573+
traceback.print_exc()
574+
await store.set_future(req_id, {"type": "RequestFailedResponse", "error_message": str(e)})
575+
576+
task = asyncio.create_task(_dispatch_vllm_generate())
577+
_background_tasks.add(task)
578+
task.add_done_callback(_background_tasks.discard)
567579

568580
return {"request_id": req_id}
569581

0 commit comments

Comments
 (0)