Skip to content

Commit 7dca161

Browse files
committed
Pool Postgres connections
Signed-off-by: itowlson <[email protected]>
1 parent 3870f54 commit 7dca161

File tree

6 files changed

+152
-70
lines changed

6 files changed

+152
-70
lines changed

Cargo.lock

+35
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/factor-outbound-pg/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = { workspace = true }
77
[dependencies]
88
anyhow = { workspace = true }
99
chrono = "0.4"
10+
deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] }
1011
native-tls = "0.2"
1112
postgres-native-tls = "0.5"
1213
spin-core = { path = "../core" }

crates/factor-outbound-pg/src/client.rs

+74-41
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,83 @@
1-
use anyhow::{anyhow, Result};
1+
use anyhow::{anyhow, Context, Result};
22
use native_tls::TlsConnector;
33
use postgres_native_tls::MakeTlsConnector;
44
use spin_world::async_trait;
55
use spin_world::spin::postgres::postgres::{
66
self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet,
77
};
88
use tokio_postgres::types::Type;
9-
use tokio_postgres::{config::SslMode, types::ToSql, Row};
10-
use tokio_postgres::{Client as TokioClient, NoTls, Socket};
9+
use tokio_postgres::{config::SslMode, types::ToSql, NoTls, Row};
10+
11+
const CONNECTION_POOL_SIZE: usize = 64;
1112

1213
#[async_trait]
13-
pub trait Client {
14-
async fn build_client(address: &str) -> Result<Self>
15-
where
16-
Self: Sized;
14+
pub trait ClientFactory: Send + Sync {
15+
type Client: Client + Send + Sync + 'static;
16+
fn new() -> Self;
17+
async fn build_client(&mut self, address: &str) -> Result<Self::Client>;
18+
}
19+
20+
pub struct PooledTokioClientFactory {
21+
pools: std::collections::HashMap<String, deadpool_postgres::Pool>,
22+
}
23+
24+
#[async_trait]
25+
impl ClientFactory for PooledTokioClientFactory {
26+
type Client = deadpool_postgres::Object;
27+
fn new() -> Self {
28+
Self {
29+
pools: Default::default(),
30+
}
31+
}
32+
async fn build_client(&mut self, address: &str) -> Result<Self::Client> {
33+
let pool_entry = self.pools.entry(address.to_owned());
34+
let pool = match pool_entry {
35+
std::collections::hash_map::Entry::Occupied(entry) => entry.into_mut(),
36+
std::collections::hash_map::Entry::Vacant(entry) => {
37+
let pool = create_connection_pool(address)
38+
.context("establishing PostgreSQL connection pool")?;
39+
entry.insert(pool)
40+
}
41+
};
42+
43+
Ok(pool.get().await?)
44+
}
45+
}
46+
47+
fn create_connection_pool(address: &str) -> Result<deadpool_postgres::Pool> {
48+
let config = address
49+
.parse::<tokio_postgres::Config>()
50+
.context("parsing Postgres connection string")?;
51+
52+
tracing::debug!("Build new connection: {}", address);
53+
54+
// TODO: This is slower but safer. Is it the right tradeoff?
55+
// https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
56+
let mgr_config = deadpool_postgres::ManagerConfig {
57+
recycling_method: deadpool_postgres::RecyclingMethod::Clean,
58+
};
59+
60+
let mgr = if config.get_ssl_mode() == SslMode::Disable {
61+
deadpool_postgres::Manager::from_config(config, NoTls, mgr_config)
62+
} else {
63+
let builder = TlsConnector::builder();
64+
let connector = MakeTlsConnector::new(builder.build()?);
65+
deadpool_postgres::Manager::from_config(config, connector, mgr_config)
66+
};
1767

68+
// TODO: what is our max size heuristic? Should this be passed in soe that different
69+
// hosts can manage it according to their needs? Will a plain number suffice for
70+
// sophisticated hosts anyway?
71+
let pool = deadpool_postgres::Pool::builder(mgr)
72+
.max_size(CONNECTION_POOL_SIZE)
73+
.build()
74+
.context("building Postgres connection pool")?;
75+
76+
Ok(pool)
77+
}
78+
79+
#[async_trait]
80+
pub trait Client {
1881
async fn execute(
1982
&self,
2083
statement: String,
@@ -29,28 +92,7 @@ pub trait Client {
2992
}
3093

3194
#[async_trait]
32-
impl Client for TokioClient {
33-
async fn build_client(address: &str) -> Result<Self>
34-
where
35-
Self: Sized,
36-
{
37-
let config = address.parse::<tokio_postgres::Config>()?;
38-
39-
tracing::debug!("Build new connection: {}", address);
40-
41-
if config.get_ssl_mode() == SslMode::Disable {
42-
let (client, connection) = config.connect(NoTls).await?;
43-
spawn_connection(connection);
44-
Ok(client)
45-
} else {
46-
let builder = TlsConnector::builder();
47-
let connector = MakeTlsConnector::new(builder.build()?);
48-
let (client, connection) = config.connect(connector).await?;
49-
spawn_connection(connection);
50-
Ok(client)
51-
}
52-
}
53-
95+
impl Client for deadpool_postgres::Object {
5496
async fn execute(
5597
&self,
5698
statement: String,
@@ -67,7 +109,8 @@ impl Client for TokioClient {
67109
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
68110
.collect();
69111

70-
self.execute(&statement, params_refs.as_slice())
112+
self.as_ref()
113+
.execute(&statement, params_refs.as_slice())
71114
.await
72115
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))
73116
}
@@ -89,6 +132,7 @@ impl Client for TokioClient {
89132
.collect();
90133

91134
let results = self
135+
.as_ref()
92136
.query(&statement, params_refs.as_slice())
93137
.await
94138
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
@@ -111,17 +155,6 @@ impl Client for TokioClient {
111155
}
112156
}
113157

114-
fn spawn_connection<T>(connection: tokio_postgres::Connection<Socket, T>)
115-
where
116-
T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
117-
{
118-
tokio::spawn(async move {
119-
if let Err(e) = connection.await {
120-
tracing::error!("Postgres connection error: {}", e);
121-
}
122-
});
123-
}
124-
125158
fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Sync>> {
126159
match value {
127160
ParameterValue::Boolean(v) => Ok(Box::new(*v)),

crates/factor-outbound-pg/src/host.rs

+14-11
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@ use tracing::field::Empty;
99
use tracing::instrument;
1010
use tracing::Level;
1111

12-
use crate::client::Client;
12+
use crate::client::{Client, ClientFactory};
1313
use crate::InstanceState;
1414

15-
impl<C: Client> InstanceState<C> {
15+
impl<CF: ClientFactory> InstanceState<CF> {
1616
async fn open_connection<Conn: 'static>(
1717
&mut self,
1818
address: &str,
1919
) -> Result<Resource<Conn>, v3::Error> {
2020
self.connections
2121
.push(
22-
C::build_client(address)
22+
self.client_factory
23+
.write()
24+
.await
25+
.build_client(address)
2326
.await
2427
.map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?,
2528
)
@@ -30,7 +33,7 @@ impl<C: Client> InstanceState<C> {
3033
async fn get_client<Conn: 'static>(
3134
&mut self,
3235
connection: Resource<Conn>,
33-
) -> Result<&C, v3::Error> {
36+
) -> Result<&CF::Client, v3::Error> {
3437
self.connections
3538
.get(connection.rep())
3639
.ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into()))
@@ -71,8 +74,8 @@ fn v2_params_to_v3(
7174
params.into_iter().map(|p| p.try_into()).collect()
7275
}
7376

74-
impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnection
75-
for InstanceState<C>
77+
impl<CF: ClientFactory + Send + Sync> spin_world::spin::postgres::postgres::HostConnection
78+
for InstanceState<CF>
7679
{
7780
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
7881
async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
@@ -122,13 +125,13 @@ impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnecti
122125
}
123126
}
124127

125-
impl<C: Send> v2_types::Host for InstanceState<C> {
128+
impl<CF: ClientFactory + Send> v2_types::Host for InstanceState<CF> {
126129
fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
127130
Ok(error)
128131
}
129132
}
130133

131-
impl<C: Send + Sync + Client> v3::Host for InstanceState<C> {
134+
impl<CF: Send + Sync + ClientFactory> v3::Host for InstanceState<CF> {
132135
fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
133136
Ok(error)
134137
}
@@ -152,9 +155,9 @@ macro_rules! delegate {
152155
}};
153156
}
154157

155-
impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}
158+
impl<CF: Send + Sync + ClientFactory> v2::Host for InstanceState<CF> {}
156159

157-
impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
160+
impl<CF: Send + Sync + ClientFactory> v2::HostConnection for InstanceState<CF> {
158161
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
159162
async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
160163
spin_factor_outbound_networking::record_address_fields(&address);
@@ -206,7 +209,7 @@ impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
206209
}
207210
}
208211

209-
impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
212+
impl<CF: Send + Sync + ClientFactory> v1::Host for InstanceState<CF> {
210213
async fn execute(
211214
&mut self,
212215
address: String,

crates/factor-outbound-pg/src/lib.rs

+15-11
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
pub mod client;
22
mod host;
33

4-
use client::Client;
4+
use std::sync::Arc;
5+
6+
use client::ClientFactory;
57
use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor};
68
use spin_factors::{
79
anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
810
};
9-
use tokio_postgres::Client as PgClient;
11+
use tokio::sync::RwLock;
1012

11-
pub struct OutboundPgFactor<C = PgClient> {
12-
_phantom: std::marker::PhantomData<C>,
13+
pub struct OutboundPgFactor<CF = crate::client::PooledTokioClientFactory> {
14+
_phantom: std::marker::PhantomData<CF>,
1315
}
1416

15-
impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
17+
impl<CF: ClientFactory + Send + Sync + 'static> Factor for OutboundPgFactor<CF> {
1618
type RuntimeConfig = ();
17-
type AppState = ();
18-
type InstanceBuilder = InstanceState<C>;
19+
type AppState = Arc<RwLock<CF>>;
20+
type InstanceBuilder = InstanceState<CF>;
1921

2022
fn init<T: Send + 'static>(
2123
&mut self,
@@ -31,7 +33,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
3133
&self,
3234
_ctx: ConfigureAppContext<T, Self>,
3335
) -> anyhow::Result<Self::AppState> {
34-
Ok(())
36+
Ok(Arc::new(RwLock::new(CF::new())))
3537
}
3638

3739
fn prepare<T: RuntimeFactors>(
@@ -43,6 +45,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
4345
.allowed_hosts();
4446
Ok(InstanceState {
4547
allowed_hosts,
48+
client_factory: ctx.app_state().clone(),
4649
connections: Default::default(),
4750
})
4851
}
@@ -62,9 +65,10 @@ impl<C> OutboundPgFactor<C> {
6265
}
6366
}
6467

65-
pub struct InstanceState<C> {
68+
pub struct InstanceState<CF: ClientFactory> {
6669
allowed_hosts: OutboundAllowedHosts,
67-
connections: spin_resource_table::Table<C>,
70+
client_factory: Arc<RwLock<CF>>,
71+
connections: spin_resource_table::Table<CF::Client>,
6872
}
6973

70-
impl<C: Send + 'static> SelfInstanceBuilder for InstanceState<C> {}
74+
impl<CF: ClientFactory + Send + 'static> SelfInstanceBuilder for InstanceState<CF> {}

0 commit comments

Comments
 (0)