@@ -66,7 +66,11 @@ impl Host for DbProxyHost {
6666 let mgr = Arc :: clone ( & self . pool_manager ) ;
6767
6868 let handle = self . runtime_handle . clone ( ) ;
69- tokio:: task:: block_in_place ( || handle. block_on ( mgr. checkout ( & key, password) ) )
69+ if mgr. has_async_factory ( ) {
70+ tokio:: task:: block_in_place ( || handle. block_on ( mgr. checkout_async ( & key, password) ) )
71+ } else {
72+ tokio:: task:: block_in_place ( || handle. block_on ( mgr. checkout ( & key, password) ) )
73+ }
7074 }
7175
7276 fn send ( & mut self , conn_handle : u64 , data : Vec < u8 > ) -> Result < u32 , String > {
@@ -79,8 +83,10 @@ impl Host for DbProxyHost {
7983 let mgr = Arc :: clone ( & self . pool_manager ) ;
8084 let handle = self . runtime_handle . clone ( ) ;
8185
86+ // Use send_query() which releases the mutex during I/O for concurrent access.
87+ // Falls back to sync backend via block_in_place if no async backend is available.
8288 let sent = tokio:: task:: block_in_place ( || {
83- handle. block_on ( mgr. send ( conn_handle, & data) )
89+ handle. block_on ( mgr. send_query ( conn_handle, & data) )
8490 } ) ?;
8591
8692 Ok ( sent as u32 )
@@ -96,8 +102,10 @@ impl Host for DbProxyHost {
96102 let mgr = Arc :: clone ( & self . pool_manager ) ;
97103 let handle = self . runtime_handle . clone ( ) ;
98104
105+ // Use receive_results() which releases the mutex during I/O.
106+ // Falls back to sync backend via block_in_place if no async backend is available.
99107 tokio:: task:: block_in_place ( || {
100- handle. block_on ( mgr. recv ( conn_handle, max_bytes as usize ) )
108+ handle. block_on ( mgr. receive_results ( conn_handle, max_bytes as usize ) )
101109 } )
102110 }
103111
@@ -118,6 +126,9 @@ impl Host for DbProxyHost {
118126mod tests {
119127 use super :: * ;
120128 use super :: super :: { ConnectionBackend , ConnectionFactory , PoolConfig } ;
129+ use super :: super :: async_io:: { AsyncConnectionBackend , AsyncConnectionFactory , AsyncConnectFuture } ;
130+ use std:: future:: Future ;
131+ use std:: pin:: Pin ;
121132 use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
122133 use std:: time:: Duration ;
123134
@@ -317,4 +328,121 @@ mod tests {
317328 // Reused — no new factory connect.
318329 assert_eq ! ( factory. connects. load( Ordering :: Relaxed ) , 1 ) ;
319330 }
331+
332+ // ── Async path tests ─────────────────────────────────────────────
333+
334+ #[ derive( Debug ) ]
335+ struct MockAsyncBackend {
336+ send_count : AtomicU64 ,
337+ }
338+
339+ impl MockAsyncBackend {
340+ fn new ( ) -> Self {
341+ Self {
342+ send_count : AtomicU64 :: new ( 0 ) ,
343+ }
344+ }
345+ }
346+
347+ impl AsyncConnectionBackend for MockAsyncBackend {
348+ fn send_async < ' a > (
349+ & ' a mut self ,
350+ data : & ' a [ u8 ] ,
351+ ) -> Pin < Box < dyn Future < Output = Result < usize , String > > + Send + ' a > > {
352+ self . send_count . fetch_add ( 1 , Ordering :: Relaxed ) ;
353+ let len = data. len ( ) ;
354+ Box :: pin ( async move { Ok ( len) } )
355+ }
356+
357+ fn recv_async < ' a > (
358+ & ' a mut self ,
359+ max_bytes : usize ,
360+ ) -> Pin < Box < dyn Future < Output = Result < Vec < u8 > , String > > + Send + ' a > > {
361+ Box :: pin ( async move { Ok ( vec ! [ 0x42 ; max_bytes. min( 4 ) ] ) } )
362+ }
363+
364+ fn ping_async ( & mut self ) -> Pin < Box < dyn Future < Output = bool > + Send + ' _ > > {
365+ Box :: pin ( async { true } )
366+ }
367+
368+ fn close_async ( & mut self ) -> Pin < Box < dyn Future < Output = ( ) > + Send + ' _ > > {
369+ Box :: pin ( async { } )
370+ }
371+ }
372+
373+ struct TestAsyncFactory {
374+ connects : AtomicU64 ,
375+ }
376+
377+ impl TestAsyncFactory {
378+ fn new ( ) -> Self {
379+ Self {
380+ connects : AtomicU64 :: new ( 0 ) ,
381+ }
382+ }
383+ }
384+
385+ impl AsyncConnectionFactory for TestAsyncFactory {
386+ fn connect_async < ' a > (
387+ & ' a self ,
388+ _key : & ' a PoolKey ,
389+ _password : Option < & ' a str > ,
390+ ) -> AsyncConnectFuture < ' a > {
391+ self . connects . fetch_add ( 1 , Ordering :: Relaxed ) ;
392+ Box :: pin ( async {
393+ Ok ( Box :: new ( MockAsyncBackend :: new ( ) ) as Box < dyn AsyncConnectionBackend > )
394+ } )
395+ }
396+ }
397+
398+ fn make_async_host ( ) -> DbProxyHost {
399+ let sync_factory = Arc :: new ( TestFactory :: new ( ) ) ;
400+ let async_factory = Arc :: new ( TestAsyncFactory :: new ( ) ) ;
401+ let config = PoolConfig {
402+ max_size : 5 ,
403+ connect_timeout : Duration :: from_millis ( 100 ) ,
404+ ..PoolConfig :: default ( )
405+ } ;
406+ let mgr = Arc :: new ( ConnectionPoolManager :: new_with_async (
407+ config,
408+ sync_factory,
409+ async_factory,
410+ ) ) ;
411+ let handle = tokio:: runtime:: Handle :: current ( ) ;
412+ DbProxyHost :: new ( mgr, handle)
413+ }
414+
415+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 2 ) ]
416+ async fn host_async_full_lifecycle ( ) {
417+ let mut host = make_async_host ( ) ;
418+
419+ // Connect via async path.
420+ let handle = host. connect ( test_connect_config ( ) ) . unwrap ( ) ;
421+
422+ // Send via async send_query path.
423+ let sent = host. send ( handle, b"SELECT 1;" . to_vec ( ) ) . unwrap ( ) ;
424+ assert_eq ! ( sent, 9 ) ;
425+
426+ // Recv via async receive_results path.
427+ let data = host. recv ( handle, 1024 ) . unwrap ( ) ;
428+ assert ! ( !data. is_empty( ) ) ;
429+
430+ // Close.
431+ host. close ( handle) . unwrap ( ) ;
432+
433+ // Handle invalid after close.
434+ assert ! ( host. send( handle, b"x" . to_vec( ) ) . is_err( ) ) ;
435+ }
436+
437+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 2 ) ]
438+ async fn host_sync_fallback_when_no_async_factory ( ) {
439+ // make_host() uses sync-only factory — verify it still works.
440+ let mut host = make_host ( ) ;
441+ let handle = host. connect ( test_connect_config ( ) ) . unwrap ( ) ;
442+ let sent = host. send ( handle, b"data" . to_vec ( ) ) . unwrap ( ) ;
443+ assert_eq ! ( sent, 4 ) ;
444+ let data = host. recv ( handle, 1024 ) . unwrap ( ) ;
445+ assert ! ( !data. is_empty( ) ) ;
446+ host. close ( handle) . unwrap ( ) ;
447+ }
320448}
0 commit comments