@@ -2,16 +2,13 @@ use base64::{engine, Engine as _};
22use chrono:: Datelike ;
33use core:: fmt:: Write ;
44use futures:: future:: BoxFuture ;
5- use futures:: { FutureExt , StreamExt , TryFutureExt } ;
5+ use futures:: { FutureExt , StreamExt , TryFutureExt , TryStreamExt } ;
66use jsonwebtoken:: { encode, Algorithm , EncodingKey , Header } ;
77use reqwest:: { Client , Response } ;
88use serde_json:: { json, value:: RawValue , Value } ;
99use sha2:: { Digest , Sha256 } ;
10- use std:: collections:: { BTreeMap , HashMap } ;
11- use std:: convert:: Infallible ;
10+ use std:: collections:: HashMap ;
1211use windmill_common:: error:: to_anyhow;
13- use windmill_common:: more_serde:: json_stream_values;
14- use windmill_common:: s3_helpers:: convert_json_line_stream;
1512use windmill_common:: worker:: Connection ;
1613
1714use windmill_common:: { error:: Error , worker:: to_raw_value} ;
@@ -179,30 +176,12 @@ fn do_snowflake_inner<'a>(
179176 if skip_collect {
180177 handle_snowflake_result ( result) . await ?;
181178 Ok ( to_raw_value ( & Value :: Array ( vec ! [ ] ) ) )
182- } else if let Some ( ref s3) = s3 {
183- // do not do parse_snowflake_response as it will call .json() and
184- // load the entire response into memory
185- let result = result. map_err ( |e| {
186- Error :: ExecutionErr ( format ! ( "Could not send request to Snowflake: {:?}" , e) )
187- } ) ?;
188-
189- let rows_stream = json_stream_values ( result. bytes_stream ( ) , |sender| {
190- RootMpscDeserializer { sender }
191- } )
192- . await ?
193- . boxed ( )
194- . map ( |chunk| Ok :: < _ , Infallible > ( chunk) ) ;
195-
196- let stream = convert_json_line_stream ( rows_stream, s3. format ) . await ?;
197- s3. upload ( stream. boxed ( ) ) . await ?;
198-
199- Ok ( serde_json:: value:: to_raw_value ( & s3. object_key ) ?)
200179 } else {
201180 let response = result
202181 . parse_snowflake_response :: < SnowflakeResponse > ( )
203182 . await ?;
204183
205- if response. resultSetMetaData . numRows > 10000 {
184+ if s3 . is_none ( ) && response. resultSetMetaData . numRows > 10000 {
206185 return Err ( Error :: ExecutionErr (
207186 "More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows"
208187 . to_string ( ) ,
@@ -219,54 +198,66 @@ fn do_snowflake_inner<'a>(
219198 ) ;
220199 }
221200
222- let mut rows = response. data ;
223-
224- if response. resultSetMetaData . partitionInfo . len ( ) > 1 {
225- for idx in 1 ..response. resultSetMetaData . partitionInfo . len ( ) {
226- let url = format ! (
227- "https://{}.snowflakecomputing.com/api/v2/statements/{}" ,
228- account_identifier. to_uppercase( ) ,
229- response. statementHandle
230- ) ;
231- let mut request = HTTP_CLIENT
232- . get ( url)
233- . bearer_auth ( token)
234- . query ( & [ ( "partition" , idx. to_string ( ) ) ] ) ;
235-
236- if token_is_keypair {
237- request =
238- request. header ( "X-Snowflake-Authorization-Token-Type" , "KEYPAIR_JWT" ) ;
239- }
240-
241- let response = request
242- . send ( )
243- . await
244- . parse_snowflake_response :: < SnowflakeDataOnlyResponse > ( )
245- . await ?;
201+ let rows_stream = async_stream:: stream! {
202+ for row in response. data {
203+ yield Ok :: <Vec <Value >, windmill_common:: error:: Error >( row) ;
204+ }
246205
247- rows. extend ( response. data ) ;
206+ if response. resultSetMetaData. partitionInfo. len( ) > 1 {
207+ for idx in 1 ..response. resultSetMetaData. partitionInfo. len( ) {
208+ let url = format!(
209+ "https://{}.snowflakecomputing.com/api/v2/statements/{}" ,
210+ account_identifier. to_uppercase( ) ,
211+ response. statementHandle
212+ ) ;
213+ let mut request = HTTP_CLIENT
214+ . get( url)
215+ . bearer_auth( token)
216+ . query( & [ ( "partition" , idx. to_string( ) ) ] ) ;
217+
218+ if token_is_keypair {
219+ request =
220+ request. header( "X-Snowflake-Authorization-Token-Type" , "KEYPAIR_JWT" ) ;
221+ }
222+
223+ let response = request
224+ . send( )
225+ . await
226+ . parse_snowflake_response:: <SnowflakeDataOnlyResponse >( )
227+ . await ?;
228+
229+ for row in response. data {
230+ yield Ok ( row) ;
231+ }
232+ }
248233 }
234+ } ;
235+
236+ let rows_stream = rows_stream. map_ok ( |row| {
237+ let mut row_map = serde_json:: Map :: new ( ) ;
238+ row. iter ( )
239+ . zip ( response. resultSetMetaData . rowType . iter ( ) )
240+ . for_each ( |( val, row_type) | {
241+ row_map. insert ( row_type. name . clone ( ) , parse_val ( & val, & row_type. r#type ) ) ;
242+ } ) ;
243+ row_map
244+ } ) ;
245+
246+ if let Some ( s3) = s3 {
247+ // let rows_stream =
248+ // rows_stream.map(|r| serde_json::value::to_value(&r?).map_err(to_anyhow));
249+ // let stream = convert_json_line_stream(rows_stream.boxed(), s3.format).await?;
250+ // TODO fix this
251+ // s3.upload(stream.boxed()).await?;
252+ Ok ( to_raw_value ( & s3. object_key ) )
253+ } else {
254+ let rows = rows_stream
255+ . collect :: < Vec < _ > > ( )
256+ . await
257+ . into_iter ( )
258+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
259+ Ok ( to_raw_value ( & rows) )
249260 }
250-
251- let rows = to_raw_value (
252- & rows
253- . iter ( )
254- . map ( |row| {
255- let mut row_map = serde_json:: Map :: new ( ) ;
256- row. iter ( )
257- . zip ( response. resultSetMetaData . rowType . iter ( ) )
258- . for_each ( |( val, row_type) | {
259- row_map. insert (
260- row_type. name . clone ( ) ,
261- parse_val ( & val, & row_type. r#type ) ,
262- ) ;
263- } ) ;
264- row_map
265- } )
266- . collect :: < Vec < _ > > ( ) ,
267- ) ;
268-
269- Ok ( rows)
270261 }
271262 } ;
272263
@@ -587,103 +578,3 @@ fn parse_val(value: &Value, typ: &str) -> Value {
587578 ) )
588579 }
589580}
590-
591- // This deserializer takes a snowflake response as a stream and sends each row to an mpsc
592- // channel as a json record without storing the full input json in memory.
593- struct RootMpscDeserializer {
594- sender : tokio:: sync:: mpsc:: Sender < serde_json:: Value > ,
595- }
596-
597- impl < ' de > serde:: de:: DeserializeSeed < ' de > for RootMpscDeserializer {
598- type Value = ( ) ;
599- fn deserialize < D > ( self , deserializer : D ) -> Result < Self :: Value , D :: Error >
600- where
601- D : serde:: Deserializer < ' de > ,
602- {
603- struct RootVisitor < ' a > {
604- sender : & ' a tokio:: sync:: mpsc:: Sender < serde_json:: Value > ,
605- col_defs : Vec < String > ,
606- }
607-
608- impl < ' de , ' a > serde:: de:: Visitor < ' de > for RootVisitor < ' a > {
609- type Value = ( ) ;
610- fn expecting ( & self , formatter : & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
611- formatter. write_str ( "data field from snowflake response" )
612- }
613- fn visit_map < A > ( mut self , mut map : A ) -> Result < ( ) , A :: Error >
614- where
615- A : serde:: de:: MapAccess < ' de > ,
616- {
617- while let Some ( key) = map. next_key :: < String > ( ) ? {
618- if key == "resultSetMetaData" {
619- let result_set_metadata: SnowflakeResultSetMetaData = map. next_value ( ) ?;
620- self . col_defs = result_set_metadata
621- . rowType
622- . iter ( )
623- . map ( |x| x. name . clone ( ) )
624- . collect :: < Vec < String > > ( ) ;
625- } else if key == "data" {
626- let ( ) = map. next_value_seed ( RowsMpscDeserializer {
627- sender : self . sender ,
628- col_defs : & self . col_defs ,
629- } ) ?;
630- } else {
631- let _: serde:: de:: IgnoredAny = map. next_value ( ) ?;
632- }
633- }
634- Ok ( ( ) )
635- }
636- }
637-
638- deserializer. deserialize_map ( RootVisitor { sender : & self . sender , col_defs : vec ! [ ] } )
639- }
640- }
641-
642- struct RowsMpscDeserializer < ' a > {
643- sender : & ' a tokio:: sync:: mpsc:: Sender < serde_json:: Value > ,
644- col_defs : & ' a Vec < String > ,
645- }
646-
647- impl < ' de , ' a > serde:: de:: DeserializeSeed < ' de > for RowsMpscDeserializer < ' a > {
648- type Value = ( ) ;
649- fn deserialize < D > ( self , deserializer : D ) -> Result < Self :: Value , D :: Error >
650- where
651- D : serde:: Deserializer < ' de > ,
652- {
653- struct RowsVisitor < ' a > {
654- sender : & ' a tokio:: sync:: mpsc:: Sender < serde_json:: Value > ,
655- col_defs : & ' a Vec < String > ,
656- }
657-
658- impl < ' de , ' a > serde:: de:: Visitor < ' de > for RowsVisitor < ' a > {
659- type Value = ( ) ;
660-
661- fn expecting ( & self , formatter : & mut std:: fmt:: Formatter ) -> std:: fmt:: Result {
662- formatter. write_str ( "a sequence of rows" )
663- }
664-
665- fn visit_seq < A > ( self , mut seq : A ) -> Result < ( ) , A :: Error >
666- where
667- A : serde:: de:: SeqAccess < ' de > ,
668- {
669- while let Some ( elem) = seq. next_element :: < Vec < Value > > ( ) ? {
670- let mut row = BTreeMap :: < & str , Value > :: new ( ) ;
671- for ( i, field) in elem. iter ( ) . enumerate ( ) {
672- let col_name = self . col_defs . get ( i) . map ( |s| s. as_str ( ) ) . unwrap_or ( "" ) ;
673- row. insert ( col_name, field. clone ( ) ) ;
674- }
675- let row = serde_json:: to_value ( row) . map_err ( |err| {
676- serde:: de:: Error :: custom ( format ! ( "Map parse failed: {err}" ) )
677- } ) ?;
678- self . sender . blocking_send ( row) . map_err ( |err| {
679- serde:: de:: Error :: custom ( format ! ( "Queue send failed: {err}" ) )
680- } ) ?;
681- }
682-
683- Ok ( ( ) )
684- }
685- }
686-
687- deserializer. deserialize_seq ( RowsVisitor { sender : & self . sender , col_defs : & self . col_defs } )
688- }
689- }
0 commit comments