di/stream sql to s3 #5704

Draft · wants to merge 5 commits into main

2 changes: 1 addition & 1 deletion backend/Cargo.lock

Some generated files are not rendered by default.

12 changes: 12 additions & 0 deletions backend/parsers/windmill-parser-sql/src/lib.rs
@@ -120,6 +120,17 @@ pub fn parse_db_resource(code: &str) -> Option<String> {
cap.map(|x| x.get(1).map(|x| x.as_str().to_string()).unwrap())
}

pub struct S3ModeArgs {
pub folder_key: String,
pub storage: Option<String>,
}
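/// Parses an `-- s3 <folder_key> [storage]` annotation from the top of a SQL script.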
pub fn parse_s3_mode(code: &str) -> Option<S3ModeArgs> {
let cap = RE_S3_MODE.captures(code)?;
let folder_key = cap.get(1).map(|x| x.as_str().to_string())?;
let storage = cap.get(2).map(|x| x.as_str().to_string());
Some(S3ModeArgs { folder_key, storage })
}

pub fn parse_sql_blocks(code: &str) -> Vec<&str> {
let mut blocks = vec![];
let mut last_idx = 0;
@@ -147,6 +158,7 @@ lazy_static::lazy_static! {
static ref RE_NONEMPTY_SQL_BLOCK: Regex = Regex::new(r#"(?m)^\s*[^\s](?:[^-]|$)"#).unwrap();

static ref RE_DB: Regex = Regex::new(r#"(?m)^-- database (\S+) *(?:\r|\n|$)"#).unwrap();
static ref RE_S3_MODE: Regex = Regex::new(r#"(?m)^-- s3 (\S+)(?: +(\S+))? *(?:\r|\n|$)"#).unwrap();

// -- $1 name (type) = default
static ref RE_ARG_MYSQL: Regex = Regex::new(r#"(?m)^-- \? (\w+) \((\w+)\)(?: ?\= ?(.+))? *(?:\r|\n|$)"#).unwrap();
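As a quick sanity check of the annotation grammar, a hypothetical test (not part of the diff; it assumes only `parse_s3_mode` and `S3ModeArgs` as defined above):

#[cfg(test)]
mod s3_mode_tests {
    use super::*;

    #[test]
    fn parses_folder_key_and_optional_storage() {
        // The storage name is the optional second token after the folder key.
        let args = parse_s3_mode("-- s3 exports/daily s3_secondary\nSELECT 1;").unwrap();
        assert_eq!(args.folder_key, "exports/daily");
        assert_eq!(args.storage.as_deref(), Some("s3_secondary"));

        // Without a second token, storage falls back to the workspace default.
        let args = parse_s3_mode("-- s3 exports/daily\nSELECT 1;").unwrap();
        assert!(args.storage.is_none());

        // Scripts without the annotation opt out of s3 mode entirely.
        assert!(parse_s3_mode("SELECT 1;").is_none());
    }
}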
31 changes: 30 additions & 1 deletion backend/windmill-worker/src/bigquery_executor.rs
@@ -8,7 +8,8 @@ use windmill_common::error::to_anyhow;
use windmill_common::worker::Connection;
use windmill_common::{error::Error, worker::to_raw_value};
use windmill_parser_sql::{
parse_bigquery_sig, parse_db_resource, parse_sql_blocks, parse_sql_statement_named_params,
parse_bigquery_sig, parse_db_resource, parse_s3_mode, parse_sql_blocks,
parse_sql_statement_named_params,
};
use windmill_queue::CanceledBy;

@@ -65,6 +66,14 @@ struct BigqueryError {
message: String,
}

#[derive(Clone)]
struct S3Mode {
client: AuthedClient,
object_key: String,
storage: Option<String>,
workspace_id: String,
}

fn do_bigquery_inner<'a>(
query: &'a str,
all_statement_values: &'a HashMap<String, Value>,
@@ -74,6 +83,7 @@ fn do_bigquery_inner<'a>(
column_order: Option<&'a mut Option<Vec<String>>>,
skip_collect: bool,
http_client: &'a Client,
s3: Option<S3Mode>,
) -> windmill_common::error::Result<BoxFuture<'a, windmill_common::error::Result<Box<RawValue>>>> {
let param_names = parse_sql_statement_named_params(query, '@');

@@ -113,6 +123,17 @@ fn do_bigquery_inner<'a>(
Ok(_) => {
if skip_collect {
return Ok(to_raw_value(&Value::Array(vec![])));
} else if let Some(ref s3) = s3 {
s3.client
.upload_s3_file(
s3.workspace_id.as_str(),
s3.object_key.clone(),
s3.storage.clone(),
response.bytes_stream(),
)
.await?;

Ok(serde_json::value::to_raw_value(&s3.object_key)?)
} else {
let result = response.json::<BigqueryResponse>().await.map_err(|e| {
Error::ExecutionErr(format!(
@@ -220,6 +241,12 @@ pub async fn do_bigquery(
let bigquery_args = build_args_values(job, client, conn).await?;

let inline_db_res_path = parse_db_resource(&query);
let s3 = parse_s3_mode(&query).map(|s3_mode| S3Mode {
client: client.clone(),
storage: s3_mode.storage,
object_key: format!("{}/{}.txt", s3_mode.folder_key, job.id),
workspace_id: job.workspace_id.clone(),
});

let db_arg = if let Some(inline_db_res_path) = inline_db_res_path {
Some(
@@ -332,6 +359,7 @@ pub async fn do_bigquery(
None,
annotations.return_last_result && i < queries.len() - 1,
&http_client,
s3.clone(),
)
})
.collect::<windmill_common::error::Result<Vec<_>>>()?;
@@ -361,6 +389,7 @@ pub async fn do_bigquery(
Some(column_order),
false,
&http_client,
s3,
)?
};

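One detail worth noting in the hunks above: the uploaded object always lands at `<folder_key>/<job_id>.txt`, and that key is what the job returns in place of the row data. A trivial sketch with a made-up job id:

fn main() {
    let folder_key = "exports/daily"; // from the `-- s3` annotation
    let job_id = "0195f3a2-0000-7000-8000-000000000000"; // hypothetical job uuid
    let object_key = format!("{}/{}.txt", folder_key, job_id);
    assert_eq!(object_key, "exports/daily/0195f3a2-0000-7000-8000-000000000000.txt");
}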
50 changes: 48 additions & 2 deletions backend/windmill-worker/src/pg_executor.rs
@@ -8,7 +8,7 @@ use anyhow::Context;
use base64::{engine, Engine as _};
use chrono::Utc;
use futures::future::BoxFuture;
use futures::{FutureExt, TryStreamExt};
use futures::{FutureExt, StreamExt, TryStreamExt};
use itertools::Itertools;
use native_tls::{Certificate, TlsConnector};
use postgres_native_tls::MakeTlsConnector;
Expand All @@ -30,7 +30,8 @@ use windmill_common::error::{self, Error};
use windmill_common::worker::{to_raw_value, Connection, CLOUD_HOSTED};
use windmill_parser::{Arg, Typ};
use windmill_parser_sql::{
parse_db_resource, parse_pg_statement_arg_indices, parse_pgsql_sig, parse_sql_blocks,
parse_db_resource, parse_pg_statement_arg_indices, parse_pgsql_sig, parse_s3_mode,
parse_sql_blocks,
};
use windmill_queue::{CanceledBy, MiniPulledJob};

@@ -53,6 +54,14 @@ struct PgDatabase {
root_certificate_pem: Option<String>,
}

#[derive(Clone)]
struct S3Mode {
client: AuthedClient,
object_key: String,
storage: Option<String>,
workspace_id: String,
}

lazy_static! {
pub static ref CONNECTION_CACHE: Arc<Mutex<Option<(String, tokio_postgres::Client)>>> =
Arc::new(Mutex::new(None));
@@ -68,6 +77,7 @@ fn do_postgresql_inner<'a>(
column_order: Option<&'a mut Option<Vec<String>>>,
siz: &'a AtomicUsize,
skip_collect: bool,
s3: Option<S3Mode>,
) -> error::Result<BoxFuture<'a, anyhow::Result<Box<RawValue>>>> {
let mut query_params = vec![];

@@ -106,6 +116,34 @@ fn do_postgresql_inner<'a>(
.execute_raw(&query, query_params)
.await
.map_err(to_anyhow)?;
} else if let Some(ref s3) = s3 {
let rows_stream = client
.query_raw(&query, query_params)
.await?
.map_err(to_anyhow)
.enumerate()
.map(|(i, row_result)| {
row_result.and_then(|row| {
postgres_row_to_json_value(row)
.map_err(to_anyhow)
.and_then(|ref v| serde_json::to_string(v).map_err(to_anyhow))
.map(|s| if i == 0 { s } else { format!(",\n{}", s) })
})
});
let start_bracket = futures::stream::once(async { Ok("{\"rows\":[\n".to_string()) });
let end_bracket = futures::stream::once(async { Ok("\n]}".to_string()) });
let rows_stream = start_bracket.chain(rows_stream).chain(end_bracket);

s3.client
.upload_s3_file(
s3.workspace_id.as_str(),
s3.object_key.clone(),
s3.storage.clone(),
rows_stream,
)
.await?;

return Ok(serde_json::value::to_raw_value(&s3.object_key)?);
} else {
let rows = client
.query_raw(&query, query_params)
@@ -171,6 +209,12 @@ pub async fn do_postgresql(
let pg_args = build_args_values(job, client, conn).await?;

let inline_db_res_path = parse_db_resource(&query);
let s3 = parse_s3_mode(&query).map(|s3_mode| S3Mode {
client: client.clone(),
storage: s3_mode.storage,
object_key: format!("{}/{}.txt", s3_mode.folder_key, job.id),
workspace_id: job.workspace_id.clone(),
});

let db_arg = if let Some(inline_db_res_path) = inline_db_res_path {
Some(
@@ -321,6 +365,7 @@ pub async fn do_postgresql(
None,
&size,
annotations.return_last_result && i < queries.len() - 1,
s3.clone(),
)
})
.collect::<error::Result<Vec<_>>>()?;
Expand All @@ -347,6 +392,7 @@ pub async fn do_postgresql(
Some(column_order),
&size,
false,
s3,
)?
};

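The framing logic in `do_postgresql_inner` is easiest to see in isolation. A self-contained sketch of the same pattern (plain `String`s instead of the executor's `TryStream` of results, error plumbing omitted):

use futures::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // Two pre-serialized rows stand in for the Postgres row stream.
    let rows = vec![serde_json::json!({"id": 1}), serde_json::json!({"id": 2})];
    // Comma-prefix every row after the first, as the executor does.
    let body = stream::iter(rows.into_iter().enumerate().map(|(i, v)| {
        let s = serde_json::to_string(&v).unwrap();
        if i == 0 { s } else { format!(",\n{}", s) }
    }));
    // Frame the rows as a single `{"rows":[...]}` object.
    let framed = stream::once(async { "{\"rows\":[\n".to_string() })
        .chain(body)
        .chain(stream::once(async { "\n]}".to_string() }));
    let out = framed.collect::<Vec<_>>().await.concat();
    assert_eq!(out, "{\"rows\":[\n{\"id\":1},\n{\"id\":2}\n]}");
}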
37 changes: 36 additions & 1 deletion backend/windmill-worker/src/snowflake_executor.rs
@@ -12,7 +12,9 @@ use windmill_common::error::to_anyhow;
use windmill_common::worker::Connection;

use windmill_common::{error::Error, worker::to_raw_value};
use windmill_parser_sql::{parse_db_resource, parse_snowflake_sig, parse_sql_blocks};
use windmill_parser_sql::{
parse_db_resource, parse_s3_mode, parse_snowflake_sig, parse_sql_blocks,
};
use windmill_queue::{CanceledBy, MiniPulledJob, HTTP_CLIENT};

use serde::{Deserialize, Serialize};
@@ -114,6 +116,14 @@ impl SnowflakeResponseExt for Result<Response, reqwest::Error> {
}
}

#[derive(Clone)]
struct S3Mode {
client: AuthedClient,
object_key: String,
storage: Option<String>,
workspace_id: String,
}

fn do_snowflake_inner<'a>(
query: &'a str,
job_args: &HashMap<String, Value>,
@@ -124,6 +134,7 @@ fn do_snowflake_inner<'a>(
column_order: Option<&'a mut Option<Vec<String>>>,
skip_collect: bool,
http_client: &'a Client,
s3: Option<S3Mode>,
) -> windmill_common::error::Result<BoxFuture<'a, windmill_common::error::Result<Box<RawValue>>>> {
let sig = parse_snowflake_sig(&query)
.map_err(|x| Error::ExecutionErr(x.to_string()))?
@@ -170,6 +181,22 @@ fn do_snowflake_inner<'a>(
if skip_collect {
handle_snowflake_result(result).await?;
Ok(to_raw_value(&Value::Array(vec![])))
} else if let Some(ref s3) = s3 {
// Do not use parse_snowflake_response here: it calls .json(), which would
// load the entire response into memory.
let result = result.map_err(|e| {
Error::ExecutionErr(format!("Could not send request to Snowflake: {:?}", e))
})?;
s3.client
.upload_s3_file(
s3.workspace_id.as_str(),
s3.object_key.clone(),
s3.storage.clone(),
result.bytes_stream(),
)
.await?;

Ok(serde_json::value::to_raw_value(&s3.object_key)?)
} else {
let response = result
.parse_snowflake_response::<SnowflakeResponse>()
@@ -260,6 +287,12 @@ pub async fn do_snowflake(
let snowflake_args = build_args_values(job, client, conn).await?;

let inline_db_res_path = parse_db_resource(&query);
let s3 = parse_s3_mode(&query).map(|s3_mode| S3Mode {
client: client.clone(),
storage: s3_mode.storage,
object_key: format!("{}/{}.txt", s3_mode.folder_key, job.id),
workspace_id: job.workspace_id.clone(),
});

let db_arg = if let Some(inline_db_res_path) = inline_db_res_path {
Some(
@@ -391,6 +424,7 @@ pub async fn do_snowflake(
None,
annotations.return_last_result && i < queries.len() - 1,
&http_client,
s3.clone(),
)
})
.collect::<windmill_common::error::Result<Vec<_>>>()?;
@@ -420,6 +454,7 @@ pub async fn do_snowflake(
Some(column_order),
false,
&http_client,
s3.clone(),
)?
};
let r = run_future_with_polling_update_job_poller(
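The comment inside `do_snowflake_inner` is the crux of the whole change: `.json()` buffers the entire body before deserializing, while `.bytes_stream()` hands chunks over as they arrive, so worker memory stays flat regardless of result size. A minimal sketch of the streaming side (hypothetical helper; assumes reqwest's `stream` feature is enabled):

use futures::Stream;

// Forward the response body chunk by chunk instead of buffering it.
fn streamed_body(resp: reqwest::Response) -> impl Stream<Item = Result<bytes::Bytes, reqwest::Error>> {
    resp.bytes_stream()
}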
40 changes: 39 additions & 1 deletion backend/windmill-worker/src/worker.rs
@@ -39,7 +39,7 @@ use windmill_common::METRICS_DEBUG_ENABLED;
#[cfg(feature = "prometheus")]
use windmill_common::METRICS_ENABLED;

use reqwest::Response;
use reqwest::{Body, Response};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sqlx::types::Json;
use std::{
Expand Down Expand Up @@ -520,6 +520,44 @@ impl AuthedClient {
_ => Err(anyhow::anyhow!(response.text().await.unwrap_or_default())),
}
}

pub async fn upload_s3_file<S>(
&self,
workspace_id: &str,
object_key: String,
storage: Option<String>,
body: S,
) -> anyhow::Result<Response>
where
S: futures::stream::TryStream + Send + 'static,
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
bytes::Bytes: From<S::Ok>,
{
let mut query = vec![("file_key", object_key)];
if let Some(storage) = storage {
query.push(("storage", storage.clone()));
}
self.force_client
.as_ref()
.unwrap_or(&HTTP_CLIENT)
.post(format!(
"{}/api/w/{}/job_helpers/upload_s3_file",
self.base_internal_url, workspace_id
))
.query(&query)
.header(
reqwest::header::ACCEPT,
reqwest::header::HeaderValue::from_static("application/json"),
)
.header(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", self.token))?,
)
.body(Body::wrap_stream(body))
.send()
.await
.context("Failed to send upload_s3_file request")
}
}
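
A hypothetical call site for the new helper (all names assumed; any TryStream whose items convert into `bytes::Bytes` will do, including a reqwest `bytes_stream()`):

async fn demo(client: &AuthedClient) -> anyhow::Result<()> {
    // A one-chunk in-memory body, for illustration only.
    let body = futures::stream::iter(vec![Ok::<_, std::io::Error>(
        bytes::Bytes::from_static(b"{\"rows\":[]}"),
    )]);
    client
        .upload_s3_file(
            "my_workspace",                 // assumed workspace id
            "exports/demo.txt".to_string(), // assumed object key
            None,                           // None selects the default storage
            body,
        )
        .await?;
    Ok(())
}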

#[derive(Clone)]