212 changes: 191 additions & 21 deletions agentlightning/store/client_server.py
@@ -3,18 +3,19 @@
from __future__ import annotations

import asyncio
import json
import logging
import os
import threading
import time
import traceback
from contextlib import suppress
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Sequence, Union
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Literal, Optional, Sequence, Union

import aiohttp
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from opentelemetry.sdk.trace import ReadableSpan
from pydantic import BaseModel, Field

@@ -348,9 +349,29 @@
async def get_next_span_sequence_id(rollout_id: str, attempt_id: str): # pyright: ignore[reportUnusedFunction]
return await self.get_next_span_sequence_id(rollout_id, attempt_id)

@self.app.post("/wait_for_rollouts", response_model=List[Rollout])
async def wait_for_rollouts(request: WaitForRolloutsRequest): # pyright: ignore[reportUnusedFunction]
return await self.wait_for_rollouts(rollout_ids=request.rollout_ids, timeout=request.timeout)
@self.app.post("/wait_for_rollouts")
async def wait_for_rollouts(request: Request): # pyright: ignore[reportUnusedFunction]
payload = WaitForRolloutsRequest.model_validate(await request.json())

async def event_stream():
# Send an initial comment to flush the headers early and keep the
# connection warm for intermediaries that expect activity.
yield ":ok\n\n"
try:
rollouts = await self.wait_for_rollouts(
rollout_ids=payload.rollout_ids,
timeout=payload.timeout,
)
except Exception as exc: # pragma: no cover - surfaced via SSE
error_payload = {"error": str(exc)}
yield "event: error\n"
yield f"data: {json.dumps(error_payload)}\n\n"
return

data = [rollout.model_dump(mode="json") for rollout in rollouts]
yield f"data: {json.dumps(data)}\n\n"

return StreamingResponse(event_stream(), media_type="text/event-stream")

Check warning: Code scanning / CodeQL

Information exposure through an exception (Medium): Stack trace information flows to this location and may be exposed to an external user.

Copilot Autofix (AI, 11 days ago)

To address the information exposure problem, we should prevent returning detailed exception information to external users. Instead, only a generic error message should be sent, such as "An internal error has occurred". The detailed exception information (including potentially the stack trace) should be logged server-side using the logging module or similar tooling.

Steps:

  • In the event_stream() function, replace the use of str(exc) in the yielded SSE data with a generic error string.
  • Log the detailed exception and stack trace using logger.error(traceback.format_exc()) or equivalent.
  • Do not change any application logic except for the error reporting in the SSE stream; maintain existing behavior and error-handling structure.

The required edits are confined to the relevant block in the shown code within agentlightning/store/client_server.py.


Suggested changeset 1: agentlightning/store/client_server.py

Autofix patch. Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/agentlightning/store/client_server.py b/agentlightning/store/client_server.py
--- a/agentlightning/store/client_server.py
+++ b/agentlightning/store/client_server.py
@@ -363,7 +363,8 @@
                         timeout=payload.timeout,
                     )
                 except Exception as exc:  # pragma: no cover - surfaced via SSE
-                    error_payload = {"error": str(exc)}
+                    logger.error("Exception in wait_for_rollouts SSE", exc_info=True)
+                    error_payload = {"error": "An internal error has occurred."}
                     yield "event: error\n"
                     yield f"data: {json.dumps(error_payload)}\n\n"
                     return
EOF

@self.app.get("/query_spans/{rollout_id}", response_model=List[Span])
async def query_spans( # pyright: ignore[reportUnusedFunction]
@@ -704,6 +725,123 @@
assert last_exc is not None
raise last_exc

async def _iter_sse_events(self, response: aiohttp.ClientResponse) -> AsyncIterator[tuple[str, Any | None]]:
"""Yield parsed SSE events from a streaming response."""

buffer = ""
async for chunk in response.content.iter_any():
try:
decoded = chunk.decode("utf-8")
except UnicodeDecodeError:
decoded = chunk.decode("utf-8", errors="ignore")
normalized = decoded.replace("\r\n", "\n").replace("\r", "\n")
buffer += normalized

while True:
separator_index = buffer.find("\n\n")
if separator_index == -1:
break

raw_event, buffer = buffer[:separator_index], buffer[separator_index + 2 :]
event_type = "message"
data_lines: list[str] = []

for line in raw_event.splitlines():
if line.startswith("event:"):
event_type = line[6:].strip() or "message"
elif line.startswith("data:"):
data_lines.append(line[5:].lstrip())
elif line.startswith(":"):
# Comment/keep-alive line; ignore
continue

if not data_lines:
yield event_type, None
continue

raw_data = "\n".join(data_lines)
try:
parsed: Any | None = json.loads(raw_data)
except json.JSONDecodeError:
parsed = raw_data
yield event_type, parsed

async def _request_sse(
self,
path: str,
*,
json: Any | None = None,
read_timeout: Optional[float],
) -> Any:
"""Issue an SSE request and return the payload from the last data event."""

session = await self._get_session()
url = f"{self.server_address}{path if path.startswith('/') else '/' + path}"

attempts = (0.0,) + self._retry_delays
last_exc: Exception | None = None

for delay in attempts:
if delay:
logger.info(f"Waiting {delay} seconds before retrying sse: {path}")
await asyncio.sleep(delay)

timeout_margin = 5.0
if read_timeout is None:
timeout_cfg = aiohttp.ClientTimeout(total=None, connect=5.0, sock_connect=5.0, sock_read=None)
else:
limit = max(read_timeout, 0.0) + timeout_margin
timeout_cfg = aiohttp.ClientTimeout(total=limit, connect=5.0, sock_connect=5.0, sock_read=limit)

try:
async with session.post(
url,
json=json,
headers={"Accept": "text/event-stream"},
timeout=timeout_cfg,
) as resp:
resp.raise_for_status()

payload: Any | None = None
async for event_type, event_data in self._iter_sse_events(resp):
if event_type == "error":
message: str
if isinstance(event_data, dict) and "error" in event_data:
message = str(event_data["error"])
else:
message = str(event_data)
raise RuntimeError(f"SSE error: {message}")

if event_data is not None:
payload = event_data

return payload if payload is not None else []
except RuntimeError:
raise
except aiohttp.ClientResponseError as cre:
logger.debug(f"ClientResponseError: {cre.status} {cre.message}", exc_info=True)
if 400 <= cre.status < 500 and cre.status != 408:
raise
last_exc = cre
logger.info(f"5xx and other status codes will be retried. Retrying the request sse: {path}")
if not await self._wait_until_healthy(session):
break
except (
aiohttp.ServerDisconnectedError,
aiohttp.ClientConnectorError,
aiohttp.ClientOSError,
aiohttp.ClientPayloadError,
asyncio.TimeoutError,
) as net_exc:
logger.debug(f"Network/session issue: {net_exc}", exc_info=True)
last_exc = net_exc
logger.info(f"Network/session issue will be retried. Retrying the request sse: {path}")
if not await self._wait_until_healthy(session):
break

assert last_exc is not None
raise last_exc

async def close(self):
"""Close the HTTP session."""
with self._lock:
@@ -945,25 +1083,57 @@
return span

async def wait_for_rollouts(self, *, rollout_ids: List[str], timeout: Optional[float] = None) -> List[Rollout]:
"""Wait for rollouts to complete.
"""Wait for the given rollouts to complete using server-sent events."""

Args:
rollout_ids: List of rollout IDs to wait for.
timeout: Timeout in seconds. If not None, the method will raise a ValueError if the timeout is greater than 0.1 seconds.
if timeout is not None and timeout < 0:
raise ValueError("Timeout must be non-negative")

Returns:
List of rollouts that are completed.
"""
if timeout is not None and timeout > 0.1:
raise ValueError(
"Timeout must be less than 0.1 seconds in LightningStoreClient to avoid blocking the event loop"
pending_ids = set(rollout_ids)
completed: Dict[str, Rollout] = {}
deadline = None if timeout is None else time.monotonic() + timeout

while pending_ids:
if deadline is not None:
remaining = deadline - time.monotonic()
if remaining <= 0:
break
request_timeout = max(0.0, remaining)
else:
request_timeout = None

per_call_timeout = request_timeout
if per_call_timeout is not None:
per_call_timeout = min(per_call_timeout, 60.0)

payload = WaitForRolloutsRequest(
rollout_ids=list(pending_ids),
timeout=per_call_timeout,
).model_dump()

raw_data = await self._request_sse(
"/wait_for_rollouts",
json=payload,
read_timeout=per_call_timeout,
)
data = await self._request_json(
"post",
"/wait_for_rollouts",
json=WaitForRolloutsRequest(rollout_ids=rollout_ids, timeout=timeout).model_dump(),
)
return [Rollout.model_validate(item) for item in data]

if isinstance(raw_data, dict) and "rollouts" in raw_data:
items = raw_data.get("rollouts")
else:
items = raw_data

new_rollout_seen = False
for item in items or []:
rollout = Rollout.model_validate(item)
completed[rollout.rollout_id] = rollout
if rollout.rollout_id in pending_ids:
pending_ids.remove(rollout.rollout_id)
new_rollout_seen = True

if not new_rollout_seen:
if per_call_timeout is None or per_call_timeout == 0:
break

return [completed[rid] for rid in rollout_ids if rid in completed]

async def query_spans(
self,
5 changes: 4 additions & 1 deletion docs/deep-dive/store.md
@@ -201,4 +201,7 @@ agl store --port 4747

!!! note

[`LightningStoreClient.wait_for_rollouts`][agentlightning.LightningStoreClient.wait_for_rollouts] intentionally enforces a tiny timeout (≤ 0.1s) to avoid blocking event loops. Poll with short timeouts or compose with `asyncio.wait_for` at a higher layer.
[`LightningStoreClient.wait_for_rollouts`][agentlightning.LightningStoreClient.wait_for_rollouts] streams results over
Server-Sent Events (SSE). The client keeps the connection open while the store waits for completions and automatically
reconnects in 60-second chunks for very large timeouts. Use long timeouts directly or wrap the call in
[`asyncio.wait_for`](https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for) if you need an overall deadline.
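For example, a minimal sketch of wrapping the call in an overall deadline (assuming a store serving at `http://localhost:4747`, as started above, and two hypothetical rollout IDs):

```python
import asyncio

from agentlightning import LightningStoreClient


async def main() -> None:
    client = LightningStoreClient("http://localhost:4747")
    try:
        # The client streams completions over SSE and reconnects internally in
        # 60-second chunks; asyncio.wait_for adds an overall deadline on top.
        rollouts = await asyncio.wait_for(
            client.wait_for_rollouts(rollout_ids=["rollout-1", "rollout-2"]),
            timeout=300.0,
        )
        print([r.rollout_id for r in rollouts])
    finally:
        await client.close()


asyncio.run(main())
```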
4 changes: 3 additions & 1 deletion docs/how-to/unsloth-sft.md
@@ -118,7 +118,9 @@ while True:

!!! note

The `timeout=0.0` is needed here because this example uses a [`LightningStoreClient`][agentlightning.LightningStoreClient], and `wait_for_rollouts` establishes an HTTP connection to that store. Currently, only non-blocking wait requests are supported, which avoids holding the store connection open.
[`LightningStoreClient.wait_for_rollouts`][agentlightning.LightningStoreClient.wait_for_rollouts] now streams results over
SSE, so you can use longer timeouts when it simplifies your loop. In this example we still poll with `timeout=0.0` to check
progress alongside other work, but a larger timeout would keep the connection open until the store reports completions.
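For illustration, a minimal sketch of the two styles as small helpers (the helper names are hypothetical; `store` stands in for the `LightningStoreClient` constructed earlier in this guide):

```python
from typing import List

from agentlightning import LightningStoreClient
from agentlightning.types import Rollout


async def poll_progress(store: LightningStoreClient, rollout_ids: List[str]) -> List[Rollout]:
    # Non-blocking poll: returns immediately with whatever has finished so far,
    # leaving the loop free to interleave other work between checks.
    return await store.wait_for_rollouts(rollout_ids=rollout_ids, timeout=0.0)


async def wait_until_done(store: LightningStoreClient, rollout_ids: List[str]) -> List[Rollout]:
    # Blocking wait: holds the SSE connection open (reconnecting internally in
    # 60-second chunks) until the store reports completions or the timeout elapses.
    return await store.wait_for_rollouts(rollout_ids=rollout_ids, timeout=600.0)
```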

Once the rollouts complete, we terminate the vLLM server to free up GPU memory.

46 changes: 41 additions & 5 deletions tests/store/test_client_server.py
@@ -5,7 +5,7 @@
import multiprocessing
import socket
import sys
from typing import Any, AsyncGenerator, Tuple, cast
from typing import Any, AsyncGenerator, Optional, Tuple, cast
from unittest.mock import patch

import aiohttp
@@ -19,7 +19,7 @@
from agentlightning.store.base import UNSET
from agentlightning.store.client_server import LightningStoreClient, LightningStoreServer
from agentlightning.store.memory import InMemoryLightningStore
from agentlightning.types import LLM, OtelResource, PromptTemplate, RolloutConfig, Span, TraceStatus
from agentlightning.types import LLM, OtelResource, PromptTemplate, Rollout, RolloutConfig, Span, TraceStatus


def _get_free_port() -> int:
@@ -622,12 +622,48 @@ def failing_health_get(self: aiohttp.ClientSession, url: Any, *args: Any, **kwar


@pytest.mark.asyncio
async def test_wait_for_rollouts_timeout_guard_raises(
async def test_wait_for_rollouts_allows_large_timeouts(
server_client: Tuple[LightningStoreServer, LightningStoreClient],
) -> None:
_, client = server_client
with pytest.raises(ValueError):
await client.wait_for_rollouts(rollout_ids=["dummy"], timeout=0.2)
# Larger timeouts are now supported via SSE streaming.
assert (await client.wait_for_rollouts(rollout_ids=["dummy"], timeout=0.2)) == []


@pytest.mark.asyncio
async def test_wait_for_rollouts_chunks_long_timeouts(monkeypatch: MonkeyPatch) -> None:
client = LightningStoreClient("http://test")
calls: list[tuple[Optional[float], Optional[float], list[str]]] = []

async def fake_request_sse(
self: LightningStoreClient,
path: str,
*,
json: Any | None = None,
read_timeout: Optional[float],
) -> Any:
assert path == "/wait_for_rollouts"
payload = cast(dict[str, Any], json)
rollout_list = list(payload["rollout_ids"])
calls.append((payload.get("timeout"), read_timeout, rollout_list))
if len(calls) < 3:
return []
rollout = Rollout(rollout_id="r1", input={}, start_time=0.0, status="succeeded")
return [rollout.model_dump(mode="json")]

monkeypatch.setattr(LightningStoreClient, "_request_sse", fake_request_sse, raising=True)

try:
result = await client.wait_for_rollouts(rollout_ids=["r1"], timeout=120.0)
assert result and result[0].rollout_id == "r1"
assert calls, "SSE helper should have been invoked"
first_timeout, first_read_timeout, rollout_list = calls[0]
assert rollout_list == ["r1"]
assert first_timeout == 60.0
assert first_read_timeout == 60.0
assert all(timeout is None or timeout <= 60.0 for timeout, _, _ in calls)
finally:
await client.close()


@pytest.mark.asyncio