Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 31 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ let ctx = SessionContext::with_state(state);
- Flight SQL
- ODBC

## Examples
## Examples (in Rust)

Run the included examples to see how to use the table providers:

Expand All @@ -45,6 +45,7 @@ cargo run --example duckdb_function --features duckdb
### SQLite

```bash
# Run from the repository root
cargo run --example sqlite --features sqlite
```

Expand All @@ -69,7 +70,9 @@ EOF
```

```bash
cargo run --example postgres --features postgres
# Run from the repository root
cargo run -p datafusion-table-providers --example postgres --features postgres

```

### MySQL
Expand All @@ -93,7 +96,8 @@ EOF
```

```bash
cargo run --example mysql --features mysql
# Run from the repository root
cargo run -p datafusion-table-providers --example mysql --features mysql
```

### Flight SQL
Expand All @@ -104,16 +108,37 @@ brew install roapi
# cargo install --locked --git https://github.com/roapi/roapi --branch main --bins roapi
roapi -t taxi=https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-01.parquet &

cargo run --example flight-sql --features flight
# Run from the repository root
cargo run -p datafusion-table-providers --example flight-sql --features flight
```

### ODBC

```bash
apt-get install unixodbc-dev libsqliteodbc
# or
# brew install unixodbc && brew install sqliteodbc
# If you use ARM Mac, please see https://github.com/pacman82/odbc-api#os-x-arm--mac-m1

cargo run --example odbc_sqlite --features odbc
```

#### ARM Mac

Please see https://github.com/pacman82/odbc-api#os-x-arm--mac-m1 for reference.

Steps:
1. Install unixodbc and sqliteodbc by running `brew install unixodbc sqliteodbc`.
2. Find the local sqliteodbc driver path by running `brew info sqliteodbc`. The path might look like `/opt/homebrew/Cellar/sqliteodbc/0.99991`.
3. Set up the ODBC config file at `~/.odbcinst.ini` with your local sqliteodbc path.
Example config file:
```
[SQLite3]
Description = SQLite3 ODBC Driver
Driver = /opt/homebrew/Cellar/sqliteodbc/0.99991/lib/libsqlite3odbc.dylib
```
4. Test the configuration by running `odbcinst -q -d -n SQLite3`. If the driver path is printed correctly, you are all set.

## Examples (in Python)
1. Create a Python virtual environment (venv).
2. Activate the virtual environment.
3. Inside the `python/` folder, run `maturin develop`.
4. Inside the `python/examples/` folder, run the corresponding example using `python3 [file_name]`.
2 changes: 1 addition & 1 deletion core/examples/duckdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn main() {
// Opening in ReadOnly mode allows multiple reader processes to access
// the database at the same time
let duckdb_pool = Arc::new(
DuckDbConnectionPool::new_file("examples/duckdb_example.db", &AccessMode::ReadOnly)
DuckDbConnectionPool::new_file("core/examples/duckdb_example.db", &AccessMode::ReadOnly)
.expect("unable to create DuckDB connection pool"),
);

Expand Down
2 changes: 1 addition & 1 deletion core/examples/odbc_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() {
// Create SQLite ODBC connection pool
let params = to_secret_map(HashMap::from([(
"connection_string".to_owned(),
"driver=SQLite3;database=examples/sqlite_example.db;".to_owned(),
"driver=SQLite3;database=core/examples/sqlite_example.db;".to_owned(),
)]));
let odbc_pool =
Arc::new(ODBCPool::new(params).expect("unable to create SQLite ODBC connection pool"));
Expand Down
2 changes: 1 addition & 1 deletion core/examples/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn main() {
// - arg3: Connection timeout duration
let sqlite_pool = Arc::new(
SqliteConnectionPoolFactory::new(
"examples/sqlite_example.db",
"core/examples/sqlite_example.db",
Mode::File,
Duration::from_millis(5000),
)
Expand Down
5 changes: 4 additions & 1 deletion core/src/flight/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::str::FromStr;
use std::sync::Arc;

use crate::flight::{flight_channel, to_df_err, FlightMetadata, FlightProperties, SizeLimits};
use crate::sql::db_connection_pool::runtime::run_async_with_tokio;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::{FlightClient, FlightEndpoint, Ticket};
Expand Down Expand Up @@ -190,7 +191,9 @@ async fn flight_stream(
) -> Result<SendableRecordBatchStream> {
let mut errors: Vec<Box<dyn Error + Send + Sync>> = vec![];
for loc in partition.locations.iter() {
let client = flight_client(loc, grpc_headers.as_ref(), &size_limits).await?;
let get_client = || async { flight_client(loc, grpc_headers.as_ref(), &size_limits).await };
let client = run_async_with_tokio(get_client).await?;
// let client = flight_client(loc, grpc_headers.as_ref(), &size_limits).await?;
Comment thread
phillipleblanc marked this conversation as resolved.
Outdated
match try_fetch_stream(client, &partition.ticket, schema.clone()).await {
Ok(stream) => return Ok(stream),
Err(e) => errors.push(Box::new(e)),
Expand Down
19 changes: 3 additions & 16 deletions core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::any::Any;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow_schema::{DataType, Field};
Expand All @@ -18,9 +18,9 @@ use duckdb::{Connection, DuckdbConnectionManager};
use dyn_clone::DynClone;
use rand::distr::{Alphanumeric, SampleString};
use snafu::{prelude::*, ResultExt};
use tokio::runtime::{Handle, Runtime};
use tokio::sync::mpsc::Sender;

use crate::sql::db_connection_pool::runtime::run_sync_with_tokio;
use crate::util::schema::SchemaValidator;
use crate::UnsupportedTypeAction;

Expand Down Expand Up @@ -282,13 +282,6 @@ impl DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParamet
}
}

/// Returns a reference to a shared, lazily created Tokio [`Runtime`].
///
/// The runtime is stored in a function-local `static` [`OnceLock`], so it is
/// constructed on the first call and the same instance is returned afterwards.
///
/// # Panics
///
/// Panics with "Failed to create Tokio runtime" if `Runtime::new()` fails
/// during the first call.
fn get_tokio_runtime() -> &'static Runtime {
// TODO: this function is a repetition of python/src/utils.rs::get_tokio_runtime.
// Think about how to refactor it
static RUNTIME: OnceLock<Runtime> = OnceLock::new();
RUNTIME.get_or_init(|| Runtime::new().expect("Failed to create Tokio runtime"))
}

impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
for DuckDbConnection
{
Expand Down Expand Up @@ -448,13 +441,7 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
)))
};

// If called directly from Rust, a Tokio runtime already exists, so no
// additional work is needed. If called from the Python FFI, there is no
// existing Tokio runtime, so we need to start a new one.
match Handle::try_current() {
Ok(_) => create_stream(),
Err(_) => get_tokio_runtime().block_on(async { create_stream() }),
}
run_sync_with_tokio(create_stream)
}

fn execute(&self, sql: &str, params: &[DuckDBParameter]) -> Result<u64> {
Expand Down
101 changes: 52 additions & 49 deletions core/src/sql/db_connection_pool/dbconnection/odbcconn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::sync::Arc;

use crate::sql::db_connection_pool::{
dbconnection::{self, AsyncDbConnection, DbConnection, GenericError},
runtime::run_async_with_tokio,
DbConnectionPool,
};
use arrow_odbc::arrow_schema_from;
Expand All @@ -42,7 +43,7 @@ use odbc_api::handles::StatementImpl;
use odbc_api::parameter::InputParameter;
use odbc_api::Cursor;
use odbc_api::CursorImpl;
use secrecy::{SecretBox, ExposeSecret, SecretString};
use secrecy::{ExposeSecret, SecretBox, SecretString};
use snafu::prelude::*;
use snafu::Snafu;
use tokio::runtime::Handle;
Expand Down Expand Up @@ -184,69 +185,71 @@ where
let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
let secrets = Arc::clone(&self.params);

let join_handle = tokio::task::spawn_blocking(move || {
let handle = Handle::current();
let cxn = handle.block_on(async { conn.lock().await });
let create_stream = async || -> Result<SendableRecordBatchStream> {
let join_handle = tokio::task::spawn_blocking(move || {
let handle = Handle::current();
let cxn = handle.block_on(async { conn.lock().await });

let mut prepared = cxn.prepare(&sql)?;
let schema = Arc::new(arrow_schema_from(&mut prepared, false)?);
blocking_channel_send(&schema_tx, Arc::clone(&schema))?;
let mut prepared = cxn.prepare(&sql)?;
let schema = Arc::new(arrow_schema_from(&mut prepared, false)?);
blocking_channel_send(&schema_tx, Arc::clone(&schema))?;

let mut statement = prepared.into_statement();
let mut statement = prepared.into_statement();

bind_parameters(&mut statement, &params)?;
bind_parameters(&mut statement, &params)?;

// StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
let cursor = unsafe {
if let SqlResult::Error { function } = statement.execute() {
return Err(Error::ODBCAPIErrorNoSource {
message: function.to_string(),
// StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
let cursor = unsafe {
if let SqlResult::Error { function } = statement.execute() {
return Err(Error::ODBCAPIErrorNoSource {
message: function.to_string(),
}
.into());
}
.into());
}

Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
}?;
Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
}?;

let reader = build_odbc_reader(cursor, &schema, &secrets)?;
for batch in reader {
blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
}
let reader = build_odbc_reader(cursor, &schema, &secrets)?;
for batch in reader {
blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
}

Ok::<_, GenericError>(())
});
Ok::<_, GenericError>(())
});

// we need to wait for the schema first before we can build our RecordBatchStreamAdapter
let Some(schema) = schema_rx.recv().await else {
// if the channel drops, the task errored
if !join_handle.is_finished() {
unreachable!("Schema channel should not have dropped before the task finished");
}
// we need to wait for the schema first before we can build our RecordBatchStreamAdapter
let Some(schema) = schema_rx.recv().await else {
// if the channel drops, the task errored
if !join_handle.is_finished() {
unreachable!("Schema channel should not have dropped before the task finished");
}

let result = join_handle.await?;
let Err(err) = result else {
unreachable!("Task should have errored");
let result = join_handle.await?;
let Err(err) = result else {
unreachable!("Task should have errored");
};

return Err(err);
};

return Err(err);
};
let output_stream = stream! {
while let Some(batch) = batch_rx.recv().await {
yield Ok(batch);
}

let output_stream = stream! {
while let Some(batch) = batch_rx.recv().await {
yield Ok(batch);
}
if let Err(e) = join_handle.await {
yield Err(DataFusionError::Execution(format!(
"Failed to execute ODBC query: {e}"
)))
}
};

if let Err(e) = join_handle.await {
yield Err(DataFusionError::Execution(format!(
"Failed to execute ODBC query: {e}"
)))
}
let result: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema, output_stream));
Ok(result)
};

Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
output_stream,
)))
run_async_with_tokio(create_stream).await
}

async fn execute(&self, query: &str, params: &[ODBCParameter]) -> Result<u64> {
Expand Down
2 changes: 0 additions & 2 deletions core/src/sql/db_connection_pool/duckdbpool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ impl std::fmt::Debug for DuckDbConnectionPool {
}

impl DuckDbConnectionPool {

/// Get the dataset path. Returns `:memory:` if the in memory database is used.
pub fn db_path(&self) -> &str {
self.path.as_ref()
Expand Down Expand Up @@ -266,7 +265,6 @@ mod test {

use super::*;
use crate::sql::db_connection_pool::DbConnectionPool;
use std::sync::Arc;

fn random_db_name() -> String {
let mut rng = rand::rng();
Expand Down
1 change: 1 addition & 0 deletions core/src/sql/db_connection_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod mysqlpool;
pub mod odbcpool;
#[cfg(feature = "postgres")]
pub mod postgrespool;
pub mod runtime;
#[cfg(feature = "sqlite")]
pub mod sqlitepool;

Expand Down
2 changes: 1 addition & 1 deletion core/src/sql/db_connection_pool/odbcpool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::sql::db_connection_pool::dbconnection::odbcconn::{ODBCDbConnection, O
use crate::sql::db_connection_pool::{DbConnectionPool, JoinPushDown};
use async_trait::async_trait;
use odbc_api::{sys::AttrConnectionPooling, Connection, ConnectionOptions, Environment};
use secrecy::{SecretBox, ExposeSecret, SecretString};
use secrecy::{ExposeSecret, SecretBox, SecretString};
use sha2::{Digest, Sha256};
use snafu::prelude::*;
use std::{
Expand Down
5 changes: 3 additions & 2 deletions core/src/sql/db_connection_pool/postgrespool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use secrecy::{ExposeSecret, SecretBox, SecretString};
use snafu::{prelude::*, ResultExt};
use tokio_postgres;

use super::DbConnectionPool;
use super::{runtime::run_async_with_tokio, DbConnectionPool};
use crate::sql::db_connection_pool::{
dbconnection::{postgresconn::PostgresConnection, AsyncDbConnection, DbConnection},
JoinPushDown,
Expand Down Expand Up @@ -385,7 +385,8 @@ impl
>,
> {
let pool = Arc::clone(&self.pool);
let conn = pool.get_owned().await.context(ConnectionPoolRunSnafu)?;
let get_conn = async || pool.get_owned().await.context(ConnectionPoolRunSnafu);
let conn = run_async_with_tokio(get_conn).await?;
Comment thread
phillipleblanc marked this conversation as resolved.
Ok(Box::new(
PostgresConnection::new(conn)
.with_unsupported_type_action(self.unsupported_type_action),
Expand Down
Loading