@@ -770,4 +770,91 @@ mod tests {
770770 "db_proxy should be Some when enabled with a factory"
771771 ) ;
772772 }
773+
774+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 2 ) ]
775+ async fn build_host_state_with_async_factory_uses_async_path ( ) {
776+ use crate :: db_proxy:: async_io:: { AsyncConnectionBackend , AsyncConnectionFactory , AsyncConnectFuture } ;
777+ use crate :: db_proxy:: { ConnectionBackend , ConnectionFactory , PoolKey } ;
778+ use std:: future:: Future ;
779+ use std:: pin:: Pin ;
780+ use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
781+
782+ #[ derive( Debug ) ]
783+ struct StubBackend ;
784+
785+ impl ConnectionBackend for StubBackend {
786+ fn send ( & mut self , data : & [ u8 ] ) -> Result < usize , String > { Ok ( data. len ( ) ) }
787+ fn recv ( & mut self , _max : usize ) -> Result < Vec < u8 > , String > { Ok ( vec ! [ ] ) }
788+ fn ping ( & mut self ) -> bool { true }
789+ fn close ( & mut self ) { }
790+ }
791+
792+ struct CountingFactory ( AtomicU64 ) ;
793+ impl ConnectionFactory for CountingFactory {
794+ fn connect ( & self , _key : & PoolKey , _password : Option < & str > ) -> Result < Box < dyn ConnectionBackend > , String > {
795+ self . 0 . fetch_add ( 1 , Ordering :: Relaxed ) ;
796+ Ok ( Box :: new ( StubBackend ) )
797+ }
798+ }
799+
800+ #[ derive( Debug ) ]
801+ struct StubAsyncBackend ;
802+ impl AsyncConnectionBackend for StubAsyncBackend {
803+ fn send_async < ' a > ( & ' a mut self , data : & ' a [ u8 ] ) -> Pin < Box < dyn Future < Output = Result < usize , String > > + Send + ' a > > {
804+ Box :: pin ( async move { Ok ( data. len ( ) ) } )
805+ }
806+ fn recv_async < ' a > ( & ' a mut self , max_bytes : usize ) -> Pin < Box < dyn Future < Output = Result < Vec < u8 > , String > > + Send + ' a > > {
807+ Box :: pin ( async move { Ok ( vec ! [ 0x42 ; max_bytes. min( 4 ) ] ) } )
808+ }
809+ fn ping_async ( & mut self ) -> Pin < Box < dyn Future < Output = bool > + Send + ' _ > > {
810+ Box :: pin ( async { true } )
811+ }
812+ fn close_async ( & mut self ) -> Pin < Box < dyn Future < Output = ( ) > + Send + ' _ > > {
813+ Box :: pin ( async { } )
814+ }
815+ }
816+
817+ struct CountingAsyncFactory ( AtomicU64 ) ;
818+ impl AsyncConnectionFactory for CountingAsyncFactory {
819+ fn connect_async < ' a > ( & ' a self , _key : & ' a PoolKey , _password : Option < & ' a str > ) -> AsyncConnectFuture < ' a > {
820+ self . 0 . fetch_add ( 1 , Ordering :: Relaxed ) ;
821+ Box :: pin ( async { Ok ( Box :: new ( StubAsyncBackend ) as Box < dyn AsyncConnectionBackend > ) } )
822+ }
823+ }
824+
825+ let sync_factory = Arc :: new ( CountingFactory ( AtomicU64 :: new ( 0 ) ) ) ;
826+ let async_factory = Arc :: new ( CountingAsyncFactory ( AtomicU64 :: new ( 0 ) ) ) ;
827+
828+ let config = ShimConfig {
829+ database_proxy : true ,
830+ dns : false ,
831+ ..ShimConfig :: default ( )
832+ } ;
833+ let engine = WarpGridEngine :: new ( config) . unwrap ( ) ;
834+ let mut state = engine. build_host_state_with_async (
835+ Some ( sync_factory. clone ( ) ) ,
836+ Some ( async_factory. clone ( ) ) ,
837+ ) ;
838+
839+ assert ! ( state. db_proxy. is_some( ) , "db_proxy should be Some" ) ;
840+
841+ // Connect through HostState should use the async factory, not the sync one.
842+ let connect_config = shim:: database_proxy:: ConnectConfig {
843+ host : "db.local" . into ( ) ,
844+ port : 5432 ,
845+ database : "mydb" . into ( ) ,
846+ user : "app" . into ( ) ,
847+ password : None ,
848+ } ;
849+ let handle = shim:: database_proxy:: Host :: connect ( & mut state, connect_config) . unwrap ( ) ;
850+ assert_eq ! ( async_factory. 0 . load( Ordering :: Relaxed ) , 1 , "async factory should be called" ) ;
851+ assert_eq ! ( sync_factory. 0 . load( Ordering :: Relaxed ) , 0 , "sync factory should NOT be called" ) ;
852+
853+ // Send/recv should work through async path.
854+ let sent = shim:: database_proxy:: Host :: send ( & mut state, handle, b"SELECT 1" . to_vec ( ) ) . unwrap ( ) ;
855+ assert_eq ! ( sent, 8 ) ;
856+
857+ let data = shim:: database_proxy:: Host :: recv ( & mut state, handle, 1024 ) . unwrap ( ) ;
858+ assert ! ( !data. is_empty( ) ) ;
859+ }
773860}
0 commit comments