Skip to content

Commit 0fd6476

Browse files
committed
Make DuckDB database attachments logic more robust
1 parent c371d52 commit 0fd6476

1 file changed

Lines changed: 169 additions & 49 deletions

File tree

src/sql/db_connection_pool/dbconnection/duckdbconn.rs

Lines changed: 169 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -65,50 +65,57 @@ pub type DuckDBParameter = Box<dyn DuckDBSyncParameter>;
6565
#[derive(Debug)]
6666
pub struct DuckDBAttachments {
6767
attachments: HashSet<Arc<str>>,
68-
search_path: Arc<str>,
6968
random_id: String,
69+
main_db: String,
7070
}
7171

7272
impl DuckDBAttachments {
7373
/// Creates a new instance of a `DuckDBAttachments`, which instructs DuckDB connections to attach other DuckDB databases for queries.
7474
#[must_use]
75-
pub fn new(id: &str, attachments: &[Arc<str>]) -> Self {
75+
pub fn new(main_db: &str, attachments: &[Arc<str>]) -> Self {
7676
let random_id = Alphanumeric.sample_string(&mut rand::thread_rng(), 8);
7777
let attachments: HashSet<Arc<str>> = attachments.iter().cloned().collect();
78-
let search_path = Self::get_search_path(id, &random_id, &attachments);
7978
Self {
8079
attachments,
81-
search_path,
8280
random_id,
81+
main_db: main_db.to_string(),
8382
}
8483
}
8584

8685
/// Returns the search path for the given database and attachments.
8786
/// The given database needs to be included separately, as the search path by default does not include the main database.
87+
/// The `attachments` parameter represents full attachment names, e.g., ["attachment_zCVN0zYJ_0", ...]
8888
#[must_use]
89-
fn get_search_path(id: &str, random_id: &str, attachments: &HashSet<Arc<str>>) -> Arc<str> {
90-
// search path includes the main database and all attached databases
91-
let mut search_path: Vec<Arc<str>> = vec![id.into()];
89+
fn get_search_path<'a>(id: &str, attachments: impl IntoIterator<Item = &'a str>) -> Arc<str> {
90+
let mut path = String::from(id);
9291

93-
search_path.extend(
94-
attachments
95-
.iter()
96-
.enumerate()
97-
.map(|(i, _)| Self::get_attachment_name(random_id, i).into()),
98-
);
92+
for attachment in attachments {
93+
path.push(',');
94+
path.push_str(attachment);
95+
}
9996

100-
search_path.join(",").into()
97+
Arc::from(path)
10198
}
10299

103100
/// Sets the search path for the given connection.
104101
///
102+
/// The `attachments` parameter represents full attachment names, e.g., ["attachment_zCVN0zYJ_0", ...]
105103
/// # Errors
106104
///
107105
/// Returns an error if the search path cannot be set or the connection fails.
108-
pub fn set_search_path(&self, conn: &Connection) -> Result<()> {
109-
conn.execute(&format!("SET search_path ='{}'", self.search_path), [])
106+
/// On success, returns the search path that was set.
107+
pub fn set_search_path<'a>(
108+
&self,
109+
conn: &Connection,
110+
attachments: impl IntoIterator<Item = &'a str>,
111+
) -> Result<Arc<str>> {
112+
let search_path = Self::get_search_path(&self.main_db, attachments);
113+
114+
tracing::trace!("Setting search_path to {search_path}");
115+
116+
conn.execute(&format!("SET search_path ='{}'", search_path), [])
110117
.context(DuckDBConnectionSnafu)?;
111-
Ok(())
118+
Ok(search_path)
112119
}
113120

114121
/// Resets the search path for the given connection to default.
@@ -123,27 +130,59 @@ impl DuckDBAttachments {
123130
}
124131

125132
/// Attaches the databases to the given connection and sets the search path for the newly attached databases.
133+
/// If the connection already contains attachments, it will skip the attachments override (including search_path).
126134
///
127135
/// # Errors
128136
///
129137
/// Returns an error if a specific attachment is missing, cannot be attached, search path cannot be set or the connection fails.
130-
pub fn attach(&self, conn: &Connection) -> Result<()> {
138+
/// On success, returns the search path that was set.
139+
pub fn attach(&self, conn: &Connection) -> Result<Arc<str>> {
140+
// Check if attachments already exist; skip attachments override in this case as it requires changing the search_path
141+
let mut stmt = conn
142+
.prepare("PRAGMA database_list;")
143+
.context(DuckDBConnectionSnafu)?;
144+
let mut rows = stmt.query([]).context(DuckDBConnectionSnafu)?;
145+
146+
let mut existing_attachments = std::collections::HashMap::new();
147+
while let Some(row) = rows.next()? {
148+
let db_name: String = row.get(1)?;
149+
let db_path: Option<String> = row.get(2)?;
150+
if db_name.starts_with("attachment_") {
151+
// attachment always has a path so it is safe to use unwrap_or_default
152+
existing_attachments.insert(db_path.unwrap_or_default(), db_name);
153+
}
154+
}
155+
156+
// If the connection already has any attachments, reuse them (warning about any requested database that is missing) instead of attaching again
157+
if !existing_attachments.is_empty() {
158+
tracing::trace!(
159+
"Attachments {:?} creation skipped as connection contains existing attachments: {existing_attachments:?}",
160+
self.attachments
161+
);
162+
for db in &self.attachments {
163+
if !existing_attachments.contains_key(db.as_ref()) {
164+
tracing::warn!("{db} not found among existing attachments");
165+
}
166+
}
167+
// The connection can have attachments but not the search_path, so we must set it based on the existing attachment names
168+
return self.set_search_path(conn, existing_attachments.values().map(|s| s.as_str()));
169+
}
170+
171+
let mut created_attachments = Vec::new();
172+
131173
for (i, db) in self.attachments.iter().enumerate() {
132174
// check the db file exists
133175
std::fs::metadata(db.as_ref()).context(UnableToAttachDatabaseSnafu {
134176
path: Arc::clone(db),
135177
})?;
136-
let sql = format!(
137-
"ATTACH IF NOT EXISTS '{db}' AS {} (READ_ONLY)",
138-
Self::get_attachment_name(&self.random_id, i)
139-
);
178+
let attachment_name = Self::get_attachment_name(&self.random_id, i);
179+
let sql = format!("ATTACH IF NOT EXISTS '{db}' AS {attachment_name} (READ_ONLY)");
140180
tracing::trace!("Attaching {db} using: {sql}");
141-
142181
conn.execute(&sql, []).context(DuckDBConnectionSnafu)?;
182+
created_attachments.push(attachment_name);
143183
}
144184

145-
self.set_search_path(conn)?;
146-
Ok(())
185+
self.set_search_path(conn, created_attachments.iter().map(|s| s.as_str()))
147186
}
148187

149188
/// Detaches the databases from the given connection and resets the search path to default.
@@ -154,7 +193,10 @@ impl DuckDBAttachments {
154193
pub fn detach(&self, conn: &Connection) -> Result<()> {
155194
for (i, _) in self.attachments.iter().enumerate() {
156195
conn.execute(
157-
&format!("DETACH {}", Self::get_attachment_name(&self.random_id, i)),
196+
&format!(
197+
"DETACH DATABASE IF EXISTS {}",
198+
Self::get_attachment_name(&self.random_id, i)
199+
),
158200
[],
159201
)
160202
.context(DuckDBConnectionSnafu)?;
@@ -320,11 +362,12 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
320362
) -> Result<SendableRecordBatchStream> {
321363
let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
322364

323-
Self::attach(&self.conn, &self.attachments)?;
365+
let conn = self.conn.try_clone()?;
366+
Self::attach(&conn, &self.attachments)?;
367+
324368
let fetch_schema_sql =
325369
format!("WITH fetch_schema AS ({sql}) SELECT * FROM fetch_schema LIMIT 0");
326-
let mut stmt = self
327-
.conn
370+
let mut stmt = conn
328371
.prepare(&fetch_schema_sql)
329372
.boxed()
330373
.context(super::UnableToGetSchemaSnafu)?;
@@ -334,21 +377,15 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
334377
.boxed()
335378
.context(super::UnableToGetSchemaSnafu)?;
336379

337-
Self::detach(&self.conn, &self.attachments)?;
338-
339380
let schema = result.get_schema();
340381

341382
let params = params.iter().map(dyn_clone::clone).collect::<Vec<_>>();
342383

343-
let conn = self.conn.try_clone()?; // try_clone creates a new connection to the same database
344-
// this creates a new connection session, requiring resetting the ATTACHments and search_path
345384
let sql = sql.to_string();
346385

347386
let cloned_schema = schema.clone();
348-
let attachments = self.attachments.clone();
349387

350388
let join_handle = tokio::task::spawn_blocking(move || {
351-
Self::attach(&conn, &attachments)?; // this attach could happen when we clone the connection, but we can't detach after the thread closes because the connection isn't thread safe
352389
let mut stmt = conn.prepare(&sql).context(DuckDBQuerySnafu)?;
353390
let params: &[&dyn ToSql] = &params
354391
.iter()
@@ -360,8 +397,6 @@ impl SyncDbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBPar
360397
for i in result {
361398
blocking_channel_send(&batch_tx, i)?;
362399
}
363-
364-
Self::detach(&conn, &attachments)?;
365400
Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
366401
});
367402

@@ -620,10 +655,31 @@ mod tests {
620655
}
621656

622657
#[test]
623-
fn test_duckdb_attachments_search_path() {
624-
let db1 = Arc::from("db1.duckdb");
625-
let db2 = Arc::from("db2.duckdb");
626-
let db3 = Arc::from("db3.duckdb");
658+
fn test_duckdb_attachments_search_path() -> Result<()> {
659+
let temp_dir = tempdir()?;
660+
let db1: Arc<str> = temp_dir
661+
.path()
662+
.join("db1.duckdb")
663+
.to_str()
664+
.expect("to convert path to str")
665+
.into();
666+
let db2: Arc<str> = temp_dir
667+
.path()
668+
.join("db2.duckdb")
669+
.to_str()
670+
.expect("to convert path to str")
671+
.into();
672+
let db3: Arc<str> = temp_dir
673+
.path()
674+
.join("db3.duckdb")
675+
.to_str()
676+
.expect("to convert path to str")
677+
.into();
678+
679+
for db in [&db1, &db2, &db3] {
680+
let conn1 = Connection::open(db.as_ref())?;
681+
conn1.execute("CREATE TABLE test1 (id INTEGER, name VARCHAR)", [])?;
682+
}
627683

628684
// Create attachments with duplicates
629685
let attachments = vec![
@@ -634,25 +690,36 @@ mod tests {
634690
Arc::clone(&db2), // duplicate of db2
635691
];
636692

637-
let duckdb_attachments = DuckDBAttachments::new("main_db", &attachments);
693+
let duckdb_attachments = DuckDBAttachments::new("main", &attachments);
694+
695+
let conn = Connection::open_in_memory()?;
696+
697+
let search_path = duckdb_attachments.attach(&conn)?;
638698

639699
// Verify that the search path contains the main database and unique attachments
640-
let search_path = duckdb_attachments.search_path.to_string();
641-
assert!(search_path.starts_with("main_db"));
700+
701+
assert!(search_path.starts_with("main"));
642702
assert!(search_path.contains("attachment_"));
643-
assert_eq!(search_path.split(',').count(), 4); // main_db + 3 unique attachments
703+
assert_eq!(search_path.split(',').count(), 4); // main + 3 unique attachments
704+
705+
Ok(())
644706
}
645707

646708
#[test]
647-
fn test_duckdb_attachments_empty() {
648-
let duckdb_attachments = DuckDBAttachments::new("main_db", &[]);
709+
fn test_duckdb_attachments_empty() -> Result<()> {
710+
let duckdb_attachments = DuckDBAttachments::new("main", &[]);
649711

650712
// Verify empty attachments
651713
assert!(duckdb_attachments.attachments.is_empty());
652714

653715
// Verify search path only contains main database
654-
let search_path = duckdb_attachments.search_path.to_string();
655-
assert_eq!(search_path, "main_db");
716+
717+
let conn = Connection::open_in_memory()?;
718+
719+
let search_path = duckdb_attachments.attach(&conn)?;
720+
assert_eq!(search_path, "main".into());
721+
722+
Ok(())
656723
}
657724

658725
#[test]
@@ -721,4 +788,57 @@ mod tests {
721788
duckdb_attachments.detach(&conn)?;
722789
Ok(())
723790
}
791+
792+
#[test]
793+
fn test_duckdb_attach_multiple_times() -> Result<()> {
794+
// Create a temporary directory for our test files
795+
let temp_dir = tempdir()?;
796+
let db1_path = temp_dir.path().join("db1.duckdb");
797+
let db2_path = temp_dir.path().join("db2.duckdb");
798+
799+
// Create two test databases with some data
800+
{
801+
let conn1 = Connection::open(&db1_path)?;
802+
conn1.execute("CREATE TABLE test1 (id INTEGER, name VARCHAR)", [])?;
803+
conn1.execute("INSERT INTO test1 VALUES (1, 'test1_1')", [])?;
804+
805+
let conn2 = Connection::open(&db2_path)?;
806+
conn2.execute("CREATE TABLE test2 (id INTEGER, name VARCHAR)", [])?;
807+
conn2.execute("INSERT INTO test2 VALUES (2, 'test2_1')", [])?;
808+
}
809+
810+
let attachments = vec![
811+
Arc::from(db1_path.to_str().expect("to convert path top str")),
812+
Arc::from(db2_path.to_str().expect("to convert path top str")),
813+
];
814+
815+
let conn = Connection::open_in_memory()?;
816+
817+
// Simulate attaching to the same connection multiple times
818+
DuckDBAttachments::new("main", &attachments).attach(&conn)?;
819+
DuckDBAttachments::new("main", &attachments).attach(&conn)?;
820+
DuckDBAttachments::new("main", &attachments).attach(&conn)?;
821+
822+
let join_result: (i64, String, i64, String) = conn
823+
.query_row(
824+
"SELECT t1.id, t1.name, t2.id, t2.name FROM test1 t1, test2 t2",
825+
[],
826+
|row| {
827+
Ok((
828+
row.get::<_, i64>(0).expect("to get i64"),
829+
row.get::<_, String>(1).expect("to get string"),
830+
row.get::<_, i64>(2).expect("to get i64"),
831+
row.get::<_, String>(3).expect("to get string"),
832+
))
833+
},
834+
)
835+
.expect("to get join result");
836+
837+
assert_eq!(
838+
join_result,
839+
(1, "test1_1".to_string(), 2, "test2_1".to_string())
840+
);
841+
842+
Ok(())
843+
}
724844
}

0 commit comments

Comments
 (0)