Skip to content

Commit 6e42290

Browse files
authored
Add MySQL write support (datafusion-contrib#134)
* Add MySQLTableWriter for insertion into MySQL * Add test_arrow_mysql_roundtrip * Refactor use_json_insert_for_type function to separate sqlite and mysql feature checks * feat: add feature mysql-federation * update datafusion-federation to 0.3.1
1 parent 86cd13c commit 6e42290

9 files changed

Lines changed: 829 additions & 25 deletions

File tree

src/mysql.rs

Lines changed: 380 additions & 5 deletions
Large diffs are not rendered by default.

src/mysql/write.rs

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
use crate::mysql::MySQL;
2+
use crate::util::on_conflict::OnConflict;
3+
use crate::util::retriable_error::check_and_mark_retriable_error;
4+
use crate::util::{constraints, to_datafusion_error};
5+
use arrow::datatypes::SchemaRef;
6+
use async_trait::async_trait;
7+
use datafusion::{
8+
catalog::Session,
9+
datasource::{TableProvider, TableType},
10+
execution::{SendableRecordBatchStream, TaskContext},
11+
logical_expr::{dml::InsertOp, Expr},
12+
physical_plan::{
13+
insert::{DataSink, DataSinkExec},
14+
metrics::MetricsSet,
15+
DisplayAs, DisplayFormatType, ExecutionPlan,
16+
},
17+
};
18+
use futures::StreamExt;
19+
use mysql_async::TxOpts;
20+
use snafu::ResultExt;
21+
use std::any::Any;
22+
use std::fmt;
23+
use std::sync::Arc;
24+
25+
#[derive(Debug, Clone)]
26+
pub struct MySQLTableWriter {
27+
pub read_provider: Arc<dyn TableProvider>,
28+
mysql: Arc<MySQL>,
29+
on_conflict: Option<OnConflict>,
30+
}
31+
32+
impl MySQLTableWriter {
33+
pub fn create(
34+
read_provider: Arc<dyn TableProvider>,
35+
mysql: MySQL,
36+
on_conflict: Option<OnConflict>,
37+
) -> Arc<Self> {
38+
Arc::new(Self {
39+
read_provider,
40+
mysql: Arc::new(mysql),
41+
on_conflict,
42+
})
43+
}
44+
45+
pub fn mysql(&self) -> Arc<MySQL> {
46+
Arc::clone(&self.mysql)
47+
}
48+
}
49+
50+
#[async_trait]
51+
impl TableProvider for MySQLTableWriter {
52+
fn as_any(&self) -> &dyn Any {
53+
self
54+
}
55+
56+
fn schema(&self) -> SchemaRef {
57+
self.read_provider.schema()
58+
}
59+
60+
fn table_type(&self) -> TableType {
61+
TableType::Base
62+
}
63+
64+
async fn scan(
65+
&self,
66+
state: &dyn Session,
67+
projection: Option<&Vec<usize>>,
68+
filters: &[Expr],
69+
limit: Option<usize>,
70+
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
71+
self.read_provider
72+
.scan(state, projection, filters, limit)
73+
.await
74+
}
75+
76+
async fn insert_into(
77+
&self,
78+
_state: &dyn Session,
79+
input: Arc<dyn ExecutionPlan>,
80+
op: InsertOp,
81+
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
82+
Ok(Arc::new(DataSinkExec::new(
83+
input,
84+
Arc::new(MySQLDataSink::new(
85+
Arc::clone(&self.mysql),
86+
op == InsertOp::Overwrite,
87+
self.on_conflict.clone(),
88+
)),
89+
self.schema(),
90+
None,
91+
)))
92+
}
93+
}
94+
95+
pub struct MySQLDataSink {
96+
pub mysql: Arc<MySQL>,
97+
pub overwrite: bool,
98+
pub on_conflict: Option<OnConflict>,
99+
}
100+
101+
#[async_trait]
102+
impl DataSink for MySQLDataSink {
103+
fn as_any(&self) -> &dyn Any {
104+
self
105+
}
106+
107+
fn metrics(&self) -> Option<MetricsSet> {
108+
None
109+
}
110+
111+
async fn write_all(
112+
&self,
113+
mut data: SendableRecordBatchStream,
114+
_context: &Arc<TaskContext>,
115+
) -> datafusion::common::Result<u64> {
116+
let mut num_rows = 0u64;
117+
118+
let mut db_conn = self.mysql.connect().await.map_err(to_datafusion_error)?;
119+
let mysql_conn = MySQL::mysql_conn(&mut db_conn).map_err(to_datafusion_error)?;
120+
121+
let mut conn_guard = mysql_conn.conn.lock().await;
122+
let mut tx = conn_guard
123+
.start_transaction(TxOpts::default())
124+
.await
125+
.context(super::UnableToBeginTransactionSnafu)
126+
.map_err(to_datafusion_error)?;
127+
128+
if self.overwrite {
129+
self.mysql
130+
.delete_all_table_data(&mut tx)
131+
.await
132+
.map_err(to_datafusion_error)?;
133+
}
134+
135+
while let Some(batch) = data.next().await {
136+
let batch = batch.map_err(check_and_mark_retriable_error)?;
137+
let batch_num_rows = batch.num_rows();
138+
139+
if batch_num_rows == 0 {
140+
continue;
141+
}
142+
143+
num_rows += batch_num_rows as u64;
144+
145+
constraints::validate_batch_with_constraints(
146+
&[batch.clone()],
147+
self.mysql.constraints(),
148+
)
149+
.await
150+
.context(super::ConstraintViolationSnafu)
151+
.map_err(to_datafusion_error)?;
152+
153+
self.mysql
154+
.insert_batch(&mut tx, batch, self.on_conflict.clone())
155+
.await
156+
.map_err(to_datafusion_error)?;
157+
}
158+
159+
tx.commit()
160+
.await
161+
.context(super::UnableToCommitMySQLTransactionSnafu)
162+
.map_err(to_datafusion_error)?;
163+
164+
drop(conn_guard);
165+
166+
Ok(num_rows)
167+
}
168+
}
169+
170+
impl MySQLDataSink {
171+
pub fn new(mysql: Arc<MySQL>, overwrite: bool, on_conflict: Option<OnConflict>) -> Self {
172+
Self {
173+
mysql,
174+
overwrite,
175+
on_conflict,
176+
}
177+
}
178+
}
179+
180+
impl fmt::Debug for MySQLDataSink {
181+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182+
write!(f, "MySQLDataSink")
183+
}
184+
}
185+
186+
impl DisplayAs for MySQLDataSink {
187+
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
188+
write!(f, "MySQLDataSink")
189+
}
190+
}

src/postgres.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ use crate::util::{
3939
indexes::IndexType,
4040
on_conflict::{self, OnConflict},
4141
secrets::to_secret_map,
42+
to_datafusion_error,
4243
};
4344

4445
use self::write::PostgresTableWriter;
@@ -310,10 +311,6 @@ impl TableProviderFactory for PostgresTableProviderFactory {
310311
}
311312
}
312313

313-
fn to_datafusion_error(error: Error) -> DataFusionError {
314-
DataFusionError::External(Box::new(error))
315-
}
316-
317314
#[derive(Debug, Clone)]
318315
pub struct Postgres {
319316
table_name: String,

src/sql/arrow_sql_gen/statement.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ use num_bigint::BigInt;
1313
use sea_query::{
1414
Alias, ColumnDef, ColumnType, Expr, GenericBuilder, Index, InsertStatement, IntoIden,
1515
IntoIndexColumn, Keyword, MysqlQueryBuilder, OnConflict, PostgresQueryBuilder, Query,
16-
QueryBuilder, SimpleExpr, SqliteQueryBuilder, StringLen, Table,
16+
QueryBuilder, SimpleExpr, SqliteQueryBuilder, Table,
1717
};
1818
use snafu::Snafu;
19-
use std::{any::Any, sync::Arc};
19+
use std::sync::Arc;
2020
use time::{OffsetDateTime, PrimitiveDateTime};
2121

2222
#[derive(Debug, Snafu)]
@@ -105,6 +105,11 @@ impl CreateTableBuilder {
105105
#[must_use]
106106
pub fn build_mysql(self) -> String {
107107
self.build(MysqlQueryBuilder, &|f: &Arc<Field>| -> ColumnType {
108+
// MySQL does not natively support Arrays, Structs, etc
109+
// so we use JSON column type for List, FixedSizeList, LargeList, Struct, etc
110+
if f.data_type().is_nested() {
111+
return ColumnType::JsonBinary;
112+
}
108113
map_data_type_to_column_type(f.data_type())
109114
})
110115
}
@@ -181,11 +186,20 @@ pub fn use_json_insert_for_type<T: QueryBuilder + 'static>(
181186
query_builder: &T,
182187
) -> bool {
183188
#[cfg(feature = "sqlite")]
184-
if (query_builder as &dyn Any)
185-
.downcast_ref::<SqliteQueryBuilder>()
186-
.is_some()
187189
{
188-
return data_type.is_nested();
190+
use std::any::Any;
191+
let any_builder = query_builder as &dyn Any;
192+
if any_builder.is::<SqliteQueryBuilder>() {
193+
return data_type.is_nested();
194+
}
195+
}
196+
#[cfg(feature = "mysql")]
197+
{
198+
use std::any::Any;
199+
let any_builder = query_builder as &dyn Any;
200+
if any_builder.is::<MysqlQueryBuilder>() {
201+
return data_type.is_nested();
202+
}
189203
}
190204
false
191205
}
@@ -1264,7 +1278,11 @@ pub(crate) fn map_data_type_to_column_type(data_type: &DataType) -> ColumnType {
12641278
| DataType::FixedSizeList(list_type, _) => {
12651279
ColumnType::Array(map_data_type_to_column_type(list_type.data_type()).into())
12661280
}
1267-
DataType::Binary | DataType::LargeBinary => ColumnType::VarBinary(StringLen::Max),
1281+
// Originally mapped to VarBinary type, corresponding to MySQL's varbinary, which has a maximum length of 65535.
1282+
// This caused the error: "Row size too large. The maximum row size for the used table type, not counting BLOBs, is 65535.
1283+
// This includes storage overhead, check the manual. You have to change some columns to TEXT or BLOBs."
1284+
// Changing to Blob fixes this issue. This change does not affect Postgres, and for Sqlite, the mapping type changes from varbinary_blob to blob.
1285+
DataType::Binary | DataType::LargeBinary => ColumnType::Blob,
12681286
DataType::FixedSizeBinary(num_bytes) => ColumnType::Binary(num_bytes.to_owned() as u32),
12691287
DataType::Interval(_) => ColumnType::Interval(None, None),
12701288
// Add more mappings here as needed

src/sql/db_connection_pool/mysqlpool.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub enum Error {
5050
UnknownMySQLDatabase { message: String },
5151
}
5252

53+
#[derive(Debug)]
5354
pub struct MySQLConnectionPool {
5455
pool: Arc<mysql_async::Pool>,
5556
join_push_down: JoinPushDown,

src/util/mod.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ use std::collections::HashMap;
44
use std::hash::Hash;
55
use std::sync::Arc;
66

7+
use crate::sql::sql_provider_datafusion::Engine;
8+
use datafusion::common::DataFusionError;
79
use datafusion::{
810
error::Result as DataFusionResult,
911
sql::unparser::{dialect::DefaultDialect, Unparser},
1012
};
1113

12-
use crate::sql::sql_provider_datafusion::Engine;
13-
1414
pub mod column_reference;
1515
pub mod constraints;
1616
pub mod indexes;
@@ -79,6 +79,14 @@ pub fn remove_prefix_from_hashmap_keys<V>(
7979
.collect()
8080
}
8181

82+
#[must_use]
83+
pub fn to_datafusion_error<E>(error: E) -> DataFusionError
84+
where
85+
E: std::error::Error + Send + Sync + 'static,
86+
{
87+
DataFusionError::External(Box::new(error))
88+
}
89+
8290
#[cfg(test)]
8391
mod tests {
8492
use super::*;

0 commit comments

Comments
 (0)