|
1 | 1 | use std::sync::{Arc, Mutex}; |
2 | 2 |
|
| 3 | +use arrow_pg::datatypes::arrow_schema_to_pg_fields; |
| 4 | +use arrow_pg::datatypes::encode_recordbatch; |
| 5 | +use arrow_pg::datatypes::into_pg_type; |
3 | 6 | use async_trait::async_trait; |
4 | | -use duckdb::arrow::datatypes::DataType; |
5 | | -use duckdb::Rows; |
6 | | -use duckdb::{params, types::ValueRef, Connection, Statement, ToSql}; |
| 7 | +use duckdb::{params, Connection, Statement, ToSql}; |
7 | 8 | use futures::stream; |
8 | | -use futures::Stream; |
9 | 9 | use pgwire::api::auth::md5pass::{hash_md5_password, Md5PasswordAuthStartupHandler}; |
10 | 10 | use pgwire::api::auth::{AuthSource, DefaultServerParameterProvider, LoginInfo, Password}; |
11 | 11 | use pgwire::api::cancel::NoopCancelHandler; |
12 | 12 | use pgwire::api::copy::NoopCopyHandler; |
13 | 13 | use pgwire::api::portal::{Format, Portal}; |
14 | 14 | use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; |
15 | 15 | use pgwire::api::results::{ |
16 | | - DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, |
17 | | - Response, Tag, |
| 16 | + DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag, |
18 | 17 | }; |
19 | 18 | use pgwire::api::stmt::{NoopQueryParser, StoredStatement}; |
20 | 19 | use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; |
21 | | -use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; |
22 | | -use pgwire::messages::data::DataRow; |
| 20 | +use pgwire::error::{PgWireError, PgWireResult}; |
23 | 21 | use pgwire::tokio::process_socket; |
24 | 22 | use tokio::net::TcpListener; |
25 | 23 |
|
@@ -55,140 +53,34 @@ impl SimpleQueryHandler for DuckDBBackend { |
55 | 53 | let mut stmt = conn |
56 | 54 | .prepare(query) |
57 | 55 | .map_err(|e| PgWireError::ApiError(Box::new(e)))?; |
58 | | - let rows = stmt |
59 | | - .query(params![]) |
| 56 | + |
| 57 | + let ret = stmt |
| 58 | + .query_arrow(params![]) |
60 | 59 | .map_err(|e| PgWireError::ApiError(Box::new(e)))?; |
61 | | - let row_stmt = rows.as_ref().unwrap(); |
62 | | - let header = Arc::new(row_desc_from_stmt(row_stmt, &Format::UnifiedText)?); |
63 | | - let s = encode_row_data(rows, header.clone()); |
64 | | - Ok(vec![Response::Query(QueryResponse::new(header, s))]) |
| 60 | + let schema = ret.get_schema(); |
| 61 | + let header = Arc::new(arrow_schema_to_pg_fields( |
| 62 | + schema.as_ref(), |
| 63 | + &Format::UnifiedText, |
| 64 | + )?); |
| 65 | + |
| 66 | + let header_ref = header.clone(); |
| 67 | + let data = ret |
| 68 | + .flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb)) |
| 69 | + .collect::<Vec<_>>(); |
| 70 | + Ok(vec![Response::Query(QueryResponse::new( |
| 71 | + header, |
| 72 | + stream::iter(data.into_iter()), |
| 73 | + ))]) |
65 | 74 | } else { |
66 | 75 | conn.execute(query, params![]) |
67 | 76 | .map(|affected_rows| { |
68 | | - vec![Response::Execution( |
69 | | - Tag::new("OK").with_rows(affected_rows).into(), |
70 | | - )] |
| 77 | + vec![Response::Execution(Tag::new("OK").with_rows(affected_rows))] |
71 | 78 | }) |
72 | 79 | .map_err(|e| PgWireError::ApiError(Box::new(e))) |
73 | 80 | } |
74 | 81 | } |
75 | 82 | } |
76 | 83 |
|
77 | | -fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> { |
78 | | - Ok(match df_type { |
79 | | - DataType::Null => Type::UNKNOWN, |
80 | | - DataType::Boolean => Type::BOOL, |
81 | | - DataType::Int8 | DataType::UInt8 => Type::CHAR, |
82 | | - DataType::Int16 | DataType::UInt16 => Type::INT2, |
83 | | - DataType::Int32 | DataType::UInt32 => Type::INT4, |
84 | | - DataType::Int64 | DataType::UInt64 => Type::INT8, |
85 | | - DataType::Timestamp(_, _) => Type::TIMESTAMP, |
86 | | - DataType::Time32(_) | DataType::Time64(_) => Type::TIME, |
87 | | - DataType::Date32 | DataType::Date64 => Type::DATE, |
88 | | - DataType::Binary => Type::BYTEA, |
89 | | - DataType::Float32 => Type::FLOAT4, |
90 | | - DataType::Float64 => Type::FLOAT8, |
91 | | - DataType::Utf8 => Type::VARCHAR, |
92 | | - DataType::List(field) => match field.data_type() { |
93 | | - DataType::Boolean => Type::BOOL_ARRAY, |
94 | | - DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, |
95 | | - DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, |
96 | | - DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY, |
97 | | - DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY, |
98 | | - DataType::Timestamp(_, _) => Type::TIMESTAMP_ARRAY, |
99 | | - DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY, |
100 | | - DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY, |
101 | | - DataType::Binary => Type::BYTEA_ARRAY, |
102 | | - DataType::Float32 => Type::FLOAT4_ARRAY, |
103 | | - DataType::Float64 => Type::FLOAT8_ARRAY, |
104 | | - DataType::Utf8 => Type::VARCHAR_ARRAY, |
105 | | - list_type => { |
106 | | - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( |
107 | | - "ERROR".to_owned(), |
108 | | - "XX000".to_owned(), |
109 | | - format!("Unsupported List Datatype {list_type}"), |
110 | | - )))); |
111 | | - } |
112 | | - }, |
113 | | - _ => { |
114 | | - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( |
115 | | - "ERROR".to_owned(), |
116 | | - "XX000".to_owned(), |
117 | | - format!("Unsupported Datatype {df_type}"), |
118 | | - )))); |
119 | | - } |
120 | | - }) |
121 | | -} |
122 | | - |
123 | | -fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> { |
124 | | - let columns = stmt.column_count(); |
125 | | - |
126 | | - (0..columns) |
127 | | - .map(|idx| { |
128 | | - let datatype = stmt.column_type(idx); |
129 | | - let name = stmt.column_name(idx).unwrap(); |
130 | | - |
131 | | - Ok(FieldInfo::new( |
132 | | - name.clone(), |
133 | | - None, |
134 | | - None, |
135 | | - into_pg_type(&datatype).unwrap(), |
136 | | - format.format_for(idx), |
137 | | - )) |
138 | | - }) |
139 | | - .collect() |
140 | | -} |
141 | | - |
142 | | -fn encode_row_data( |
143 | | - mut rows: Rows<'_>, |
144 | | - schema: Arc<Vec<FieldInfo>>, |
145 | | -) -> impl Stream<Item = PgWireResult<DataRow>> { |
146 | | - let mut results = Vec::new(); |
147 | | - let ncols = schema.len(); |
148 | | - while let Ok(Some(row)) = rows.next() { |
149 | | - let mut encoder = DataRowEncoder::new(schema.clone()); |
150 | | - for idx in 0..ncols { |
151 | | - let data = row.get_ref_unwrap::<usize>(idx); |
152 | | - match data { |
153 | | - ValueRef::Null => encoder.encode_field(&None::<i8>).unwrap(), |
154 | | - ValueRef::TinyInt(i) => { |
155 | | - encoder.encode_field(&i).unwrap(); |
156 | | - } |
157 | | - ValueRef::SmallInt(i) => { |
158 | | - encoder.encode_field(&i).unwrap(); |
159 | | - } |
160 | | - ValueRef::Int(i) => { |
161 | | - encoder.encode_field(&i).unwrap(); |
162 | | - } |
163 | | - ValueRef::BigInt(i) => { |
164 | | - encoder.encode_field(&i).unwrap(); |
165 | | - } |
166 | | - ValueRef::Float(f) => { |
167 | | - encoder.encode_field(&f).unwrap(); |
168 | | - } |
169 | | - ValueRef::Double(f) => { |
170 | | - encoder.encode_field(&f).unwrap(); |
171 | | - } |
172 | | - ValueRef::Text(t) => { |
173 | | - encoder |
174 | | - .encode_field(&String::from_utf8_lossy(t).as_ref()) |
175 | | - .unwrap(); |
176 | | - } |
177 | | - ValueRef::Blob(b) => { |
178 | | - encoder.encode_field(&b).unwrap(); |
179 | | - } |
180 | | - _ => { |
181 | | - unimplemented!("More types to be supported.") |
182 | | - } |
183 | | - } |
184 | | - } |
185 | | - |
186 | | - results.push(encoder.finish()); |
187 | | - } |
188 | | - |
189 | | - stream::iter(results.into_iter()) |
190 | | -} |
191 | | - |
192 | 84 | fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> { |
193 | 85 | let mut results = Vec::with_capacity(portal.parameter_len()); |
194 | 86 | for i in 0..portal.parameter_len() { |
@@ -232,6 +124,25 @@ fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> { |
232 | 124 | results |
233 | 125 | } |
234 | 126 |
|
| 127 | +fn row_desc_from_stmt(stmt: &Statement, format: &Format) -> PgWireResult<Vec<FieldInfo>> { |
| 128 | + let columns = stmt.column_count(); |
| 129 | + |
| 130 | + (0..columns) |
| 131 | + .map(|idx| { |
| 132 | + let datatype = stmt.column_type(idx); |
| 133 | + let name = stmt.column_name(idx).unwrap(); |
| 134 | + |
| 135 | + Ok(FieldInfo::new( |
| 136 | + name.clone(), |
| 137 | + None, |
| 138 | + None, |
| 139 | + into_pg_type(&datatype).unwrap(), |
| 140 | + format.format_for(idx), |
| 141 | + )) |
| 142 | + }) |
| 143 | + .collect() |
| 144 | +} |
| 145 | + |
235 | 146 | #[async_trait] |
236 | 147 | impl ExtendedQueryHandler for DuckDBBackend { |
237 | 148 | type Statement = String; |
@@ -262,18 +173,27 @@ impl ExtendedQueryHandler for DuckDBBackend { |
262 | 173 | .collect::<Vec<&dyn duckdb::ToSql>>(); |
263 | 174 |
|
264 | 175 | if query.to_uppercase().starts_with("SELECT") { |
265 | | - let rows: Rows<'_> = stmt |
266 | | - .query::<&[&dyn duckdb::ToSql]>(params_ref.as_ref()) |
| 176 | + let ret = stmt |
| 177 | + .query_arrow(params![]) |
267 | 178 | .map_err(|e| PgWireError::ApiError(Box::new(e)))?; |
268 | | - let row_stmt = rows.as_ref().unwrap(); |
269 | | - let header = Arc::new(row_desc_from_stmt(row_stmt, &portal.result_column_format)?); |
270 | | - let s = encode_row_data(rows, header.clone()); |
271 | | - Ok(Response::Query(QueryResponse::new(header, s))) |
| 179 | + let schema = ret.get_schema(); |
| 180 | + let header = Arc::new(arrow_schema_to_pg_fields( |
| 181 | + schema.as_ref(), |
| 182 | + &Format::UnifiedText, |
| 183 | + )?); |
| 184 | + |
| 185 | + let header_ref = header.clone(); |
| 186 | + let data = ret |
| 187 | + .flat_map(move |rb| encode_recordbatch(header_ref.clone(), rb)) |
| 188 | + .collect::<Vec<_>>(); |
| 189 | + |
| 190 | + Ok(Response::Query(QueryResponse::new( |
| 191 | + header, |
| 192 | + stream::iter(data.into_iter()), |
| 193 | + ))) |
272 | 194 | } else { |
273 | 195 | stmt.execute::<&[&dyn duckdb::ToSql]>(params_ref.as_ref()) |
274 | | - .map(|affected_rows| { |
275 | | - Response::Execution(Tag::new("OK").with_rows(affected_rows).into()) |
276 | | - }) |
| 196 | + .map(|affected_rows| Response::Execution(Tag::new("OK").with_rows(affected_rows))) |
277 | 197 | .map_err(|e| PgWireError::ApiError(Box::new(e))) |
278 | 198 | } |
279 | 199 | } |
@@ -307,8 +227,7 @@ impl ExtendedQueryHandler for DuckDBBackend { |
307 | 227 | let stmt = conn |
308 | 228 | .prepare_cached(&portal.statement.statement) |
309 | 229 | .map_err(|e| PgWireError::ApiError(Box::new(e)))?; |
310 | | - row_desc_from_stmt(&stmt, &portal.result_column_format) |
311 | | - .map(|fields| DescribePortalResponse::new(fields)) |
| 230 | + row_desc_from_stmt(&stmt, &portal.result_column_format).map(DescribePortalResponse::new) |
312 | 231 | } |
313 | 232 | } |
314 | 233 |
|
|
0 commit comments