1
- use anyhow:: { anyhow, Result } ;
1
+ use anyhow:: { anyhow, Context , Result } ;
2
2
use native_tls:: TlsConnector ;
3
3
use postgres_native_tls:: MakeTlsConnector ;
4
4
use spin_world:: async_trait;
5
5
use spin_world:: spin:: postgres:: postgres:: {
6
6
self as v3, Column , DbDataType , DbValue , ParameterValue , RowSet ,
7
7
} ;
8
8
use tokio_postgres:: types:: Type ;
9
- use tokio_postgres:: { config:: SslMode , types:: ToSql , Row } ;
10
- use tokio_postgres:: { Client as TokioClient , NoTls , Socket } ;
9
+ use tokio_postgres:: { config:: SslMode , types:: ToSql , NoTls , Row } ;
10
+
11
+ const CONNECTION_POOL_SIZE : usize = 64 ;
11
12
12
13
#[ async_trait]
13
- pub trait Client {
14
- async fn build_client ( address : & str ) -> Result < Self >
15
- where
16
- Self : Sized ;
14
+ pub trait ClientFactory : Send + Sync {
15
+ type Client : Client + Send + Sync + ' static ;
16
+ fn new ( ) -> Self ;
17
+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > ;
18
+ }
19
+
20
+ pub struct PooledTokioClientFactory {
21
+ pools : std:: collections:: HashMap < String , deadpool_postgres:: Pool > ,
22
+ }
23
+
24
+ #[ async_trait]
25
+ impl ClientFactory for PooledTokioClientFactory {
26
+ type Client = deadpool_postgres:: Object ;
27
+ fn new ( ) -> Self {
28
+ Self {
29
+ pools : Default :: default ( ) ,
30
+ }
31
+ }
32
+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > {
33
+ let pool_entry = self . pools . entry ( address. to_owned ( ) ) ;
34
+ let pool = match pool_entry {
35
+ std:: collections:: hash_map:: Entry :: Occupied ( entry) => entry. into_mut ( ) ,
36
+ std:: collections:: hash_map:: Entry :: Vacant ( entry) => {
37
+ let pool = create_connection_pool ( address)
38
+ . context ( "establishing PostgreSQL connection pool" ) ?;
39
+ entry. insert ( pool)
40
+ }
41
+ } ;
42
+
43
+ Ok ( pool. get ( ) . await ?)
44
+ }
45
+ }
46
+
47
+ fn create_connection_pool ( address : & str ) -> Result < deadpool_postgres:: Pool > {
48
+ let config = address
49
+ . parse :: < tokio_postgres:: Config > ( )
50
+ . context ( "parsing Postgres connection string" ) ?;
51
+
52
+ tracing:: debug!( "Build new connection: {}" , address) ;
53
+
54
+ // TODO: This is slower but safer. Is it the right tradeoff?
55
+ // https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
56
+ let mgr_config = deadpool_postgres:: ManagerConfig {
57
+ recycling_method : deadpool_postgres:: RecyclingMethod :: Clean ,
58
+ } ;
59
+
60
+ let mgr = if config. get_ssl_mode ( ) == SslMode :: Disable {
61
+ deadpool_postgres:: Manager :: from_config ( config, NoTls , mgr_config)
62
+ } else {
63
+ let builder = TlsConnector :: builder ( ) ;
64
+ let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
65
+ deadpool_postgres:: Manager :: from_config ( config, connector, mgr_config)
66
+ } ;
17
67
68
+ // TODO: what is our max size heuristic? Should this be passed in soe that different
69
+ // hosts can manage it according to their needs? Will a plain number suffice for
70
+ // sophisticated hosts anyway?
71
+ let pool = deadpool_postgres:: Pool :: builder ( mgr)
72
+ . max_size ( CONNECTION_POOL_SIZE )
73
+ . build ( )
74
+ . context ( "building Postgres connection pool" ) ?;
75
+
76
+ Ok ( pool)
77
+ }
78
+
79
+ #[ async_trait]
80
+ pub trait Client {
18
81
async fn execute (
19
82
& self ,
20
83
statement : String ,
@@ -29,28 +92,7 @@ pub trait Client {
29
92
}
30
93
31
94
#[ async_trait]
32
- impl Client for TokioClient {
33
- async fn build_client ( address : & str ) -> Result < Self >
34
- where
35
- Self : Sized ,
36
- {
37
- let config = address. parse :: < tokio_postgres:: Config > ( ) ?;
38
-
39
- tracing:: debug!( "Build new connection: {}" , address) ;
40
-
41
- if config. get_ssl_mode ( ) == SslMode :: Disable {
42
- let ( client, connection) = config. connect ( NoTls ) . await ?;
43
- spawn_connection ( connection) ;
44
- Ok ( client)
45
- } else {
46
- let builder = TlsConnector :: builder ( ) ;
47
- let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
48
- let ( client, connection) = config. connect ( connector) . await ?;
49
- spawn_connection ( connection) ;
50
- Ok ( client)
51
- }
52
- }
53
-
95
+ impl Client for deadpool_postgres:: Object {
54
96
async fn execute (
55
97
& self ,
56
98
statement : String ,
@@ -67,7 +109,8 @@ impl Client for TokioClient {
67
109
. map ( |b| b. as_ref ( ) as & ( dyn ToSql + Sync ) )
68
110
. collect ( ) ;
69
111
70
- self . execute ( & statement, params_refs. as_slice ( ) )
112
+ self . as_ref ( )
113
+ . execute ( & statement, params_refs. as_slice ( ) )
71
114
. await
72
115
. map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{:?}" , e) ) )
73
116
}
@@ -89,6 +132,7 @@ impl Client for TokioClient {
89
132
. collect ( ) ;
90
133
91
134
let results = self
135
+ . as_ref ( )
92
136
. query ( & statement, params_refs. as_slice ( ) )
93
137
. await
94
138
. map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{:?}" , e) ) ) ?;
@@ -111,17 +155,6 @@ impl Client for TokioClient {
111
155
}
112
156
}
113
157
114
- fn spawn_connection < T > ( connection : tokio_postgres:: Connection < Socket , T > )
115
- where
116
- T : tokio_postgres:: tls:: TlsStream + std:: marker:: Unpin + std:: marker:: Send + ' static ,
117
- {
118
- tokio:: spawn ( async move {
119
- if let Err ( e) = connection. await {
120
- tracing:: error!( "Postgres connection error: {}" , e) ;
121
- }
122
- } ) ;
123
- }
124
-
125
158
fn to_sql_parameter ( value : & ParameterValue ) -> Result < Box < dyn ToSql + Send + Sync > > {
126
159
match value {
127
160
ParameterValue :: Boolean ( v) => Ok ( Box :: new ( * v) ) ,
0 commit comments