Skip to content

Commit 1aad833

Browse files
authored
feat: high-level connection manager and cancel support (#427)
* feat: add connection manager * feat: query with cancel token * feat: integrate connection manager into startup and cancel handler * feat: don't panic on duplicated registration * chore: add new example to feature gate * fix: revert changes to sqlite example * fix: revert changes in test server * refactor: remove inconsistent query * chore: bring back comments
1 parent 202ff9f commit 1aad833

12 files changed

Lines changed: 606 additions & 93 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ required-features = ["server-api-aws-lc-rs"]
196196
name = "transaction"
197197
required-features = ["server-api-aws-lc-rs"]
198198

199+
[[example]]
200+
name = "cancel"
201+
required-features = ["server-api-aws-lc-rs"]
202+
199203
[[example]]
200204
name = "client"
201205
required-features = ["client-api"]

examples/cancel.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use std::sync::Arc;
2+
use std::time::Duration;
3+
4+
use async_trait::async_trait;
5+
use futures::{Sink, SinkExt, StreamExt, stream};
6+
use tokio::net::TcpListener;
7+
use tokio::time::sleep;
8+
9+
use pgwire::api::auth::noop::NoopStartupHandler;
10+
use pgwire::api::cancel::DefaultCancelHandler;
11+
use pgwire::api::query::SimpleQueryHandler;
12+
use pgwire::api::results::{FieldFormat, FieldInfo, QueryResponse, Response, Tag};
13+
use pgwire::api::store::PortalStore;
14+
use pgwire::api::{ClientInfo, ClientPortalStore, ConnectionManager, PgWireServerHandlers, Type};
15+
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
16+
use pgwire::messages::response::NoticeResponse;
17+
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
18+
use pgwire::tokio::process_socket;
19+
20+
struct SlowProcessor {
21+
manager: Arc<ConnectionManager>,
22+
}
23+
24+
#[async_trait]
25+
impl NoopStartupHandler for SlowProcessor {
26+
fn connection_manager(&self) -> Option<Arc<ConnectionManager>> {
27+
Some(self.manager.clone())
28+
}
29+
30+
async fn post_startup<C>(
31+
&self,
32+
client: &mut C,
33+
_message: PgWireFrontendMessage,
34+
) -> PgWireResult<()>
35+
where
36+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
37+
C::Error: std::fmt::Debug,
38+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
39+
{
40+
client
41+
.send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from(
42+
ErrorInfo::new(
43+
"NOTICE".to_owned(),
44+
"01000".to_owned(),
45+
"This example demonstrates query cancellation.\n\
46+
Supported queries:\n\
47+
- SELECT 1; (instant query)\n\
48+
- SELECT pg_sleep(10); (sleeps 10s, press Ctrl+C to cancel)"
49+
.to_string(),
50+
),
51+
)))
52+
.await?;
53+
Ok(())
54+
}
55+
}
56+
57+
#[async_trait]
58+
impl SimpleQueryHandler for SlowProcessor {
59+
async fn do_query<C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
60+
where
61+
C: ClientInfo + ClientPortalStore + Unpin + Send + Sync,
62+
C::PortalStore: PortalStore,
63+
{
64+
if query.trim().starts_with("SELECT pg_sleep") {
65+
sleep(Duration::from_secs(10)).await;
66+
Ok(vec![Response::Execution(Tag::new("SELECT 1"))])
67+
} else if query.trim().starts_with("SELECT") {
68+
let f1 = FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text);
69+
let schema = Arc::new(vec![f1]);
70+
let data = vec![(Some(1),), (Some(2),), (Some(3),)];
71+
let schema_ref = schema.clone();
72+
let data_row_stream = stream::iter(data).map(move |r| {
73+
let mut encoder = pgwire::api::results::DataRowEncoder::new(schema_ref.clone());
74+
encoder.encode_field(&r.0)?;
75+
Ok(encoder.take_row())
76+
});
77+
Ok(vec![Response::Query(QueryResponse::new(
78+
schema,
79+
data_row_stream,
80+
))])
81+
} else {
82+
Ok(vec![Response::Execution(Tag::new("OK").with_rows(1))])
83+
}
84+
}
85+
}
86+
87+
struct HandlerFactory {
88+
processor: Arc<SlowProcessor>,
89+
cancel_handler: Arc<DefaultCancelHandler>,
90+
}
91+
92+
impl PgWireServerHandlers for HandlerFactory {
93+
fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
94+
self.processor.clone()
95+
}
96+
97+
fn startup_handler(&self) -> Arc<impl pgwire::api::auth::StartupHandler> {
98+
self.processor.clone()
99+
}
100+
101+
fn cancel_handler(&self) -> Arc<impl pgwire::api::cancel::CancelHandler> {
102+
self.cancel_handler.clone()
103+
}
104+
}
105+
106+
#[tokio::main]
107+
pub async fn main() {
108+
let manager = Arc::new(ConnectionManager::new());
109+
let processor = Arc::new(SlowProcessor {
110+
manager: manager.clone(),
111+
});
112+
let cancel_handler = Arc::new(DefaultCancelHandler::new(manager));
113+
let factory = Arc::new(HandlerFactory {
114+
processor,
115+
cancel_handler,
116+
});
117+
118+
let server_addr = "127.0.0.1:5432";
119+
let listener = TcpListener::bind(server_addr).await.unwrap();
120+
println!("Listening to {}", server_addr);
121+
println!("Connect with: psql -h 127.0.0.1 -p 5432");
122+
println!("Test cancel: SELECT pg_sleep(10); then press Ctrl+C");
123+
loop {
124+
let incoming_socket = listener.accept().await.unwrap();
125+
let factory_ref = factory.clone();
126+
tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await });
127+
}
128+
}

src/api/auth/cleartext.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::fmt::Debug;
2+
use std::sync::Arc;
23

34
use async_trait::async_trait;
45
use futures::sink::{Sink, SinkExt};
@@ -7,14 +8,40 @@ use super::{
78
AuthSource, ClientInfo, LoginInfo, PgWireConnectionState, ServerParameterProvider,
89
StartupHandler,
910
};
11+
use crate::api::{ConnectionManager, PidSecretKeyGenerator, RandomPidSecretKeyGenerator};
1012
use crate::error::{PgWireError, PgWireResult};
1113
use crate::messages::startup::Authentication;
1214
use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
1315

14-
#[derive(new)]
1516
pub struct CleartextPasswordAuthStartupHandler<A, P> {
1617
auth_source: A,
1718
parameter_provider: P,
19+
pid_secret_key_generator: Arc<dyn PidSecretKeyGenerator>,
20+
connection_manager: Option<Arc<ConnectionManager>>,
21+
}
22+
23+
impl<A, P> CleartextPasswordAuthStartupHandler<A, P> {
24+
pub fn new(auth_source: A, parameter_provider: P) -> Self {
25+
Self {
26+
auth_source,
27+
parameter_provider,
28+
pid_secret_key_generator: Arc::new(RandomPidSecretKeyGenerator),
29+
connection_manager: None,
30+
}
31+
}
32+
33+
pub fn with_pid_secret_key_generator(
34+
mut self,
35+
generator: Arc<dyn PidSecretKeyGenerator>,
36+
) -> Self {
37+
self.pid_secret_key_generator = generator;
38+
self
39+
}
40+
41+
pub fn with_connection_manager(mut self, manager: Arc<ConnectionManager>) -> Self {
42+
self.connection_manager = Some(manager);
43+
self
44+
}
1845
}
1946

2047
#[async_trait]
@@ -47,6 +74,11 @@ impl<V: AuthSource, P: ServerParameterProvider> StartupHandler
4774
let login_info = LoginInfo::from_client_info(client);
4875
let pass = self.auth_source.get_password(&login_info).await?;
4976
if pass.password == pwd.password.as_bytes() {
77+
let (pid, secret_key) = self.pid_secret_key_generator.generate(client);
78+
client.set_pid_and_secret_key(pid, secret_key);
79+
if let Some(manager) = &self.connection_manager {
80+
super::register_connection(client, manager);
81+
}
5082
super::finish_authentication(client, &self.parameter_provider).await?;
5183
} else {
5284
return Err(PgWireError::InvalidPassword(

src/api/auth/md5pass.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ use super::{
99
AuthSource, ClientInfo, LoginInfo, PgWireConnectionState, ServerParameterProvider,
1010
StartupHandler,
1111
};
12+
use crate::api::{ConnectionManager, PidSecretKeyGenerator, RandomPidSecretKeyGenerator};
1213
use crate::error::{PgWireError, PgWireResult};
1314
use crate::messages::startup::Authentication;
1415
use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
1516

1617
pub struct Md5PasswordAuthStartupHandler<A, P> {
1718
auth_source: Arc<A>,
1819
parameter_provider: Arc<P>,
20+
pid_secret_key_generator: Arc<dyn PidSecretKeyGenerator>,
21+
connection_manager: Option<Arc<ConnectionManager>>,
1922
cached_password: Mutex<Vec<u8>>,
2023
}
2124

@@ -24,9 +27,24 @@ impl<A, P> Md5PasswordAuthStartupHandler<A, P> {
2427
Md5PasswordAuthStartupHandler {
2528
auth_source,
2629
parameter_provider,
30+
pid_secret_key_generator: Arc::new(RandomPidSecretKeyGenerator),
31+
connection_manager: None,
2732
cached_password: Mutex::new(vec![]),
2833
}
2934
}
35+
36+
pub fn with_pid_secret_key_generator(
37+
mut self,
38+
generator: Arc<dyn PidSecretKeyGenerator>,
39+
) -> Self {
40+
self.pid_secret_key_generator = generator;
41+
self
42+
}
43+
44+
pub fn with_connection_manager(mut self, manager: Arc<ConnectionManager>) -> Self {
45+
self.connection_manager = Some(manager);
46+
self
47+
}
3048
}
3149

3250
#[async_trait]
@@ -73,6 +91,11 @@ impl<A: AuthSource, P: ServerParameterProvider> StartupHandler
7391
let cached_pass = self.cached_password.lock().await;
7492

7593
if pwd.password.as_bytes() == *cached_pass {
94+
let (pid, secret_key) = self.pid_secret_key_generator.generate(client);
95+
client.set_pid_and_secret_key(pid, secret_key);
96+
if let Some(manager) = &self.connection_manager {
97+
super::register_connection(client, manager);
98+
}
7699
super::finish_authentication(client, self.parameter_provider.as_ref()).await?;
77100
} else {
78101
let login_info = LoginInfo::from_client_info(client);

src/api/auth/mod.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
use std::collections::HashMap;
22
use std::fmt::Debug;
3+
use std::sync::Arc;
34

45
use async_trait::async_trait;
56
use futures::sink::{Sink, SinkExt};
67

78
use super::{
8-
ClientInfo, METADATA_APPLICATION_NAME, METADATA_CLIENT_ENCODING, METADATA_DATABASE,
9-
METADATA_USER, PgWireConnectionState,
9+
ClientInfo, ConnectionGuard, ConnectionHandle, METADATA_APPLICATION_NAME,
10+
METADATA_CLIENT_ENCODING, METADATA_DATABASE, METADATA_USER, PgWireConnectionState,
1011
};
1112
use crate::error::{PgWireError, PgWireResult};
1213
use crate::messages::response::{ReadyForQuery, TransactionStatus};
@@ -291,6 +292,18 @@ where
291292
);
292293
}
293294

295+
pub(crate) fn register_connection<C>(client: &C, manager: &Arc<super::ConnectionManager>)
296+
where
297+
C: ClientInfo,
298+
{
299+
let (pid, secret_key) = client.pid_and_secret_key();
300+
let (handle, guard) = manager.register(pid, secret_key);
301+
client
302+
.session_extensions()
303+
.insert::<Arc<ConnectionHandle>>(handle);
304+
client.session_extensions().insert::<ConnectionGuard>(guard);
305+
}
306+
294307
pub(crate) async fn finish_authentication0<C, P>(
295308
client: &mut C,
296309
server_parameter_provider: &P,

src/api/auth/noop.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
use std::fmt::Debug;
2+
use std::sync::Arc;
23

34
use async_trait::async_trait;
45
use futures::sink::{Sink, SinkExt};
56

67
use super::{ClientInfo, DefaultServerParameterProvider, StartupHandler};
7-
use crate::api::PgWireConnectionState;
8+
use crate::api::{
9+
ConnectionManager, PgWireConnectionState, PidSecretKeyGenerator, RandomPidSecretKeyGenerator,
10+
};
811
use crate::error::{PgWireError, PgWireResult};
912
use crate::messages::response::{ReadyForQuery, TransactionStatus};
1013
use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
1114

1215
#[async_trait]
1316
pub trait NoopStartupHandler: StartupHandler {
17+
fn connection_manager(&self) -> Option<Arc<ConnectionManager>> {
18+
None
19+
}
20+
21+
fn pid_secret_key_generator(&self) -> &dyn PidSecretKeyGenerator {
22+
&RandomPidSecretKeyGenerator
23+
}
24+
1425
async fn post_startup<C>(
1526
&self,
1627
_client: &mut C,
@@ -43,6 +54,11 @@ where
4354
if let PgWireFrontendMessage::Startup(ref startup) = message {
4455
super::protocol_negotiation(client, startup).await?;
4556
super::save_startup_parameters_to_metadata(client, startup);
57+
let (pid, secret_key) = self.pid_secret_key_generator().generate(client);
58+
client.set_pid_and_secret_key(pid, secret_key);
59+
if let Some(manager) = self.connection_manager() {
60+
super::register_connection(client, &manager);
61+
}
4662
super::finish_authentication0(client, &DefaultServerParameterProvider::default())
4763
.await?;
4864

0 commit comments

Comments
 (0)