@@ -58,27 +58,18 @@ pub struct DfSessionService {
5858 session_context : Arc < SessionContext > ,
5959 parser : Arc < Parser > ,
6060 timezone : Arc < Mutex < String > > ,
61- catalog_name : String ,
6261}
6362
6463impl DfSessionService {
65- pub fn new ( session_context : SessionContext , catalog_name : Option < String > ) -> DfSessionService {
64+ pub fn new ( session_context : SessionContext ) -> DfSessionService {
6665 let session_context = Arc :: new ( session_context) ;
6766 let parser = Arc :: new ( Parser {
6867 session_context : session_context. clone ( ) ,
6968 } ) ;
70- let catalog_name = catalog_name. unwrap_or_else ( || {
71- session_context
72- . catalog_names ( )
73- . first ( )
74- . cloned ( )
75- . unwrap_or_else ( || "datafusion" . to_string ( ) )
76- } ) ;
7769 DfSessionService {
7870 session_context,
7971 parser,
8072 timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
81- catalog_name,
8273 }
8374 }
8475
@@ -103,35 +94,40 @@ impl DfSessionService {
10394
10495 // Mock pg_namespace response
10596 async fn mock_pg_namespace < ' a > ( & self ) -> PgWireResult < QueryResponse < ' a > > {
106- let fields = vec ! [ FieldInfo :: new(
97+ let fields = Arc :: new ( vec ! [ FieldInfo :: new(
10798 "nspname" . to_string( ) ,
10899 None ,
109100 None ,
110101 Type :: VARCHAR ,
111102 FieldFormat :: Text ,
112- ) ] ;
103+ ) ] ) ;
113104
114- let row = {
115- let mut encoder = pgwire:: api:: results:: DataRowEncoder :: new ( Arc :: new ( fields. clone ( ) ) ) ;
116- encoder. encode_field ( & Some ( & self . catalog_name ) ) ?; // Return catalog_name as a schema
117- encoder. finish ( )
118- } ;
119-
120- let row_stream = futures:: stream:: once ( async move { row } ) ;
121- Ok ( QueryResponse :: new ( Arc :: new ( fields) , Box :: pin ( row_stream) ) )
105+ let fields_ref = fields. clone ( ) ;
106+ let rows = self
107+ . session_context
108+ . catalog_names ( )
109+ . into_iter ( )
110+ . map ( move |name| {
111+ let mut encoder = pgwire:: api:: results:: DataRowEncoder :: new ( fields_ref. clone ( ) ) ;
112+ encoder. encode_field ( & Some ( & name) ) ?; // Return catalog_name as a schema
113+ encoder. finish ( )
114+ } ) ;
115+
116+ let row_stream = futures:: stream:: iter ( rows) ;
117+ Ok ( QueryResponse :: new ( fields. clone ( ) , Box :: pin ( row_stream) ) )
122118 }
123119
124120 async fn try_respond_set_time_zone < ' a > (
125121 & self ,
126122 query_lower : & str ,
127- ) -> PgWireResult < Option < Vec < Response < ' a > > > > {
123+ ) -> PgWireResult < Option < Response < ' a > > > {
128124 if query_lower. starts_with ( "set time zone" ) {
129125 let parts: Vec < & str > = query_lower. split_whitespace ( ) . collect ( ) ;
130126 if parts. len ( ) >= 4 {
131127 let tz = parts[ 3 ] . trim_matches ( '"' ) ;
132128 let mut timezone = self . timezone . lock ( ) . await ;
133129 * timezone = tz. to_string ( ) ;
134- Ok ( Some ( vec ! [ Response :: Execution ( Tag :: new( "SET" ) ) ] ) )
130+ Ok ( Some ( Response :: Execution ( Tag :: new ( "SET" ) ) ) )
135131 } else {
136132 Err ( PgWireError :: UserError ( Box :: new (
137133 pgwire:: error:: ErrorInfo :: new (
@@ -149,32 +145,33 @@ impl DfSessionService {
149145 async fn try_respond_show_statements < ' a > (
150146 & self ,
151147 query_lower : & str ,
152- ) -> PgWireResult < Option < Vec < Response < ' a > > > > {
148+ ) -> PgWireResult < Option < Response < ' a > > > {
153149 if query_lower. starts_with ( "show " ) {
154- match query_lower {
150+ match query_lower. strip_suffix ( ";" ) . unwrap_or ( query_lower ) {
155151 "show time zone" => {
156152 let timezone = self . timezone . lock ( ) . await . clone ( ) ;
157153 let resp = Self :: mock_show_response ( "TimeZone" , & timezone) ?;
158- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
154+ Ok ( Some ( Response :: Query ( resp) ) )
159155 }
160156 "show server_version" => {
161157 let resp = Self :: mock_show_response ( "server_version" , "15.0 (DataFusion)" ) ?;
162- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
158+ Ok ( Some ( Response :: Query ( resp) ) )
163159 }
164160 "show transaction_isolation" => {
165161 let resp =
166162 Self :: mock_show_response ( "transaction_isolation" , "read uncommitted" ) ?;
167- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
163+ Ok ( Some ( Response :: Query ( resp) ) )
168164 }
169165 "show catalogs" => {
170166 let catalogs = self . session_context . catalog_names ( ) ;
171167 let value = catalogs. join ( ", " ) ;
172168 let resp = Self :: mock_show_response ( "Catalogs" , & value) ?;
173- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
169+ Ok ( Some ( Response :: Query ( resp) ) )
174170 }
175171 "show search_path" => {
176- let resp = Self :: mock_show_response ( "search_path" , & self . catalog_name ) ?;
177- Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) )
172+ let default_catalog = "datafusion" ;
173+ let resp = Self :: mock_show_response ( "search_path" , default_catalog) ?;
174+ Ok ( Some ( Response :: Query ( resp) ) )
178175 }
179176 _ => Err ( PgWireError :: UserError ( Box :: new (
180177 pgwire:: error:: ErrorInfo :: new (
@@ -192,31 +189,31 @@ impl DfSessionService {
192189 async fn try_respond_information_schema < ' a > (
193190 & self ,
194191 query_lower : & str ,
195- ) -> PgWireResult < Option < Vec < Response < ' a > > > > {
192+ ) -> PgWireResult < Option < Response < ' a > > > {
196193 if query_lower. contains ( "information_schema.schemata" ) {
197194 let df = schemata_df ( & self . session_context )
198195 . await
199196 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
200197 let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
201- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
198+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
202199 } else if query_lower. contains ( "information_schema.tables" ) {
203200 let df = tables_df ( & self . session_context )
204201 . await
205202 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
206203 let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
207- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
204+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
208205 } else if query_lower. contains ( "information_schema.columns" ) {
209206 let df = columns_df ( & self . session_context )
210207 . await
211208 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
212209 let resp = datatypes:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
213- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
210+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
214211 }
215212
216213 // Handle pg_catalog.pg_namespace for pgcli compatibility
217214 if query_lower. contains ( "pg_catalog.pg_namespace" ) {
218215 let resp = self . mock_pg_namespace ( ) . await ?;
219- return Ok ( Some ( vec ! [ Response :: Query ( resp) ] ) ) ;
216+ return Ok ( Some ( Response :: Query ( resp) ) ) ;
220217 }
221218
222219 Ok ( None )
@@ -233,15 +230,15 @@ impl SimpleQueryHandler for DfSessionService {
233230 log:: debug!( "Received query: {}" , query) ; // Log the query for debugging
234231
235232 if let Some ( resp) = self . try_respond_set_time_zone ( & query_lower) . await ? {
236- return Ok ( resp) ;
233+ return Ok ( vec ! [ resp] ) ;
237234 }
238235
239236 if let Some ( resp) = self . try_respond_show_statements ( & query_lower) . await ? {
240- return Ok ( resp) ;
237+ return Ok ( vec ! [ resp] ) ;
241238 }
242239
243240 if let Some ( resp) = self . try_respond_information_schema ( & query_lower) . await ? {
244- return Ok ( resp) ;
241+ return Ok ( vec ! [ resp] ) ;
245242 }
246243
247244 let df = self
@@ -352,67 +349,12 @@ impl ExtendedQueryHandler for DfSessionService {
352349 . to_string ( ) ;
353350 log:: debug!( "Received extended query: {}" , query) ; // Log for debugging
354351
355- if query. starts_with ( "show " ) {
356- match query. as_str ( ) {
357- "show time zone" => {
358- let timezone = self . timezone . lock ( ) . await . clone ( ) ;
359- let resp = Self :: mock_show_response ( "TimeZone" , & timezone) ?;
360- return Ok ( Response :: Query ( resp) ) ;
361- }
362- "show server_version" => {
363- let resp = Self :: mock_show_response ( "server_version" , "15.0 (DataFusion)" ) ?;
364- return Ok ( Response :: Query ( resp) ) ;
365- }
366- "show transaction_isolation" => {
367- let resp =
368- Self :: mock_show_response ( "transaction_isolation" , "read uncommitted" ) ?;
369- return Ok ( Response :: Query ( resp) ) ;
370- }
371- "show catalogs" => {
372- let catalogs = self . session_context . catalog_names ( ) ;
373- let value = catalogs. join ( ", " ) ;
374- let resp = Self :: mock_show_response ( "Catalogs" , & value) ?;
375- return Ok ( Response :: Query ( resp) ) ;
376- }
377- "show search_path" => {
378- let resp = Self :: mock_show_response ( "search_path" , & self . catalog_name ) ?;
379- return Ok ( Response :: Query ( resp) ) ;
380- }
381- _ => {
382- return Err ( PgWireError :: UserError ( Box :: new (
383- pgwire:: error:: ErrorInfo :: new (
384- "ERROR" . to_string ( ) ,
385- "42704" . to_string ( ) ,
386- format ! ( "Unrecognized SHOW command: {}" , query) ,
387- ) ,
388- ) ) ) ;
389- }
390- }
391- }
392-
393- if query. contains ( "information_schema.schemata" ) {
394- let df = schemata_df ( & self . session_context )
395- . await
396- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
397- let resp = datatypes:: encode_dataframe ( df, & portal. result_column_format ) . await ?;
398- return Ok ( Response :: Query ( resp) ) ;
399- } else if query. contains ( "information_schema.tables" ) {
400- let df = tables_df ( & self . session_context )
401- . await
402- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
403- let resp = datatypes:: encode_dataframe ( df, & portal. result_column_format ) . await ?;
404- return Ok ( Response :: Query ( resp) ) ;
405- } else if query. contains ( "information_schema.columns" ) {
406- let df = columns_df ( & self . session_context )
407- . await
408- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
409- let resp = datatypes:: encode_dataframe ( df, & portal. result_column_format ) . await ?;
410- return Ok ( Response :: Query ( resp) ) ;
352+ if let Some ( resp) = self . try_respond_show_statements ( & query) . await ? {
353+ return Ok ( resp) ;
411354 }
412355
413- if query. contains ( "pg_catalog.pg_namespace" ) {
414- let resp = self . mock_pg_namespace ( ) . await ?;
415- return Ok ( Response :: Query ( resp) ) ;
356+ if let Some ( resp) = self . try_respond_information_schema ( & query) . await ? {
357+ return Ok ( resp) ;
416358 }
417359
418360 let plan = & portal. statement . statement ;
0 commit comments