Skip to content

Commit 3db3f59

Browse files
Add IO runtime handle to postgres connector (#549)
1 parent ebdf355 commit 3db3f59

2 files changed

Lines changed: 64 additions & 3 deletions

File tree

core/src/sql/db_connection_pool/postgrespool.rs

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use native_tls::{Certificate, TlsConnector};
1111
use postgres_native_tls::MakeTlsConnector;
1212
use secrecy::{ExposeSecret, SecretBox, SecretString};
1313
use snafu::{prelude::*, ResultExt};
14+
use tokio::runtime::Handle;
1415
use tokio_postgres;
1516

1617
use super::{
@@ -79,6 +80,9 @@ pub enum Error {
7980
PasswordProviderError {
8081
source: Box<dyn std::error::Error + Send + Sync>,
8182
},
83+
84+
#[snafu(display("Task failed to execute on IO runtime.\n{source}"))]
85+
IoRuntimeError { source: tokio::task::JoinError },
8286
}
8387

8488
pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -187,6 +191,7 @@ pub struct PostgresConnectionPool {
187191
pool: Arc<bb8::Pool<ConnectionManager>>,
188192
join_push_down: JoinPushDown,
189193
unsupported_type_action: UnsupportedTypeAction,
194+
io_handle: Option<Handle>,
190195
}
191196

192197
impl PostgresConnectionPool {
@@ -380,6 +385,7 @@ impl PostgresConnectionPool {
380385
pool: Arc::new(pool),
381386
join_push_down,
382387
unsupported_type_action: UnsupportedTypeAction::default(),
388+
io_handle: None,
383389
})
384390
}
385391

@@ -390,14 +396,28 @@ impl PostgresConnectionPool {
390396
self
391397
}
392398

399+
/// Route all Postgres connection background tasks to a dedicated IO runtime.
400+
#[must_use]
401+
pub fn with_io_runtime(mut self, handle: Handle) -> Self {
402+
self.io_handle = Some(handle);
403+
self
404+
}
405+
393406
/// Returns a direct connection to the underlying database.
394407
///
395408
/// # Errors
396409
///
397410
/// Returns an error if there is a problem creating the connection pool.
398411
pub async fn connect_direct(&self) -> super::Result<PostgresConnection> {
399412
let pool = Arc::clone(&self.pool);
400-
let conn = pool.get_owned().await.map_err(map_pool_run_error)?;
413+
let conn = if let Some(handle) = &self.io_handle {
414+
handle
415+
.spawn(async move { pool.get_owned().await.map_err(map_pool_run_error) })
416+
.await
417+
.context(IoRuntimeSnafu)??
418+
} else {
419+
pool.get_owned().await.map_err(map_pool_run_error)?
420+
};
401421
Ok(PostgresConnection::new(conn))
402422
}
403423
}
@@ -581,8 +601,15 @@ impl
581601
>,
582602
> {
583603
let pool = Arc::clone(&self.pool);
584-
let get_conn = async || pool.get_owned().await.map_err(map_pool_run_error);
585-
let conn = run_async_with_tokio(get_conn).await?;
604+
let conn = if let Some(handle) = &self.io_handle {
605+
handle
606+
.spawn(async move { pool.get_owned().await.map_err(map_pool_run_error) })
607+
.await
608+
.context(IoRuntimeSnafu)??
609+
} else {
610+
let get_conn = async || pool.get_owned().await.map_err(map_pool_run_error);
611+
run_async_with_tokio(get_conn).await?
612+
};
586613
Ok(Box::new(
587614
PostgresConnection::new(conn)
588615
.with_unsupported_type_action(self.unsupported_type_action),

core/tests/postgres/mod.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,37 @@ async fn arrow_postgres_one_way(
453453

454454
assert_eq!(record_batch[0], expected_record);
455455
}
456+
457+
#[rstest]
458+
#[test_log::test(tokio::test)]
459+
async fn test_postgres_io_runtime_segregation(container_manager: &Mutex<ContainerManager>) {
460+
let mut container_manager = container_manager.lock().await;
461+
if !container_manager.claimed {
462+
container_manager.claimed = true;
463+
start_container(&mut container_manager).await;
464+
}
465+
466+
// Create a separate IO runtime
467+
let io_runtime = tokio::runtime::Builder::new_multi_thread()
468+
.worker_threads(2)
469+
.enable_all()
470+
.build()
471+
.expect("IO runtime should be created");
472+
473+
let pool = common::get_postgres_connection_pool(container_manager.port)
474+
.await
475+
.expect("pool created")
476+
.with_io_runtime(io_runtime.handle().clone());
477+
478+
// Verify the pool works through the IO runtime
479+
let sqltable_pool: Arc<DynPostgresConnectionPool> = Arc::new(pool);
480+
let conn = sqltable_pool.connect().await.expect("connect should work");
481+
let async_conn = conn.as_async().expect("should be async connection");
482+
// Execute a simple query to confirm IO runtime is functional
483+
let stream = async_conn
484+
.query_arrow("SELECT 1 AS val", &[], None)
485+
.await
486+
.expect("query should work");
487+
let batches: Vec<_> = futures::StreamExt::collect(stream).await;
488+
assert!(!batches.is_empty(), "should return results via IO runtime");
489+
}

0 commit comments

Comments
 (0)