Skip to content

Commit e17abc5

Browse files
committed
feat(acp-nats): add Bridge infrastructure for session cancel and ready
Add CancelledSessions, spawn_session_ready, session-ready task tracking, cancel_waiter_for_session, new NATS wildcard subjects, and parsing consolidation. Includes comprehensive unit tests for all new code paths. Signed-off-by: Yordis Prieto <yordis.prieto@gmail.com>
1 parent cbd2cf8 commit e17abc5

File tree

8 files changed

+526
-58
lines changed

8 files changed

+526
-58
lines changed

rsworkspace/crates/acp-nats/src/agent/mod.rs

Lines changed: 290 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
// Safety: `CancelledSessions` uses `Mutex` for interior mutability so the `Bridge` remains
2+
// `Send`/`Sync`. The fire-and-forget publish in `spawn_session_ready` uses `tokio::spawn`
3+
// and captures only cloned, `Send` values — it never touches shared state from the closure.
4+
15
mod authenticate;
26
mod cancel;
37
mod ext_method;
@@ -8,27 +12,85 @@ mod new_session;
812
mod prompt;
913
mod set_session_mode;
1014

11-
use crate::config::Config;
12-
use crate::nats::{FlushClient, PublishClient, RequestClient};
15+
use crate::config::{Config, SESSION_READY_DELAY};
16+
use crate::nats::{self, ExtSessionReady, FlushClient, PublishClient, RequestClient, agent};
1317
use crate::pending_prompt_waiters::PendingSessionPromptResponseWaiters;
1418
use crate::prompt_slot_counter::PromptSlotCounter;
1519
use crate::telemetry::metrics::Metrics;
1620
use agent_client_protocol::{
1721
Agent, AuthenticateRequest, AuthenticateResponse, CancelNotification, ExtNotification,
1822
ExtRequest, ExtResponse, InitializeRequest, InitializeResponse, LoadSessionRequest,
1923
LoadSessionResponse, NewSessionRequest, NewSessionResponse, PromptRequest, PromptResponse,
20-
Result, SetSessionModeRequest, SetSessionModeResponse,
24+
Result, SessionId, SetSessionModeRequest, SetSessionModeResponse,
2125
};
2226
use opentelemetry::metrics::Meter;
27+
use std::collections::HashMap;
28+
use std::sync::Mutex;
29+
use std::time::Duration;
30+
use tracing::{info, warn};
2331
use trogon_std::time::GetElapsed;
2432

33+
#[allow(dead_code)]
34+
const CANCELLED_SESSION_TTL: Duration = Duration::from_secs(300);
35+
#[allow(dead_code)]
36+
const CLEANUP_EVERY: usize = 16;
37+
38+
#[allow(dead_code)]
39+
pub(crate) struct CancelledSessions<I: Copy> {
40+
map: Mutex<HashMap<SessionId, I>>,
41+
cleanup_counter: std::sync::atomic::AtomicUsize,
42+
}
43+
44+
#[allow(dead_code)]
45+
impl<I: Copy> CancelledSessions<I> {
46+
pub fn new() -> Self {
47+
Self {
48+
map: Mutex::new(HashMap::new()),
49+
cleanup_counter: std::sync::atomic::AtomicUsize::new(0),
50+
}
51+
}
52+
53+
pub fn mark_cancelled<C: GetElapsed<Instant = I>>(&self, session_id: SessionId, clock: &C) {
54+
let mut map = self.map.lock().unwrap();
55+
map.insert(session_id, clock.now());
56+
let count = self
57+
.cleanup_counter
58+
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
59+
if count.is_multiple_of(CLEANUP_EVERY) {
60+
map.retain(|_, ts| clock.elapsed(*ts) < CANCELLED_SESSION_TTL);
61+
}
62+
}
63+
64+
pub fn take_if_cancelled<C: GetElapsed<Instant = I>>(
65+
&self,
66+
session_id: &SessionId,
67+
clock: &C,
68+
) -> Option<()> {
69+
let mut map = self.map.lock().unwrap();
70+
let is_valid = map
71+
.get(session_id)
72+
.is_some_and(|ts| clock.elapsed(*ts) < CANCELLED_SESSION_TTL);
73+
74+
if is_valid {
75+
map.remove(session_id);
76+
Some(())
77+
} else {
78+
map.remove(session_id);
79+
None
80+
}
81+
}
82+
}
83+
2584
pub struct Bridge<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> {
2685
pub(crate) nats: N,
2786
pub(crate) clock: C,
2887
pub(crate) metrics: Metrics,
88+
#[allow(dead_code)]
89+
pub(crate) cancelled_sessions: CancelledSessions<C::Instant>,
2990
pub(crate) pending_session_prompt_responses: PendingSessionPromptResponseWaiters<C::Instant>,
3091
pub(crate) prompt_slot_counter: PromptSlotCounter,
3192
pub(crate) config: Config,
93+
pub(crate) session_ready_publish_tasks: Mutex<Vec<tokio::task::JoinHandle<()>>>,
3294
}
3395

3496
impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Bridge<N, C> {
@@ -39,14 +101,73 @@ impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Bridge<N, C>
39101
clock,
40102
config,
41103
metrics: Metrics::new(meter),
104+
cancelled_sessions: CancelledSessions::new(),
42105
pending_session_prompt_responses: PendingSessionPromptResponseWaiters::new(),
43106
prompt_slot_counter: PromptSlotCounter::new(max_concurrent),
107+
session_ready_publish_tasks: Mutex::new(Vec::new()),
44108
}
45109
}
46110

47111
pub(crate) fn nats(&self) -> &N {
48112
&self.nats
49113
}
114+
115+
#[allow(dead_code)]
116+
pub(crate) fn register_session_ready_task(&self, task: tokio::task::JoinHandle<()>) {
117+
let mut tasks = self.session_ready_publish_tasks.lock().unwrap();
118+
tasks.retain(|task| !task.is_finished());
119+
tasks.push(task);
120+
}
121+
122+
pub fn has_pending_session_ready_tasks(&self) -> bool {
123+
let mut tasks = self.session_ready_publish_tasks.lock().unwrap();
124+
tasks.retain(|task| !task.is_finished());
125+
!tasks.is_empty()
126+
}
127+
128+
pub async fn await_session_ready_tasks(&self) {
129+
let tasks = std::mem::take(&mut *self.session_ready_publish_tasks.lock().unwrap());
130+
for task in tasks {
131+
if let Err(e) = task.await {
132+
warn!(error = %e, "session_ready task panicked");
133+
}
134+
}
135+
}
136+
137+
#[allow(dead_code)]
138+
pub(crate) fn spawn_session_ready(&self, session_id: &SessionId) {
139+
let nats_clone = self.nats.clone();
140+
let prefix = self.config.acp_prefix().to_string();
141+
let session_id = session_id.clone();
142+
let metrics = self.metrics.clone();
143+
let session_ready_task = tokio::spawn(async move {
144+
tokio::time::sleep(SESSION_READY_DELAY).await;
145+
146+
let ready_subject = agent::ext_session_ready(&prefix, &session_id.to_string());
147+
info!(session_id = %session_id, subject = %ready_subject, "Publishing session.ready");
148+
149+
let ready_message = ExtSessionReady::new(session_id.clone());
150+
151+
let options = nats::PublishOptions::builder()
152+
.publish_retry_policy(nats::RetryPolicy::standard())
153+
.flush_policy(nats::FlushPolicy::standard())
154+
.build();
155+
156+
if let Err(e) =
157+
nats::publish(&nats_clone, &ready_subject, &ready_message, options).await
158+
{
159+
warn!(
160+
error = %e,
161+
session_id = %session_id,
162+
"Failed to publish session.ready"
163+
);
164+
metrics.record_error("session_ready", "session_ready_publish_failed");
165+
} else {
166+
info!(session_id = %session_id, "Published session.ready");
167+
}
168+
});
169+
self.register_session_ready_task(session_ready_task);
170+
}
50171
}
51172

52173
#[async_trait::async_trait(?Send)]
@@ -103,3 +224,169 @@ mod send_sync_tests {
103224
assert_send_sync::<Bridge<AdvancedMockNatsClient, SystemClock>>();
104225
}
105226
}
227+
228+
#[cfg(test)]
229+
mod bridge_session_ready_tests {
230+
use super::*;
231+
use trogon_nats::AdvancedMockNatsClient;
232+
use trogon_std::time::MockClock;
233+
234+
fn test_bridge() -> Bridge<AdvancedMockNatsClient, MockClock> {
235+
let nats = AdvancedMockNatsClient::new();
236+
let clock = MockClock::new();
237+
let provider = opentelemetry::global::meter_provider();
238+
let meter = provider.meter("test");
239+
let config = Config::for_test("acp");
240+
Bridge::new(nats, clock, &meter, config)
241+
}
242+
243+
#[tokio::test]
244+
async fn register_and_has_pending() {
245+
let bridge = test_bridge();
246+
assert!(!bridge.has_pending_session_ready_tasks());
247+
248+
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
249+
let task = tokio::spawn(async move {
250+
let _ = rx.await;
251+
});
252+
bridge.register_session_ready_task(task);
253+
assert!(bridge.has_pending_session_ready_tasks());
254+
tx.send(()).unwrap();
255+
bridge.await_session_ready_tasks().await;
256+
}
257+
258+
#[tokio::test]
259+
async fn register_retains_only_unfinished_tasks() {
260+
let bridge = test_bridge();
261+
let finished = tokio::spawn(async {});
262+
bridge.register_session_ready_task(finished);
263+
tokio::task::yield_now().await;
264+
265+
let (tx, rx) = tokio::sync::oneshot::channel::<()>();
266+
let pending = tokio::spawn(async move {
267+
let _ = rx.await;
268+
});
269+
bridge.register_session_ready_task(pending);
270+
271+
{
272+
let tasks = bridge.session_ready_publish_tasks.lock().unwrap();
273+
assert_eq!(tasks.len(), 1);
274+
}
275+
276+
tx.send(()).unwrap();
277+
bridge.await_session_ready_tasks().await;
278+
}
279+
280+
#[tokio::test]
281+
async fn await_session_ready_tasks_drains() {
282+
let bridge = test_bridge();
283+
let task = tokio::spawn(async {});
284+
bridge.register_session_ready_task(task);
285+
bridge.await_session_ready_tasks().await;
286+
assert!(!bridge.has_pending_session_ready_tasks());
287+
}
288+
289+
#[tokio::test]
290+
async fn await_handles_panicking_task() {
291+
let bridge = test_bridge();
292+
let task = tokio::spawn(async { panic!("intentional") });
293+
bridge.register_session_ready_task(task);
294+
bridge.await_session_ready_tasks().await;
295+
assert!(!bridge.has_pending_session_ready_tasks());
296+
}
297+
298+
#[tokio::test]
299+
async fn spawn_session_ready_registers_task() {
300+
let bridge = test_bridge();
301+
let session_id = SessionId::new("test-session".to_string());
302+
bridge.spawn_session_ready(&session_id);
303+
304+
assert!(bridge.has_pending_session_ready_tasks());
305+
bridge.await_session_ready_tasks().await;
306+
}
307+
308+
#[tokio::test]
309+
async fn spawn_session_ready_error_path() {
310+
let nats = AdvancedMockNatsClient::new();
311+
nats.fail_publish_count(4);
312+
let clock = MockClock::new();
313+
let provider = opentelemetry::global::meter_provider();
314+
let meter = provider.meter("test");
315+
let config = Config::for_test("acp");
316+
let bridge = Bridge::new(nats, clock, &meter, config);
317+
318+
let session_id = SessionId::new("fail-session".to_string());
319+
bridge.spawn_session_ready(&session_id);
320+
bridge.await_session_ready_tasks().await;
321+
}
322+
323+
#[tokio::test]
324+
async fn has_pending_filters_finished_tasks() {
325+
let bridge = test_bridge();
326+
let task = tokio::spawn(async {});
327+
bridge.register_session_ready_task(task);
328+
tokio::task::yield_now().await;
329+
tokio::task::yield_now().await;
330+
assert!(!bridge.has_pending_session_ready_tasks());
331+
}
332+
}
333+
334+
#[cfg(test)]
335+
mod cancelled_sessions_tests {
336+
use super::*;
337+
use agent_client_protocol::SessionId;
338+
use trogon_std::time::MockClock;
339+
340+
fn session(id: &str) -> SessionId {
341+
SessionId::new(id.to_string())
342+
}
343+
344+
#[test]
345+
fn mark_and_take_within_ttl() {
346+
let clock = MockClock::new();
347+
let cs = CancelledSessions::new();
348+
cs.mark_cancelled(session("s1"), &clock);
349+
assert!(cs.take_if_cancelled(&session("s1"), &clock).is_some());
350+
}
351+
352+
#[test]
353+
fn take_removes_entry() {
354+
let clock = MockClock::new();
355+
let cs = CancelledSessions::new();
356+
cs.mark_cancelled(session("s1"), &clock);
357+
cs.take_if_cancelled(&session("s1"), &clock);
358+
assert!(cs.take_if_cancelled(&session("s1"), &clock).is_none());
359+
}
360+
361+
#[test]
362+
fn take_returns_none_for_unknown_session() {
363+
let clock = MockClock::new();
364+
let cs = CancelledSessions::new();
365+
assert!(cs.take_if_cancelled(&session("nope"), &clock).is_none());
366+
}
367+
368+
#[test]
369+
fn expired_entry_returns_none() {
370+
let clock = MockClock::new();
371+
let cs = CancelledSessions::new();
372+
cs.mark_cancelled(session("s1"), &clock);
373+
clock.advance(CANCELLED_SESSION_TTL + Duration::from_secs(1));
374+
assert!(cs.take_if_cancelled(&session("s1"), &clock).is_none());
375+
}
376+
377+
#[test]
378+
fn cleanup_evicts_expired_entries() {
379+
let clock = MockClock::new();
380+
let cs = CancelledSessions::new();
381+
382+
cs.mark_cancelled(session("old"), &clock);
383+
clock.advance(CANCELLED_SESSION_TTL + Duration::from_secs(1));
384+
385+
for i in 0..CLEANUP_EVERY {
386+
cs.mark_cancelled(session(&format!("s{i}")), &clock);
387+
}
388+
389+
let map = cs.map.lock().unwrap();
390+
assert!(!map.contains_key(&session("old")));
391+
}
392+
}

0 commit comments

Comments
 (0)