@@ -3252,20 +3252,32 @@ fn field_types(
32523252
32533253#[ cfg( test) ]
32543254mod tests {
3255- use std:: time:: { Duration , Instant } ;
3255+ use std:: {
3256+ io:: BufReader ,
3257+ time:: { Duration , Instant } ,
3258+ } ;
32563259
3260+ use camino:: Utf8PathBuf ;
32573261 use chrono:: { DateTime , Utc } ;
3258- use corro_tests:: launch_test_agent;
3262+ use corro_tests:: { launch_test_agent, TestAgent } ;
3263+ use corro_types:: {
3264+ config:: PgTlsConfig ,
3265+ tls:: { generate_ca, generate_client_cert, generate_server_cert} ,
3266+ } ;
3267+ use rcgen:: Certificate ;
32593268 use spawn:: wait_for_all_pending_handles;
3269+ use tempfile:: TempDir ;
32603270 use tokio_postgres:: NoTls ;
3271+ use tokio_postgres_rustls:: MakeRustlsConnect ;
32613272 use tripwire:: Tripwire ;
32623273
32633274 use super :: * ;
32643275
3265- #[ tokio:: test( flavor = "multi_thread" ) ]
3266- async fn test_pg ( ) -> Result < ( ) , BoxError > {
3276+ async fn setup_pg_test_server (
3277+ tripwire : Tripwire ,
3278+ tls_config : Option < PgTlsConfig > ,
3279+ ) -> Result < ( TestAgent , PgServer ) , BoxError > {
32673280 _ = tracing_subscriber:: fmt:: try_init ( ) ;
3268- let ( tripwire, tripwire_worker, tripwire_tx) = Tripwire :: new_simple ( ) ;
32693281
32703282 let tmpdir = tempfile:: tempdir ( ) ?;
32713283
@@ -3291,18 +3303,27 @@ mod tests {
32913303 )
32923304 . await ?;
32933305
3294- let sema = ta. agent . write_sema ( ) . clone ( ) ;
3295-
32963306 let server = start (
32973307 ta. agent . clone ( ) ,
32983308 PgConfig {
32993309 bind_addr : "127.0.0.1:0" . parse ( ) ?,
3300- tls : None ,
3310+ tls : tls_config ,
33013311 } ,
33023312 tripwire,
33033313 )
33043314 . await ?;
33053315
3316+ Ok ( ( ta, server) )
3317+ }
3318+
3319+ #[ tokio:: test( flavor = "multi_thread" ) ]
3320+ async fn test_pg ( ) -> Result < ( ) , BoxError > {
3321+ let ( tripwire, tripwire_worker, tripwire_tx) = Tripwire :: new_simple ( ) ;
3322+
3323+ let ( ta, server) = setup_pg_test_server ( tripwire, None ) . await ?;
3324+
3325+ let sema = ta. agent . write_sema ( ) . clone ( ) ;
3326+
33063327 let conn_str = format ! (
33073328 "host={} port={} user=testuser" ,
33083329 server. local_addr. ip( ) ,
@@ -3472,4 +3493,191 @@ mod tests {
34723493
34733494 Ok ( ( ) )
34743495 }
3496+
3497+ struct TestCertificates {
3498+ ca_cert : Certificate ,
3499+ client_cert_signed : String ,
3500+ client_key : Vec < u8 > ,
3501+ ca_file : Utf8PathBuf ,
3502+ server_cert_file : Utf8PathBuf ,
3503+ server_key_file : Utf8PathBuf ,
3504+ }
3505+
3506+ async fn generate_and_write_certs ( tmpdir : & TempDir ) -> Result < TestCertificates , BoxError > {
3507+ let ca_cert = generate_ca ( ) ?;
3508+ let ( server_cert, server_cert_signed) = generate_server_cert (
3509+ & ca_cert. serialize_pem ( ) ?,
3510+ & ca_cert. serialize_private_key_pem ( ) ,
3511+ "127.0.0.1" . parse ( ) ?,
3512+ ) ?;
3513+
3514+ let ( client_cert, client_cert_signed) = generate_client_cert (
3515+ & ca_cert. serialize_pem ( ) ?,
3516+ & ca_cert. serialize_private_key_pem ( ) ,
3517+ ) ?;
3518+
3519+ let base_path = Utf8PathBuf :: from ( tmpdir. path ( ) . display ( ) . to_string ( ) ) ;
3520+
3521+ let cert_file = base_path. join ( "cert.pem" ) ;
3522+ let key_file = base_path. join ( "cert.key" ) ;
3523+ let ca_file = base_path. join ( "ca.pem" ) ;
3524+
3525+ let client_cert_file = base_path. join ( "client-cert.pem" ) ;
3526+ let client_key_file = base_path. join ( "client-cert.key" ) ;
3527+
3528+ tokio:: fs:: write ( & cert_file, & server_cert_signed) . await ?;
3529+ tokio:: fs:: write ( & key_file, server_cert. serialize_private_key_pem ( ) ) . await ?;
3530+
3531+ tokio:: fs:: write ( & ca_file, ca_cert. serialize_pem ( ) ?) . await ?;
3532+
3533+ tokio:: fs:: write ( & client_cert_file, & client_cert_signed) . await ?;
3534+ tokio:: fs:: write ( & client_key_file, client_cert. serialize_private_key_pem ( ) ) . await ?;
3535+
3536+ Ok ( TestCertificates {
3537+ server_cert_file : cert_file,
3538+ server_key_file : key_file,
3539+ ca_cert,
3540+ client_cert_signed : client_cert_signed,
3541+ client_key : client_cert. serialize_private_key_der ( ) ,
3542+ ca_file,
3543+ } )
3544+ }
3545+
3546+ #[ tokio:: test( flavor = "multi_thread" ) ]
3547+ async fn test_pg_ssl ( ) -> Result < ( ) , BoxError > {
3548+ let ( tripwire, tripwire_worker, tripwire_tx) = Tripwire :: new_simple ( ) ;
3549+
3550+ let tmpdir = TempDir :: new ( ) ?;
3551+ let certs = generate_and_write_certs ( & tmpdir) . await ?;
3552+
3553+ let ( ta, server) = setup_pg_test_server (
3554+ tripwire,
3555+ Some ( PgTlsConfig {
3556+ cert_file : certs. server_cert_file ,
3557+ key_file : certs. server_key_file ,
3558+ ca_file : None ,
3559+ verify_client : false ,
3560+ } ) ,
3561+ )
3562+ . await ?;
3563+
3564+ let sema = ta. agent . write_sema ( ) . clone ( ) ;
3565+
3566+ let conn_str = format ! (
3567+ "host={} port={} user=testuser" ,
3568+ server. local_addr. ip( ) ,
3569+ server. local_addr. port( )
3570+ ) ;
3571+
3572+ {
3573+ let mut root_cert_store = tokio_rustls:: rustls:: RootCertStore :: empty ( ) ;
3574+ root_cert_store. add ( & rustls:: Certificate ( certs. ca_cert . serialize_der ( ) ?) ) ?;
3575+ let config = rustls:: ClientConfig :: builder ( )
3576+ . with_safe_defaults ( )
3577+ . with_root_certificates ( root_cert_store)
3578+ . with_no_client_auth ( ) ;
3579+
3580+ let connector = MakeRustlsConnect :: new ( config) ;
3581+
3582+ println ! ( "connecting to: {conn_str}" ) ;
3583+
3584+ let ( client, client_conn) = tokio_postgres:: connect ( & conn_str, connector) . await ?;
3585+
3586+ tokio:: spawn ( client_conn) ;
3587+
3588+ let _permit = sema. acquire ( ) . await ;
3589+
3590+ println ! ( "before query" ) ;
3591+
3592+ client. simple_query ( "SELECT 1" ) . await ?;
3593+ }
3594+
3595+ tripwire_tx. send ( ( ) ) . await . ok ( ) ;
3596+ tripwire_worker. await ;
3597+ wait_for_all_pending_handles ( ) . await ;
3598+
3599+ Ok ( ( ) )
3600+ }
3601+
3602+ #[ tokio:: test( flavor = "multi_thread" ) ]
3603+ async fn test_pg_mtls ( ) -> Result < ( ) , BoxError > {
3604+ let ( tripwire, tripwire_worker, tripwire_tx) = Tripwire :: new_simple ( ) ;
3605+
3606+ let tmpdir = TempDir :: new ( ) ?;
3607+
3608+ let certs = generate_and_write_certs ( & tmpdir) . await ?;
3609+
3610+ let ( ta, server) = setup_pg_test_server (
3611+ tripwire,
3612+ Some ( PgTlsConfig {
3613+ cert_file : certs. server_cert_file ,
3614+ key_file : certs. server_key_file ,
3615+ ca_file : Some ( certs. ca_file ) ,
3616+ verify_client : true ,
3617+ } ) ,
3618+ )
3619+ . await ?;
3620+
3621+ let sema = ta. agent . write_sema ( ) . clone ( ) ;
3622+
3623+ let conn_str = format ! (
3624+ "host={} port={} user=testuser" ,
3625+ server. local_addr. ip( ) ,
3626+ server. local_addr. port( )
3627+ ) ;
3628+
3629+ {
3630+ let mut root_cert_store = tokio_rustls:: rustls:: RootCertStore :: empty ( ) ;
3631+ root_cert_store. add ( & rustls:: Certificate ( certs. ca_cert . serialize_der ( ) ?) ) ?;
3632+
3633+ let client_cert =
3634+ rustls_pemfile:: certs ( & mut BufReader :: new ( certs. client_cert_signed . as_bytes ( ) ) )
3635+ . map_err ( |e| format ! ( "failed to read client cert: {e}" ) ) ?;
3636+
3637+ let client_cert: Vec < rustls:: Certificate > = client_cert
3638+ . iter ( )
3639+ . map ( |cert| rustls:: Certificate ( cert. clone ( ) ) )
3640+ . collect ( ) ;
3641+
3642+ let config = rustls:: ClientConfig :: builder ( )
3643+ . with_safe_defaults ( )
3644+ . with_root_certificates ( root_cert_store. clone ( ) )
3645+ . with_client_auth_cert ( client_cert, rustls:: PrivateKey ( certs. client_key ) ) ?;
3646+
3647+ let connector = MakeRustlsConnect :: new ( config) ;
3648+
3649+ println ! ( "connecting to: {conn_str} with client auth cert" ) ;
3650+ let ( client, client_conn) = tokio_postgres:: connect ( & conn_str, connector) . await ?;
3651+
3652+ tokio:: spawn ( client_conn) ;
3653+
3654+ println ! ( "successfully connected!" ) ;
3655+
3656+ let _permit = sema. acquire ( ) . await ;
3657+
3658+ client. simple_query ( "SELECT 1" ) . await ?;
3659+
3660+ let config = rustls:: ClientConfig :: builder ( )
3661+ . with_safe_defaults ( )
3662+ . with_root_certificates ( root_cert_store)
3663+ . with_no_client_auth ( ) ;
3664+
3665+ let connector = MakeRustlsConnect :: new ( config) ;
3666+
3667+ println ! ( "connecting to: {conn_str} without client auth cert" ) ;
3668+ let result = tokio_postgres:: connect ( & conn_str, connector) . await ;
3669+ assert ! (
3670+ result. is_err( ) ,
3671+ "expected connect to fail without client auth cert"
3672+ ) ;
3673+
3674+ println ! ( "successfully failed to connect without client auth cert" ) ;
3675+ }
3676+
3677+ tripwire_tx. send ( ( ) ) . await . ok ( ) ;
3678+ tripwire_worker. await ;
3679+ wait_for_all_pending_handles ( ) . await ;
3680+
3681+ Ok ( ( ) )
3682+ }
34753683}
0 commit comments