Skip to content

Commit 454f527

Browse files
perf(anthropic): decode upstream SSE via shared SseDecoder
Replace the hand-rolled byte buffer + `find_double_newline` + `parse_sse_frame` in `consume_and_forward` with the shared `common::sse::SseDecoder`.

The decoder provides cursor-tracked buffering (a single compact per batch instead of a per-frame `Vec::drain`), deferred UTF-8 validation for multi-byte characters split across network chunks, and the same 1 MB DoS cap.

A small `resolve_event` helper preserves Anthropic's existing behavior of inferring the event type from the payload's `"type"` field when an explicit `event:` line is absent. Frame parsing is now zero-allocation for the common single-line `data:` case.

Tests are updated to exercise the decoder path; the inference fallback is covered explicitly.

Signed-off-by: XinyueZhang369 <zoeyzhang369@gmail.com>
1 parent 06d761c commit 454f527

1 file changed

Lines changed: 122 additions & 93 deletions

File tree

  • model_gateway/src/routers/anthropic

model_gateway/src/routers/anthropic/sse.rs

Lines changed: 122 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//! Provides SSE frame parsing, event formatting, stream wrappers,
44
//! and the core stream consumption logic used by the streaming processor.
55
6-
use std::io;
6+
use std::{borrow::Cow, io};
77

88
use axum::{
99
body::Body,
@@ -18,7 +18,10 @@ use tokio::sync::mpsc;
1818
use tracing::{debug, error, warn};
1919

2020
use super::mcp::{IterationResult, McpToolCall};
21-
use crate::routers::error::internal_error;
21+
use crate::routers::{
22+
common::sse::{SseDecodeError, SseDecoder, SseFrame},
23+
error::internal_error,
24+
};
2225

2326
// ============================================================================
2427
// Constants
@@ -216,47 +219,43 @@ where
216219
F: Fn(&str) -> String,
217220
{
218221
let mut stream = response.bytes_stream();
219-
let mut buffer = Vec::<u8>::new();
222+
// Shared SSE decoder: cursor-tracked buffering, deferred UTF-8 validation
223+
// (handles multi-byte characters split across network chunks), and built-in
224+
// DoS protection at `MAX_SSE_BUFFER_SIZE`.
225+
let mut decoder = SseDecoder::with_max_size(MAX_SSE_BUFFER_SIZE);
220226
let mut processor =
221227
EventProcessor::new(tx, global_index, is_first_iteration, resolve_server_name);
222228

223229
while let Some(chunk_result) = stream.next().await {
224230
let chunk = chunk_result.map_err(|e| format!("Stream read error: {e}"))?;
225231

226-
// Guard against unbounded buffer growth (DoS protection).
227-
// Check *before* extending so a single oversized chunk never
228-
// causes an allocation beyond the cap.
229-
if buffer.len() + chunk.len() > MAX_SSE_BUFFER_SIZE {
230-
return Err(format!(
232+
decoder.push(&chunk).map_err(|e| match e {
233+
SseDecodeError::BufferOverflow => format!(
231234
"SSE buffer exceeded maximum size ({MAX_SSE_BUFFER_SIZE} bytes) — possible malformed upstream stream"
232-
));
233-
}
234-
235-
buffer.extend_from_slice(&chunk);
235+
),
236+
other => format!("SSE decode error: {other}"),
237+
})?;
236238

237-
// Process complete SSE frames (delimited by double newline).
238-
// UTF-8 validation is deferred to complete frames so that multi-byte
239-
// characters split across network chunks don't cause spurious errors.
240-
while let Some(pos) = find_double_newline(&buffer) {
241-
let frame_bytes = &buffer[..pos];
242-
let frame = std::str::from_utf8(frame_bytes)
243-
.map_err(|e| format!("Invalid UTF-8 in SSE frame: {e}"))?;
244-
if let Some((event_type, data)) = parse_sse_frame(frame) {
239+
while let Some(frame) = decoder.next_frame() {
240+
let frame = frame.map_err(|e| format!("Invalid UTF-8 in SSE frame: {e}"))?;
241+
if let Some((event_type, data)) = resolve_event(frame) {
245242
processor.process(&event_type, &data).await?;
246243
}
247-
buffer.drain(..pos + 2);
248244
}
245+
decoder.compact();
249246
}
250247

251-
// Process any remaining data in buffer
252-
if !buffer.is_empty() {
253-
let remaining = std::str::from_utf8(&buffer)
254-
.map_err(|e| format!("Invalid UTF-8 in final SSE data: {e}"))?;
255-
let trimmed = remaining.trim();
256-
if !trimmed.is_empty() {
257-
if let Some((event_type, data)) = parse_sse_frame(trimmed) {
258-
processor.process(&event_type, &data).await?;
259-
}
248+
// Process any trailing data not terminated by a blank line.
249+
if let Some(frame) = decoder.flush() {
250+
let frame = frame.map_err(|e| match e {
251+
SseDecodeError::InvalidUtf8(u) => format!("Invalid UTF-8 in final SSE data: {u}"),
252+
// The loop above drains every complete frame, so `flush()` can't
253+
// return `IncompleteFlush` here — the catch-all arm below only keeps the
254+
// error message accurate if that ever changes.
255+
other => format!("SSE decode error on flush: {other}"),
256+
})?;
257+
if let Some((event_type, data)) = resolve_event(frame) {
258+
processor.process(&event_type, &data).await?;
260259
}
261260
}
262261

@@ -653,57 +652,33 @@ where
653652
}
654653

655654
// ============================================================================
656-
// SSE frame parsing
655+
// SSE frame resolution
657656
// ============================================================================
658657

659-
/// Find the position of `\n\n` in a byte buffer.
660-
fn find_double_newline(buf: &[u8]) -> Option<usize> {
661-
buf.windows(2).position(|w| w == b"\n\n")
662-
}
663-
664-
/// Parse a raw SSE frame into `(event_type, data)`.
658+
/// Resolve a decoded [`SseFrame`] into Anthropic's `(event_type, data)` pair.
665659
///
666-
/// SSE frames look like:
667-
/// ```text
668-
/// event: content_block_start
669-
/// data: {"type":"content_block_start",...}
670-
/// ```
671-
fn parse_sse_frame(frame: &str) -> Option<(String, String)> {
672-
let mut event_type = String::new();
673-
let mut data_lines = Vec::new();
674-
675-
for line in frame.lines() {
676-
let line = line.trim();
677-
if line.is_empty() {
678-
continue;
679-
}
680-
if let Some(value) = line.strip_prefix("event:") {
681-
event_type = value.trim().to_string();
682-
} else if let Some(value) = line.strip_prefix("data:") {
683-
data_lines.push(value.trim().to_string());
684-
}
685-
}
686-
687-
if data_lines.is_empty() {
688-
return None;
689-
}
690-
691-
let data = data_lines.join("\n");
692-
693-
// If no event type specified, try to infer from data
694-
if event_type.is_empty() {
695-
if let Ok(parsed) = serde_json::from_str::<Value>(&data) {
696-
if let Some(t) = parsed.get("type").and_then(|v| v.as_str()) {
697-
event_type = t.to_string();
698-
}
660+
/// Anthropic always emits an explicit `event:` line, but as a fallback
661+
/// (preserving prior behavior) the event type is inferred from the payload's
662+
/// `"type"` field when the `event:` line is absent or empty. Returns `None`
663+
/// when no event type can be determined.
664+
///
665+
/// The event type is returned as a `Cow` so the common case (explicit `event:`
666+
/// line) borrows straight from the decoded frame with no allocation; only the
667+
/// rare inference fallback allocates.
668+
fn resolve_event(frame: SseFrame<'_>) -> Option<(Cow<'_, str>, Cow<'_, str>)> {
669+
let event_type = match frame.event_type {
670+
Some(e) if !e.is_empty() => e,
671+
_ => {
672+
let parsed: Value = serde_json::from_str(&frame.data).ok()?;
673+
Cow::Owned(parsed.get("type")?.as_str()?.to_string())
699674
}
700-
}
675+
};
701676

702677
if event_type.is_empty() {
703678
return None;
704679
}
705680

706-
Some((event_type, data))
681+
Some((event_type, frame.data))
707682
}
708683

709684
// ============================================================================
@@ -714,34 +689,79 @@ fn parse_sse_frame(frame: &str) -> Option<(String, String)> {
714689
mod tests {
715690
use super::*;
716691

692+
/// Decode `bytes` through the shared `SseDecoder` and resolve each frame
693+
/// to `(event_type, data)` the way `consume_and_forward` does.
694+
fn decode_events(bytes: &[u8]) -> Vec<(String, String)> {
695+
let mut decoder = SseDecoder::new();
696+
decoder.push(bytes).unwrap();
697+
let mut out = Vec::new();
698+
while let Some(frame) = decoder.next_frame() {
699+
if let Some((event_type, data)) = resolve_event(frame.unwrap()) {
700+
out.push((event_type.into_owned(), data.into_owned()));
701+
}
702+
}
703+
out
704+
}
705+
717706
#[test]
718-
fn test_parse_sse_frame_basic() {
719-
let frame = "event: message_start\ndata: {\"type\":\"message_start\"}";
720-
let (event_type, data) = parse_sse_frame(frame).unwrap();
721-
assert_eq!(event_type, "message_start");
707+
fn test_resolve_event_basic() {
708+
let frame = SseFrame {
709+
event_type: Some(Cow::Borrowed("message_start")),
710+
data: Cow::Borrowed("{\"type\":\"message_start\"}"),
711+
};
712+
let (event_type, data) = resolve_event(frame).unwrap();
713+
assert_eq!(event_type.as_ref(), "message_start");
722714
assert_eq!(data, "{\"type\":\"message_start\"}");
723715
}
724716

725717
#[test]
726-
fn test_parse_sse_frame_content_block() {
727-
let frame = "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}";
728-
let (event_type, data) = parse_sse_frame(frame).unwrap();
729-
assert_eq!(event_type, "content_block_start");
730-
let parsed: Value = serde_json::from_str(&data).unwrap();
731-
assert_eq!(parsed["index"], 0);
718+
fn test_resolve_event_no_event_type_infers() {
719+
// No `event:` line -> infer the type from the payload's "type" field.
720+
let frame = SseFrame {
721+
event_type: None,
722+
data: Cow::Borrowed("{\"type\":\"ping\"}"),
723+
};
724+
let (event_type, _data) = resolve_event(frame).unwrap();
725+
assert_eq!(event_type.as_ref(), "ping");
726+
}
727+
728+
#[test]
729+
fn test_resolve_event_uninferable_is_none() {
730+
// No event type and no "type" field in the payload -> skipped.
731+
let frame = SseFrame {
732+
event_type: None,
733+
data: Cow::Borrowed("{\"foo\":1}"),
734+
};
735+
assert!(resolve_event(frame).is_none());
732736
}
733737

734738
#[test]
735-
fn test_parse_sse_frame_no_event_type_infers() {
736-
let frame = "data: {\"type\":\"ping\"}";
737-
let (event_type, _data) = parse_sse_frame(frame).unwrap();
738-
assert_eq!(event_type, "ping");
739+
fn test_decode_events_basic() {
740+
let events = decode_events(
741+
b"event: message_start\ndata: {\"type\":\"message_start\"}\n\nevent: ping\ndata: {\"type\":\"ping\"}\n\n",
742+
);
743+
assert_eq!(events.len(), 2);
744+
assert_eq!(events[0].0, "message_start");
745+
assert_eq!(events[1].0, "ping");
746+
}
747+
748+
#[test]
749+
fn test_decode_events_content_block() {
750+
let events = decode_events(b"event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n");
751+
assert_eq!(events.len(), 1);
752+
assert_eq!(events[0].0, "content_block_start");
753+
let parsed: Value = serde_json::from_str(&events[0].1).unwrap();
754+
assert_eq!(parsed["index"], 0);
739755
}
740756

741757
#[test]
742-
fn test_parse_sse_frame_empty() {
743-
assert!(parse_sse_frame("").is_none());
744-
assert!(parse_sse_frame("event: foo").is_none());
758+
fn test_decode_events_infers_event_type() {
759+
// A `data:`-only frame (no `event:` line) infers its type from the payload.
760+
let events = decode_events(b"data: {\"type\":\"ping\"}\n\n");
761+
assert_eq!(
762+
events,
763+
vec![("ping".to_string(), "{\"type\":\"ping\"}".to_string())]
764+
);
745765
}
746766

747767
#[test]
@@ -755,10 +775,19 @@ mod tests {
755775
}
756776

757777
#[test]
758-
fn test_parse_sse_frame_with_extra_whitespace() {
759-
let frame = " event: content_block_delta \n data: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{}\"}} ";
760-
let (event_type, data) = parse_sse_frame(frame).unwrap();
761-
assert_eq!(event_type, "content_block_delta");
778+
fn test_decode_events_split_across_chunks() {
779+
// The decoder must reassemble a frame whose bytes arrive in two separate pushes.
780+
let mut decoder = SseDecoder::new();
781+
decoder
782+
.push(b"event: content_block_delta\ndata: {\"type\":\"content_block_de")
783+
.unwrap();
784+
assert!(decoder.next_frame().is_none()); // incomplete
785+
decoder
786+
.push(b"lta\",\"index\":1,\"delta\":{\"partial_json\":\"{}\"}}\n\n")
787+
.unwrap();
788+
let frame = decoder.next_frame().unwrap().unwrap();
789+
let (event_type, data) = resolve_event(frame).unwrap();
790+
assert_eq!(event_type.as_ref(), "content_block_delta");
762791
let parsed: Value = serde_json::from_str(&data).unwrap();
763792
assert_eq!(parsed["index"], 1);
764793
}

0 commit comments

Comments
 (0)