Skip to content

Commit 8f2a63f

Browse files
Merge pull request #112 from hozan23/rebase-spiceai-main-branches
Rebase spiceai onto main branch
2 parents 55cfc67 + 9332074 commit 8f2a63f

30 files changed

Lines changed: 2179 additions & 414 deletions

File tree

.github/workflows/pr.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ on:
55
pull_request:
66
branches:
77
- main
8+
- spiceai
89

910
jobs:
1011
lint:

Cargo.toml

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,25 @@ license = "Apache-2.0"
88
description = "Extend the capabilities of DataFusion to support additional data sources via implementations of the `TableProvider` trait."
99

1010
[dependencies]
11-
arrow = "53.1.0"
12-
arrow-array = { version = "53.1.0", optional = true }
13-
arrow-cast = { version = "53.1.0", optional = true }
14-
arrow-flight = { version = "53.1.0", optional = true, features = ["flight-sql-experimental", "tls"] }
15-
arrow-schema = { version = "53.1.0", optional = true, features = ["serde"] }
16-
arrow-json = "53.1.0"
11+
arrow = "53"
12+
arrow-array = { version = "53", optional = true }
13+
arrow-cast = { version = "53", optional = true }
14+
arrow-flight = { version = "53", optional = true, features = ["flight-sql-experimental", "tls"] }
15+
arrow-schema = { version = "53", optional = true, features = ["serde"] }
16+
arrow-json = "53"
1717
async-stream = { version = "0.3.5", optional = true }
1818
async-trait = "0.1.80"
1919
num-bigint = "0.4.4"
2020
bigdecimal = "0.4.5"
21-
bigdecimal_0_3_0 = { package = "bigdecimal", version = "0.3.0" }
2221
byteorder = "1.5.0"
2322
chrono = "0.4.38"
2423
datafusion = "42.0.0"
2524
datafusion-expr = { version = "42.0.0", optional = true }
2625
datafusion-physical-expr = { version = "42.0.0", optional = true }
2726
datafusion-physical-plan = { version = "42.0.0", optional = true }
2827
datafusion-proto = { version = "42.0.0", optional = true }
29-
datafusion-federation = { version = "0.3.0", features = ["sql"] }
30-
duckdb = { version = "1", features = [
28+
datafusion-federation = { version = "0.3.0", features = ["sql"] }
29+
duckdb = { version = "1.1.1", features = [
3130
"bundled",
3231
"r2d2",
3332
"vtab",
@@ -37,16 +36,28 @@ duckdb = { version = "1", features = [
3736
fallible-iterator = "0.3.0"
3837
futures = "0.3.30"
3938
mysql_async = { version = "0.34.1", features = ["native-tls-tls", "chrono"], optional = true }
39+
prost = { version = "0.13.2", optional = true }
4040
r2d2 = { version = "0.8.10", optional = true }
4141
rusqlite = { version = "0.31.0", optional = true }
42-
sea-query = { version = "0.31.0", features = ["backend-sqlite", "backend-postgres", "postgres-array", "with-rust_decimal", "with-bigdecimal", "with-time", "with-chrono"] }
42+
sea-query = { version = "0.32.0-rc.1", features = [
43+
"backend-sqlite",
44+
"backend-postgres",
45+
"postgres-array",
46+
"with-rust_decimal",
47+
"with-bigdecimal",
48+
"with-time",
49+
"with-chrono"] }
4350
secrecy = "0.8.0"
4451
serde = { version = "1.0.209", optional = true }
4552
serde_json = "1.0.124"
4653
snafu = "0.8.3"
4754
time = "0.3.36"
4855
tokio = { version = "1.38.0", features = ["macros", "fs"] }
49-
tokio-postgres = { version = "0.7.10", features = ["with-chrono-0_4", "with-uuid-1", "with-serde_json-1", "with-geo-types-0_7"], optional = true }
56+
tokio-postgres = { version = "0.7.10", features = [
57+
"with-chrono-0_4",
58+
"with-uuid-1",
59+
"with-serde_json-1",
60+
"with-geo-types-0_7"], optional = true }
5061
tracing = "0.1.40"
5162
uuid = { version = "1.9.1", optional = true }
5263
postgres-native-tls = { version = "0.5.0", optional = true }
@@ -57,9 +68,11 @@ trust-dns-resolver = "0.23.2"
5768
url = "2.5.1"
5869
pem = { version = "3.0.4", optional = true }
5970
tokio-rusqlite = { version = "0.5.1", optional = true }
60-
tonic = { version = "0.12", optional = true }
71+
tonic = { version = "0.12.2", optional = true }
6172
itertools = "0.13.0"
73+
dyn-clone = { version = "1.0.17", optional = true }
6274
geo-types = "0.7.13"
75+
fundu = "2.0.1"
6376

6477
[dev-dependencies]
6578
anyhow = "1.0.86"
@@ -79,7 +92,7 @@ prost = { version = "0.13"}
7992
mysql = ["dep:mysql_async", "dep:async-stream"]
8093
postgres = ["dep:tokio-postgres", "dep:uuid", "dep:postgres-native-tls", "dep:bb8", "dep:bb8-postgres", "dep:native-tls", "dep:pem", "dep:async-stream"]
8194
sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"]
82-
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid"]
95+
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid", "dep:dyn-clone", "dep:async-stream"]
8396
flight = [
8497
"dep:arrow-array",
8598
"dep:arrow-cast",
@@ -94,3 +107,6 @@ flight = [
94107
]
95108
duckdb-federation = ["duckdb"]
96109
sqlite-federation = ["sqlite"]
110+
postgres-federation = ["postgres"]
111+
112+

examples/duckdb_external_table.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use duckdb::AccessMode;
1212
/// DuckDB-backed tables can be created at runtime.
1313
#[tokio::main]
1414
async fn main() {
15-
let duckdb = Arc::new(DuckDBTableProviderFactory::new().access_mode(AccessMode::ReadWrite));
15+
let duckdb = Arc::new(DuckDBTableProviderFactory::new(AccessMode::ReadWrite));
1616

1717
let runtime = Arc::new(RuntimeEnv::default());
1818
let state = SessionStateBuilder::new()

src/duckdb.rs

Lines changed: 66 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
use crate::sql::db_connection_pool::{
22
self,
33
dbconnection::{
4-
duckdbconn::{flatten_table_function_name, is_table_function, DuckDbConnection},
4+
duckdbconn::{
5+
flatten_table_function_name, is_table_function, DuckDBParameter, DuckDbConnection,
6+
},
57
get_schema, DbConnection,
68
},
79
duckdbpool::DuckDbConnectionPool,
8-
DbConnectionPool, Mode,
10+
DbConnectionPool, DbInstanceKey, Mode,
911
};
1012
use crate::sql::sql_provider_datafusion;
1113
use crate::util::{
@@ -25,10 +27,11 @@ use datafusion::{
2527
logical_expr::CreateExternalTable,
2628
sql::TableReference,
2729
};
28-
use duckdb::{AccessMode, DuckdbConnectionManager, ToSql, Transaction};
30+
use duckdb::{AccessMode, DuckdbConnectionManager, Transaction};
2931
use itertools::Itertools;
3032
use snafu::prelude::*;
3133
use std::{cmp, collections::HashMap, sync::Arc};
34+
use tokio::sync::Mutex;
3235

3336
use self::{creator::TableCreator, sql_table::DuckDBTable, write::DuckDBTableWriter};
3437

@@ -87,11 +90,6 @@ pub enum Error {
8790
#[snafu(display("Unable to commit transaction: {source}"))]
8891
UnableToCommitTransaction { source: duckdb::Error },
8992

90-
#[snafu(display("Unable to checkpoint duckdb: {source}"))]
91-
UnableToCheckpoint {
92-
source: Box<dyn std::error::Error + Send + Sync>,
93-
},
94-
9593
#[snafu(display("Unable to begin duckdb transaction: {source}"))]
9694
UnableToBeginTransaction { source: duckdb::Error },
9795

@@ -121,6 +119,7 @@ type Result<T, E = Error> = std::result::Result<T, E>;
121119

122120
pub struct DuckDBTableProviderFactory {
123121
access_mode: AccessMode,
122+
instances: Arc<Mutex<HashMap<DbInstanceKey, DuckDbConnectionPool>>>,
124123
}
125124

126125
const DUCKDB_DB_PATH_PARAM: &str = "open";
@@ -129,9 +128,10 @@ const DUCKDB_ATTACH_DATABASES_PARAM: &str = "attach_databases";
129128

130129
impl DuckDBTableProviderFactory {
131130
#[must_use]
132-
pub fn new() -> Self {
131+
pub fn new(access_mode: AccessMode) -> Self {
133132
Self {
134-
access_mode: AccessMode::ReadOnly,
133+
access_mode,
134+
instances: Arc::new(Mutex::new(HashMap::new())),
135135
}
136136
}
137137

@@ -148,12 +148,6 @@ impl DuckDBTableProviderFactory {
148148
.unwrap_or_default()
149149
}
150150

151-
#[must_use]
152-
pub fn access_mode(mut self, access_mode: AccessMode) -> Self {
153-
self.access_mode = access_mode;
154-
self
155-
}
156-
157151
#[must_use]
158152
pub fn duckdb_file_path(&self, name: &str, options: &mut HashMap<String, String>) -> String {
159153
let options = util::remove_prefix_from_hashmap_keys(options.clone(), "duckdb_");
@@ -169,15 +163,44 @@ impl DuckDBTableProviderFactory {
169163
.cloned()
170164
.unwrap_or(default_filepath)
171165
}
172-
}
173166

174-
impl Default for DuckDBTableProviderFactory {
175-
fn default() -> Self {
176-
Self::new()
167+
pub async fn get_or_init_memory_instance(&self) -> Result<DuckDbConnectionPool> {
168+
let key = DbInstanceKey::memory();
169+
let mut instances = self.instances.lock().await;
170+
171+
if let Some(instance) = instances.get(&key) {
172+
return Ok(instance.clone());
173+
}
174+
175+
let pool = DuckDbConnectionPool::new_memory().context(DbConnectionPoolSnafu)?;
176+
177+
instances.insert(key, pool.clone());
178+
179+
Ok(pool)
180+
}
181+
182+
pub async fn get_or_init_file_instance(
183+
&self,
184+
db_path: impl Into<Arc<str>>,
185+
) -> Result<DuckDbConnectionPool> {
186+
let db_path = db_path.into();
187+
let key = DbInstanceKey::file(Arc::clone(&db_path));
188+
let mut instances = self.instances.lock().await;
189+
190+
if let Some(instance) = instances.get(&key) {
191+
return Ok(instance.clone());
192+
}
193+
194+
let pool = DuckDbConnectionPool::new_file(&db_path, &self.access_mode)
195+
.context(DbConnectionPoolSnafu)?;
196+
197+
instances.insert(key, pool.clone());
198+
199+
Ok(pool)
177200
}
178201
}
179202

180-
type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>
203+
type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
181204
+ Send
182205
+ Sync;
183206

@@ -229,12 +252,13 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
229252
// open duckdb at given path or create a new one
230253
let db_path = self.duckdb_file_path(&name, &mut options);
231254

232-
DuckDbConnectionPool::new_file(&db_path, &self.access_mode)
233-
.context(DbConnectionPoolSnafu)
255+
self.get_or_init_file_instance(db_path)
256+
.await
234257
.map_err(to_datafusion_error)?
235258
}
236-
Mode::Memory => DuckDbConnectionPool::new_memory()
237-
.context(DbConnectionPoolSnafu)
259+
Mode::Memory => self
260+
.get_or_init_memory_instance()
261+
.await
238262
.map_err(to_datafusion_error)?,
239263
};
240264

@@ -265,7 +289,12 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
265289
));
266290

267291
#[cfg(feature = "duckdb-federation")]
268-
let read_provider = Arc::new(read_provider.create_federated_table_provider()?);
292+
let read_provider: Arc<dyn TableProvider> = if mode == Mode::File {
293+
// federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
294+
Arc::new(read_provider.create_federated_table_provider()?)
295+
} else {
296+
read_provider
297+
};
269298

270299
Ok(DuckDBTableWriter::create(
271300
read_provider,
@@ -317,18 +346,18 @@ impl DuckDB {
317346
pub fn connect_sync(
318347
&self,
319348
) -> Result<
320-
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>>,
349+
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
321350
> {
322351
Arc::clone(&self.pool)
323352
.connect_sync()
324353
.context(DbConnectionSnafu)
325354
}
326355

327-
pub fn duckdb_conn<'a>(
328-
db_connection: &'a mut Box<
329-
dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>,
356+
pub fn duckdb_conn(
357+
db_connection: &mut Box<
358+
dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
330359
>,
331-
) -> Result<&'a mut DuckDbConnection> {
360+
) -> Result<&mut DuckDbConnection> {
332361
db_connection
333362
.as_any_mut()
334363
.downcast_mut::<DuckDbConnection>()
@@ -441,7 +470,12 @@ impl DuckDBTableFactory {
441470
));
442471

443472
#[cfg(feature = "duckdb-federation")]
444-
let table_provider = Arc::new(table_provider.create_federated_table_provider()?);
473+
let table_provider: Arc<dyn TableProvider> = if self.pool.mode() == Mode::File {
474+
// federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
475+
Arc::new(table_provider.create_federated_table_provider()?)
476+
} else {
477+
table_provider
478+
};
445479

446480
Ok(table_provider)
447481
}

0 commit comments

Comments
 (0)