Skip to content

Commit a75bba1

Browse files
authored
feat: add cursor support for portal (#433)
* feat: create a constructor for using portal as cursor * feat: add an example for cursor * refactor: use extended query handler to use portal api * refactor: add an error item * fix: remove unnamed portal on sync * feat: remove unnamed statement as well * fix: correct removal of unnamed portal * fix: sync should not destroy unnamed statement
1 parent 367ea7c commit a75bba1

6 files changed

Lines changed: 491 additions & 87 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ required-features = ["server-api-aws-lc-rs"]
200200
name = "cancel"
201201
required-features = ["server-api-aws-lc-rs"]
202202

203+
[[example]]
204+
name = "cursor"
205+
required-features = ["server-api-aws-lc-rs"]
206+
203207
[[example]]
204208
name = "client"
205209
required-features = ["client-api"]

examples/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ impl SimpleQueryHandler for DummyProcessor {
8282
}
8383
}
8484

85+
#[allow(dead_code)]
8586
pub struct DummyProcessorFactory {
8687
pub handler: Arc<DummyProcessor>,
8788
}
8889

90+
#[allow(dead_code)]
8991
impl DummyProcessorFactory {
9092
pub fn new() -> DummyProcessorFactory {
9193
Self {

examples/cursor.rs

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
use std::ops::Deref;
2+
use std::sync::Arc;
3+
4+
use async_trait::async_trait;
5+
use futures::{Sink, SinkExt, StreamExt, stream};
6+
use tokio::net::TcpListener;
7+
8+
use pgwire::api::auth::StartupHandler;
9+
use pgwire::api::auth::noop::NoopStartupHandler;
10+
use pgwire::api::portal::Portal;
11+
use pgwire::api::query::SimpleQueryHandler;
12+
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag};
13+
use pgwire::api::stmt::StoredStatement;
14+
use pgwire::api::store::{MemPortalStore, PortalStore};
15+
use pgwire::api::{ClientInfo, ClientPortalStore, PgWireServerHandlers, Type};
16+
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
17+
use pgwire::messages::response::NoticeResponse;
18+
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
19+
use pgwire::tokio::process_socket;
20+
21+
mod common;
22+
23+
type Statement = String;
24+
25+
struct CursorBackend;
26+
27+
fn make_row_data() -> (Arc<Vec<FieldInfo>>, Vec<(i32, &'static str)>) {
28+
let schema = Arc::new(vec![
29+
FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text),
30+
FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text),
31+
]);
32+
let data = vec![
33+
(1, "Alice"),
34+
(2, "Bob"),
35+
(3, "Charlie"),
36+
(4, "Diana"),
37+
(5, "Eve"),
38+
(6, "Frank"),
39+
(7, "Grace"),
40+
(8, "Hank"),
41+
(9, "Ivy"),
42+
(10, "Jack"),
43+
];
44+
(schema, data)
45+
}
46+
47+
enum CursorCommand {
48+
Declare { name: String, inner_query: String },
49+
Fetch { name: String, count: usize },
50+
Close { name: String },
51+
}
52+
53+
fn parse_cursor_command(query: &str) -> Option<CursorCommand> {
54+
let upper = query.to_uppercase();
55+
if upper.starts_with("DECLARE") {
56+
let rest = &query["DECLARE".len()..].trim_start();
57+
let parts: Vec<&str> = rest.splitn(2, "FOR").collect();
58+
if parts.len() == 2 {
59+
Some(CursorCommand::Declare {
60+
name: parts[0].trim().to_string(),
61+
inner_query: parts[1].trim().to_string(),
62+
})
63+
} else {
64+
None
65+
}
66+
} else if upper.starts_with("FETCH") {
67+
let rest = &query["FETCH".len()..].trim_start();
68+
let parts: Vec<&str> = rest.splitn(2, "FROM").collect();
69+
if parts.len() == 2 {
70+
Some(CursorCommand::Fetch {
71+
count: parts[0].trim().parse().unwrap_or(1),
72+
name: parts[1].trim().trim_end_matches(';').trim().to_string(),
73+
})
74+
} else {
75+
None
76+
}
77+
} else if upper.starts_with("CLOSE") {
78+
let rest = &query["CLOSE".len()..].trim_start();
79+
Some(CursorCommand::Close {
80+
name: rest.trim().trim_end_matches(';').trim().to_string(),
81+
})
82+
} else {
83+
None
84+
}
85+
}
86+
87+
fn encode_row_data(
88+
data: Vec<(i32, &'static str)>,
89+
schema: Arc<Vec<FieldInfo>>,
90+
) -> impl futures::Stream<Item = PgWireResult<pgwire::messages::data::DataRow>> + use<> {
91+
let mut encoder = DataRowEncoder::new(schema);
92+
stream::iter(data).map(move |(id, name)| {
93+
encoder.encode_field(&id)?;
94+
encoder.encode_field(&name)?;
95+
Ok(encoder.take_row())
96+
})
97+
}
98+
99+
/// Demo stub: validates that the query starts with SELECT, then returns
100+
/// hardcoded data. A real implementation would parse and execute the query.
101+
fn execute_inner_query(inner_query: &str) -> PgWireResult<QueryResponse> {
102+
if !inner_query.to_uppercase().starts_with("SELECT") {
103+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
104+
"ERROR".to_owned(),
105+
"42809".to_owned(),
106+
"DECLARE CURSOR can only be used with SELECT queries".to_string(),
107+
))));
108+
}
109+
110+
let (schema, data) = make_row_data();
111+
let row_stream = encode_row_data(data, schema.clone());
112+
Ok(QueryResponse::new(schema, row_stream))
113+
}
114+
115+
#[async_trait]
116+
impl NoopStartupHandler for CursorBackend {
117+
async fn post_startup<C>(
118+
&self,
119+
client: &mut C,
120+
_message: PgWireFrontendMessage,
121+
) -> PgWireResult<()>
122+
where
123+
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
124+
C::Error: std::fmt::Debug,
125+
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
126+
{
127+
println!("Client connected: {}", client.socket_addr());
128+
129+
let notice = NoticeResponse::from(ErrorInfo::new(
130+
"INFO".into(),
131+
"00000".into(),
132+
"Cursor example server. Supported statements:\n \
133+
DECLARE <name> FOR SELECT ...\n \
134+
FETCH <n> FROM <name>\n \
135+
CLOSE <name>\n \
136+
SELECT ..."
137+
.into(),
138+
));
139+
client
140+
.send(PgWireBackendMessage::NoticeResponse(notice))
141+
.await?;
142+
143+
Ok(())
144+
}
145+
}
146+
147+
#[async_trait]
148+
impl SimpleQueryHandler for CursorBackend {
149+
async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
150+
where
151+
C: ClientInfo + ClientPortalStore + Unpin + Send + Sync,
152+
C::PortalStore: PortalStore,
153+
{
154+
println!("Query: {:?}", query);
155+
156+
let portal_store = client
157+
.portal_store()
158+
.as_any()
159+
.downcast_ref::<MemPortalStore<Statement>>()
160+
.expect("expected MemPortalStore<String>");
161+
162+
if let Some(cmd) = parse_cursor_command(query) {
163+
match cmd {
164+
CursorCommand::Declare { name, inner_query } => {
165+
handle_declare(portal_store, &name, &inner_query)
166+
}
167+
CursorCommand::Fetch { name, count } => {
168+
handle_fetch(portal_store, &name, count).await
169+
}
170+
CursorCommand::Close { name } => handle_close(portal_store, &name),
171+
}
172+
} else if query.to_uppercase().starts_with("SELECT") {
173+
let (schema, data) = make_row_data();
174+
let row_stream = encode_row_data(data, schema.clone());
175+
Ok(vec![Response::Query(QueryResponse::new(
176+
schema, row_stream,
177+
))])
178+
} else {
179+
Ok(vec![Response::Execution(Tag::new("OK").with_rows(1))])
180+
}
181+
}
182+
}
183+
184+
fn handle_declare(
185+
portal_store: &MemPortalStore<Statement>,
186+
cursor_name: &str,
187+
inner_query: &str,
188+
) -> PgWireResult<Vec<Response>> {
189+
println!("DECLARE cursor '{}' FOR {}", cursor_name, inner_query);
190+
191+
if !inner_query.to_uppercase().starts_with("SELECT") {
192+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
193+
"ERROR".to_owned(),
194+
"42809".to_owned(),
195+
"DECLARE CURSOR can only be used with SELECT queries".to_string(),
196+
))));
197+
}
198+
199+
let statement = StoredStatement::new(cursor_name.to_string(), inner_query.to_string(), vec![]);
200+
let portal = Portal::new_cursor(cursor_name.to_string(), Arc::new(statement));
201+
portal_store.put_portal(Arc::new(portal));
202+
203+
Ok(vec![Response::Execution(Tag::new("DECLARE CURSOR"))])
204+
}
205+
206+
async fn handle_fetch(
207+
portal_store: &MemPortalStore<Statement>,
208+
cursor_name: &str,
209+
count: usize,
210+
) -> PgWireResult<Vec<Response>> {
211+
println!("FETCH {} FROM {}", count, cursor_name);
212+
213+
let Some(portal) = portal_store.get_portal(cursor_name) else {
214+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
215+
"ERROR".to_owned(),
216+
"34000".to_owned(),
217+
format!("cursor \"{}\" does not exist", cursor_name),
218+
))));
219+
};
220+
221+
// Lazy execution: if the cursor hasn't been started yet, execute the
222+
// stored statement now and transition to Suspended.
223+
if matches!(
224+
portal.state().lock().await.deref(),
225+
pgwire::api::portal::PortalExecutionState::Initial
226+
) {
227+
let inner_query = &portal.statement.statement;
228+
println!(" -> Lazy execution of: {}", inner_query);
229+
let response = execute_inner_query(inner_query)?;
230+
portal.start(response).await;
231+
}
232+
233+
let fetch_result = portal.fetch(count).await?;
234+
println!(
235+
" -> Fetched {} rows, has_more: {}",
236+
fetch_result.rows.len(),
237+
fetch_result.suspended
238+
);
239+
240+
let schema = fetch_result.row_schema;
241+
let row_stream = stream::iter(fetch_result.rows.into_iter().map(Ok));
242+
Ok(vec![Response::Query(QueryResponse::new(
243+
schema, row_stream,
244+
))])
245+
}
246+
247+
fn handle_close(
248+
portal_store: &MemPortalStore<Statement>,
249+
cursor_name: &str,
250+
) -> PgWireResult<Vec<Response>> {
251+
println!("CLOSE {}", cursor_name);
252+
portal_store.rm_portal(cursor_name);
253+
Ok(vec![Response::Execution(Tag::new("CLOSE CURSOR"))])
254+
}
255+
256+
struct CursorBackendFactory {
257+
handler: Arc<CursorBackend>,
258+
}
259+
260+
impl PgWireServerHandlers for CursorBackendFactory {
261+
fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
262+
self.handler.clone()
263+
}
264+
265+
fn startup_handler(&self) -> Arc<impl StartupHandler> {
266+
self.handler.clone()
267+
}
268+
}
269+
270+
#[tokio::main]
271+
pub async fn main() {
272+
let factory = Arc::new(CursorBackendFactory {
273+
handler: Arc::new(CursorBackend),
274+
});
275+
276+
let server_addr = "127.0.0.1:5432";
277+
let listener = TcpListener::bind(server_addr).await.unwrap();
278+
println!("Listening to {}", server_addr);
279+
println!();
280+
println!("Try these commands with psql:");
281+
println!(" DECLARE my_cursor FOR SELECT * FROM users;");
282+
println!(" FETCH 3 FROM my_cursor;");
283+
println!(" FETCH 3 FROM my_cursor;");
284+
println!(" FETCH 3 FROM my_cursor;");
285+
println!(" FETCH 3 FROM my_cursor;");
286+
println!(" CLOSE my_cursor;");
287+
loop {
288+
let incoming_socket = listener.accept().await.unwrap();
289+
let factory_ref = factory.clone();
290+
tokio::spawn(async move { process_socket(incoming_socket.0, None, factory_ref).await });
291+
}
292+
}

0 commit comments

Comments
 (0)