Skip to content

Commit bae45f1

Browse files
committed
feat: update cursor to support extended query
1 parent f57f49e commit bae45f1

5 files changed

Lines changed: 58 additions & 89 deletions

File tree

Cargo.lock

Lines changed: 6 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,6 @@ opt-level = "z"
3030
lto = true
3131
codegen-units = 1
3232
panic = "abort"
33+
34+
[patch.crates-io]
35+
pgwire = { git = "https://github.com/sunng87/pgwire", branch = "feature/cursor-for-portal" }

arrow-pg/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ datafusion = { workspace = true, optional = true }
2727
futures.workspace = true
2828
geoarrow = { version = "0.8", optional = true }
2929
geoarrow-schema = { version = "0.8", optional = true }
30-
pg_interval = { version = "0.5.1", package = "pg_interval_2" }
30+
pg_interval = { version = "0.5.0" }
3131
pgwire = { workspace = true, default-features = false, features = ["server-api", "pg-ext-types"] }
3232
postgres-types.workspace = true
3333
rust_decimal.workspace = true

datafusion-pg-catalog/src/sql/parser.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::sync::Arc;
22

33
use datafusion::sql::sqlparser::ast::Statement;
44
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
5+
use datafusion::sql::sqlparser::keywords::Keyword;
56
use datafusion::sql::sqlparser::parser::Parser;
67
use datafusion::sql::sqlparser::parser::ParserError;
78
use datafusion::sql::sqlparser::tokenizer::Token;
@@ -247,10 +248,19 @@ impl PostgresCompatibilityParser {
247248

248249
// Get token values (without spans) and filter out only whitespace
249250
// Keep semicolons as they separate statements
251+
// Also rewrite ABORT to ROLLBACK for postgres compatibility
252+
// remove this when https://github.com/apache/datafusion-sqlparser-rs/pull/2332 is ready
250253
let filtered_tokens: Vec<Token> = tokens
251254
.iter()
252255
.map(|t| t.token.clone())
253256
.filter(|t| !matches!(t, Token::Whitespace(_)))
257+
.map(|t| {
258+
if matches!(&t, Token::Word(w) if w.keyword == Keyword::ABORT) {
259+
Token::make_keyword("ROLLBACK")
260+
} else {
261+
t
262+
}
263+
})
254264
.collect();
255265

256266
// Handle empty input

datafusion-postgres/src/hooks/cursor.rs

Lines changed: 38 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::ops::DerefMut;
21
use std::sync::Arc;
32

43
use async_trait::async_trait;
@@ -7,9 +6,8 @@ use datafusion::logical_expr::LogicalPlan;
76
use datafusion::prelude::SessionContext;
87
use datafusion::sql::sqlparser;
98
use datafusion::sql::sqlparser::ast::{CloseCursor, DeclareType, FetchDirection};
10-
use futures::StreamExt;
119
use pgwire::api::ClientInfo;
12-
use pgwire::api::portal::{Format, Portal, PortalExecutionState};
10+
use pgwire::api::portal::{Format, Portal};
1311
use pgwire::api::results::{QueryResponse, Response, Tag};
1412
use pgwire::api::stmt::StoredStatement;
1513
use pgwire::api::store::{MemPortalStore, PortalStore};
@@ -48,22 +46,43 @@ impl QueryHook for CursorStatementHook {
4846

4947
async fn handle_extended_parse_query(
5048
&self,
51-
_statement: &sqlparser::ast::Statement,
49+
statement: &sqlparser::ast::Statement,
5250
_session_context: &SessionContext,
5351
_client: &(dyn ClientInfo + Send + Sync),
5452
) -> Option<PgWireResult<LogicalPlan>> {
55-
None
53+
match statement {
54+
sqlparser::ast::Statement::Declare { .. }
55+
| sqlparser::ast::Statement::Fetch { .. }
56+
| sqlparser::ast::Statement::Close { .. } => Some(Ok(LogicalPlan::EmptyRelation(
57+
datafusion::logical_expr::EmptyRelation {
58+
produce_one_row: false,
59+
schema: Arc::new(datafusion::common::DFSchema::empty()),
60+
},
61+
))),
62+
_ => None,
63+
}
5664
}
5765

5866
async fn handle_extended_query(
5967
&self,
60-
_statement: &sqlparser::ast::Statement,
68+
statement: &sqlparser::ast::Statement,
6169
_logical_plan: &LogicalPlan,
6270
_params: &ParamValues,
63-
_session_context: &SessionContext,
64-
_client: &mut dyn HookClient,
71+
session_context: &SessionContext,
72+
client: &mut dyn HookClient,
6573
) -> Option<PgWireResult<Response>> {
66-
None
74+
let store = client.portal_store();
75+
76+
match statement {
77+
sqlparser::ast::Statement::Declare { stmts } => {
78+
Some(handle_declare(store, stmts, session_context).await)
79+
}
80+
sqlparser::ast::Statement::Fetch {
81+
name, direction, ..
82+
} => Some(handle_fetch(store, name, direction).await),
83+
sqlparser::ast::Statement::Close { cursor } => Some(handle_close(store, cursor)),
84+
_ => None,
85+
}
6786
}
6887
}
6988

@@ -122,20 +141,9 @@ async fn handle_declare(
122141
vec![],
123142
));
124143

125-
let bind = pgwire::messages::extendedquery::Bind::new(
126-
Some(cursor_name.clone()),
127-
None,
128-
vec![],
129-
vec![],
130-
vec![],
131-
);
132-
let portal = Portal::try_new(&bind, stored_stmt)?;
144+
let portal = Portal::new_cursor(cursor_name.clone(), stored_stmt);
133145

134-
let state = portal.state();
135-
{
136-
let mut portal_state = state.lock().await;
137-
*portal_state = PortalExecutionState::Suspended(query_response);
138-
}
146+
portal.start(query_response).await;
139147

140148
store.put_portal(Arc::new(portal));
141149
}
@@ -150,7 +158,7 @@ async fn handle_fetch(
150158
) -> PgWireResult<Response> {
151159
let cursor_name = &name.value;
152160

153-
let limit = match direction {
161+
let max_rows = match direction {
154162
FetchDirection::Next | FetchDirection::Forward { limit: None } => Some(1),
155163
FetchDirection::Forward { limit: Some(v) } | FetchDirection::Count { limit: v } => {
156164
parse_value_as_usize(v)
@@ -187,70 +195,19 @@ async fn handle_fetch(
187195
)))
188196
})?;
189197

190-
let state = portal.state();
191-
let mut state = state.lock().await;
192-
193-
let query_response = match state.deref_mut() {
194-
PortalExecutionState::Suspended(qr) => qr,
195-
PortalExecutionState::Finished => {
196-
return Ok(Response::Execution(Tag::new("FETCH").with_rows(0)));
197-
}
198-
PortalExecutionState::Initial => {
199-
return Err(PgWireError::UserError(Box::new(
200-
pgwire::error::ErrorInfo::new(
201-
"ERROR".to_string(),
202-
"24000".to_string(),
203-
"cursor is in invalid state".to_string(),
204-
),
205-
)));
206-
}
207-
};
208-
209-
let schema = query_response.row_schema();
210-
let mut fetched_rows: Vec<pgwire::messages::data::DataRow> = vec![];
211-
let mut stream_exhausted = false;
212-
213-
if let Some(n) = limit {
214-
for _ in 0..n {
215-
match query_response.data_rows().next().await {
216-
Some(Ok(row)) => fetched_rows.push(row),
217-
Some(Err(e)) => return Err(e),
218-
None => {
219-
stream_exhausted = true;
220-
break;
221-
}
222-
}
223-
}
224-
} else {
225-
loop {
226-
match query_response.data_rows().next().await {
227-
Some(Ok(row)) => fetched_rows.push(row),
228-
Some(Err(e)) => return Err(e),
229-
None => {
230-
stream_exhausted = true;
231-
break;
232-
}
233-
}
234-
}
235-
}
236-
237-
if stream_exhausted {
238-
*state = PortalExecutionState::Finished;
239-
}
240-
241-
drop(state);
198+
let fetch_result = portal.fetch(max_rows.unwrap_or(0)).await?;
242199

243-
if fetched_rows.is_empty() {
200+
if fetch_result.rows.is_empty() {
244201
return Ok(Response::Execution(Tag::new("FETCH").with_rows(0)));
245202
}
246203

247-
let mut fetched_response = QueryResponse::new(
248-
schema,
249-
futures::stream::iter(fetched_rows.into_iter().map(Ok)),
204+
let mut response = QueryResponse::new(
205+
fetch_result.row_schema,
206+
futures::stream::iter(fetch_result.rows.into_iter().map(Ok)),
250207
);
251-
fetched_response.set_command_tag("FETCH");
208+
response.set_command_tag("FETCH");
252209

253-
Ok(Response::Query(fetched_response))
210+
Ok(Response::Query(response))
254211
}
255212

256213
fn handle_close(

0 commit comments

Comments
 (0)