Skip to content

Commit 882b3b6

Browse files
committed
agent-loop: sanitize upstream errors + cap Brave response
PR #144 round 3: 1. Inline SSE error chunks were forwarded verbatim. The chunk transform only encrypts `delta`/`message` fields, so `error.message` stayed plaintext on the wire — and backends sometimes echo input/validation details in those messages, which under E2EE is data we decrypted inside the CVM. Fix: - Unified the existing `emit_upstream_error_chunk` into a more general `emit_synthetic_error_chunk(message, error_type, code)` that always emits a controlled message string — no upstream-provided text is passed through. - `run_iteration` now detects `outcome.upstream_error.is_some()` immediately after `ingest_chunk_metadata`, replaces the offending chunk with a sanitized synthetic, and breaks the iteration. The original upstream chunk is never forwarded. - Both the mid-loop HTTP-non-2xx path and the inline SSE-error path funnel through the same helper. 2. Brave response was read with `response.text().await` (unbounded) and `format_context_response` walked all entries with no size cap. A misconfigured or compromised search backend could allocate arbitrarily inside the proxy. - `BRAVE_MAX_RESPONSE_BYTES = 2 MiB` enforced via streaming `bytes_stream` read; oversize → `BraveError::Other` → tool-error chunk, loop continues. - `MAX_FORMATTED_OUTPUT_BYTES = 32 KiB` enforced inside `format_context_response`; truncates on a UTF-8 boundary with a `\n[truncated]` marker. Same string is both emitted downstream and fed back to the model on the next iteration, so this also bounds prompt growth across iterations. 3. The earlier disconnect regression test created a new app for the `/v1/signature/...` check — fresh cache, so 404 was guaranteed regardless of what the loop did. Now clones the `Router` (and therefore its `AppState`/`ChatCache`) and queries the same instance, so the 404 actually proves no signature was written. New regression tests: - `upstream_error_message_does_not_leak_prompt_or_query`: upstream emits an error chunk whose `message` echoes the user's prompt; the proxy's response body contains the sanitized synthetic message and does NOT contain the prompt fragment. - `brave_response_oversized_body_is_rejected`: Brave returns 3 MiB body (above the 2 MiB cap); the loop surfaces it as `status:"error"` and continues; no body bytes leak into the stream. - `brave_formatted_output_is_truncated`: Brave returns 200 entries with ~600 B snippets each (~120 KB formatted); the emitted `nearai_tool_result.output` is capped at 32 KiB + marker. Existing `mid_stream_error_chunk_skips_done_and_signature` updated: now asserts the sanitized message is on the wire and the original `"BadRequestError"` / `"request was aborted"` strings are NOT. Totals: 302 unit + 1 main + 15 agent_loop (was 12) + 134 integration = 452 tests, all pass. Clippy + fmt clean.
1 parent ac2e749 commit 882b3b6

2 files changed

Lines changed: 397 additions & 44 deletions

File tree

src/agent_loop.rs

Lines changed: 115 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ pub const WEB_CONTEXT_SEARCH_TOOL_NAME: &str = "web_context_search";
3838
const NEARAI_TOOL_RESULT_KEY: &str = "nearai_tool_result";
3939
const NEARAI_LOOP_TERMINATED_KEY: &str = "nearai_loop_terminated";
4040

41+
/// Hard cap on bytes read from Brave's response body. The defaults we send
42+
/// (`maximum_number_of_tokens=8192`, `maximum_number_of_urls=20`) should
43+
/// produce well under this; the cap is a backstop against a misconfigured
44+
/// or malicious search endpoint that returns an unbounded body.
45+
const BRAVE_MAX_RESPONSE_BYTES: usize = 2 * 1024 * 1024;
46+
47+
/// Hard cap on the formatted tool output that we emit downstream and feed
48+
/// back to the model on the next iteration. Independent of Brave's input
49+
/// caps so we don't depend on the upstream respecting them. Beyond this,
50+
/// the output is truncated with a marker.
51+
const MAX_FORMATTED_OUTPUT_BYTES: usize = 32 * 1024;
52+
4153
/// True iff the request's `tools` field is exactly one entry of type
4254
/// `web_context_search`. Mixed tool types or multiple entries return false
4355
/// and let the request flow through the existing pass-through path.
@@ -338,16 +350,22 @@ async fn drive_loop(
338350
// Surface the failure to the client as an SSE error
339351
// chunk so they don't see a silent stall. No `[DONE]`
340352
// and no signature — this is not a successful
341-
// completion.
342-
emit_upstream_error_chunk(
353+
// completion. The `body` from upstream is dropped
354+
// (not forwarded) because it can contain
355+
// provider-side internals or user data.
356+
let _ = body;
357+
let status_code = status.as_u16();
358+
let code_str = status_code.to_string();
359+
emit_synthetic_error_chunk(
343360
ctx.tx,
344361
ctx.chunk_transform,
345362
&mut hasher,
346363
chat_id.as_deref(),
347364
model_echo.as_deref(),
348365
created,
349-
status.as_u16(),
350-
&body,
366+
&format!("upstream returned HTTP {status_code} on a follow-up tool-loop iteration"),
367+
"upstream_error",
368+
Some(&code_str),
351369
)
352370
.await?;
353371
terminated_by = "upstream_error";
@@ -722,6 +740,33 @@ async fn run_iteration(
722740
};
723741
ingest_chunk_metadata(&parsed, &mut outcome);
724742

743+
// SGLang and friends emit top-level
744+
// `{"error": {...}}` chunks on aborts.
745+
// Do NOT forward the upstream chunk
746+
// verbatim — `error.message` is outside
747+
// what the chunk transform encrypts, and
748+
// backends may put validation
749+
// input/request details (which under E2EE
750+
// is data we decrypted inside the CVM)
751+
// into it. Replace with a sanitized
752+
// synthetic error chunk and abort the
753+
// iteration.
754+
if outcome.upstream_error.is_some() {
755+
emit_synthetic_error_chunk(
756+
ctx.tx,
757+
ctx.chunk_transform,
758+
ctx.hasher,
759+
outcome.chat_id.as_deref(),
760+
outcome.model.as_deref(),
761+
outcome.created,
762+
"upstream emitted an error chunk; response was aborted",
763+
"upstream_error",
764+
None,
765+
)
766+
.await?;
767+
break 'outer;
768+
}
769+
725770
if let Some(new_id) = ctx.rewrite_id_to {
726771
if parsed.get("id").is_some() {
727772
parsed["id"] = Value::String(new_id.to_string());
@@ -947,44 +992,47 @@ async fn emit_tool_result_chunk(
947992
Ok(())
948993
}
949994

950-
/// Emit an OpenAI-shaped error chunk to the client when an upstream
951-
/// iteration past iter 0 returns non-2xx. We've already sent
952-
/// `200 text/event-stream` headers, so we can't change the HTTP status;
953-
/// instead surface the failure as a `data: {"error": {...}}` chunk and
954-
/// close the stream without `[DONE]` so the response isn't signed.
995+
/// Emit a sanitized OpenAI-shaped error chunk to the client when an
996+
/// upstream failure happens after `200 text/event-stream` headers have
997+
/// already been sent. The message text is controlled by us — we don't
998+
/// pass through upstream-provided strings, which under E2EE could
999+
/// contain prompt fragments or other user data the upstream backend
1000+
/// echoed back. Closing without `[DONE]` keeps the response unsigned.
9551001
#[allow(clippy::too_many_arguments)]
956-
async fn emit_upstream_error_chunk(
1002+
async fn emit_synthetic_error_chunk(
9571003
tx: &tokio::sync::mpsc::Sender<Result<Bytes, std::io::Error>>,
9581004
chunk_transform: Option<&ChunkTransform>,
9591005
hasher: &mut Sha256,
9601006
chat_id: Option<&str>,
9611007
model: Option<&str>,
9621008
created: Option<i64>,
963-
status_code: u16,
964-
upstream_body: &[u8],
1009+
message: &str,
1010+
error_type: &str,
1011+
code: Option<&str>,
9651012
) -> Result<(), AppError> {
966-
// Don't leak upstream body bytes verbatim — they could contain provider
967-
// internals. Surface the status code only.
968-
let _ = upstream_body;
1013+
let mut error = json!({
1014+
"message": message,
1015+
"type": error_type,
1016+
});
1017+
if let Some(c) = code {
1018+
error["code"] = c.into();
1019+
}
9691020
let mut chunk = json!({
9701021
"id": chat_id.unwrap_or(""),
9711022
"object": "chat.completion.chunk",
9721023
"choices": [],
973-
"error": {
974-
"message": format!("upstream returned HTTP {status_code} on a follow-up tool-loop iteration"),
975-
"type": "upstream_error",
976-
"code": status_code.to_string(),
977-
}
1024+
"error": error,
9781025
});
9791026
if let Some(m) = model {
9801027
chunk["model"] = m.into();
9811028
}
9821029
if let Some(c) = created {
9831030
chunk["created"] = c.into();
9841031
}
985-
// Pass through the chunk transform so E2EE clients still get a
986-
// well-formed (encrypted-where-applicable) error chunk. The `error`
987-
// object itself has no encryptable string fields by design.
1032+
// Pass through the chunk transform for shape parity with normal
1033+
// chunks. The `error` object has no encryptable string fields under
1034+
// the current transform — but the message text we put here is
1035+
// controlled by us, so this is safe regardless.
9881036
if let Some(transform) = chunk_transform {
9891037
transform(&mut chunk)?;
9901038
}
@@ -1101,11 +1149,25 @@ async fn brave_llm_context_search(
11011149
return Err(BraveError::Other(format!("brave HTTP {}", status.as_u16())));
11021150
}
11031151

1104-
let body = response.text().await.map_err(|e| {
1105-
BraveError::Other(format!("brave body read failed: {}", error_category(&e)))
1106-
})?;
1152+
// Read the body in chunks so we can enforce a hard size cap regardless
1153+
// of `Content-Length`. The defaults we send keep responses well under
1154+
// `BRAVE_MAX_RESPONSE_BYTES`; this cap is a backstop in case the search
1155+
// endpoint is misconfigured or compromised.
1156+
let mut body_bytes: Vec<u8> = Vec::with_capacity(16 * 1024);
1157+
let mut body_stream = response.bytes_stream();
1158+
while let Some(chunk) = body_stream.next().await {
1159+
let chunk = chunk.map_err(|e| {
1160+
BraveError::Other(format!("brave body read failed: {}", error_category(&e)))
1161+
})?;
1162+
if body_bytes.len().saturating_add(chunk.len()) > BRAVE_MAX_RESPONSE_BYTES {
1163+
return Err(BraveError::Other(format!(
1164+
"brave response exceeded {BRAVE_MAX_RESPONSE_BYTES}-byte cap"
1165+
)));
1166+
}
1167+
body_bytes.extend_from_slice(&chunk);
1168+
}
11071169

1108-
let parsed: BraveContextResponse = serde_json::from_str(&body)
1170+
let parsed: BraveContextResponse = serde_json::from_slice(&body_bytes)
11091171
.map_err(|e| BraveError::Other(format!("brave JSON parse failed: {e}")))?;
11101172

11111173
Ok(format_context_response(&parsed))
@@ -1162,10 +1224,16 @@ struct BraveContextSource {
11621224
/// consume directly. Skips entries with no URL or no usable snippets; falls
11631225
/// back to sources[url].title when the grounding entry has no title of its
11641226
/// own. Mirrors `context_response_to_web_results` in cloud-api/brave.rs.
1227+
///
1228+
/// Truncates at `MAX_FORMATTED_OUTPUT_BYTES` with a marker, since we don't
1229+
/// want to depend on Brave honoring its input caps. The truncated output
1230+
/// is what we both emit to the client AND feed back to the model on the
1231+
/// next iteration, so this also bounds prompt growth across iterations.
11651232
fn format_context_response(resp: &BraveContextResponse) -> String {
11661233
let mut out = String::new();
11671234
let mut n: u32 = 0;
1168-
for entry in &resp.grounding.generic {
1235+
let mut truncated = false;
1236+
'outer: for entry in &resp.grounding.generic {
11691237
let url = entry.url.trim();
11701238
if url.is_empty() {
11711239
continue;
@@ -1195,11 +1263,26 @@ fn format_context_response(resp: &BraveContextResponse) -> String {
11951263
})
11961264
.unwrap_or_else(|| url.to_string());
11971265
n += 1;
1198-
if n > 1 {
1199-
out.push_str("\n\n");
1266+
let separator = if n > 1 { "\n\n" } else { "" };
1267+
let header = format!("{separator}[{n}] {title}\n{url}\n");
1268+
let joined = snippets.join("\n\n");
1269+
for piece in [header.as_str(), joined.as_str()] {
1270+
if out.len() + piece.len() > MAX_FORMATTED_OUTPUT_BYTES {
1271+
let remaining = MAX_FORMATTED_OUTPUT_BYTES.saturating_sub(out.len());
1272+
// Find a UTF-8 char boundary at or before `remaining`.
1273+
let mut cut = remaining;
1274+
while cut > 0 && !piece.is_char_boundary(cut) {
1275+
cut -= 1;
1276+
}
1277+
out.push_str(&piece[..cut]);
1278+
truncated = true;
1279+
break 'outer;
1280+
}
1281+
out.push_str(piece);
12001282
}
1201-
out.push_str(&format!("[{n}] {title}\n{url}\n"));
1202-
out.push_str(&snippets.join("\n\n"));
1283+
}
1284+
if truncated {
1285+
out.push_str("\n[truncated]");
12031286
}
12041287
if out.is_empty() {
12051288
"No results.".to_string()

0 commit comments

Comments
 (0)