@@ -79,9 +79,8 @@ impl PostgresStore {
7979 . unwrap ( ) ;
8080
8181 let inner = tokio:: task:: block_in_place ( || {
82- internal_runtime. block_on ( async {
83- PostgresStoreInner :: new ( & connection_string, kv_table_name) . await
84- } )
82+ internal_runtime
83+ . block_on ( async { PostgresStoreInner :: new ( connection_string, kv_table_name) . await } )
8584 } ) ?;
8685
8786 let inner = Arc :: new ( inner) ;
@@ -310,13 +309,14 @@ impl PaginatedKVStore for PostgresStore {
310309
311310struct PostgresStoreInner {
312311 client : tokio:: sync:: Mutex < Client > ,
312+ connection_string : String ,
313313 kv_table_name : String ,
314314 write_version_locks : Mutex < HashMap < String , Arc < tokio:: sync:: Mutex < u64 > > > > ,
315315 next_sort_order : AtomicI64 ,
316316}
317317
318318impl PostgresStoreInner {
319- async fn new ( connection_string : & str , kv_table_name : Option < String > ) -> io:: Result < Self > {
319+ async fn new ( connection_string : String , kv_table_name : Option < String > ) -> io:: Result < Self > {
320320 let kv_table_name = kv_table_name. unwrap_or ( DEFAULT_KV_TABLE_NAME . to_string ( ) ) ;
321321
322322 // If a dbname is specified in the connection string, ensure the database exists
@@ -327,20 +327,10 @@ impl PostgresStoreInner {
327327 } ) ?;
328328
329329 if let Some ( db_name) = config. get_dbname ( ) {
330- Self :: create_database_if_not_exists ( connection_string, db_name) . await ?;
330+ Self :: create_database_if_not_exists ( & connection_string, db_name) . await ?;
331331 }
332332
333- let ( client, connection) = connect ( connection_string, NoTls ) . await . map_err ( |e| {
334- let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
335- io:: Error :: new ( io:: ErrorKind :: Other , msg)
336- } ) ?;
337-
338- // Spawn the connection task so it runs in the background.
339- tokio:: spawn ( async move {
340- if let Err ( e) = connection. await {
341- log:: error!( "PostgreSQL connection error: {e}" ) ;
342- }
343- } ) ;
333+ let client = Self :: make_connection ( & connection_string) . await ?;
344334
345335 // Create the KV data table if it doesn't exist.
346336 let sql = format ! (
@@ -409,7 +399,7 @@ impl PostgresStoreInner {
409399
410400 let client = tokio:: sync:: Mutex :: new ( client) ;
411401 let write_version_locks = Mutex :: new ( HashMap :: new ( ) ) ;
412- Ok ( Self { client, kv_table_name, write_version_locks, next_sort_order } )
402+ Ok ( Self { client, connection_string , kv_table_name, write_version_locks, next_sort_order } )
413403 }
414404
415405 async fn create_database_if_not_exists (
@@ -453,6 +443,32 @@ impl PostgresStoreInner {
453443 Ok ( ( ) )
454444 }
455445
446+ async fn make_connection ( connection_string : & str ) -> io:: Result < Client > {
447+ let ( client, connection) = connect ( connection_string, NoTls ) . await . map_err ( |e| {
448+ let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
449+ io:: Error :: new ( io:: ErrorKind :: Other , msg)
450+ } ) ?;
451+
452+ tokio:: spawn ( async move {
453+ if let Err ( e) = connection. await {
454+ log:: error!( "PostgreSQL connection error: {e}" ) ;
455+ }
456+ } ) ;
457+
458+ Ok ( client)
459+ }
460+
461+ async fn ensure_connected (
462+ & self , client : & mut tokio:: sync:: MutexGuard < ' _ , Client > ,
463+ ) -> io:: Result < ( ) > {
464+ if client. is_closed ( ) || client. check_connection ( ) . await . is_err ( ) {
465+ log:: debug!( "Reconnecting to PostgreSQL database" ) ;
466+ let new_client = Self :: make_connection ( & self . connection_string ) . await ?;
467+ * * client = new_client;
468+ }
469+ Ok ( ( ) )
470+ }
471+
456472 fn get_inner_lock_ref ( & self , locking_key : String ) -> Arc < tokio:: sync:: Mutex < u64 > > {
457473 let mut outer_lock = self . write_version_locks . lock ( ) . unwrap ( ) ;
458474 Arc :: clone ( & outer_lock. entry ( locking_key) . or_default ( ) )
@@ -463,7 +479,8 @@ impl PostgresStoreInner {
463479 ) -> io:: Result < Vec < u8 > > {
464480 check_namespace_key_validity ( primary_namespace, secondary_namespace, Some ( key) , "read" ) ?;
465481
466- let locked_client = self . client . lock ( ) . await ;
482+ let mut locked_client = self . client . lock ( ) . await ;
483+ self . ensure_connected ( & mut locked_client) . await ?;
467484 let sql = format ! (
468485 "SELECT value FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3" ,
469486 self . kv_table_name
@@ -507,7 +524,8 @@ impl PostgresStoreInner {
507524 check_namespace_key_validity ( primary_namespace, secondary_namespace, Some ( key) , "write" ) ?;
508525
509526 self . execute_locked_write ( inner_lock_ref, locking_key, version, async move || {
510- let locked_client = self . client . lock ( ) . await ;
527+ let mut locked_client = self . client . lock ( ) . await ;
528+ self . ensure_connected ( & mut locked_client) . await ?;
511529
512530 let sort_order = self . next_sort_order . fetch_add ( 1 , Ordering :: Relaxed ) ;
513531
@@ -552,7 +570,8 @@ impl PostgresStoreInner {
552570 check_namespace_key_validity ( primary_namespace, secondary_namespace, Some ( key) , "remove" ) ?;
553571
554572 self . execute_locked_write ( inner_lock_ref, locking_key, version, async move || {
555- let locked_client = self . client . lock ( ) . await ;
573+ let mut locked_client = self . client . lock ( ) . await ;
574+ self . ensure_connected ( & mut locked_client) . await ?;
556575
557576 let sql = format ! (
558577 "DELETE FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2 AND key=$3" ,
@@ -582,7 +601,8 @@ impl PostgresStoreInner {
582601 ) -> io:: Result < Vec < String > > {
583602 check_namespace_key_validity ( primary_namespace, secondary_namespace, None , "list" ) ?;
584603
585- let locked_client = self . client . lock ( ) . await ;
604+ let mut locked_client = self . client . lock ( ) . await ;
605+ self . ensure_connected ( & mut locked_client) . await ?;
586606
587607 let sql = format ! (
588608 "SELECT key FROM {} WHERE primary_namespace=$1 AND secondary_namespace=$2" ,
@@ -611,7 +631,8 @@ impl PostgresStoreInner {
611631 "list_paginated" ,
612632 ) ?;
613633
614- let locked_client = self . client . lock ( ) . await ;
634+ let mut locked_client = self . client . lock ( ) . await ;
635+ self . ensure_connected ( & mut locked_client) . await ?;
615636
616637 // Fetch one extra row beyond PAGE_SIZE to determine whether a next page exists.
617638 let fetch_limit = ( PAGE_SIZE + 1 ) as i64 ;
@@ -772,6 +793,40 @@ mod tests {
772793 cleanup_store ( & store_1) ;
773794 }
774795
/// Checks that reads and writes keep working after the store's backend connection
/// has been terminated underneath it.
#[test]
fn test_postgres_store_auto_reconnect() {
	let store = create_test_store("test_pg_reconnect");

	let ns = "test_ns";
	let sub = "test_sub";

	// Write a value before disconnecting.
	KVStoreSync::write(&store, ns, sub, "key_a", vec![1u8; 8]).unwrap();

	// Terminate the backend connection to simulate a dropped connection. Fail loudly if
	// there is no runtime to do so: skipping this step would let the test pass without
	// ever exercising the reconnect path.
	let runtime = store
		.internal_runtime
		.as_ref()
		.expect("test store must have an internal runtime to simulate a dropped connection");
	tokio::task::block_in_place(|| {
		runtime.block_on(async {
			let client = store.inner.client.lock().await;
			// The backend terminates itself while serving this query, so the call is
			// expected to come back with an error; the result is deliberately ignored.
			let _ = client.execute("SELECT pg_terminate_backend(pg_backend_pid())", &[]).await;
		})
	});

	// Read should auto-reconnect and return the previously written value.
	let data = KVStoreSync::read(&store, ns, sub, "key_a").unwrap();
	assert_eq!(data, vec![1u8; 8]);

	// Write should also work after reconnect.
	KVStoreSync::write(&store, ns, sub, "key_b", vec![2u8; 8]).unwrap();
	let data = KVStoreSync::read(&store, ns, sub, "key_b").unwrap();
	assert_eq!(data, vec![2u8; 8]);

	cleanup_store(&store);
}
829+
775830 #[ test]
776831 fn test_postgres_store_paginated_listing ( ) {
777832 let store = create_test_store ( "test_pg_paginated" ) ;
0 commit comments