@@ -16,12 +16,17 @@ use lightning::util::persist::{
1616 KVStore , KVStoreSync , PageToken , PaginatedKVStore , PaginatedKVStoreSync , PaginatedListResponse ,
1717} ;
1818use lightning_types:: string:: PrintableString ;
19- use tokio_postgres:: { connect, Client , Config , Error as PgError , NoTls } ;
19+ use native_tls:: TlsConnector ;
20+ use postgres_native_tls:: MakeTlsConnector ;
21+ use tokio_postgres:: { Client , Config , Error as PgError , NoTls } ;
2022
2123use crate :: io:: utils:: check_namespace_key_validity;
2224
2325mod migrations;
2426
27+ /// The default database name used when none is specified.
28+ pub const DEFAULT_DB_NAME : & str = "ldk_node" ;
29+
2530/// The default table in which we store all data.
2631pub const DEFAULT_KV_TABLE_NAME : & str = "ldk_data" ;
2732
@@ -61,11 +66,21 @@ impl PostgresStore {
6166 ///
6267 /// Connects to the PostgreSQL database at the given `connection_string`.
6368 ///
64- /// If the connection string includes a `dbname`, the database will be created automatically
65- /// if it doesn't already exist.
69+ /// The given `db_name` will be used or default to [`DEFAULT_DB_NAME`]. The database will be
70+ /// created automatically if it doesn't already exist.
6671 ///
6772 /// The given `kv_table_name` will be used or default to [`DEFAULT_KV_TABLE_NAME`].
68- pub fn new ( connection_string : String , kv_table_name : Option < String > ) -> io:: Result < Self > {
73+ ///
74+ /// If `tls_config` is `Some`, TLS will be used for database connections. A custom CA
75+ /// certificate can be provided via [`PostgresTlsConfig::certificate_pem`], otherwise the
76+ /// system's default root certificates are used. If `tls_config` is `None`, connections
77+ /// will be unencrypted.
78+ pub fn new (
79+ connection_string : String , db_name : Option < String > , kv_table_name : Option < String > ,
80+ tls_config : Option < PostgresTlsConfig > ,
81+ ) -> io:: Result < Self > {
82+ let tls = Self :: build_tls_connector ( tls_config) ?;
83+
6984 let internal_runtime = tokio:: runtime:: Builder :: new_multi_thread ( )
7085 . enable_all ( )
7186 . thread_name_fn ( || {
@@ -79,15 +94,41 @@ impl PostgresStore {
7994 . unwrap ( ) ;
8095
8196 let inner = tokio:: task:: block_in_place ( || {
82- internal_runtime
83- . block_on ( async { PostgresStoreInner :: new ( connection_string, kv_table_name) . await } )
97+ internal_runtime. block_on ( async {
98+ PostgresStoreInner :: new ( connection_string, db_name, kv_table_name, tls) . await
99+ } )
84100 } ) ?;
85101
86102 let inner = Arc :: new ( inner) ;
87103 let next_write_version = AtomicU64 :: new ( 1 ) ;
88104 Ok ( Self { inner, next_write_version, internal_runtime : Some ( internal_runtime) } )
89105 }
90106
107+ fn build_tls_connector ( tls_config : Option < PostgresTlsConfig > ) -> io:: Result < PgTlsConnector > {
108+ match tls_config {
109+ Some ( config) => {
110+ let mut builder = TlsConnector :: builder ( ) ;
111+ if let Some ( pem) = config. certificate_pem {
112+ let crt = native_tls:: Certificate :: from_pem ( pem. as_bytes ( ) ) . map_err ( |e| {
113+ io:: Error :: new (
114+ io:: ErrorKind :: InvalidInput ,
115+ format ! ( "Failed to parse PEM certificate: {e}" ) ,
116+ )
117+ } ) ?;
118+ builder. add_root_certificate ( crt) ;
119+ }
120+ let connector = builder. build ( ) . map_err ( |e| {
121+ io:: Error :: new (
122+ io:: ErrorKind :: Other ,
123+ format ! ( "Failed to build TLS connector: {e}" ) ,
124+ )
125+ } ) ?;
126+ Ok ( PgTlsConnector :: NativeTls ( MakeTlsConnector :: new ( connector) ) )
127+ } ,
128+ None => Ok ( PgTlsConnector :: Plain ) ,
129+ }
130+ }
131+
91132 fn build_locking_key (
92133 & self , primary_namespace : & str , secondary_namespace : & str , key : & str ,
93134 ) -> String {
@@ -309,28 +350,39 @@ impl PaginatedKVStore for PostgresStore {
309350
310351struct PostgresStoreInner {
311352 client : tokio:: sync:: Mutex < Client > ,
312- connection_string : String ,
353+ config : Config ,
313354 kv_table_name : String ,
355+ tls : PgTlsConnector ,
314356 write_version_locks : Mutex < HashMap < String , Arc < tokio:: sync:: Mutex < u64 > > > > ,
315357 next_sort_order : AtomicI64 ,
316358}
317359
318360impl PostgresStoreInner {
319- async fn new ( connection_string : String , kv_table_name : Option < String > ) -> io:: Result < Self > {
361+ async fn new (
362+ connection_string : String , db_name : Option < String > , kv_table_name : Option < String > ,
363+ tls : PgTlsConnector ,
364+ ) -> io:: Result < Self > {
320365 let kv_table_name = kv_table_name. unwrap_or ( DEFAULT_KV_TABLE_NAME . to_string ( ) ) ;
321366
322- // If a dbname is specified in the connection string, ensure the database exists
323- // by first connecting without a dbname and creating it if necessary.
324- let config: Config = connection_string. parse ( ) . map_err ( |e : PgError | {
367+ let mut config: Config = connection_string. parse ( ) . map_err ( |e : PgError | {
325368 let msg = format ! ( "Failed to parse PostgreSQL connection string: {e}" ) ;
326369 io:: Error :: new ( io:: ErrorKind :: InvalidInput , msg)
327370 } ) ?;
328371
329- if let Some ( db_name) = config. get_dbname ( ) {
330- Self :: create_database_if_not_exists ( & connection_string, db_name) . await ?;
372+ if db_name. is_some ( ) && config. get_dbname ( ) . is_some ( ) {
373+ return Err ( io:: Error :: new (
374+ io:: ErrorKind :: InvalidInput ,
375+ "db_name must not be set when the connection string already contains a dbname" ,
376+ ) ) ;
331377 }
332378
333- let client = Self :: make_connection ( & connection_string) . await ?;
379+ let db_name = db_name
380+ . or_else ( || config. get_dbname ( ) . map ( |s| s. to_string ( ) ) )
381+ . unwrap_or ( DEFAULT_DB_NAME . to_string ( ) ) ;
382+ config. dbname ( & db_name) ;
383+ Self :: create_database_if_not_exists ( & config, & db_name, & tls) . await ?;
384+
385+ let client = Self :: make_config_connection ( & config, & tls) . await ?;
334386
335387 // Create the KV data table if it doesn't exist.
336388 let sql = format ! (
@@ -399,29 +451,17 @@ impl PostgresStoreInner {
399451
400452 let client = tokio:: sync:: Mutex :: new ( client) ;
401453 let write_version_locks = Mutex :: new ( HashMap :: new ( ) ) ;
402- Ok ( Self { client, connection_string , kv_table_name, write_version_locks, next_sort_order } )
454+ Ok ( Self { client, config , kv_table_name, tls , write_version_locks, next_sort_order } )
403455 }
404456
405457 async fn create_database_if_not_exists (
406- connection_string : & str , db_name : & str ,
458+ config : & Config , db_name : & str , tls : & PgTlsConnector ,
407459 ) -> io:: Result < ( ) > {
408460 // Connect without a dbname (to the default database) so we can create the target.
409- let mut config: Config = connection_string. parse ( ) . map_err ( |e : PgError | {
410- let msg = format ! ( "Failed to parse PostgreSQL connection string: {e}" ) ;
411- io:: Error :: new ( io:: ErrorKind :: InvalidInput , msg)
412- } ) ?;
461+ let mut config = config. clone ( ) ;
413462 config. dbname ( "postgres" ) ;
414463
415- let ( client, connection) = config. connect ( NoTls ) . await . map_err ( |e| {
416- let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
417- io:: Error :: new ( io:: ErrorKind :: Other , msg)
418- } ) ?;
419-
420- tokio:: spawn ( async move {
421- if let Err ( e) = connection. await {
422- log:: error!( "PostgreSQL connection error: {e}" ) ;
423- }
424- } ) ;
464+ let client = Self :: make_config_connection ( & config, tls) . await ?;
425465
426466 let row = client
427467 . query_opt ( "SELECT 1 FROM pg_database WHERE datname = $1" , & [ & db_name] )
@@ -443,27 +483,41 @@ impl PostgresStoreInner {
443483 Ok ( ( ) )
444484 }
445485
446- async fn make_connection ( connection_string : & str ) -> io:: Result < Client > {
447- let ( client , connection ) = connect ( connection_string , NoTls ) . await . map_err ( |e| {
486+ async fn make_config_connection ( config : & Config , tls : & PgTlsConnector ) -> io:: Result < Client > {
487+ let err_map = |e| {
448488 let msg = format ! ( "Failed to connect to PostgreSQL: {e}" ) ;
449489 io:: Error :: new ( io:: ErrorKind :: Other , msg)
450- } ) ?;
451-
452- tokio:: spawn ( async move {
453- if let Err ( e) = connection. await {
454- log:: error!( "PostgreSQL connection error: {e}" ) ;
455- }
456- } ) ;
490+ } ;
457491
458- Ok ( client)
492+ match tls {
493+ PgTlsConnector :: Plain => {
494+ let ( client, connection) = config. connect ( NoTls ) . await . map_err ( err_map) ?;
495+ tokio:: spawn ( async move {
496+ if let Err ( e) = connection. await {
497+ log:: error!( "PostgreSQL connection error: {e}" ) ;
498+ }
499+ } ) ;
500+ Ok ( client)
501+ } ,
502+ PgTlsConnector :: NativeTls ( tls_connector) => {
503+ let ( client, connection) =
504+ config. connect ( tls_connector. clone ( ) ) . await . map_err ( err_map) ?;
505+ tokio:: spawn ( async move {
506+ if let Err ( e) = connection. await {
507+ log:: error!( "PostgreSQL connection error: {e}" ) ;
508+ }
509+ } ) ;
510+ Ok ( client)
511+ } ,
512+ }
459513 }
460514
461515 async fn ensure_connected (
462516 & self , client : & mut tokio:: sync:: MutexGuard < ' _ , Client > ,
463517 ) -> io:: Result < ( ) > {
464518 if client. is_closed ( ) || client. check_connection ( ) . await . is_err ( ) {
465519 log:: debug!( "Reconnecting to PostgreSQL database" ) ;
466- let new_client = Self :: make_connection ( & self . connection_string ) . await ?;
520+ let new_client = Self :: make_config_connection ( & self . config , & self . tls ) . await ?;
467521 * * client = new_client;
468522 }
469523 Ok ( ( ) )
@@ -750,6 +804,19 @@ impl PostgresStoreInner {
750804 }
751805}
752806
807+ /// TLS configuration for PostgreSQL connections.
808+ #[ derive( Debug , Clone ) ]
809+ pub struct PostgresTlsConfig {
810+ /// PEM-encoded CA certificate. If `None`, the system's default root certificates are used.
811+ pub certificate_pem : Option < String > ,
812+ }
813+
814+ #[ derive( Clone ) ]
815+ enum PgTlsConnector {
816+ Plain ,
817+ NativeTls ( MakeTlsConnector ) ,
818+ }
819+
753820#[ cfg( test) ]
754821mod tests {
755822 use super :: * ;
@@ -761,7 +828,8 @@ mod tests {
761828 }
762829
763830 fn create_test_store ( table_name : & str ) -> PostgresStore {
764- PostgresStore :: new ( test_connection_string ( ) , Some ( table_name. to_string ( ) ) ) . unwrap ( )
831+ PostgresStore :: new ( test_connection_string ( ) , None , Some ( table_name. to_string ( ) ) , None )
832+ . unwrap ( )
765833 }
766834
767835 fn cleanup_store ( store : & PostgresStore ) {
@@ -1092,4 +1160,25 @@ mod tests {
10921160 cleanup_store ( & store) ;
10931161 }
10941162 }
1163+
1164+ #[ test]
1165+ fn test_tls_config_none_builds_plain_connector ( ) {
1166+ let connector = PostgresStore :: build_tls_connector ( None ) . unwrap ( ) ;
1167+ assert ! ( matches!( connector, PgTlsConnector :: Plain ) ) ;
1168+ }
1169+
1170+ #[ test]
1171+ fn test_tls_config_system_certs_builds_native_tls_connector ( ) {
1172+ let config = Some ( PostgresTlsConfig { certificate_pem : None } ) ;
1173+ let connector = PostgresStore :: build_tls_connector ( config) . unwrap ( ) ;
1174+ assert ! ( matches!( connector, PgTlsConnector :: NativeTls ( _) ) ) ;
1175+ }
1176+
1177+ #[ test]
1178+ fn test_tls_config_invalid_pem_returns_error ( ) {
1179+ let config =
1180+ Some ( PostgresTlsConfig { certificate_pem : Some ( "not-a-valid-pem" . to_string ( ) ) } ) ;
1181+ let result = PostgresStore :: build_tls_connector ( config) ;
1182+ assert ! ( result. is_err( ) ) ;
1183+ }
10951184}
0 commit comments