1+ // Safety: `CancelledSessions` uses `Mutex` for interior mutability so the `Bridge` remains
2+ // `Send`/`Sync`. The fire-and-forget publish in `spawn_session_ready` uses `tokio::spawn`
3+ // and captures only cloned, `Send` values — it never touches shared state from the closure.
4+
15mod authenticate;
26mod cancel;
37mod ext_method;
@@ -8,27 +12,85 @@ mod new_session;
812mod prompt;
913mod set_session_mode;
1014
11- use crate :: config:: Config ;
12- use crate :: nats:: { FlushClient , PublishClient , RequestClient } ;
15+ use crate :: config:: { Config , SESSION_READY_DELAY } ;
16+ use crate :: nats:: { self , ExtSessionReady , FlushClient , PublishClient , RequestClient , agent } ;
1317use crate :: pending_prompt_waiters:: PendingSessionPromptResponseWaiters ;
1418use crate :: prompt_slot_counter:: PromptSlotCounter ;
1519use crate :: telemetry:: metrics:: Metrics ;
1620use agent_client_protocol:: {
1721 Agent , AuthenticateRequest , AuthenticateResponse , CancelNotification , ExtNotification ,
1822 ExtRequest , ExtResponse , InitializeRequest , InitializeResponse , LoadSessionRequest ,
1923 LoadSessionResponse , NewSessionRequest , NewSessionResponse , PromptRequest , PromptResponse ,
20- Result , SetSessionModeRequest , SetSessionModeResponse ,
24+ Result , SessionId , SetSessionModeRequest , SetSessionModeResponse ,
2125} ;
2226use opentelemetry:: metrics:: Meter ;
27+ use std:: collections:: HashMap ;
28+ use std:: sync:: Mutex ;
29+ use std:: time:: Duration ;
30+ use tracing:: { info, warn} ;
2331use trogon_std:: time:: GetElapsed ;
2432
33+ #[ allow( dead_code) ]
34+ const CANCELLED_SESSION_TTL : Duration = Duration :: from_secs ( 300 ) ;
35+ #[ allow( dead_code) ]
36+ const CLEANUP_EVERY : usize = 16 ;
37+
38+ #[ allow( dead_code) ]
39+ pub ( crate ) struct CancelledSessions < I : Copy > {
40+ map : Mutex < HashMap < SessionId , I > > ,
41+ cleanup_counter : std:: sync:: atomic:: AtomicUsize ,
42+ }
43+
44+ #[ allow( dead_code) ]
45+ impl < I : Copy > CancelledSessions < I > {
46+ pub fn new ( ) -> Self {
47+ Self {
48+ map : Mutex :: new ( HashMap :: new ( ) ) ,
49+ cleanup_counter : std:: sync:: atomic:: AtomicUsize :: new ( 0 ) ,
50+ }
51+ }
52+
53+ pub fn mark_cancelled < C : GetElapsed < Instant = I > > ( & self , session_id : SessionId , clock : & C ) {
54+ let mut map = self . map . lock ( ) . unwrap ( ) ;
55+ map. insert ( session_id, clock. now ( ) ) ;
56+ let count = self
57+ . cleanup_counter
58+ . fetch_add ( 1 , std:: sync:: atomic:: Ordering :: Relaxed ) ;
59+ if count. is_multiple_of ( CLEANUP_EVERY ) {
60+ map. retain ( |_, ts| clock. elapsed ( * ts) < CANCELLED_SESSION_TTL ) ;
61+ }
62+ }
63+
64+ pub fn take_if_cancelled < C : GetElapsed < Instant = I > > (
65+ & self ,
66+ session_id : & SessionId ,
67+ clock : & C ,
68+ ) -> Option < ( ) > {
69+ let mut map = self . map . lock ( ) . unwrap ( ) ;
70+ let is_valid = map
71+ . get ( session_id)
72+ . is_some_and ( |ts| clock. elapsed ( * ts) < CANCELLED_SESSION_TTL ) ;
73+
74+ if is_valid {
75+ map. remove ( session_id) ;
76+ Some ( ( ) )
77+ } else {
78+ map. remove ( session_id) ;
79+ None
80+ }
81+ }
82+ }
83+
2584pub struct Bridge < N : RequestClient + PublishClient + FlushClient , C : GetElapsed > {
2685 pub ( crate ) nats : N ,
2786 pub ( crate ) clock : C ,
2887 pub ( crate ) metrics : Metrics ,
88+ #[ allow( dead_code) ]
89+ pub ( crate ) cancelled_sessions : CancelledSessions < C :: Instant > ,
2990 pub ( crate ) pending_session_prompt_responses : PendingSessionPromptResponseWaiters < C :: Instant > ,
3091 pub ( crate ) prompt_slot_counter : PromptSlotCounter ,
3192 pub ( crate ) config : Config ,
93+ pub ( crate ) session_ready_publish_tasks : Mutex < Vec < tokio:: task:: JoinHandle < ( ) > > > ,
3294}
3395
3496impl < N : RequestClient + PublishClient + FlushClient , C : GetElapsed > Bridge < N , C > {
@@ -39,14 +101,73 @@ impl<N: RequestClient + PublishClient + FlushClient, C: GetElapsed> Bridge<N, C>
39101 clock,
40102 config,
41103 metrics : Metrics :: new ( meter) ,
104+ cancelled_sessions : CancelledSessions :: new ( ) ,
42105 pending_session_prompt_responses : PendingSessionPromptResponseWaiters :: new ( ) ,
43106 prompt_slot_counter : PromptSlotCounter :: new ( max_concurrent) ,
107+ session_ready_publish_tasks : Mutex :: new ( Vec :: new ( ) ) ,
44108 }
45109 }
46110
47111 pub ( crate ) fn nats ( & self ) -> & N {
48112 & self . nats
49113 }
114+
115+ #[ allow( dead_code) ]
116+ pub ( crate ) fn register_session_ready_task ( & self , task : tokio:: task:: JoinHandle < ( ) > ) {
117+ let mut tasks = self . session_ready_publish_tasks . lock ( ) . unwrap ( ) ;
118+ tasks. retain ( |task| !task. is_finished ( ) ) ;
119+ tasks. push ( task) ;
120+ }
121+
122+ pub fn has_pending_session_ready_tasks ( & self ) -> bool {
123+ let mut tasks = self . session_ready_publish_tasks . lock ( ) . unwrap ( ) ;
124+ tasks. retain ( |task| !task. is_finished ( ) ) ;
125+ !tasks. is_empty ( )
126+ }
127+
128+ pub async fn await_session_ready_tasks ( & self ) {
129+ let tasks = std:: mem:: take ( & mut * self . session_ready_publish_tasks . lock ( ) . unwrap ( ) ) ;
130+ for task in tasks {
131+ if let Err ( e) = task. await {
132+ warn ! ( error = %e, "session_ready task panicked" ) ;
133+ }
134+ }
135+ }
136+
137+ #[ allow( dead_code) ]
138+ pub ( crate ) fn spawn_session_ready ( & self , session_id : & SessionId ) {
139+ let nats_clone = self . nats . clone ( ) ;
140+ let prefix = self . config . acp_prefix ( ) . to_string ( ) ;
141+ let session_id = session_id. clone ( ) ;
142+ let metrics = self . metrics . clone ( ) ;
143+ let session_ready_task = tokio:: spawn ( async move {
144+ tokio:: time:: sleep ( SESSION_READY_DELAY ) . await ;
145+
146+ let ready_subject = agent:: ext_session_ready ( & prefix, & session_id. to_string ( ) ) ;
147+ info ! ( session_id = %session_id, subject = %ready_subject, "Publishing session.ready" ) ;
148+
149+ let ready_message = ExtSessionReady :: new ( session_id. clone ( ) ) ;
150+
151+ let options = nats:: PublishOptions :: builder ( )
152+ . publish_retry_policy ( nats:: RetryPolicy :: standard ( ) )
153+ . flush_policy ( nats:: FlushPolicy :: standard ( ) )
154+ . build ( ) ;
155+
156+ if let Err ( e) =
157+ nats:: publish ( & nats_clone, & ready_subject, & ready_message, options) . await
158+ {
159+ warn ! (
160+ error = %e,
161+ session_id = %session_id,
162+ "Failed to publish session.ready"
163+ ) ;
164+ metrics. record_error ( "session_ready" , "session_ready_publish_failed" ) ;
165+ } else {
166+ info ! ( session_id = %session_id, "Published session.ready" ) ;
167+ }
168+ } ) ;
169+ self . register_session_ready_task ( session_ready_task) ;
170+ }
50171}
51172
52173#[ async_trait:: async_trait( ?Send ) ]
@@ -103,3 +224,169 @@ mod send_sync_tests {
103224 assert_send_sync :: < Bridge < AdvancedMockNatsClient , SystemClock > > ( ) ;
104225 }
105226}
227+
228+ #[ cfg( test) ]
229+ mod bridge_session_ready_tests {
230+ use super :: * ;
231+ use trogon_nats:: AdvancedMockNatsClient ;
232+ use trogon_std:: time:: MockClock ;
233+
234+ fn test_bridge ( ) -> Bridge < AdvancedMockNatsClient , MockClock > {
235+ let nats = AdvancedMockNatsClient :: new ( ) ;
236+ let clock = MockClock :: new ( ) ;
237+ let provider = opentelemetry:: global:: meter_provider ( ) ;
238+ let meter = provider. meter ( "test" ) ;
239+ let config = Config :: for_test ( "acp" ) ;
240+ Bridge :: new ( nats, clock, & meter, config)
241+ }
242+
243+ #[ tokio:: test]
244+ async fn register_and_has_pending ( ) {
245+ let bridge = test_bridge ( ) ;
246+ assert ! ( !bridge. has_pending_session_ready_tasks( ) ) ;
247+
248+ let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < ( ) > ( ) ;
249+ let task = tokio:: spawn ( async move {
250+ let _ = rx. await ;
251+ } ) ;
252+ bridge. register_session_ready_task ( task) ;
253+ assert ! ( bridge. has_pending_session_ready_tasks( ) ) ;
254+ tx. send ( ( ) ) . unwrap ( ) ;
255+ bridge. await_session_ready_tasks ( ) . await ;
256+ }
257+
258+ #[ tokio:: test]
259+ async fn register_retains_only_unfinished_tasks ( ) {
260+ let bridge = test_bridge ( ) ;
261+ let finished = tokio:: spawn ( async { } ) ;
262+ bridge. register_session_ready_task ( finished) ;
263+ tokio:: task:: yield_now ( ) . await ;
264+
265+ let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < ( ) > ( ) ;
266+ let pending = tokio:: spawn ( async move {
267+ let _ = rx. await ;
268+ } ) ;
269+ bridge. register_session_ready_task ( pending) ;
270+
271+ {
272+ let tasks = bridge. session_ready_publish_tasks . lock ( ) . unwrap ( ) ;
273+ assert_eq ! ( tasks. len( ) , 1 ) ;
274+ }
275+
276+ tx. send ( ( ) ) . unwrap ( ) ;
277+ bridge. await_session_ready_tasks ( ) . await ;
278+ }
279+
280+ #[ tokio:: test]
281+ async fn await_session_ready_tasks_drains ( ) {
282+ let bridge = test_bridge ( ) ;
283+ let task = tokio:: spawn ( async { } ) ;
284+ bridge. register_session_ready_task ( task) ;
285+ bridge. await_session_ready_tasks ( ) . await ;
286+ assert ! ( !bridge. has_pending_session_ready_tasks( ) ) ;
287+ }
288+
289+ #[ tokio:: test]
290+ async fn await_handles_panicking_task ( ) {
291+ let bridge = test_bridge ( ) ;
292+ let task = tokio:: spawn ( async { panic ! ( "intentional" ) } ) ;
293+ bridge. register_session_ready_task ( task) ;
294+ bridge. await_session_ready_tasks ( ) . await ;
295+ assert ! ( !bridge. has_pending_session_ready_tasks( ) ) ;
296+ }
297+
298+ #[ tokio:: test]
299+ async fn spawn_session_ready_registers_task ( ) {
300+ let bridge = test_bridge ( ) ;
301+ let session_id = SessionId :: new ( "test-session" . to_string ( ) ) ;
302+ bridge. spawn_session_ready ( & session_id) ;
303+
304+ assert ! ( bridge. has_pending_session_ready_tasks( ) ) ;
305+ bridge. await_session_ready_tasks ( ) . await ;
306+ }
307+
308+ #[ tokio:: test]
309+ async fn spawn_session_ready_error_path ( ) {
310+ let nats = AdvancedMockNatsClient :: new ( ) ;
311+ nats. fail_publish_count ( 4 ) ;
312+ let clock = MockClock :: new ( ) ;
313+ let provider = opentelemetry:: global:: meter_provider ( ) ;
314+ let meter = provider. meter ( "test" ) ;
315+ let config = Config :: for_test ( "acp" ) ;
316+ let bridge = Bridge :: new ( nats, clock, & meter, config) ;
317+
318+ let session_id = SessionId :: new ( "fail-session" . to_string ( ) ) ;
319+ bridge. spawn_session_ready ( & session_id) ;
320+ bridge. await_session_ready_tasks ( ) . await ;
321+ }
322+
323+ #[ tokio:: test]
324+ async fn has_pending_filters_finished_tasks ( ) {
325+ let bridge = test_bridge ( ) ;
326+ let task = tokio:: spawn ( async { } ) ;
327+ bridge. register_session_ready_task ( task) ;
328+ tokio:: task:: yield_now ( ) . await ;
329+ tokio:: task:: yield_now ( ) . await ;
330+ assert ! ( !bridge. has_pending_session_ready_tasks( ) ) ;
331+ }
332+ }
333+
334+ #[ cfg( test) ]
335+ mod cancelled_sessions_tests {
336+ use super :: * ;
337+ use agent_client_protocol:: SessionId ;
338+ use trogon_std:: time:: MockClock ;
339+
340+ fn session ( id : & str ) -> SessionId {
341+ SessionId :: new ( id. to_string ( ) )
342+ }
343+
344+ #[ test]
345+ fn mark_and_take_within_ttl ( ) {
346+ let clock = MockClock :: new ( ) ;
347+ let cs = CancelledSessions :: new ( ) ;
348+ cs. mark_cancelled ( session ( "s1" ) , & clock) ;
349+ assert ! ( cs. take_if_cancelled( & session( "s1" ) , & clock) . is_some( ) ) ;
350+ }
351+
352+ #[ test]
353+ fn take_removes_entry ( ) {
354+ let clock = MockClock :: new ( ) ;
355+ let cs = CancelledSessions :: new ( ) ;
356+ cs. mark_cancelled ( session ( "s1" ) , & clock) ;
357+ cs. take_if_cancelled ( & session ( "s1" ) , & clock) ;
358+ assert ! ( cs. take_if_cancelled( & session( "s1" ) , & clock) . is_none( ) ) ;
359+ }
360+
361+ #[ test]
362+ fn take_returns_none_for_unknown_session ( ) {
363+ let clock = MockClock :: new ( ) ;
364+ let cs = CancelledSessions :: new ( ) ;
365+ assert ! ( cs. take_if_cancelled( & session( "nope" ) , & clock) . is_none( ) ) ;
366+ }
367+
368+ #[ test]
369+ fn expired_entry_returns_none ( ) {
370+ let clock = MockClock :: new ( ) ;
371+ let cs = CancelledSessions :: new ( ) ;
372+ cs. mark_cancelled ( session ( "s1" ) , & clock) ;
373+ clock. advance ( CANCELLED_SESSION_TTL + Duration :: from_secs ( 1 ) ) ;
374+ assert ! ( cs. take_if_cancelled( & session( "s1" ) , & clock) . is_none( ) ) ;
375+ }
376+
377+ #[ test]
378+ fn cleanup_evicts_expired_entries ( ) {
379+ let clock = MockClock :: new ( ) ;
380+ let cs = CancelledSessions :: new ( ) ;
381+
382+ cs. mark_cancelled ( session ( "old" ) , & clock) ;
383+ clock. advance ( CANCELLED_SESSION_TTL + Duration :: from_secs ( 1 ) ) ;
384+
385+ for i in 0 ..CLEANUP_EVERY {
386+ cs. mark_cancelled ( session ( & format ! ( "s{i}" ) ) , & clock) ;
387+ }
388+
389+ let map = cs. map . lock ( ) . unwrap ( ) ;
390+ assert ! ( !map. contains_key( & session( "old" ) ) ) ;
391+ }
392+ }
0 commit comments