@@ -20,6 +20,7 @@ use std::sync::Arc;
2020
2121use crate :: sql:: db_connection_pool:: {
2222 dbconnection:: { self , AsyncDbConnection , DbConnection , GenericError } ,
23+ runtime:: run_async_with_tokio,
2324 DbConnectionPool ,
2425} ;
2526use arrow_odbc:: arrow_schema_from;
@@ -42,7 +43,7 @@ use odbc_api::handles::StatementImpl;
4243use odbc_api:: parameter:: InputParameter ;
4344use odbc_api:: Cursor ;
4445use odbc_api:: CursorImpl ;
45- use secrecy:: { SecretBox , ExposeSecret , SecretString } ;
46+ use secrecy:: { ExposeSecret , SecretBox , SecretString } ;
4647use snafu:: prelude:: * ;
4748use snafu:: Snafu ;
4849use tokio:: runtime:: Handle ;
@@ -184,69 +185,71 @@ where
184185 let params = params. iter ( ) . map ( dyn_clone:: clone) . collect :: < Vec < _ > > ( ) ;
185186 let secrets = Arc :: clone ( & self . params ) ;
186187
187- let join_handle = tokio:: task:: spawn_blocking ( move || {
188- let handle = Handle :: current ( ) ;
189- let cxn = handle. block_on ( async { conn. lock ( ) . await } ) ;
188+ let create_stream = async || -> Result < SendableRecordBatchStream > {
189+ let join_handle = tokio:: task:: spawn_blocking ( move || {
190+ let handle = Handle :: current ( ) ;
191+ let cxn = handle. block_on ( async { conn. lock ( ) . await } ) ;
190192
191- let mut prepared = cxn. prepare ( & sql) ?;
192- let schema = Arc :: new ( arrow_schema_from ( & mut prepared, false ) ?) ;
193- blocking_channel_send ( & schema_tx, Arc :: clone ( & schema) ) ?;
193+ let mut prepared = cxn. prepare ( & sql) ?;
194+ let schema = Arc :: new ( arrow_schema_from ( & mut prepared, false ) ?) ;
195+ blocking_channel_send ( & schema_tx, Arc :: clone ( & schema) ) ?;
194196
195- let mut statement = prepared. into_statement ( ) ;
197+ let mut statement = prepared. into_statement ( ) ;
196198
197- bind_parameters ( & mut statement, & params) ?;
199+ bind_parameters ( & mut statement, & params) ?;
198200
199- // StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
200- let cursor = unsafe {
201- if let SqlResult :: Error { function } = statement. execute ( ) {
202- return Err ( Error :: ODBCAPIErrorNoSource {
203- message : function. to_string ( ) ,
201+ // StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
202+ let cursor = unsafe {
203+ if let SqlResult :: Error { function } = statement. execute ( ) {
204+ return Err ( Error :: ODBCAPIErrorNoSource {
205+ message : function. to_string ( ) ,
206+ }
207+ . into ( ) ) ;
204208 }
205- . into ( ) ) ;
206- }
207209
208- Ok :: < _ , GenericError > ( CursorImpl :: new ( statement. as_stmt_ref ( ) ) )
209- } ?;
210+ Ok :: < _ , GenericError > ( CursorImpl :: new ( statement. as_stmt_ref ( ) ) )
211+ } ?;
210212
211- let reader = build_odbc_reader ( cursor, & schema, & secrets) ?;
212- for batch in reader {
213- blocking_channel_send ( & batch_tx, batch. context ( ArrowSnafu ) ?) ?;
214- }
213+ let reader = build_odbc_reader ( cursor, & schema, & secrets) ?;
214+ for batch in reader {
215+ blocking_channel_send ( & batch_tx, batch. context ( ArrowSnafu ) ?) ?;
216+ }
215217
216- Ok :: < _ , GenericError > ( ( ) )
217- } ) ;
218+ Ok :: < _ , GenericError > ( ( ) )
219+ } ) ;
218220
219- // we need to wait for the schema first before we can build our RecordBatchStreamAdapter
220- let Some ( schema) = schema_rx. recv ( ) . await else {
221- // if the channel drops, the task errored
222- if !join_handle. is_finished ( ) {
223- unreachable ! ( "Schema channel should not have dropped before the task finished" ) ;
224- }
221+ // we need to wait for the schema first before we can build our RecordBatchStreamAdapter
222+ let Some ( schema) = schema_rx. recv ( ) . await else {
223+ // if the channel drops, the task errored
224+ if !join_handle. is_finished ( ) {
225+ unreachable ! ( "Schema channel should not have dropped before the task finished" ) ;
226+ }
225227
226- let result = join_handle. await ?;
227- let Err ( err) = result else {
228- unreachable ! ( "Task should have errored" ) ;
228+ let result = join_handle. await ?;
229+ let Err ( err) = result else {
230+ unreachable ! ( "Task should have errored" ) ;
231+ } ;
232+
233+ return Err ( err) ;
229234 } ;
230235
231- return Err ( err) ;
232- } ;
236+ let output_stream = stream ! {
237+ while let Some ( batch) = batch_rx. recv( ) . await {
238+ yield Ok ( batch) ;
239+ }
233240
234- let output_stream = stream ! {
235- while let Some ( batch) = batch_rx. recv( ) . await {
236- yield Ok ( batch) ;
237- }
241+ if let Err ( e) = join_handle. await {
242+ yield Err ( DataFusionError :: Execution ( format!(
243+ "Failed to execute ODBC query: {e}"
244+ ) ) )
245+ }
246+ } ;
238247
239- if let Err ( e) = join_handle. await {
240- yield Err ( DataFusionError :: Execution ( format!(
241- "Failed to execute ODBC query: {e}"
242- ) ) )
243- }
248+ let result: SendableRecordBatchStream =
249+ Box :: pin ( RecordBatchStreamAdapter :: new ( schema, output_stream) ) ;
250+ Ok ( result)
244251 } ;
245-
246- Ok ( Box :: pin ( RecordBatchStreamAdapter :: new (
247- schema,
248- output_stream,
249- ) ) )
252+ run_async_with_tokio ( create_stream) . await
250253 }
251254
252255 async fn execute ( & self , query : & str , params : & [ ODBCParameter ] ) -> Result < u64 > {
0 commit comments