Skip to content

Commit 111abaa

Browse files
authored
refactor: use arrow-pg for duckdb example (#273)
* refactor: use arrow-pg for duckdb example * chore: update arrow-pg
1 parent 8ade649 commit 111abaa

2 files changed

Lines changed: 66 additions & 143 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,17 @@ rustls-pki-types = { version = "1.10" }
9090
rusqlite = { version = "0.36.0", features = ["column_decltype"] }
9191
## for duckdb example
9292
duckdb = { version = "1.0.0" }
93+
arrow-pg = "0.1.1"
9394

9495
## for loading custom cert files
9596
rustls-pemfile = "2.0"
9697
## webpki-roots has mozilla's set of roots
9798
## rustls-native-certs loads roots from current system
9899
gluesql = { version = "0.16", default-features = false, features = ["gluesql_memory_storage"] }
99100

101+
[patch.crates-io]
102+
pgwire = { path = "." }
103+
100104
[workspace]
101105
members = [
102106
".",

examples/duckdb.rs

Lines changed: 62 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,23 @@
11
use std::sync::{Arc, Mutex};
22

3+
use arrow_pg::datatypes::arrow_schema_to_pg_fields;
4+
use arrow_pg::datatypes::encode_recordbatch;
5+
use arrow_pg::datatypes::into_pg_type;
36
use async_trait::async_trait;
4-
use duckdb::arrow::datatypes::DataType;
5-
use duckdb::Rows;
6-
use duckdb::{params, types::ValueRef, Connection, Statement, ToSql};
7+
use duckdb::{params, Connection, Statement, ToSql};
78
use futures::stream;
8-
use futures::Stream;
99
use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler};
1010
use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password};
1111
use pgwire::api::cancel::NoopCancelHandler;
1212
use pgwire::api::copy::NoopCopyHandler;
1313
use pgwire::api::portal::{Format, Portal};
1414
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
1515
use pgwire::api::results::{
16-
DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse,
17-
Response, Tag,
16+
DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
1817
};
1918
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
2019
use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type};
21-
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
22-
use pgwire::messages::data::DataRow;
20+
use pgwire::error::{PgWireError, PgWireResult};
2321
use pgwire::tokio::process_socket;
2422
use tokio::net::TcpListener;
2523

@@ -55,140 +53,34 @@ impl SimpleQueryHandler for DuckDBBackend {
5553
let mut stmt = conn
5654
.prepare(query)
5755
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
58-
let rows = stmt
59-
.query(params![])
56+
57+
let ret = stmt
58+
.query_arrow(params![])
6059
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
61-
let row_stmt = rows.as_ref().unwrap();
62-
let header = Arc::new(row_desc_from_stmt(row_stmt, &Format::UnifiedText)?);
63-
let s = encode_row_data(rows, header.clone());
64-
Ok(vec![Response::Query(QueryResponse::new(header, s))])
60+
let schema = ret.get_schema();
61+
let header = Arc::new(arrow_schema_to_pg_fields(
62+
schema.as_ref(),
63+
&Format::UnifiedText,
64+
)?);
65+
66+
let header_ref = header.clone();
67+
let data = ret
68+
.flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb))
69+
.collect::<Vec<_>>();
70+
Ok(vec![Response::Query(QueryResponse::new(
71+
header,
72+
stream::iter(data.into_iter()),
73+
))])
6574
} else {
6675
conn.execute(query, params![])
6776
.map(|affected_rows| {
68-
vec![Response::Execution(
69-
Tag::new("OK").with_rows(affected_rows).into(),
70-
)]
77+
vec![Response::Execution(Tag::new("OK").with_rows(affected_rows))]
7178
})
7279
.map_err(|e| PgWireError::ApiError(Box::new(e)))
7380
}
7481
}
7582
}
7683

77-
fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
78-
Ok(match df_type {
79-
DataType::Null => Type::UNKNOWN,
80-
DataType::Boolean => Type::BOOL,
81-
DataType::Int8 | DataType::UInt8 => Type::CHAR,
82-
DataType::Int16 | DataType::UInt16 => Type::INT2,
83-
DataType::Int32 | DataType::UInt32 => Type::INT4,
84-
DataType::Int64 | DataType::UInt64 => Type::INT8,
85-
DataType::Timestamp(_, _) => Type::TIMESTAMP,
86-
DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
87-
DataType::Date32 | DataType::Date64 => Type::DATE,
88-
DataType::Binary => Type::BYTEA,
89-
DataType::Float32 => Type::FLOAT4,
90-
DataType::Float64 => Type::FLOAT8,
91-
DataType::Utf8 => Type::VARCHAR,
92-
DataType::List(field) => match field.data_type() {
93-
DataType::Boolean => Type::BOOL_ARRAY,
94-
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
95-
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
96-
DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
97-
DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
98-
DataType::Timestamp(_, _) => Type::TIMESTAMP_ARRAY,
99-
DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
100-
DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
101-
DataType::Binary => Type::BYTEA_ARRAY,
102-
DataType::Float32 => Type::FLOAT4_ARRAY,
103-
DataType::Float64 => Type::FLOAT8_ARRAY,
104-
DataType::Utf8 => Type::VARCHAR_ARRAY,
105-
list_type => {
106-
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
107-
"ERROR".to_owned(),
108-
"XX000".to_owned(),
109-
format!("Unsupported List Datatype {list_type}"),
110-
))));
111-
}
112-
},
113-
_ => {
114-
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
115-
"ERROR".to_owned(),
116-
"XX000".to_owned(),
117-
format!("Unsupported Datatype {df_type}"),
118-
))));
119-
}
120-
})
121-
}
122-
123-
fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
124-
let columns = stmt.column_count();
125-
126-
(0..columns)
127-
.map(|idx| {
128-
let datatype = stmt.column_type(idx);
129-
let name = stmt.column_name(idx).unwrap();
130-
131-
Ok(FieldInfo::new(
132-
name.clone(),
133-
None,
134-
None,
135-
into_pg_type(&datatype).unwrap(),
136-
format.format_for(idx),
137-
))
138-
})
139-
.collect()
140-
}
141-
142-
fn encode_row_data(
143-
mut rows: Rows<'_>,
144-
schema: Arc<Vec<FieldInfo>>,
145-
) -> impl Stream<Item = PgWireResult<DataRow>> {
146-
let mut results = Vec::new();
147-
let ncols = schema.len();
148-
while let Ok(Some(row)) = rows.next() {
149-
let mut encoder = DataRowEncoder::new(schema.clone());
150-
for idx in 0..ncols {
151-
let data = row.get_ref_unwrap::<usize>(idx);
152-
match data {
153-
ValueRef::Null => encoder.encode_field(&None::<i8>).unwrap(),
154-
ValueRef::TinyInt(i) => {
155-
encoder.encode_field(&i).unwrap();
156-
}
157-
ValueRef::SmallInt(i) => {
158-
encoder.encode_field(&i).unwrap();
159-
}
160-
ValueRef::Int(i) => {
161-
encoder.encode_field(&i).unwrap();
162-
}
163-
ValueRef::BigInt(i) => {
164-
encoder.encode_field(&i).unwrap();
165-
}
166-
ValueRef::Float(f) => {
167-
encoder.encode_field(&f).unwrap();
168-
}
169-
ValueRef::Double(f) => {
170-
encoder.encode_field(&f).unwrap();
171-
}
172-
ValueRef::Text(t) => {
173-
encoder
174-
.encode_field(&String::from_utf8_lossy(t).as_ref())
175-
.unwrap();
176-
}
177-
ValueRef::Blob(b) => {
178-
encoder.encode_field(&b).unwrap();
179-
}
180-
_ => {
181-
unimplemented!("More types to be supported.")
182-
}
183-
}
184-
}
185-
186-
results.push(encoder.finish());
187-
}
188-
189-
stream::iter(results.into_iter())
190-
}
191-
19284
fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
19385
let mut results = Vec::with_capacity(portal.parameter_len());
19486
for i in 0..portal.parameter_len() {
@@ -232,6 +124,25 @@ fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
232124
results
233125
}
234126

127+
fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
128+
let columns = stmt.column_count();
129+
130+
(0..columns)
131+
.map(|idx| {
132+
let datatype = stmt.column_type(idx);
133+
let name = stmt.column_name(idx).unwrap();
134+
135+
Ok(FieldInfo::new(
136+
name.clone(),
137+
None,
138+
None,
139+
into_pg_type(&datatype).unwrap(),
140+
format.format_for(idx),
141+
))
142+
})
143+
.collect()
144+
}
145+
235146
#[async_trait]
236147
impl ExtendedQueryHandler for DuckDBBackend {
237148
type Statement = String;
@@ -262,18 +173,27 @@ impl ExtendedQueryHandler for DuckDBBackend {
262173
.collect::<Vec<&dyn duckdb::ToSql>>();
263174

264175
if query.to_uppercase().starts_with("SELECT") {
265-
let rows: Rows<'_> = stmt
266-
.query::<&[&dyn duckdb::ToSql]>(params_ref.as_ref())
176+
let ret = stmt
177+
.query_arrow(params![])
267178
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
268-
let row_stmt = rows.as_ref().unwrap();
269-
let header = Arc::new(row_desc_from_stmt(row_stmt, &portal.result_column_format)?);
270-
let s = encode_row_data(rows, header.clone());
271-
Ok(Response::Query(QueryResponse::new(header, s)))
179+
let schema = ret.get_schema();
180+
let header = Arc::new(arrow_schema_to_pg_fields(
181+
schema.as_ref(),
182+
&Format::UnifiedText,
183+
)?);
184+
185+
let header_ref = header.clone();
186+
let data = ret
187+
.flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb))
188+
.collect::<Vec<_>>();
189+
190+
Ok(Response::Query(QueryResponse::new(
191+
header,
192+
stream::iter(data.into_iter()),
193+
)))
272194
} else {
273195
stmt.execute::<&[&dyn duckdb::ToSql]>(params_ref.as_ref())
274-
.map(|affected_rows| {
275-
Response::Execution(Tag::new("OK").with_rows(affected_rows).into())
276-
})
196+
.map(|affected_rows| Response::Execution(Tag::new("OK").with_rows(affected_rows)))
277197
.map_err(|e| PgWireError::ApiError(Box::new(e)))
278198
}
279199
}
@@ -307,8 +227,7 @@ impl ExtendedQueryHandler for DuckDBBackend {
307227
let stmt = conn
308228
.prepare_cached(&portal.statement.statement)
309229
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
310-
row_desc_from_stmt(&stmt, &portal.result_column_format)
311-
.map(|fields| DescribePortalResponse::new(fields))
230+
row_desc_from_stmt(&stmt, &portal.result_column_format).map(DescribePortalResponse::new)
312231
}
313232
}
314233

0 commit comments

Comments
 (0)