Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,22 @@ cargo run --example flight-sql --features flight
apt-get install unixodbc-dev libsqliteodbc
# or
# brew install unixodbc & brew install sqliteodbc
# If you use ARM Mac, please see https://github.com/pacman82/odbc-api#os-x-arm--mac-m1

cargo run --example odbc_sqlite --features odbc
```

#### ARM Mac

Please see https://github.com/pacman82/odbc-api#os-x-arm--mac-m1 for reference.

Steps:
1. Install unixodbc and sqliteodbc by `brew install unixodbc sqliteodbc`.
2. Find local sqliteodbc driver path by running `brew info sqliteodbc`. The path might look like `/opt/homebrew/Cellar/sqliteodbc/0.99991`.
3. Set up odbc config file at `~/.odbcinst.ini` with your local sqliteodbc path.
Example config file:
```
[SQLite3]
Description = SQLite3 ODBC Driver
Driver = /opt/homebrew/Cellar/sqliteodbc/0.99991/lib/libsqlite3odbc.dylib
```
4. Test configuration by running `odbcinst -q -d -n SQLite3`. If the path is printed out correctly, then you are all set.
2 changes: 1 addition & 1 deletion core/examples/odbc_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() {
// Create SQLite ODBC connection pool
let params = to_secret_map(HashMap::from([(
"connection_string".to_owned(),
"driver=SQLite3;database=examples/sqlite_example.db;".to_owned(),
"driver=SQLite3;database=core/examples/sqlite_example.db;".to_owned(),
)]));
let odbc_pool =
Arc::new(ODBCPool::new(params).expect("unable to create SQLite ODBC connection pool"));
Expand Down
12 changes: 3 additions & 9 deletions core/src/sql/db_connection_pool/dbconnection/duckdbconn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::any::Any;
use std::sync::{Arc, OnceLock};
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow_schema::{DataType, Field};
Expand All @@ -18,9 +18,10 @@ use duckdb::{Connection, DuckdbConnectionManager};
use dyn_clone::DynClone;
use rand::distr::{Alphanumeric, SampleString};
use snafu::{prelude::*, ResultExt};
use tokio::runtime::{Handle, Runtime};
use tokio::runtime::Handle;
use tokio::sync::mpsc::Sender;

use crate::sql::db_connection_pool::runtime::get_tokio_runtime;
use crate::util::schema::SchemaValidator;
use crate::UnsupportedTypeAction;

Expand Down Expand Up @@ -282,13 +283,6 @@ impl DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParamet
}
}

fn get_tokio_runtime() -> &'static Runtime {
// TODO: this function is a repetition of python/src/utils.rs::get_tokio_runtime.
// Think about how to refactor it
static RUNTIME: OnceLock<Runtime> = OnceLock::new();
RUNTIME.get_or_init(|| Runtime::new().expect("Failed to create Tokio runtime"))
}

impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
for DuckDbConnection
{
Expand Down
104 changes: 55 additions & 49 deletions core/src/sql/db_connection_pool/dbconnection/odbcconn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::sync::Arc;

use crate::sql::db_connection_pool::{
dbconnection::{self, AsyncDbConnection, DbConnection, GenericError},
runtime::get_tokio_runtime,
DbConnectionPool,
};
use arrow_odbc::arrow_schema_from;
Expand All @@ -42,7 +43,7 @@ use odbc_api::handles::StatementImpl;
use odbc_api::parameter::InputParameter;
use odbc_api::Cursor;
use odbc_api::CursorImpl;
use secrecy::{SecretBox, ExposeSecret, SecretString};
use secrecy::{ExposeSecret, SecretBox, SecretString};
use snafu::prelude::*;
use snafu::Snafu;
use tokio::runtime::Handle;
Expand Down Expand Up @@ -184,69 +185,74 @@ where
let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
let secrets = Arc::clone(&self.params);

let join_handle = tokio::task::spawn_blocking(move || {
let handle = Handle::current();
let cxn = handle.block_on(async { conn.lock().await });
let create_stream = async || -> Result<SendableRecordBatchStream> {
let join_handle = tokio::task::spawn_blocking(move || {
let handle = Handle::current();
let cxn = handle.block_on(async { conn.lock().await });

let mut prepared = cxn.prepare(&sql)?;
let schema = Arc::new(arrow_schema_from(&mut prepared, false)?);
blocking_channel_send(&schema_tx, Arc::clone(&schema))?;
let mut prepared = cxn.prepare(&sql)?;
let schema = Arc::new(arrow_schema_from(&mut prepared, false)?);
blocking_channel_send(&schema_tx, Arc::clone(&schema))?;

let mut statement = prepared.into_statement();
let mut statement = prepared.into_statement();

bind_parameters(&mut statement, &params)?;
bind_parameters(&mut statement, &params)?;

// StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
let cursor = unsafe {
if let SqlResult::Error { function } = statement.execute() {
return Err(Error::ODBCAPIErrorNoSource {
message: function.to_string(),
// StatementImpl<'_>::execute is unsafe, CursorImpl<_>::new is unsafe
let cursor = unsafe {
if let SqlResult::Error { function } = statement.execute() {
return Err(Error::ODBCAPIErrorNoSource {
message: function.to_string(),
}
.into());
}
.into());
}

Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
}?;
Ok::<_, GenericError>(CursorImpl::new(statement.as_stmt_ref()))
}?;

let reader = build_odbc_reader(cursor, &schema, &secrets)?;
for batch in reader {
blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
}
let reader = build_odbc_reader(cursor, &schema, &secrets)?;
for batch in reader {
blocking_channel_send(&batch_tx, batch.context(ArrowSnafu)?)?;
}

Ok::<_, GenericError>(())
});
Ok::<_, GenericError>(())
});

// we need to wait for the schema first before we can build our RecordBatchStreamAdapter
let Some(schema) = schema_rx.recv().await else {
// if the channel drops, the task errored
if !join_handle.is_finished() {
unreachable!("Schema channel should not have dropped before the task finished");
}
// we need to wait for the schema first before we can build our RecordBatchStreamAdapter
let Some(schema) = schema_rx.recv().await else {
// if the channel drops, the task errored
if !join_handle.is_finished() {
unreachable!("Schema channel should not have dropped before the task finished");
}

let result = join_handle.await?;
let Err(err) = result else {
unreachable!("Task should have errored");
let result = join_handle.await?;
let Err(err) = result else {
unreachable!("Task should have errored");
};

return Err(err);
};

return Err(err);
};
let output_stream = stream! {
while let Some(batch) = batch_rx.recv().await {
yield Ok(batch);
}

let output_stream = stream! {
while let Some(batch) = batch_rx.recv().await {
yield Ok(batch);
}
if let Err(e) = join_handle.await {
yield Err(DataFusionError::Execution(format!(
"Failed to execute ODBC query: {e}"
)))
}
};

if let Err(e) = join_handle.await {
yield Err(DataFusionError::Execution(format!(
"Failed to execute ODBC query: {e}"
)))
}
let result: SendableRecordBatchStream =
Box::pin(RecordBatchStreamAdapter::new(schema, output_stream));
Ok(result)
};

Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
output_stream,
)))
match Handle::try_current() {
Ok(_) => create_stream().await,
Err(_) => get_tokio_runtime().block_on(async { create_stream().await }),
}
}

async fn execute(&self, query: &str, params: &[ODBCParameter]) -> Result<u64> {
Expand Down
1 change: 1 addition & 0 deletions core/src/sql/db_connection_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod mysqlpool;
pub mod odbcpool;
#[cfg(feature = "postgres")]
pub mod postgrespool;
pub mod runtime;
#[cfg(feature = "sqlite")]
pub mod sqlitepool;

Expand Down
2 changes: 1 addition & 1 deletion core/src/sql/db_connection_pool/odbcpool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::sql::db_connection_pool::dbconnection::odbcconn::{ODBCDbConnection, O
use crate::sql::db_connection_pool::{DbConnectionPool, JoinPushDown};
use async_trait::async_trait;
use odbc_api::{sys::AttrConnectionPooling, Connection, ConnectionOptions, Environment};
use secrecy::{SecretBox, ExposeSecret, SecretString};
use secrecy::{ExposeSecret, SecretBox, SecretString};
use sha2::{Digest, Sha256};
use snafu::prelude::*;
use std::{
Expand Down
10 changes: 10 additions & 0 deletions core/src/sql/db_connection_pool/runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use std::sync::OnceLock;

use tokio::runtime::Runtime;

pub fn get_tokio_runtime() -> &'static Runtime {
// TODO: this function is a repetition of python/src/utils.rs::get_tokio_runtime.
// Think about how to refactor it
Comment thread
crystalxyz marked this conversation as resolved.
Outdated
static RUNTIME: OnceLock<Runtime> = OnceLock::new();
RUNTIME.get_or_init(|| Runtime::new().expect("Failed to create Tokio runtime"))
}
2 changes: 1 addition & 1 deletion python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ doc = false
arrow = { workspace = true }
datafusion = { workspace = true, features = ["pyarrow"] }
datafusion-ffi = { workspace = true }
datafusion-table-providers = { workspace = true, features = ["sqlite", "duckdb"] }
datafusion-table-providers = { workspace = true, features = ["sqlite", "duckdb", "odbc"] }
pyo3 = { version = "0.23" }
tokio = { version = "1.42", features = ["macros", "rt", "rt-multi-thread", "sync"] }
duckdb = { workspace = true }
9 changes: 9 additions & 0 deletions python/examples/odbc_sqlite_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from datafusion import SessionContext
from datafusion_table_providers import odbc

ctx = SessionContext()
connection_param: dict = {'connection_string': 'driver=SQLite3;database=../../core/examples/sqlite_example.db;'}
pool = odbc.ODBCTableFactory(connection_param)

ctx.register_table_provider(name = "companies", provider = pool.get_table("companies"))
ctx.table("companies").show()
35 changes: 35 additions & 0 deletions python/python/datafusion_table_providers/odbc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Python interface for ODBC table provider."""

from typing import Any, List
from . import _internal

class ODBCTableFactory:
"""ODBC table factory."""

def __init__(self, params: dict) -> None:
"""Create am odbc table factory."""
Comment thread
crystalxyz marked this conversation as resolved.
Outdated
self._raw = _internal.odbc.RawODBCTableFactory(params)

def get_table(self, table_reference: str) -> Any:
"""Return the table provider for table named `table_reference`.

Args:
table_reference (str): table name
"""
return self._raw.get_table(table_reference)
4 changes: 4 additions & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,9 @@ fn _internal(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
duckdb::init_module(&duckdb)?;
m.add_submodule(&duckdb)?;

let odbc = PyModule::new(py, "odbc")?;
odbc::init_module(&odbc)?;
m.add_submodule(&odbc)?;

Ok(())
}
65 changes: 65 additions & 0 deletions python/src/odbc.rs
Original file line number Diff line number Diff line change
@@ -1 +1,66 @@
use std::{collections::HashMap, sync::Arc};

use datafusion_table_providers::{
odbc::ODBCTableFactory, sql::db_connection_pool::odbcpool::ODBCPool,
util::secrets::to_secret_map,
};
use pyo3::{prelude::*, types::PyDict};

use crate::{
utils::{to_pyerr, wait_for_future},
RawTableProvider,
};

#[pyclass(module = "datafusion_table_providers._internal.odbc")]
struct RawODBCTableFactory {
_pool: Arc<ODBCPool>,
// TODO: 'static lifetime might be wrong, we want the lifetime to be 'py but it is
// still unclear how to define it.
factory: ODBCTableFactory<'static>,
Comment thread
phillipleblanc marked this conversation as resolved.
}

#[pymethods]
impl RawODBCTableFactory {
#[new]
#[pyo3(signature = (params))]
pub fn new(params: &Bound<'_, PyDict>) -> PyResult<Self> {
// Convert Python dict into Rust hashmap, and convert it to secret map
let mut hashmap = HashMap::new();
for (key, value) in params.iter() {
let key: String = key.extract()?;
let value: String = value.extract()?;
hashmap.insert(key, value);
}
let hashmap = to_secret_map(hashmap);
Comment thread
crystalxyz marked this conversation as resolved.

let pool = Arc::new(ODBCPool::new(hashmap).map_err(to_pyerr)?);
Ok(Self {
factory: ODBCTableFactory::new(pool.clone()),
_pool: pool,
})
}

pub fn tables(&self) -> PyResult<Vec<String>> {
// This method is not supported yet because of unimplemented traints in odbcconn.
unimplemented!();
Comment thread
phillipleblanc marked this conversation as resolved.
}
Comment thread
phillipleblanc marked this conversation as resolved.

pub fn get_table(&self, py: Python, table_reference: &str) -> PyResult<RawTableProvider> {
let table = wait_for_future(
py,
self.factory.table_provider(table_reference.into(), None),
Comment thread
crystalxyz marked this conversation as resolved.
)
.map_err(to_pyerr)?;

Ok(RawTableProvider {
table,
supports_pushdown_filters: true,
})
}
}

pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<RawODBCTableFactory>()?;

Ok(())
}