Skip to content

Commit 177c58c

Browse files
authored
Merge pull request #144 from nearai/feat/agent-loop-web-context-search
feat(agent-loop): server-side web_context_search loop for /v1/chat/completions
2 parents 7b67f54 + 882b3b6 commit 177c58c

9 files changed

Lines changed: 2874 additions & 10 deletions

File tree

benches/e2e.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ fn build_test_app(mock_url: &str) -> axum::Router {
8787
dstack_socket_path: "/var/run/dstack.sock".to_string(),
8888
gpu_evidence_delegate_url: None,
8989
gpu_evidence_delegate_timeout_secs: 30,
90+
web_context_search_url: None,
91+
web_context_search_api_key: None,
92+
agent_loop_max_iterations: 5,
93+
web_context_search_timeout_secs: 30,
9094
};
9195

9296
let ecdsa = signing::EcdsaContext::from_key_bytes(&ECDSA_KEY).unwrap();

src/agent_loop.rs

Lines changed: 1454 additions & 0 deletions
Large diffs are not rendered by default.

src/config.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,21 @@ pub struct Config {
160160
/// unreachable (otherwise `/v1/attestation/report` silently 500s while
161161
/// `/v1/models` still passes). Default: `/var/run/dstack.sock`.
162162
pub dstack_socket_path: String,
163+
164+
// Agent loop (server-side web_context_search tool)
165+
/// Brave LLM Context API endpoint. When unset, requests advertising the
166+
/// `{"type":"web_context_search"}` tool are rejected with 400. All
167+
/// tool execution happens inside the CVM; the query is the only thing
168+
/// that egresses, going directly to Brave under TLS.
169+
pub web_context_search_url: Option<String>,
170+
/// Brave subscription token for the LLM Context endpoint. Sent as
171+
/// `X-Subscription-Token`. Required when `web_context_search_url` is set.
172+
pub web_context_search_api_key: Option<String>,
173+
/// Hard cap on tool-call iterations within a single chat completion.
174+
/// Once hit, the loop emits a synthetic terminator chunk and stops.
175+
pub agent_loop_max_iterations: u32,
176+
/// Per-tool-call timeout for the Brave HTTP request.
177+
pub web_context_search_timeout_secs: u64,
163178
}
164179

165180
impl Config {
@@ -316,6 +331,18 @@ impl Config {
316331
rerank_url_override,
317332
score_url_override,
318333
dstack_socket_path: env_or("DSTACK_SOCKET_PATH", "/var/run/dstack.sock"),
334+
web_context_search_url: env::var("WEB_CONTEXT_SEARCH_URL")
335+
.ok()
336+
.filter(|s| !s.is_empty()),
337+
web_context_search_api_key: env::var("WEB_CONTEXT_SEARCH_API_KEY")
338+
.ok()
339+
.filter(|s| !s.is_empty()),
340+
// `env_int` returns `usize`; on 64-bit hosts a user-supplied value
341+
// > u32::MAX would silently wrap. `try_from` surfaces it as a
342+
// config error instead so a typo can't become a tiny iteration cap.
343+
agent_loop_max_iterations: u32::try_from(env_int("AGENT_LOOP_MAX_ITERATIONS", 5))
344+
.map_err(|_| anyhow::anyhow!("AGENT_LOOP_MAX_ITERATIONS exceeds the u32 range"))?,
345+
web_context_search_timeout_secs: env_int("WEB_CONTEXT_SEARCH_TIMEOUT_SECS", 30) as u64,
319346
};
320347

321348
// Validate attestation cache TTL (TTL/2 is used as refresh interval, so TTL < 2 would cause a busy loop)
@@ -334,6 +361,19 @@ impl Config {
334361
anyhow::bail!("STARTUP_CHECK_TIMEOUT_SECS must be greater than 0");
335362
}
336363

364+
// Agent loop: URL and key must both be set or both unset; iteration cap and timeout must be positive.
365+
if config.web_context_search_url.is_some() != config.web_context_search_api_key.is_some() {
366+
anyhow::bail!(
367+
"WEB_CONTEXT_SEARCH_URL and WEB_CONTEXT_SEARCH_API_KEY must both be set or both unset"
368+
);
369+
}
370+
if config.agent_loop_max_iterations == 0 {
371+
anyhow::bail!("AGENT_LOOP_MAX_ITERATIONS must be at least 1");
372+
}
373+
if config.web_context_search_timeout_secs == 0 {
374+
anyhow::bail!("WEB_CONTEXT_SEARCH_TIMEOUT_SECS must be greater than 0");
375+
}
376+
337377
Ok(config)
338378
}
339379
}

src/encryption.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,19 @@ fn encrypt_chat_response_choices(
718718
encrypt_field(audio, "data", ctx, signing)?;
719719
}
720720

721+
// Synthetic agent-loop tool result chunk (streaming only).
722+
// Emitted by `agent_loop::run_chat_completion` between
723+
// iterations as `delta.nearai_tool_result.output`. Encrypted
724+
// unconditionally whenever an encryption context is active —
725+
// the agent loop is the privacy-critical path, so the search
726+
// output must travel encrypted whether or not the client set
727+
// `X-Encrypt-All-Fields`. The field only appears when the
728+
// client also opted into server-side tool execution by
729+
// sending `tools: [{"type":"web_context_search"}]`.
730+
if let Some(tool_result) = msg.get_mut("nearai_tool_result") {
731+
encrypt_field(tool_result, "output", ctx, signing)?;
732+
}
733+
721734
// Extended fields — only when client opts in
722735
if ctx.encrypt_all_fields {
723736
encrypt_field(msg, "refusal", ctx, signing)?;

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use axum::middleware::Next;
55
use axum::response::Response;
66
use tracing::Instrument;
77

8+
pub mod agent_loop;
89
pub mod attestation;
910
pub mod attestation_sdk;
1011
pub mod auth;

src/proxy.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ fn try_report_usage(response_data: &serde_json::Value, id: &str, opts: &ProxyOpt
310310
/// service-token path (`/v1/internal/usage`) when [`UsageReporter::can_use_service_token_path`]
311311
/// is true; otherwise falls back to the legacy sk-bearer path
312312
/// (`/v1/usage`) verbatim.
313-
fn spawn_usage_report(reporter: &UsageReporter, mut body: serde_json::Value) {
313+
pub(crate) fn spawn_usage_report(reporter: &UsageReporter, mut body: serde_json::Value) {
314314
let client = reporter.http_client.clone();
315315
let (url, auth, mode) = if reporter.can_use_service_token_path() {
316316
// Inject subject identity into the body. Cloud-api's
@@ -1169,7 +1169,7 @@ pub async fn proxy_streaming_request(
11691169
/// (vLLM `qwen3` parser as of v0.10) so downstream clients see the standard
11701170
/// `delta.reasoning_content` field consistently across reasoning models.
11711171
/// If both fields are present the existing `reasoning_content` is kept.
1172-
fn normalize_chat_chunk(val: &mut serde_json::Value) {
1172+
pub(crate) fn normalize_chat_chunk(val: &mut serde_json::Value) {
11731173
let Some(choices) = val.get_mut("choices").and_then(|c| c.as_array_mut()) else {
11741174
return;
11751175
};
@@ -1652,10 +1652,10 @@ pub async fn proxy_streaming_response(
16521652

16531653
/// Drop guard that tracks the streaming_connections gauge.
16541654
/// Increments on creation, decrements on drop — guarantees they stay paired.
1655-
struct StreamingGuard;
1655+
pub(crate) struct StreamingGuard;
16561656

16571657
impl StreamingGuard {
1658-
fn new() -> Self {
1658+
pub(crate) fn new() -> Self {
16591659
metrics::gauge!("streaming_connections").increment(1);
16601660
Self
16611661
}

src/routes/chat.rs

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use axum::Extension;
66

77
use sha2::Digest;
88

9+
use crate::agent_loop;
910
use crate::auth::RequireAuth;
1011
use crate::encryption::{self, Endpoint};
1112
use crate::error::AppError;
@@ -35,11 +36,8 @@ pub async fn chat_completions(
3536
// honored when authenticated with config.token (trusted gateway); sk- clients
3637
// always bind signatures to the wire body so they cannot forge a hash for a
3738
// different payload.
38-
let original_request_hash = Some(resolve_request_hash_for_signing(
39-
&headers,
40-
&request_body,
41-
auth.cloud_api_key.is_none(),
42-
));
39+
let request_hash =
40+
resolve_request_hash_for_signing(&headers, &request_body, auth.cloud_api_key.is_none());
4341

4442
// Decrypt request fields if encryption is active
4543
if let Some(ref ctx) = enc_ctx {
@@ -56,6 +54,50 @@ pub async fn chat_completions(
5654
.and_then(|v| v.as_bool())
5755
.unwrap_or(false);
5856

57+
// Server-side agent loop opt-in: the request advertises exactly
58+
// `{"type":"web_context_search"}` and nothing else. Requires streaming
59+
// (we splice tool-result chunks between iterations) and Brave creds
60+
// configured on this CVM (400 otherwise). Requests that don't opt in fall through to
61+
// the existing pass-through path below, byte-for-byte identical.
62+
if agent_loop::is_web_context_search_request(&request_json) {
63+
if !is_stream {
64+
return Err(AppError::BadRequest(
65+
"web_context_search requires stream:true".to_string(),
66+
));
67+
}
68+
if state.config.web_context_search_url.is_none()
69+
|| state.config.web_context_search_api_key.is_none()
70+
{
71+
return Err(AppError::BadRequest(
72+
"web_context_search is not configured on this deployment".to_string(),
73+
));
74+
}
75+
76+
// Build chunk transform if E2EE is active. The agent loop is the
77+
// privacy-critical path — both our synthetic `nearai_tool_result`
78+
// chunks AND the model's own `tool_calls[].function.{name,arguments}`
79+
// (which contain the search query the model just generated from the
80+
// user's E2EE-decrypted prompt) must travel encrypted. Force
81+
// `encrypt_all_fields: true` on the context used to build this
82+
// transform so clients don't need to remember to send
83+
// `X-Encrypt-All-Fields: true` to get the full privacy guarantee.
84+
// This only affects the agent-loop path; the regular chat path
85+
// below still honors the client's `X-Encrypt-All-Fields` choice.
86+
let chunk_transform = enc_ctx.map(|mut ctx| {
87+
ctx.encrypt_all_fields = true;
88+
encryption::make_chunk_transform(Endpoint::ChatCompletions, ctx, state.signing.clone())
89+
});
90+
return agent_loop::run_chat_completion(
91+
state,
92+
auth,
93+
tracing_ids,
94+
request_hash,
95+
request_json,
96+
chunk_transform,
97+
)
98+
.await;
99+
}
100+
59101
// For cloud API key requests with streaming, force include_usage
60102
// so the backend always sends token counts for billing.
61103
// (Non-streaming requests also stream internally via proxy_json_request,
@@ -102,7 +144,7 @@ pub async fn chat_completions(
102144
model_name: state.config.model_name.clone(),
103145
usage_reporter: make_usage_reporter(&auth, &state),
104146
usage_type: UsageType::ChatCompletion,
105-
request_hash: original_request_hash,
147+
request_hash: Some(request_hash),
106148
response_transform,
107149
chunk_transform,
108150
backend_guard: Some(guard),

0 commit comments

Comments
 (0)