@@ -6,9 +6,10 @@ use tokio::sync::mpsc;
66
77/// Manages active download sessions for streaming files from MSP nodes to clients.
88///
9- /// Each session maps a file key to a channel sender, allowing the internal upload
9+ /// Each session maps a session ID to a channel sender, allowing the internal upload
1010/// endpoint (which receives chunks from the MSP node) to forward them to the
1111/// download endpoint (which streams them to the client).
12+ #[ derive( Debug ) ]
1213pub struct DownloadSessionManager {
1314 sessions : Arc < RwLock < HashMap < String , mpsc:: Sender < Result < Bytes , std:: io:: Error > > > > > ,
1415 max_sessions : usize ,
@@ -22,13 +23,15 @@ impl DownloadSessionManager {
2223 }
2324 }
2425
25- /// Atomically adds a new download session for the given file key.
26- /// Fails if there is already an active session for the file.
27- pub fn add_session (
26+ /// Atomically registers a new download session for the given session ID.
27+ /// Returns a guard that will automatically clean up the session when dropped.
28+ /// Returns an error if there is already an active session with this ID
29+ /// or if the maximum number of concurrent downloads has been reached.
30+ pub fn start_session (
2831 & self ,
29- id : & String ,
32+ session_id : String ,
3033 sender : mpsc:: Sender < Result < Bytes , std:: io:: Error > > ,
31- ) -> Result < ( ) , String > {
34+ ) -> Result < DownloadSessionGuard , String > {
3235 let mut sessions = self
3336 . sessions
3437 . write ( )
@@ -41,22 +44,32 @@ impl DownloadSessionManager {
4144 ) ) ;
4245 }
4346
44- match sessions. entry ( id. clone ( ) ) {
45- Entry :: Occupied ( _) => Err ( "File is already being downloaded" . to_string ( ) ) ,
47+ match sessions. entry ( session_id. clone ( ) ) {
48+ Entry :: Occupied ( _) => Err ( format ! (
49+ "Session ID {} is already active. Please retry with a new session ID." ,
50+ session_id
51+ ) ) ,
4652 Entry :: Vacant ( entry) => {
4753 entry. insert ( sender) ;
48- Ok ( ( ) )
54+ Ok ( DownloadSessionGuard {
55+ manager : self . clone ( ) ,
56+ session_id,
57+ } )
4958 }
5059 }
5160 }
5261
53- pub fn remove_session ( & self , id : & str ) {
62+ /// Removes a download session for the given session ID.
63+ /// This is called automatically by the guard's Drop implementation.
64+ fn end_session ( & self , session_id : & str ) {
5465 self . sessions
5566 . write ( )
5667 . expect ( "Download sessions lock poisoned" )
57- . remove ( id ) ;
68+ . remove ( session_id ) ;
5869 }
5970
71+ /// Retrieves the channel sender for the given session ID.
72+ /// Used by internal_upload_by_key to forward chunks to the client.
6073 pub fn get_session ( & self , id : & str ) -> Option < mpsc:: Sender < Result < Bytes , std:: io:: Error > > > {
6174 self . sessions
6275 . read ( )
@@ -65,3 +78,138 @@ impl DownloadSessionManager {
6578 . cloned ( )
6679 }
6780}
81+
82+ impl Clone for DownloadSessionManager {
83+ fn clone ( & self ) -> Self {
84+ DownloadSessionManager {
85+ sessions : Arc :: clone ( & self . sessions ) ,
86+ max_sessions : self . max_sessions ,
87+ }
88+ }
89+ }
90+
91+ /// RAII guard that ensures download sessions are always cleaned up.
92+ /// The download session will be automatically removed when this guard is dropped,
93+ /// regardless of whether the download succeeded, failed, or panicked.
94+ #[ derive( Debug ) ]
95+ pub struct DownloadSessionGuard {
96+ manager : DownloadSessionManager ,
97+ session_id : String ,
98+ }
99+
100+ impl Drop for DownloadSessionGuard {
101+ fn drop ( & mut self ) {
102+ self . manager . end_session ( & self . session_id ) ;
103+ }
104+ }
105+
106+ #[ cfg( test) ]
107+ mod tests {
108+ use super :: * ;
109+
110+ #[ test]
111+ fn test_start_session_success ( ) {
112+ let manager = DownloadSessionManager :: new ( 100 ) ;
113+ let session_id = "test_session_123" ;
114+ let ( tx, _rx) = mpsc:: channel ( 10 ) ;
115+
116+ let _guard = manager. start_session ( session_id. to_string ( ) , tx) . unwrap ( ) ;
117+
118+ // Session should exist
119+ assert ! ( manager. get_session( session_id) . is_some( ) ) ;
120+ }
121+
122+ #[ test]
123+ fn test_start_session_duplicate_fails ( ) {
124+ let manager = DownloadSessionManager :: new ( 100 ) ;
125+ let session_id = "test_session_123" ;
126+ let ( tx1, _rx1) = mpsc:: channel ( 10 ) ;
127+ let ( tx2, _rx2) = mpsc:: channel ( 10 ) ;
128+
129+ let _guard = manager. start_session ( session_id. to_string ( ) , tx1) . unwrap ( ) ;
130+ let result = manager. start_session ( session_id. to_string ( ) , tx2) ;
131+
132+ assert ! ( result. is_err( ) ) ;
133+ assert ! ( result. unwrap_err( ) . contains( "is already active" ) ) ;
134+ }
135+
136+ #[ test]
137+ fn test_guard_cleanup_on_drop ( ) {
138+ let manager = DownloadSessionManager :: new ( 100 ) ;
139+ let session_id = "test_session_123" ;
140+ let ( tx1, _rx1) = mpsc:: channel ( 10 ) ;
141+ let ( tx2, _rx2) = mpsc:: channel ( 10 ) ;
142+
143+ {
144+ let _guard = manager. start_session ( session_id. to_string ( ) , tx1) . unwrap ( ) ;
145+ assert ! ( manager. get_session( session_id) . is_some( ) ) ;
146+ } // guard dropped here
147+
148+ assert ! ( manager. get_session( session_id) . is_none( ) ) ;
149+
150+ // Should be able to start a new session after guard is dropped
151+ let _guard = manager. start_session ( session_id. to_string ( ) , tx2) . unwrap ( ) ;
152+ assert ! ( manager. get_session( session_id) . is_some( ) ) ;
153+ }
154+
155+ #[ test]
156+ fn test_max_sessions_limit ( ) {
157+ let manager = DownloadSessionManager :: new ( 2 ) ;
158+ let ( tx1, _rx1) = mpsc:: channel ( 10 ) ;
159+ let ( tx2, _rx2) = mpsc:: channel ( 10 ) ;
160+ let ( tx3, _rx3) = mpsc:: channel ( 10 ) ;
161+
162+ let _guard1 = manager. start_session ( "session1" . to_string ( ) , tx1) . unwrap ( ) ;
163+ let _guard2 = manager. start_session ( "session2" . to_string ( ) , tx2) . unwrap ( ) ;
164+
165+ // Third session should fail due to max sessions reached
166+ let result = manager. start_session ( "session3" . to_string ( ) , tx3) ;
167+ assert ! ( result. is_err( ) ) ;
168+ assert ! ( result. unwrap_err( ) . contains( "Maximum number" ) ) ;
169+ }
170+
171+ #[ test]
172+ fn test_multiple_different_sessions ( ) {
173+ let manager = DownloadSessionManager :: new ( 100 ) ;
174+ let ( tx1, _rx1) = mpsc:: channel ( 10 ) ;
175+ let ( tx2, _rx2) = mpsc:: channel ( 10 ) ;
176+ let ( tx3, _rx3) = mpsc:: channel ( 10 ) ;
177+
178+ let _guard1 = manager. start_session ( "session1" . to_string ( ) , tx1) . unwrap ( ) ;
179+ let _guard2 = manager. start_session ( "session2" . to_string ( ) , tx2) . unwrap ( ) ;
180+ let _guard3 = manager. start_session ( "session3" . to_string ( ) , tx3) . unwrap ( ) ;
181+
182+ assert ! ( manager. get_session( "session1" ) . is_some( ) ) ;
183+ assert ! ( manager. get_session( "session2" ) . is_some( ) ) ;
184+ assert ! ( manager. get_session( "session3" ) . is_some( ) ) ;
185+
186+ drop ( _guard2) ;
187+ assert ! ( manager. get_session( "session1" ) . is_some( ) ) ;
188+ assert ! ( manager. get_session( "session2" ) . is_none( ) ) ;
189+ assert ! ( manager. get_session( "session3" ) . is_some( ) ) ;
190+ }
191+
192+ #[ tokio:: test]
193+ async fn test_guard_cleanup_on_task_failure ( ) {
194+ let manager = DownloadSessionManager :: new ( 100 ) ;
195+ let session_id = "test_session_123" ;
196+ let ( tx, _rx) = mpsc:: channel ( 10 ) ;
197+
198+ let guard = manager. start_session ( session_id. to_string ( ) , tx) . unwrap ( ) ;
199+ assert ! ( manager. get_session( session_id) . is_some( ) ) ;
200+
201+ // Simulate what happens in download_by_key: move guard into task
202+ let manager_clone = manager. clone ( ) ;
203+ let handle = tokio:: spawn ( async move {
204+ let _guard = guard;
205+ // Simulate RPC failure
206+ Err :: < ( ) , String > ( "RPC call failed" . to_string ( ) )
207+ } ) ;
208+
209+ // Wait for task to complete
210+ let _ = handle. await ;
211+
212+ // Session should be cleaned up even though task failed
213+ assert ! ( manager_clone. get_session( session_id) . is_none( ) ) ;
214+ }
215+ }
0 commit comments