Skip to content

Commit 74b285e

Browse files
committed
add some test case for corro-pg tls and mtls
1 parent dc58de3 commit 74b285e

File tree

2 files changed

+219
-8
lines changed

2 files changed

+219
-8
lines changed

crates/corro-pg/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,6 @@ tokio-rustls = "0.24.1"
3434
corro-tests = { path = "../corro-tests" }
3535
tokio-postgres = { version = "0.7.10" }
3636
tracing-subscriber = { workspace = true }
37+
camino = { workspace = true }
38+
tokio-postgres-rustls = "0.10.0"
39+
rcgen = { workspace = true }

crates/corro-pg/src/lib.rs

Lines changed: 216 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3252,20 +3252,32 @@ fn field_types(
32523252

32533253
#[cfg(test)]
32543254
mod tests {
3255-
use std::time::{Duration, Instant};
3255+
use std::{
3256+
io::BufReader,
3257+
time::{Duration, Instant},
3258+
};
32563259

3260+
use camino::Utf8PathBuf;
32573261
use chrono::{DateTime, Utc};
3258-
use corro_tests::launch_test_agent;
3262+
use corro_tests::{launch_test_agent, TestAgent};
3263+
use corro_types::{
3264+
config::PgTlsConfig,
3265+
tls::{generate_ca, generate_client_cert, generate_server_cert},
3266+
};
3267+
use rcgen::Certificate;
32593268
use spawn::wait_for_all_pending_handles;
3269+
use tempfile::TempDir;
32603270
use tokio_postgres::NoTls;
3271+
use tokio_postgres_rustls::MakeRustlsConnect;
32613272
use tripwire::Tripwire;
32623273

32633274
use super::*;
32643275

3265-
#[tokio::test(flavor = "multi_thread")]
3266-
async fn test_pg() -> Result<(), BoxError> {
3276+
async fn setup_pg_test_server(
3277+
tripwire: Tripwire,
3278+
tls_config: Option<PgTlsConfig>,
3279+
) -> Result<(TestAgent, PgServer), BoxError> {
32673280
_ = tracing_subscriber::fmt::try_init();
3268-
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
32693281

32703282
let tmpdir = tempfile::tempdir()?;
32713283

@@ -3291,18 +3303,27 @@ mod tests {
32913303
)
32923304
.await?;
32933305

3294-
let sema = ta.agent.write_sema().clone();
3295-
32963306
let server = start(
32973307
ta.agent.clone(),
32983308
PgConfig {
32993309
bind_addr: "127.0.0.1:0".parse()?,
3300-
tls: None,
3310+
tls: tls_config,
33013311
},
33023312
tripwire,
33033313
)
33043314
.await?;
33053315

3316+
Ok((ta, server))
3317+
}
3318+
3319+
#[tokio::test(flavor = "multi_thread")]
3320+
async fn test_pg() -> Result<(), BoxError> {
3321+
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
3322+
3323+
let (ta, server) = setup_pg_test_server(tripwire, None).await?;
3324+
3325+
let sema = ta.agent.write_sema().clone();
3326+
33063327
let conn_str = format!(
33073328
"host={} port={} user=testuser",
33083329
server.local_addr.ip(),
@@ -3472,4 +3493,191 @@ mod tests {
34723493

34733494
Ok(())
34743495
}
3496+
3497+
struct TestCertificates {
3498+
ca_cert: Certificate,
3499+
client_cert_signed: String,
3500+
client_key: Vec<u8>,
3501+
ca_file: Utf8PathBuf,
3502+
server_cert_file: Utf8PathBuf,
3503+
server_key_file: Utf8PathBuf,
3504+
}
3505+
3506+
async fn generate_and_write_certs(tmpdir: &TempDir) -> Result<TestCertificates, BoxError> {
3507+
let ca_cert = generate_ca()?;
3508+
let (server_cert, server_cert_signed) = generate_server_cert(
3509+
&ca_cert.serialize_pem()?,
3510+
&ca_cert.serialize_private_key_pem(),
3511+
"127.0.0.1".parse()?,
3512+
)?;
3513+
3514+
let (client_cert, client_cert_signed) = generate_client_cert(
3515+
&ca_cert.serialize_pem()?,
3516+
&ca_cert.serialize_private_key_pem(),
3517+
)?;
3518+
3519+
let base_path = Utf8PathBuf::from(tmpdir.path().display().to_string());
3520+
3521+
let cert_file = base_path.join("cert.pem");
3522+
let key_file = base_path.join("cert.key");
3523+
let ca_file = base_path.join("ca.pem");
3524+
3525+
let client_cert_file = base_path.join("client-cert.pem");
3526+
let client_key_file = base_path.join("client-cert.key");
3527+
3528+
tokio::fs::write(&cert_file, &server_cert_signed).await?;
3529+
tokio::fs::write(&key_file, server_cert.serialize_private_key_pem()).await?;
3530+
3531+
tokio::fs::write(&ca_file, ca_cert.serialize_pem()?).await?;
3532+
3533+
tokio::fs::write(&client_cert_file, &client_cert_signed).await?;
3534+
tokio::fs::write(&client_key_file, client_cert.serialize_private_key_pem()).await?;
3535+
3536+
Ok(TestCertificates {
3537+
server_cert_file: cert_file,
3538+
server_key_file: key_file,
3539+
ca_cert,
3540+
client_cert_signed: client_cert_signed,
3541+
client_key: client_cert.serialize_private_key_der(),
3542+
ca_file,
3543+
})
3544+
}
3545+
3546+
#[tokio::test(flavor = "multi_thread")]
3547+
async fn test_pg_ssl() -> Result<(), BoxError> {
3548+
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
3549+
3550+
let tmpdir = TempDir::new()?;
3551+
let certs = generate_and_write_certs(&tmpdir).await?;
3552+
3553+
let (ta, server) = setup_pg_test_server(
3554+
tripwire,
3555+
Some(PgTlsConfig {
3556+
cert_file: certs.server_cert_file,
3557+
key_file: certs.server_key_file,
3558+
ca_file: None,
3559+
verify_client: false,
3560+
}),
3561+
)
3562+
.await?;
3563+
3564+
let sema = ta.agent.write_sema().clone();
3565+
3566+
let conn_str = format!(
3567+
"host={} port={} user=testuser",
3568+
server.local_addr.ip(),
3569+
server.local_addr.port()
3570+
);
3571+
3572+
{
3573+
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
3574+
root_cert_store.add(&rustls::Certificate(certs.ca_cert.serialize_der()?))?;
3575+
let config = rustls::ClientConfig::builder()
3576+
.with_safe_defaults()
3577+
.with_root_certificates(root_cert_store)
3578+
.with_no_client_auth();
3579+
3580+
let connector = MakeRustlsConnect::new(config);
3581+
3582+
println!("connecting to: {conn_str}");
3583+
3584+
let (client, client_conn) = tokio_postgres::connect(&conn_str, connector).await?;
3585+
3586+
tokio::spawn(client_conn);
3587+
3588+
let _permit = sema.acquire().await;
3589+
3590+
println!("before query");
3591+
3592+
client.simple_query("SELECT 1").await?;
3593+
}
3594+
3595+
tripwire_tx.send(()).await.ok();
3596+
tripwire_worker.await;
3597+
wait_for_all_pending_handles().await;
3598+
3599+
Ok(())
3600+
}
3601+
3602+
#[tokio::test(flavor = "multi_thread")]
3603+
async fn test_pg_mtls() -> Result<(), BoxError> {
3604+
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
3605+
3606+
let tmpdir = TempDir::new()?;
3607+
3608+
let certs = generate_and_write_certs(&tmpdir).await?;
3609+
3610+
let (ta, server) = setup_pg_test_server(
3611+
tripwire,
3612+
Some(PgTlsConfig {
3613+
cert_file: certs.server_cert_file,
3614+
key_file: certs.server_key_file,
3615+
ca_file: Some(certs.ca_file),
3616+
verify_client: true,
3617+
}),
3618+
)
3619+
.await?;
3620+
3621+
let sema = ta.agent.write_sema().clone();
3622+
3623+
let conn_str = format!(
3624+
"host={} port={} user=testuser",
3625+
server.local_addr.ip(),
3626+
server.local_addr.port()
3627+
);
3628+
3629+
{
3630+
let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
3631+
root_cert_store.add(&rustls::Certificate(certs.ca_cert.serialize_der()?))?;
3632+
3633+
let client_cert =
3634+
rustls_pemfile::certs(&mut BufReader::new(certs.client_cert_signed.as_bytes()))
3635+
.map_err(|e| format!("failed to read client cert: {e}"))?;
3636+
3637+
let client_cert: Vec<rustls::Certificate> = client_cert
3638+
.iter()
3639+
.map(|cert| rustls::Certificate(cert.clone()))
3640+
.collect();
3641+
3642+
let config = rustls::ClientConfig::builder()
3643+
.with_safe_defaults()
3644+
.with_root_certificates(root_cert_store.clone())
3645+
.with_client_auth_cert(client_cert, rustls::PrivateKey(certs.client_key))?;
3646+
3647+
let connector = MakeRustlsConnect::new(config);
3648+
3649+
println!("connecting to: {conn_str} with client auth cert");
3650+
let (client, client_conn) = tokio_postgres::connect(&conn_str, connector).await?;
3651+
3652+
tokio::spawn(client_conn);
3653+
3654+
println!("successfully connected!");
3655+
3656+
let _permit = sema.acquire().await;
3657+
3658+
client.simple_query("SELECT 1").await?;
3659+
3660+
let config = rustls::ClientConfig::builder()
3661+
.with_safe_defaults()
3662+
.with_root_certificates(root_cert_store)
3663+
.with_no_client_auth();
3664+
3665+
let connector = MakeRustlsConnect::new(config);
3666+
3667+
println!("connecting to: {conn_str} without client auth cert");
3668+
let result = tokio_postgres::connect(&conn_str, connector).await;
3669+
assert!(
3670+
result.is_err(),
3671+
"expected connect to fail without client auth cert"
3672+
);
3673+
3674+
println!("successfully failed to connect without client auth cert");
3675+
}
3676+
3677+
tripwire_tx.send(()).await.ok();
3678+
tripwire_worker.await;
3679+
wait_for_all_pending_handles().await;
3680+
3681+
Ok(())
3682+
}
34753683
}

0 commit comments

Comments
 (0)