Skip to content

Commit baecec8

Browse files
authored
Merge pull request #92 from zhuyanhuazhuyanhua/sse-split
feat(sse): add SSE agent and utils
2 parents d678810 + 318add4 commit baecec8

File tree

2 files changed

+194
-62
lines changed

2 files changed

+194
-62
lines changed

oxygent/oxy/agents/sse_oxy_agent.py

Lines changed: 42 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ...schemas import OxyRequest, OxyResponse, OxyState
1010
from ...utils.common_utils import build_url
11+
from ...utils.sse_utils import iter_sse_events
1112
from .remote_agent import RemoteAgent
1213

1314
logger = logging.getLogger(__name__)
@@ -101,77 +102,56 @@ async def _execute(self, oxy_request: OxyRequest) -> OxyResponse:
101102
status=resp.status
102103
)
103104

104-
message_id = None
105-
message_event = None
106-
message_data = None
107-
message_retry = None
105+
# 使用规范的 SSE 事件解析
106+
async for event in iter_sse_events(resp):
107+
message_event = event.get("event")
108+
message_data = event.get("data")
109+
message_id = event.get("id")
110+
message_retry = event.get("retry")
108111

109-
async for line in resp.content:
110-
line = line.decode("utf-8").strip()
111-
112-
if line.startswith("id:"):
113-
message_id = line[3:].strip()
114-
elif line.startswith("event:"):
115-
message_event = line[6:].strip()
116-
elif line.startswith("data:"):
117-
message_data = line[5:].strip()
118-
elif line.startswith("retry:"):
112+
if message_event == "close":
113+
logger.info(
114+
f"Received request to terminate SSE connection: {message_data}. {self.server_url}",
115+
extra={
116+
"trace_id": oxy_request.current_trace_id,
117+
"node_id": oxy_request.node_id,
118+
},
119+
)
120+
await resp.release()
121+
return OxyResponse(state=OxyState.COMPLETED, output=answer)
122+
else:
119123
try:
120-
message_retry = int(line[6:].strip())
121-
except ValueError:
122-
message_retry = None
123-
124-
# 当到达一个完整的事件时,处理它
125-
if line == "":
126-
if message_event == "close":
127-
logger.info(
128-
f"Received request to terminate SSE connection: {message_data}. {self.server_url}",
129-
extra={
130-
"trace_id": oxy_request.current_trace_id,
131-
"node_id": oxy_request.node_id,
132-
},
133-
)
134-
await resp.release()
135-
return OxyResponse(state=OxyState.COMPLETED, output=answer)
136-
else:
137-
try:
138-
data = json.loads(message_data)
139-
message_data_type = data.get("type", "")
140-
if message_data_type == "answer":
141-
answer = data.get("content")
142-
elif message_data_type in [
143-
"tool_call",
144-
"observation",
145-
]:
146-
if (
147-
data["content"]["caller_category"] == "user"
148-
or data["content"]["callee_category"] == "user"
149-
):
150-
continue
151-
else:
152-
# Discord user and callee
153-
if not self.is_share_call_stack:
154-
data["content"]["call_stack"] = (
155-
oxy_request.call_stack
156-
+ data["content"]["call_stack"][2:]
157-
)
158-
await oxy_request.send_message(
159-
data, event=message_event, id=message_id, retry=message_retry
160-
)
124+
data = json.loads(message_data)
125+
message_data_type = data.get("type", "")
126+
if message_data_type == "answer":
127+
answer = data.get("content")
128+
elif message_data_type in [
129+
"tool_call",
130+
"observation",
131+
]:
132+
if (
133+
data["content"]["caller_category"] == "user"
134+
or data["content"]["callee_category"] == "user"
135+
):
136+
continue
161137
else:
138+
# Discord user and callee
139+
if not self.is_share_call_stack:
140+
data["content"]["call_stack"] = (
141+
oxy_request.call_stack
142+
+ data["content"]["call_stack"][2:]
143+
)
162144
await oxy_request.send_message(
163145
data, event=message_event, id=message_id, retry=message_retry
164146
)
165-
except json.JSONDecodeError:
147+
else:
166148
await oxy_request.send_message(
167149
data, event=message_event, id=message_id, retry=message_retry
168150
)
169-
170-
# 重置变量以准备下一个事件
171-
message_id = None
172-
message_event = None
173-
message_data = None
174-
message_retry = None
151+
except json.JSONDecodeError:
152+
await oxy_request.send_message(
153+
data, event=message_event, id=message_id, retry=message_retry
154+
)
175155

176156
# 如果正常完成,直接返回
177157
return OxyResponse(state=OxyState.COMPLETED, output=answer)

oxygent/utils/sse_utils.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import aiohttp
2+
from typing import AsyncIterator, Dict, Any, Optional
3+
async def iter_sse_events(
    resp: "aiohttp.ClientResponse",  # forward-ref annotation: evaluated lazily, no hard aiohttp dependency at call time
    *,
    chunk_size: int = 8 * 1024,
    max_buffer_bytes: int = 2 * 1024 * 1024,  # 2 MiB total buffered (anti-DoS)
    max_event_bytes: int = 512 * 1024,  # 512 KiB per event block (anti-DoS)
    max_data_bytes: int = 512 * 1024,  # 512 KiB of accumulated "data:" per event
    allow_partial_final_event: bool = True,  # flush remaining bytes on EOF
) -> AsyncIterator[Dict[str, Any]]:
    """
    Parse an HTTP response body as a Server-Sent Events (SSE) stream.

    Yields dicts:
        {
            "event": str | None,
            "data": str,
            "id": str | None,
            "retry": int | None,
        }

    Args:
        resp: streaming HTTP response; only ``resp.content.iter_chunked``
            is used, so any object exposing that async iterator works.
        chunk_size: bytes pulled from the response per read.
        max_buffer_bytes: hard cap on the total parse buffer (anti-DoS).
        max_event_bytes: hard cap on a single event block (anti-DoS).
        max_data_bytes: hard cap on accumulated "data:" bytes per event.
        allow_partial_final_event: if True, a final event that was not
            terminated by a blank line is still flushed on EOF (lenient;
            the SSE spec says to discard such an event).

    Raises:
        ValueError: when any of the size caps above is exceeded.

    Security/robustness hardening:
    - Caps buffer size to mitigate memory/CPU DoS if delimiters never arrive.
    - Caps per-event size and per-event data size.
    - Best-effort flush of the final event on EOF (optional).
    - Ignores `id` values containing NUL (\\x00).
    - Accepts `retry` only if non-negative int.
    """

    def _parse_event_block(raw: bytes) -> Optional[Dict[str, Any]]:
        """Parse one blank-line-delimited event block into a field dict (or None)."""
        # Trim leading/trailing newlines within the event block.
        raw = raw.strip(b"\n")
        if not raw:
            return None

        event_type: Optional[str] = None
        event_id: Optional[str] = None
        retry: Optional[int] = None
        data_lines: list[str] = []
        data_bytes = 0

        for line in raw.split(b"\n"):
            # Comment/heartbeat lines start with ":" and should be ignored.
            if line.startswith(b":"):
                continue

            # Field parsing:
            # - "field:value"
            # - "field:" (empty value)
            # - "field" (empty value)
            if b":" in line:
                field_b, value_b = line.split(b":", 1)
                # Per the SSE spec, a single leading space in the value is stripped.
                if value_b.startswith(b" "):
                    value_b = value_b[1:]
            else:
                field_b, value_b = line, b""

            # Empty field names are ignored in practice.
            if not field_b:
                continue

            field = field_b.decode("utf-8", errors="replace")

            if field == "event":
                event_type = value_b.decode("utf-8", errors="replace")

            elif field == "data":
                # Cap total event data bytes (prevents huge multi-line data DoS).
                data_bytes += len(value_b)
                if data_bytes > max_data_bytes:
                    raise ValueError(
                        f"SSE event data too large (> {max_data_bytes} bytes)"
                    )
                data_lines.append(value_b.decode("utf-8", errors="replace"))

            elif field == "id":
                # Ignore IDs containing NUL (common interoperability/safety behavior).
                if b"\x00" in value_b:
                    continue
                event_id = value_b.decode("utf-8", errors="replace")

            elif field == "retry":
                try:
                    n = int(value_b.decode("utf-8", errors="replace"))
                    if n >= 0:
                        retry = n
                except ValueError:
                    pass

        # Skip blocks that contain no meaningful fields (e.g., only comments).
        if event_type is None and event_id is None and retry is None and not data_lines:
            return None

        return {
            "event": event_type,
            "data": "\n".join(data_lines),
            "id": event_id,
            "retry": retry,
        }

    buf = bytearray()

    async for chunk in resp.content.iter_chunked(chunk_size):
        if not chunk:
            continue

        # Normalize newlines to LF ("\n") safely across chunk boundaries.
        # A trailing '\r' is deliberately kept in `buf` (see below): it is
        # the first half of a CRLF that may have been split across chunks.
        if buf and buf[-1] == 0x0D:  # '\r'
            if chunk[:1] == b"\n":
                buf[-1] = 0x0A  # '\n': complete the split CRLF as ONE newline
                chunk = chunk[1:]
            else:
                buf[-1] = 0x0A  # lone CR: treat as a line break

        # Normalize inside this chunk, but keep a trailing lone '\r'
        # un-normalized: its matching '\n' may arrive in the next chunk
        # (handled above).  BUGFIX: eagerly converting a trailing '\r' to
        # '\n' here turned a CRLF split across chunks into a spurious
        # "\n\n" event boundary, dispatching one event as two.
        pending_cr = chunk.endswith(b"\r")
        if pending_cr:
            chunk = chunk[:-1]
        chunk = chunk.replace(b"\r\n", b"\n").replace(b"\r", b"\n")
        buf.extend(chunk)
        if pending_cr:
            buf.append(0x0D)

        # Total buffer cap (anti-DoS).
        if len(buf) > max_buffer_bytes:
            raise ValueError(
                f"SSE buffer too large (> {max_buffer_bytes} bytes); delimiter not found"
            )

        while True:
            # After normalization, SSE events are separated by "\n\n".
            sep = buf.find(b"\n\n")
            if sep == -1:
                break

            # Per-event cap before copying bytes out.
            if sep > max_event_bytes:
                raise ValueError(f"SSE event too large (> {max_event_bytes} bytes)")

            raw = bytes(buf[:sep])
            del buf[: sep + 2]  # consume the delimiter too

            evt = _parse_event_block(raw)
            if evt is not None:
                yield evt

    # EOF flush: normalize a trailing '\r' that never got paired.
    if buf and buf[-1] == 0x0D:
        buf[-1] = 0x0A

    if allow_partial_final_event and buf:
        if len(buf) > max_event_bytes:
            raise ValueError(f"SSE final event too large (> {max_event_bytes} bytes)")
        evt = _parse_event_block(bytes(buf))
        if evt is not None:
            yield evt

0 commit comments

Comments
 (0)