Skip to content

Commit 3a58cb8

Browse files
committed
remove extra stuff
1 parent 1b261ce commit 3a58cb8

File tree

2 files changed

+63
-240
lines changed

2 files changed

+63
-240
lines changed

backend/windmill-common/src/more_serde.rs

Lines changed: 3 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,10 @@
88

99
//! helpers for serde + serde derive attributes
1010
11-
use crate::{
12-
error::{self, to_anyhow},
13-
utils::rd_string,
14-
};
15-
use bytes::Bytes;
16-
use futures::TryStreamExt;
17-
use serde::{de::DeserializeSeed, Deserialize, Deserializer};
18-
use serde_json::{value::RawValue, Value};
11+
use crate::utils::rd_string;
12+
use serde::{Deserialize, Deserializer};
13+
use serde_json::value::RawValue;
1914
use std::{fmt::Display, str::FromStr};
20-
use tokio::{
21-
sync::mpsc::Sender,
22-
task::{self},
23-
};
24-
use tokio_stream::StreamExt;
2515

2616
pub fn default_true() -> bool {
2717
true
@@ -75,61 +65,3 @@ where
7565
NumericOrNull::Null => Ok(None),
7666
}
7767
}
78-
79-
// Takes a json stream and returns a stream of json values, without loading the
80-
// entire input into memory.
81-
pub async fn json_stream_values<
82-
'a,
83-
D: DeserializeSeed<'a> + 'static + Send,
84-
F: FnOnce(Sender<Value>) -> D,
85-
E: Display,
86-
>(
87-
mut stream: impl TryStreamExt<Item = Result<Bytes, E>> + Send + Unpin + 'static,
88-
mpsc_deserializer_factory: F,
89-
) -> error::Result<impl StreamExt<Item = serde_json::Value>> {
90-
const MAX_MPSC_SIZE: usize = 1000;
91-
92-
use std::path::PathBuf;
93-
use tokio::io::AsyncWriteExt;
94-
95-
let tmp_filename = format!("tmp_json_stream_{}", rd_string(8));
96-
97-
// Start by writing the async stream (from network) to a file.
98-
let mut path = PathBuf::from(std::env::temp_dir());
99-
path.push(tmp_filename);
100-
let mut file: tokio::fs::File = tokio::fs::File::create(&path).await.map_err(to_anyhow)?;
101-
while let Some(chunk) = stream.next().await {
102-
let chunk: Bytes = match chunk {
103-
Ok(chunk) => chunk,
104-
Err(e) => {
105-
std::fs::remove_file(&path)?;
106-
return Err(error::Error::ExecutionErr(format!(
107-
"Error reading stream: {}",
108-
e
109-
)));
110-
}
111-
};
112-
file.write_all(&chunk).await?;
113-
}
114-
115-
let (tx, rx) = tokio::sync::mpsc::channel(MAX_MPSC_SIZE);
116-
// Takes ownership of tx
117-
let mpsc_deserializer = mpsc_deserializer_factory(tx);
118-
119-
// We read the file and pipe each element to the channel in a blocking task.
120-
let _ = task::spawn_blocking::<_, anyhow::Result<()>>(move || {
121-
let sync_file = std::fs::File::open(&path).map_err(to_anyhow)?;
122-
let mut buf_reader = std::io::BufReader::new(sync_file);
123-
124-
let mut deserializer = serde_json::Deserializer::from_reader(&mut buf_reader);
125-
// This deserializer will read the file and send each row to the channel
126-
let _ = mpsc_deserializer.deserialize(&mut deserializer)?;
127-
128-
std::fs::remove_file(&path)?;
129-
Ok(())
130-
// tx drops with mpsc_deserializer so the stream ends
131-
});
132-
133-
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
134-
Ok(stream)
135-
}

backend/windmill-worker/src/snowflake_executor.rs

Lines changed: 60 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,13 @@ use base64::{engine, Engine as _};
22
use chrono::Datelike;
33
use core::fmt::Write;
44
use futures::future::BoxFuture;
5-
use futures::{FutureExt, StreamExt, TryFutureExt};
5+
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
66
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
77
use reqwest::{Client, Response};
88
use serde_json::{json, value::RawValue, Value};
99
use sha2::{Digest, Sha256};
10-
use std::collections::{BTreeMap, HashMap};
11-
use std::convert::Infallible;
10+
use std::collections::HashMap;
1211
use windmill_common::error::to_anyhow;
13-
use windmill_common::more_serde::json_stream_values;
14-
use windmill_common::s3_helpers::convert_json_line_stream;
1512
use windmill_common::worker::Connection;
1613

1714
use windmill_common::{error::Error, worker::to_raw_value};
@@ -179,30 +176,12 @@ fn do_snowflake_inner<'a>(
179176
if skip_collect {
180177
handle_snowflake_result(result).await?;
181178
Ok(to_raw_value(&Value::Array(vec![])))
182-
} else if let Some(ref s3) = s3 {
183-
// do not do parse_snowflake_response as it will call .json() and
184-
// load the entire response into memory
185-
let result = result.map_err(|e| {
186-
Error::ExecutionErr(format!("Could not send request to Snowflake: {:?}", e))
187-
})?;
188-
189-
let rows_stream = json_stream_values(result.bytes_stream(), |sender| {
190-
RootMpscDeserializer { sender }
191-
})
192-
.await?
193-
.boxed()
194-
.map(|chunk| Ok::<_, Infallible>(chunk));
195-
196-
let stream = convert_json_line_stream(rows_stream, s3.format).await?;
197-
s3.upload(stream.boxed()).await?;
198-
199-
Ok(serde_json::value::to_raw_value(&s3.object_key)?)
200179
} else {
201180
let response = result
202181
.parse_snowflake_response::<SnowflakeResponse>()
203182
.await?;
204183

205-
if response.resultSetMetaData.numRows > 10000 {
184+
if s3.is_none() && response.resultSetMetaData.numRows > 10000 {
206185
return Err(Error::ExecutionErr(
207186
"More than 10000 rows were requested, use LIMIT 10000 to limit the number of rows"
208187
.to_string(),
@@ -219,54 +198,66 @@ fn do_snowflake_inner<'a>(
219198
);
220199
}
221200

222-
let mut rows = response.data;
223-
224-
if response.resultSetMetaData.partitionInfo.len() > 1 {
225-
for idx in 1..response.resultSetMetaData.partitionInfo.len() {
226-
let url = format!(
227-
"https://{}.snowflakecomputing.com/api/v2/statements/{}",
228-
account_identifier.to_uppercase(),
229-
response.statementHandle
230-
);
231-
let mut request = HTTP_CLIENT
232-
.get(url)
233-
.bearer_auth(token)
234-
.query(&[("partition", idx.to_string())]);
235-
236-
if token_is_keypair {
237-
request =
238-
request.header("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT");
239-
}
240-
241-
let response = request
242-
.send()
243-
.await
244-
.parse_snowflake_response::<SnowflakeDataOnlyResponse>()
245-
.await?;
201+
let rows_stream = async_stream::stream! {
202+
for row in response.data {
203+
yield Ok::<Vec<Value>, windmill_common::error::Error>(row);
204+
}
246205

247-
rows.extend(response.data);
206+
if response.resultSetMetaData.partitionInfo.len() > 1 {
207+
for idx in 1..response.resultSetMetaData.partitionInfo.len() {
208+
let url = format!(
209+
"https://{}.snowflakecomputing.com/api/v2/statements/{}",
210+
account_identifier.to_uppercase(),
211+
response.statementHandle
212+
);
213+
let mut request = HTTP_CLIENT
214+
.get(url)
215+
.bearer_auth(token)
216+
.query(&[("partition", idx.to_string())]);
217+
218+
if token_is_keypair {
219+
request =
220+
request.header("X-Snowflake-Authorization-Token-Type", "KEYPAIR_JWT");
221+
}
222+
223+
let response = request
224+
.send()
225+
.await
226+
.parse_snowflake_response::<SnowflakeDataOnlyResponse>()
227+
.await?;
228+
229+
for row in response.data {
230+
yield Ok(row);
231+
}
232+
}
248233
}
234+
};
235+
236+
let rows_stream = rows_stream.map_ok(|row| {
237+
let mut row_map = serde_json::Map::new();
238+
row.iter()
239+
.zip(response.resultSetMetaData.rowType.iter())
240+
.for_each(|(val, row_type)| {
241+
row_map.insert(row_type.name.clone(), parse_val(&val, &row_type.r#type));
242+
});
243+
row_map
244+
});
245+
246+
if let Some(s3) = s3 {
247+
// let rows_stream =
248+
// rows_stream.map(|r| serde_json::value::to_value(&r?).map_err(to_anyhow));
249+
// let stream = convert_json_line_stream(rows_stream.boxed(), s3.format).await?;
250+
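// TODO: restore the S3 upload below; while it is commented out, no rows are uploaded and only the object key is returned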
251+
// s3.upload(stream.boxed()).await?;
252+
Ok(to_raw_value(&s3.object_key))
253+
} else {
254+
let rows = rows_stream
255+
.collect::<Vec<_>>()
256+
.await
257+
.into_iter()
258+
.collect::<Result<Vec<_>, _>>()?;
259+
Ok(to_raw_value(&rows))
249260
}
250-
251-
let rows = to_raw_value(
252-
&rows
253-
.iter()
254-
.map(|row| {
255-
let mut row_map = serde_json::Map::new();
256-
row.iter()
257-
.zip(response.resultSetMetaData.rowType.iter())
258-
.for_each(|(val, row_type)| {
259-
row_map.insert(
260-
row_type.name.clone(),
261-
parse_val(&val, &row_type.r#type),
262-
);
263-
});
264-
row_map
265-
})
266-
.collect::<Vec<_>>(),
267-
);
268-
269-
Ok(rows)
270261
}
271262
};
272263

@@ -587,103 +578,3 @@ fn parse_val(value: &Value, typ: &str) -> Value {
587578
))
588579
}
589580
}
590-
591-
// This deserializer takes a snowflake response as a stream and sends each row to an mpsc
592-
// channel as a json record without storing the full input json in memory.
593-
struct RootMpscDeserializer {
594-
sender: tokio::sync::mpsc::Sender<serde_json::Value>,
595-
}
596-
597-
impl<'de> serde::de::DeserializeSeed<'de> for RootMpscDeserializer {
598-
type Value = ();
599-
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
600-
where
601-
D: serde::Deserializer<'de>,
602-
{
603-
struct RootVisitor<'a> {
604-
sender: &'a tokio::sync::mpsc::Sender<serde_json::Value>,
605-
col_defs: Vec<String>,
606-
}
607-
608-
impl<'de, 'a> serde::de::Visitor<'de> for RootVisitor<'a> {
609-
type Value = ();
610-
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
611-
formatter.write_str("data field from snowflake response")
612-
}
613-
fn visit_map<A>(mut self, mut map: A) -> Result<(), A::Error>
614-
where
615-
A: serde::de::MapAccess<'de>,
616-
{
617-
while let Some(key) = map.next_key::<String>()? {
618-
if key == "resultSetMetaData" {
619-
let result_set_metadata: SnowflakeResultSetMetaData = map.next_value()?;
620-
self.col_defs = result_set_metadata
621-
.rowType
622-
.iter()
623-
.map(|x| x.name.clone())
624-
.collect::<Vec<String>>();
625-
} else if key == "data" {
626-
let () = map.next_value_seed(RowsMpscDeserializer {
627-
sender: self.sender,
628-
col_defs: &self.col_defs,
629-
})?;
630-
} else {
631-
let _: serde::de::IgnoredAny = map.next_value()?;
632-
}
633-
}
634-
Ok(())
635-
}
636-
}
637-
638-
deserializer.deserialize_map(RootVisitor { sender: &self.sender, col_defs: vec![] })
639-
}
640-
}
641-
642-
struct RowsMpscDeserializer<'a> {
643-
sender: &'a tokio::sync::mpsc::Sender<serde_json::Value>,
644-
col_defs: &'a Vec<String>,
645-
}
646-
647-
impl<'de, 'a> serde::de::DeserializeSeed<'de> for RowsMpscDeserializer<'a> {
648-
type Value = ();
649-
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
650-
where
651-
D: serde::Deserializer<'de>,
652-
{
653-
struct RowsVisitor<'a> {
654-
sender: &'a tokio::sync::mpsc::Sender<serde_json::Value>,
655-
col_defs: &'a Vec<String>,
656-
}
657-
658-
impl<'de, 'a> serde::de::Visitor<'de> for RowsVisitor<'a> {
659-
type Value = ();
660-
661-
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
662-
formatter.write_str("a sequence of rows")
663-
}
664-
665-
fn visit_seq<A>(self, mut seq: A) -> Result<(), A::Error>
666-
where
667-
A: serde::de::SeqAccess<'de>,
668-
{
669-
while let Some(elem) = seq.next_element::<Vec<Value>>()? {
670-
let mut row = BTreeMap::<&str, Value>::new();
671-
for (i, field) in elem.iter().enumerate() {
672-
let col_name = self.col_defs.get(i).map(|s| s.as_str()).unwrap_or("");
673-
row.insert(col_name, field.clone());
674-
}
675-
let row = serde_json::to_value(row).map_err(|err| {
676-
serde::de::Error::custom(format!("Map parse failed: {err}"))
677-
})?;
678-
self.sender.blocking_send(row).map_err(|err| {
679-
serde::de::Error::custom(format!("Queue send failed: {err}"))
680-
})?;
681-
}
682-
683-
Ok(())
684-
}
685-
}
686-
687-
deserializer.deserialize_seq(RowsVisitor { sender: &self.sender, col_defs: &self.col_defs })
688-
}
689-
}

0 commit comments

Comments
 (0)