Skip to content

Commit 4742f40

Browse files
authored
Merge pull request #402 from superfly/somtochi/pgwire-0.36.0
Parse arrays with text format in postgres
2 parents f508236 + 9e5d848 commit 4742f40

File tree

4 files changed

+244
-51
lines changed

4 files changed

+244
-51
lines changed

crates/corro-pg/src/codec.rs

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@ use pgwire::{
55
api::{self, ClientInfo},
66
error::PgWireError,
77
messages as msg,
8+
types::FromSqlText,
89
};
10+
use postgres_types::Type;
911
use std::{collections::HashMap, io};
1012
use tokio_util::codec;
13+
use tracing::debug;
1114

1215
pub struct Client {
1316
pub socket_addr: std::net::SocketAddr,
@@ -169,3 +172,113 @@ where
169172
}
170173
}
171174
}
175+
176+
pub trait VecFromSqlText: Sized {
177+
fn from_vec_sql_text(
178+
ty: &Type,
179+
input: &[u8],
180+
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>>;
181+
}
182+
183+
// Re-implementation of the ToSqlText trait from pg_wire to make it generic over different types.
184+
// Implemented as a macro in pgwire
185+
// https://github.com/sunng87/pgwire/blob/6cbce9d444cc86a01d992f6b35f84c024f10ceda/src/types/from_sql_text.rs#L402
186+
impl<T: FromSqlText> VecFromSqlText for Vec<T> {
187+
fn from_vec_sql_text(
188+
ty: &Type,
189+
input: &[u8],
190+
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
191+
// PostgreSQL array text format: {elem1,elem2,elem3}
192+
// Remove the outer braces
193+
let input_str = std::str::from_utf8(input)?;
194+
195+
if input_str.is_empty() {
196+
return Ok(Vec::new());
197+
}
198+
199+
// Check if it's an array format
200+
if !input_str.starts_with('{') || !input_str.ends_with('}') {
201+
return Err(format!(
202+
"Invalid array format: must start with '{{' and end with '}}', input: {input_str}"
203+
)
204+
.into());
205+
}
206+
207+
let inner = &input_str[1..input_str.len() - 1];
208+
209+
if inner.is_empty() {
210+
return Ok(Vec::new());
211+
}
212+
213+
let elements = extract_array_elements(inner)?;
214+
let mut result = Vec::new();
215+
216+
for element_str in elements {
217+
let element = T::from_sql_text(ty, element_str.as_bytes())?;
218+
result.push(element);
219+
}
220+
221+
Ok(result)
222+
}
223+
}
224+
225+
// Helper function to extract array elements
226+
// https://github.com/sunng87/pgwire/blob/6cbce9d444cc86a01d992f6b35f84c024f10ceda/src/types/from_sql_text.rs#L402
227+
fn extract_array_elements(
228+
input: &str,
229+
) -> Result<Vec<String>, Box<dyn std::error::Error + Sync + Send>> {
230+
if input.is_empty() {
231+
return Ok(Vec::new());
232+
}
233+
234+
let mut elements = Vec::new();
235+
let mut current = String::new();
236+
let mut in_quotes = false;
237+
let mut escape_next = false;
238+
let mut seen_content = false; // helpful for tracking when the last element is an empty string
239+
240+
for ch in input.chars() {
241+
match ch {
242+
'\\' if !escape_next => {
243+
escape_next = true;
244+
}
245+
'"' if !escape_next => {
246+
in_quotes = !in_quotes;
247+
// we have seen a new element surrounded by quotes
248+
if !in_quotes {
249+
seen_content = true;
250+
}
251+
// Don't include the quotes in the output
252+
}
253+
'{' if !in_quotes && !escape_next => {
254+
return Err("Nested arrays are not supported".into());
255+
}
256+
'}' if !in_quotes && !escape_next => {
257+
return Err("Nested arrays are not supported".into());
258+
}
259+
',' if !in_quotes && !escape_next => {
260+
// End of current element
261+
if !current.trim().eq_ignore_ascii_case("NULL") {
262+
elements.push(std::mem::take(&mut current));
263+
seen_content = false;
264+
}
265+
}
266+
_ => {
267+
current.push(ch);
268+
escape_next = false;
269+
seen_content = true;
270+
}
271+
}
272+
}
273+
274+
// Process the last element
275+
if seen_content && !current.trim().eq_ignore_ascii_case("NULL") {
276+
elements.push(current);
277+
}
278+
279+
debug!(
280+
"extracted elements: {elements:?} from input: {input}, lenght: {}",
281+
elements.len()
282+
);
283+
Ok(elements)
284+
}

crates/corro-pg/src/lib.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod ssl;
44
pub mod utils;
55
mod vtab;
66

7+
use codec::VecFromSqlText;
78
use eyre::WrapErr;
89
use std::{
910
collections::{BTreeSet, HashMap, VecDeque},
@@ -42,6 +43,7 @@ use pgwire::{
4243
startup::ParameterStatus,
4344
PgWireBackendMessage, PgWireFrontendMessage,
4445
},
46+
types::FromSqlText,
4547
};
4648
use postgres_types::{FromSql, Type};
4749
use rusqlite::{
@@ -2647,13 +2649,13 @@ fn from_type_and_format<'a, E, T: FromSql<'a> + FromStr<Err = E>>(
26472649
})
26482650
}
26492651

2650-
fn from_array_type_and_format<'a, T: FromSql<'a>>(
2652+
fn from_array_type_and_format<'a, T: FromSql<'a> + FromSqlText>(
26512653
t: &Type,
26522654
b: &'a [u8],
26532655
format_code: FormatCode,
26542656
) -> Result<Vec<T>, ToParamError<String>> {
26552657
Ok(match format_code {
2656-
FormatCode::Text => panic!("Impossible - arrays are only sent in binary format"),
2658+
FormatCode::Text => Vec::<T>::from_vec_sql_text(t, b).map_err(ToParamError::FromSql)?,
26572659
FormatCode::Binary => Vec::<T>::from_sql(t, b).map_err(ToParamError::FromSql)?,
26582660
})
26592661
}
@@ -2675,7 +2677,7 @@ impl From<UnsupportedSqliteToPostgresType> for ErrorResponse {
26752677
}
26762678

26772679
#[derive(Debug, thiserror::Error)]
2678-
#[error("Untyped array argument for unnest(), please use CAST($N AS T) where T is one of: TEXT[] BLOB[] INT[] INTEGER[] BIGINT[] REAL[] FLOAT[] DOUBLE[]")]
2680+
#[error("Untyped array argument for unnest() (or corro_unnest()), please use CAST($N AS T) where T is one of: TEXT[] BLOB[] INT[] INTEGER[] BIGINT[] REAL[] FLOAT[] DOUBLE[]")]
26792681
struct UntypedUnnestParameter;
26802682

26812683
impl From<UntypedUnnestParameter> for PgWireBackendMessage {
@@ -2902,10 +2904,16 @@ fn extract_params<'schema, 'stmt>(
29022904
Expr::FunctionCall {
29032905
name: _,
29042906
distinctness: _,
2905-
args: _,
2907+
args,
29062908
filter_over: _,
29072909
order_by: _,
2908-
} => {}
2910+
} => {
2911+
if let Some(args) = args {
2912+
for expr in args.iter() {
2913+
extract_params(schema, expr, tables, params)?
2914+
}
2915+
}
2916+
}
29092917

29102918
Expr::FunctionCallStar {
29112919
name: _,
@@ -3105,7 +3113,8 @@ fn handle_table_call_params<'schema, 'stmt>(
31053113
params: &mut ParamsList<'stmt, 'schema>,
31063114
) -> Result<(), UntypedUnnestParameter> {
31073115
if let Some(exprs) = args {
3108-
let is_unnest = qname.name.0.eq_ignore_ascii_case("UNNEST");
3116+
let is_unnest = qname.name.0.eq_ignore_ascii_case("CORRO_UNNEST")
3117+
|| qname.name.0.eq_ignore_ascii_case("UNNEST");
31093118

31103119
for expr in exprs.iter() {
31113120
// If not unnest, just extract params
@@ -3352,7 +3361,11 @@ fn parameter_types<'schema, 'stmt>(
33523361

33533362
let mut tables = HashMap::new();
33543363
if let Some(tbl) = schema.tables.get(&tbl_name.name.0) {
3355-
tables.insert(tbl_name.name.0.clone(), tbl);
3364+
if let Some(alias) = &tbl_name.alias {
3365+
tables.insert(alias.0.clone(), tbl);
3366+
} else {
3367+
tables.insert(tbl_name.name.0.clone(), tbl);
3368+
}
33563369
}
33573370
if let Some(where_clause) = where_clause {
33583371
extract_params(schema, where_clause, &tables, &mut params)?;

0 commit comments

Comments
 (0)