1- use std:: ops:: DerefMut ;
21use std:: sync:: Arc ;
32
43use async_trait:: async_trait;
@@ -7,9 +6,8 @@ use datafusion::logical_expr::LogicalPlan;
76use datafusion:: prelude:: SessionContext ;
87use datafusion:: sql:: sqlparser;
98use datafusion:: sql:: sqlparser:: ast:: { CloseCursor , DeclareType , FetchDirection } ;
10- use futures:: StreamExt ;
119use pgwire:: api:: ClientInfo ;
12- use pgwire:: api:: portal:: { Format , Portal , PortalExecutionState } ;
10+ use pgwire:: api:: portal:: { Format , Portal } ;
1311use pgwire:: api:: results:: { QueryResponse , Response , Tag } ;
1412use pgwire:: api:: stmt:: StoredStatement ;
1513use pgwire:: api:: store:: { MemPortalStore , PortalStore } ;
@@ -48,22 +46,43 @@ impl QueryHook for CursorStatementHook {
4846
4947 async fn handle_extended_parse_query (
5048 & self ,
51- _statement : & sqlparser:: ast:: Statement ,
49+ statement : & sqlparser:: ast:: Statement ,
5250 _session_context : & SessionContext ,
5351 _client : & ( dyn ClientInfo + Send + Sync ) ,
5452 ) -> Option < PgWireResult < LogicalPlan > > {
55- None
53+ match statement {
54+ sqlparser:: ast:: Statement :: Declare { .. }
55+ | sqlparser:: ast:: Statement :: Fetch { .. }
56+ | sqlparser:: ast:: Statement :: Close { .. } => Some ( Ok ( LogicalPlan :: EmptyRelation (
57+ datafusion:: logical_expr:: EmptyRelation {
58+ produce_one_row : false ,
59+ schema : Arc :: new ( datafusion:: common:: DFSchema :: empty ( ) ) ,
60+ } ,
61+ ) ) ) ,
62+ _ => None ,
63+ }
5664 }
5765
5866 async fn handle_extended_query (
5967 & self ,
60- _statement : & sqlparser:: ast:: Statement ,
68+ statement : & sqlparser:: ast:: Statement ,
6169 _logical_plan : & LogicalPlan ,
6270 _params : & ParamValues ,
63- _session_context : & SessionContext ,
64- _client : & mut dyn HookClient ,
71+ session_context : & SessionContext ,
72+ client : & mut dyn HookClient ,
6573 ) -> Option < PgWireResult < Response > > {
66- None
74+ let store = client. portal_store ( ) ;
75+
76+ match statement {
77+ sqlparser:: ast:: Statement :: Declare { stmts } => {
78+ Some ( handle_declare ( store, stmts, session_context) . await )
79+ }
80+ sqlparser:: ast:: Statement :: Fetch {
81+ name, direction, ..
82+ } => Some ( handle_fetch ( store, name, direction) . await ) ,
83+ sqlparser:: ast:: Statement :: Close { cursor } => Some ( handle_close ( store, cursor) ) ,
84+ _ => None ,
85+ }
6786 }
6887}
6988
@@ -122,20 +141,9 @@ async fn handle_declare(
122141 vec ! [ ] ,
123142 ) ) ;
124143
125- let bind = pgwire:: messages:: extendedquery:: Bind :: new (
126- Some ( cursor_name. clone ( ) ) ,
127- None ,
128- vec ! [ ] ,
129- vec ! [ ] ,
130- vec ! [ ] ,
131- ) ;
132- let portal = Portal :: try_new ( & bind, stored_stmt) ?;
144+ let portal = Portal :: new_cursor ( cursor_name. clone ( ) , stored_stmt) ;
133145
134- let state = portal. state ( ) ;
135- {
136- let mut portal_state = state. lock ( ) . await ;
137- * portal_state = PortalExecutionState :: Suspended ( query_response) ;
138- }
146+ portal. start ( query_response) . await ;
139147
140148 store. put_portal ( Arc :: new ( portal) ) ;
141149 }
@@ -150,7 +158,7 @@ async fn handle_fetch(
150158) -> PgWireResult < Response > {
151159 let cursor_name = & name. value ;
152160
153- let limit = match direction {
161+ let max_rows = match direction {
154162 FetchDirection :: Next | FetchDirection :: Forward { limit : None } => Some ( 1 ) ,
155163 FetchDirection :: Forward { limit : Some ( v) } | FetchDirection :: Count { limit : v } => {
156164 parse_value_as_usize ( v)
@@ -187,70 +195,19 @@ async fn handle_fetch(
187195 ) ) )
188196 } ) ?;
189197
190- let state = portal. state ( ) ;
191- let mut state = state. lock ( ) . await ;
192-
193- let query_response = match state. deref_mut ( ) {
194- PortalExecutionState :: Suspended ( qr) => qr,
195- PortalExecutionState :: Finished => {
196- return Ok ( Response :: Execution ( Tag :: new ( "FETCH" ) . with_rows ( 0 ) ) ) ;
197- }
198- PortalExecutionState :: Initial => {
199- return Err ( PgWireError :: UserError ( Box :: new (
200- pgwire:: error:: ErrorInfo :: new (
201- "ERROR" . to_string ( ) ,
202- "24000" . to_string ( ) ,
203- "cursor is in invalid state" . to_string ( ) ,
204- ) ,
205- ) ) ) ;
206- }
207- } ;
208-
209- let schema = query_response. row_schema ( ) ;
210- let mut fetched_rows: Vec < pgwire:: messages:: data:: DataRow > = vec ! [ ] ;
211- let mut stream_exhausted = false ;
212-
213- if let Some ( n) = limit {
214- for _ in 0 ..n {
215- match query_response. data_rows ( ) . next ( ) . await {
216- Some ( Ok ( row) ) => fetched_rows. push ( row) ,
217- Some ( Err ( e) ) => return Err ( e) ,
218- None => {
219- stream_exhausted = true ;
220- break ;
221- }
222- }
223- }
224- } else {
225- loop {
226- match query_response. data_rows ( ) . next ( ) . await {
227- Some ( Ok ( row) ) => fetched_rows. push ( row) ,
228- Some ( Err ( e) ) => return Err ( e) ,
229- None => {
230- stream_exhausted = true ;
231- break ;
232- }
233- }
234- }
235- }
236-
237- if stream_exhausted {
238- * state = PortalExecutionState :: Finished ;
239- }
240-
241- drop ( state) ;
198+ let fetch_result = portal. fetch ( max_rows. unwrap_or ( 0 ) ) . await ?;
242199
243- if fetched_rows . is_empty ( ) {
200+ if fetch_result . rows . is_empty ( ) {
244201 return Ok ( Response :: Execution ( Tag :: new ( "FETCH" ) . with_rows ( 0 ) ) ) ;
245202 }
246203
247- let mut fetched_response = QueryResponse :: new (
248- schema ,
249- futures:: stream:: iter ( fetched_rows . into_iter ( ) . map ( Ok ) ) ,
204+ let mut response = QueryResponse :: new (
205+ fetch_result . row_schema ,
206+ futures:: stream:: iter ( fetch_result . rows . into_iter ( ) . map ( Ok ) ) ,
250207 ) ;
251- fetched_response . set_command_tag ( "FETCH" ) ;
208+ response . set_command_tag ( "FETCH" ) ;
252209
253- Ok ( Response :: Query ( fetched_response ) )
210+ Ok ( Response :: Query ( response ) )
254211}
255212
256213fn handle_close (
0 commit comments