Skip to content

Commit 721d649

Browse files
committed
Pool Postgres connections
Signed-off-by: itowlson <[email protected]>
1 parent 3870f54 commit 721d649

File tree

6 files changed

+139
-69
lines changed

6 files changed

+139
-69
lines changed

Cargo.lock

+35
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/factor-outbound-pg/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = { workspace = true }
77
[dependencies]
88
anyhow = { workspace = true }
99
chrono = "0.4"
10+
deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] }
1011
native-tls = "0.2"
1112
postgres-native-tls = "0.5"
1213
spin-core = { path = "../core" }

crates/factor-outbound-pg/src/client.rs

+61-40
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,66 @@ use spin_world::spin::postgres::postgres::{
66
self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet,
77
};
88
use tokio_postgres::types::Type;
9-
use tokio_postgres::{config::SslMode, types::ToSql, Row};
10-
use tokio_postgres::{Client as TokioClient, NoTls, Socket};
9+
use tokio_postgres::{config::SslMode, types::ToSql, NoTls, Row};
1110

1211
#[async_trait]
13-
pub trait Client {
14-
async fn build_client(address: &str) -> Result<Self>
15-
where
16-
Self: Sized;
12+
pub trait ClientFactory: Send + Sync {
13+
type Client: Client + Send + Sync + 'static;
14+
fn new() -> Self;
15+
async fn build_client(&mut self, address: &str) -> Result<Self::Client>;
16+
}
17+
18+
pub struct PooledTokioClientFactory {
19+
pools: std::collections::HashMap<String, deadpool_postgres::Pool>,
20+
}
21+
22+
#[async_trait]
23+
impl ClientFactory for PooledTokioClientFactory {
24+
type Client = deadpool_postgres::Object;
25+
fn new() -> Self {
26+
Self {
27+
pools: Default::default(),
28+
}
29+
}
30+
async fn build_client(&mut self, address: &str) -> Result<Self::Client> {
31+
// TODO: propagate error instead of unwrapping
32+
let pool = self
33+
.pools
34+
.entry(address.to_owned())
35+
.or_insert_with(|| pool_for(address).unwrap());
36+
Ok(pool.get().await?)
37+
}
38+
}
39+
40+
fn pool_for(address: &str) -> Result<deadpool_postgres::Pool> {
41+
let config = address.parse::<tokio_postgres::Config>()?;
1742

43+
tracing::debug!("Build new connection: {}", address);
44+
45+
// TODO: This is slower but safer. Is it the right tradeoff?
46+
// https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
47+
let mgr_config = deadpool_postgres::ManagerConfig {
48+
recycling_method: deadpool_postgres::RecyclingMethod::Clean,
49+
};
50+
51+
let mgr = if config.get_ssl_mode() == SslMode::Disable {
52+
deadpool_postgres::Manager::from_config(config.clone(), NoTls, mgr_config)
53+
} else {
54+
let builder = TlsConnector::builder();
55+
let connector = MakeTlsConnector::new(builder.build()?);
56+
deadpool_postgres::Manager::from_config(config.clone(), connector, mgr_config)
57+
};
58+
59+
// TODO: what is our max size heuristic? Should this be passed in soe that different
60+
// hosts can manage it according to their needs? Will a plain number suffice for
61+
// sophisticated hosts anyway?
62+
let pool = deadpool_postgres::Pool::builder(mgr).max_size(4).build()?;
63+
64+
Ok(pool)
65+
}
66+
67+
#[async_trait]
68+
pub trait Client {
1869
async fn execute(
1970
&self,
2071
statement: String,
@@ -29,28 +80,7 @@ pub trait Client {
2980
}
3081

3182
#[async_trait]
32-
impl Client for TokioClient {
33-
async fn build_client(address: &str) -> Result<Self>
34-
where
35-
Self: Sized,
36-
{
37-
let config = address.parse::<tokio_postgres::Config>()?;
38-
39-
tracing::debug!("Build new connection: {}", address);
40-
41-
if config.get_ssl_mode() == SslMode::Disable {
42-
let (client, connection) = config.connect(NoTls).await?;
43-
spawn_connection(connection);
44-
Ok(client)
45-
} else {
46-
let builder = TlsConnector::builder();
47-
let connector = MakeTlsConnector::new(builder.build()?);
48-
let (client, connection) = config.connect(connector).await?;
49-
spawn_connection(connection);
50-
Ok(client)
51-
}
52-
}
53-
83+
impl Client for deadpool_postgres::Object {
5484
async fn execute(
5585
&self,
5686
statement: String,
@@ -67,7 +97,8 @@ impl Client for TokioClient {
6797
.map(|b| b.as_ref() as &(dyn ToSql + Sync))
6898
.collect();
6999

70-
self.execute(&statement, params_refs.as_slice())
100+
self.as_ref()
101+
.execute(&statement, params_refs.as_slice())
71102
.await
72103
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))
73104
}
@@ -89,6 +120,7 @@ impl Client for TokioClient {
89120
.collect();
90121

91122
let results = self
123+
.as_ref()
92124
.query(&statement, params_refs.as_slice())
93125
.await
94126
.map_err(|e| v3::Error::QueryFailed(format!("{:?}", e)))?;
@@ -111,17 +143,6 @@ impl Client for TokioClient {
111143
}
112144
}
113145

114-
fn spawn_connection<T>(connection: tokio_postgres::Connection<Socket, T>)
115-
where
116-
T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
117-
{
118-
tokio::spawn(async move {
119-
if let Err(e) = connection.await {
120-
tracing::error!("Postgres connection error: {}", e);
121-
}
122-
});
123-
}
124-
125146
fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Sync>> {
126147
match value {
127148
ParameterValue::Boolean(v) => Ok(Box::new(*v)),

crates/factor-outbound-pg/src/host.rs

+14-11
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@ use tracing::field::Empty;
99
use tracing::instrument;
1010
use tracing::Level;
1111

12-
use crate::client::Client;
12+
use crate::client::{Client, ClientFactory};
1313
use crate::InstanceState;
1414

15-
impl<C: Client> InstanceState<C> {
15+
impl<CF: ClientFactory> InstanceState<CF> {
1616
async fn open_connection<Conn: 'static>(
1717
&mut self,
1818
address: &str,
1919
) -> Result<Resource<Conn>, v3::Error> {
2020
self.connections
2121
.push(
22-
C::build_client(address)
22+
self.client_factory
23+
.write()
24+
.await
25+
.build_client(address)
2326
.await
2427
.map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?,
2528
)
@@ -30,7 +33,7 @@ impl<C: Client> InstanceState<C> {
3033
async fn get_client<Conn: 'static>(
3134
&mut self,
3235
connection: Resource<Conn>,
33-
) -> Result<&C, v3::Error> {
36+
) -> Result<&CF::Client, v3::Error> {
3437
self.connections
3538
.get(connection.rep())
3639
.ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into()))
@@ -71,8 +74,8 @@ fn v2_params_to_v3(
7174
params.into_iter().map(|p| p.try_into()).collect()
7275
}
7376

74-
impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnection
75-
for InstanceState<C>
77+
impl<CF: ClientFactory + Send + Sync> spin_world::spin::postgres::postgres::HostConnection
78+
for InstanceState<CF>
7679
{
7780
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
7881
async fn open(&mut self, address: String) -> Result<Resource<v3::Connection>, v3::Error> {
@@ -122,13 +125,13 @@ impl<C: Send + Sync + Client> spin_world::spin::postgres::postgres::HostConnecti
122125
}
123126
}
124127

125-
impl<C: Send> v2_types::Host for InstanceState<C> {
128+
impl<CF: ClientFactory + Send> v2_types::Host for InstanceState<CF> {
126129
fn convert_error(&mut self, error: v2::Error) -> Result<v2::Error> {
127130
Ok(error)
128131
}
129132
}
130133

131-
impl<C: Send + Sync + Client> v3::Host for InstanceState<C> {
134+
impl<CF: Send + Sync + ClientFactory> v3::Host for InstanceState<CF> {
132135
fn convert_error(&mut self, error: v3::Error) -> Result<v3::Error> {
133136
Ok(error)
134137
}
@@ -152,9 +155,9 @@ macro_rules! delegate {
152155
}};
153156
}
154157

155-
impl<C: Send + Sync + Client> v2::Host for InstanceState<C> {}
158+
impl<CF: Send + Sync + ClientFactory> v2::Host for InstanceState<CF> {}
156159

157-
impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
160+
impl<CF: Send + Sync + ClientFactory> v2::HostConnection for InstanceState<CF> {
158161
#[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))]
159162
async fn open(&mut self, address: String) -> Result<Resource<v2::Connection>, v2::Error> {
160163
spin_factor_outbound_networking::record_address_fields(&address);
@@ -206,7 +209,7 @@ impl<C: Send + Sync + Client> v2::HostConnection for InstanceState<C> {
206209
}
207210
}
208211

209-
impl<C: Send + Sync + Client> v1::Host for InstanceState<C> {
212+
impl<CF: Send + Sync + ClientFactory> v1::Host for InstanceState<CF> {
210213
async fn execute(
211214
&mut self,
212215
address: String,

crates/factor-outbound-pg/src/lib.rs

+15-11
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
pub mod client;
22
mod host;
33

4-
use client::Client;
4+
use std::sync::Arc;
5+
6+
use client::ClientFactory;
57
use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor};
68
use spin_factors::{
79
anyhow, ConfigureAppContext, Factor, PrepareContext, RuntimeFactors, SelfInstanceBuilder,
810
};
9-
use tokio_postgres::Client as PgClient;
11+
use tokio::sync::RwLock;
1012

11-
pub struct OutboundPgFactor<C = PgClient> {
12-
_phantom: std::marker::PhantomData<C>,
13+
pub struct OutboundPgFactor<CF = crate::client::PooledTokioClientFactory> {
14+
_phantom: std::marker::PhantomData<CF>,
1315
}
1416

15-
impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
17+
impl<CF: ClientFactory + Send + Sync + 'static> Factor for OutboundPgFactor<CF> {
1618
type RuntimeConfig = ();
17-
type AppState = ();
18-
type InstanceBuilder = InstanceState<C>;
19+
type AppState = Arc<RwLock<CF>>;
20+
type InstanceBuilder = InstanceState<CF>;
1921

2022
fn init<T: Send + 'static>(
2123
&mut self,
@@ -31,7 +33,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
3133
&self,
3234
_ctx: ConfigureAppContext<T, Self>,
3335
) -> anyhow::Result<Self::AppState> {
34-
Ok(())
36+
Ok(Arc::new(RwLock::new(CF::new())))
3537
}
3638

3739
fn prepare<T: RuntimeFactors>(
@@ -43,6 +45,7 @@ impl<C: Send + Sync + Client + 'static> Factor for OutboundPgFactor<C> {
4345
.allowed_hosts();
4446
Ok(InstanceState {
4547
allowed_hosts,
48+
client_factory: ctx.app_state().clone(),
4649
connections: Default::default(),
4750
})
4851
}
@@ -62,9 +65,10 @@ impl<C> OutboundPgFactor<C> {
6265
}
6366
}
6467

65-
pub struct InstanceState<C> {
68+
pub struct InstanceState<CF: ClientFactory> {
6669
allowed_hosts: OutboundAllowedHosts,
67-
connections: spin_resource_table::Table<C>,
70+
client_factory: Arc<RwLock<CF>>,
71+
connections: spin_resource_table::Table<CF::Client>,
6872
}
6973

70-
impl<C: Send + 'static> SelfInstanceBuilder for InstanceState<C> {}
74+
impl<CF: ClientFactory + Send + 'static> SelfInstanceBuilder for InstanceState<CF> {}

crates/factor-outbound-pg/tests/factor_test.rs

+13-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use anyhow::{bail, Result};
22
use spin_factor_outbound_networking::OutboundNetworkingFactor;
33
use spin_factor_outbound_pg::client::Client;
4+
use spin_factor_outbound_pg::client::ClientFactory;
45
use spin_factor_outbound_pg::OutboundPgFactor;
56
use spin_factor_variables::VariablesFactor;
67
use spin_factors::{anyhow, RuntimeFactors};
@@ -15,14 +16,14 @@ use spin_world::spin::postgres::postgres::{ParameterValue, RowSet};
1516
struct TestFactors {
1617
variables: VariablesFactor,
1718
networking: OutboundNetworkingFactor,
18-
pg: OutboundPgFactor<MockClient>,
19+
pg: OutboundPgFactor<MockClientFactory>,
1920
}
2021

2122
fn factors() -> TestFactors {
2223
TestFactors {
2324
variables: VariablesFactor::default(),
2425
networking: OutboundNetworkingFactor::new(),
25-
pg: OutboundPgFactor::<MockClient>::new(),
26+
pg: OutboundPgFactor::<MockClientFactory>::new(),
2627
}
2728
}
2829

@@ -104,17 +105,22 @@ async fn exercise_query() -> anyhow::Result<()> {
104105
}
105106

106107
// TODO: We can expand this mock to track calls and simulate return values
108+
pub struct MockClientFactory {}
107109
pub struct MockClient {}
108110

109111
#[async_trait]
110-
impl Client for MockClient {
111-
async fn build_client(_address: &str) -> anyhow::Result<Self>
112-
where
113-
Self: Sized,
114-
{
112+
impl ClientFactory for MockClientFactory {
113+
type Client = MockClient;
114+
fn new() -> Self {
115+
Self {}
116+
}
117+
async fn build_client(&mut self, _address: &str) -> Result<Self::Client> {
115118
Ok(MockClient {})
116119
}
120+
}
117121

122+
#[async_trait]
123+
impl Client for MockClient {
118124
async fn execute(
119125
&self,
120126
_statement: String,

0 commit comments

Comments
 (0)