Skip to content

Commit aee1967

Browse files
authored
MongoDB: upgrade crate, add SRV support, re-enable Python bindings (#602)
* MongoDB: upgrade crate, add SRV support, re-enable Python bindings - Upgrade mongodb crate from 3.2.2 to 3.5.2 - Add mongodb+srv:// SRV connection format support via 'srv' parameter - Add tables() method to MongoDBConnection for listing collections - Re-enable MongoDB Python bindings (feature flag, PyO3 module, wrapper class) - Add MongoDB Python example and demo script - Add unit tests for SRV connection URI building * Update Rust toolchain to 1.92.0 to match DataFusion v52 * CI: authenticate to ghcr.io to avoid Docker pull rate limits
1 parent df93942 commit aee1967

12 files changed

Lines changed: 285 additions & 13 deletions

File tree

.github/workflows/pr.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,13 @@ jobs:
7575
- name: Cache Rust dependencies
7676
uses: Swatinem/rust-cache@v2
7777

78+
- name: Log in to container registries
79+
run: |
80+
echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin
81+
7882
- name: Pull the Postgres/MySQL images
7983
run: |
84+
: work around spurious network errors in curl 8.0
8085
docker pull ${{ env.PG_DOCKER_IMAGE }}
8186
docker pull ${{ env.MYSQL_DOCKER_IMAGE }}
8287

Cargo.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ fundu = { workspace = true }
5050
futures = { workspace = true }
5151
geo-types = { workspace = true }
5252
itertools = { workspace = true }
53-
mongodb = { version = "3.2.2", features = ["openssl-tls"], optional = true }
53+
mongodb = { version = "3.5.2", features = ["openssl-tls"], optional = true }
5454
mysql_async = { version = "0.36", features = [
5555
"native-tls-tls",
5656
"chrono",

core/src/mongodb/connection.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use std::sync::Arc;
1515
use crate::mongodb::utils::arrow::mongo_docs_to_arrow;
1616
use crate::mongodb::utils::schema::infer_arrow_schema_from_documents;
1717
use crate::mongodb::utils::unnest::{unnest_bson_documents, UnnestBehavior, UnnestParameters};
18-
use crate::mongodb::{Error, QuerySnafu, Result, UnableToGetSchemaSnafu};
18+
use crate::mongodb::{Error, QuerySnafu, Result, UnableToGetSchemaSnafu, UnableToGetTablesSnafu};
1919

2020
pub struct MongoDBConnection {
2121
pub client: Arc<Client>,
@@ -46,6 +46,16 @@ impl MongoDBConnection {
4646
self.client.database(&self.db_name).collection(collection)
4747
}
4848

49+
pub async fn tables(&self) -> Result<Vec<String>, Error> {
50+
let db = self.client.database(&self.db_name);
51+
let collections = db
52+
.list_collection_names()
53+
.await
54+
.boxed()
55+
.context(UnableToGetTablesSnafu)?;
56+
Ok(collections)
57+
}
58+
4959
pub async fn get_schema(&self, table_reference: &TableReference) -> Result<SchemaRef, Error> {
5060
let collection_name = table_reference.table();
5161
let coll = self.get_collection(collection_name);

core/src/mongodb/connection_pool.rs

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,17 @@ fn build_connection_uri(
133133
format!("?{}", query_params.join("&"))
134134
};
135135

136-
let uri = format!("mongodb://{auth}{host}:{port}/{db_name}{query_string}");
136+
let use_srv = params
137+
.get("srv")
138+
.map(|s| s.expose_secret().eq_ignore_ascii_case("true"))
139+
.unwrap_or(false);
140+
141+
let uri = if use_srv {
142+
// mongodb+srv:// uses DNS SRV records for discovery; port is not allowed
143+
format!("mongodb+srv://{auth}{host}/{db_name}{query_string}")
144+
} else {
145+
format!("mongodb://{auth}{host}:{port}/{db_name}{query_string}")
146+
};
137147
Ok((uri, Some(db_name.to_string())))
138148
}
139149

@@ -327,6 +337,70 @@ mod tests {
327337
assert_eq!(result.1, Some("testdb".to_string()));
328338
}
329339

340+
#[test]
341+
fn test_build_connection_uri_with_srv() {
342+
let params = create_params(vec![
343+
("db", "mydb"),
344+
("host", "cluster0.example.mongodb.net"),
345+
("user", "testuser"),
346+
("pass", "testpass"),
347+
("srv", "true"),
348+
]);
349+
350+
let result = build_connection_uri(&params).unwrap();
351+
assert_eq!(
352+
result.0,
353+
"mongodb+srv://testuser:testpass@cluster0.example.mongodb.net/mydb"
354+
);
355+
assert_eq!(result.1, Some("mydb".to_string()));
356+
}
357+
358+
#[test]
359+
fn test_build_connection_uri_with_srv_and_query_params() {
360+
let params = create_params(vec![
361+
("db", "mydb"),
362+
("host", "cluster0.example.mongodb.net"),
363+
("srv", "true"),
364+
("auth_source", "admin"),
365+
]);
366+
367+
let result = build_connection_uri(&params).unwrap();
368+
assert_eq!(
369+
result.0,
370+
"mongodb+srv://cluster0.example.mongodb.net/mydb?authSource=admin"
371+
);
372+
assert_eq!(result.1, Some("mydb".to_string()));
373+
}
374+
375+
#[test]
376+
fn test_build_connection_uri_with_srv_false() {
377+
let params = create_params(vec![
378+
("db", "mydb"),
379+
("host", "localhost"),
380+
("port", "27017"),
381+
("srv", "false"),
382+
]);
383+
384+
let result = build_connection_uri(&params).unwrap();
385+
assert_eq!(result.0, "mongodb://localhost:27017/mydb");
386+
assert_eq!(result.1, Some("mydb".to_string()));
387+
}
388+
389+
#[test]
390+
fn test_build_connection_uri_with_connection_string_srv() {
391+
let params = create_params(vec![(
392+
"connection_string",
393+
"mongodb+srv://user:pass@cluster0.example.mongodb.net/testdb",
394+
)]);
395+
396+
let result = build_connection_uri(&params).unwrap();
397+
assert_eq!(
398+
result.0,
399+
"mongodb+srv://user:pass@cluster0.example.mongodb.net/testdb"
400+
);
401+
assert_eq!(result.1, None);
402+
}
403+
330404
#[test]
331405
fn test_build_connection_uri_user_without_password() {
332406
let params = create_params(vec![("user", "testuser")]);

python/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ default = [
3636
"mysql",
3737
"postgres",
3838
"odbc",
39-
"flight",
39+
"flight",
40+
"mongodb",
4041
]
4142
clickhouse = ["datafusion-table-providers/clickhouse-federation"]
4243
duckdb = ["dep:duckdb", "datafusion-table-providers/duckdb-federation"]
@@ -45,3 +46,4 @@ mysql = ["datafusion-table-providers/mysql-federation"]
4546
postgres = ["datafusion-table-providers/postgres-federation"]
4647
odbc = ["datafusion-table-providers/odbc-federation"]
4748
flight = ["dep:arrow-flight", "datafusion-table-providers/flight"]
49+
mongodb = ["datafusion-table-providers/mongodb"]

python/examples/mongodb_demo.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Example demonstrating MongoDB table provider usage.
18+
19+
Prerequisites:
20+
Start a MongoDB server using Docker:
21+
22+
```bash
23+
docker run --name mongodb \
24+
-e MONGO_INITDB_ROOT_USERNAME=root \
25+
-e MONGO_INITDB_ROOT_PASSWORD=password \
26+
-e MONGO_INITDB_DATABASE=mongo_db \
27+
-p 27017:27017 \
28+
-d mongo:7.0
29+
30+
# Wait for the MongoDB server to start
31+
sleep 30
32+
33+
# Create a table in the MongoDB server and insert some data
34+
docker exec -i mongodb mongosh -u root -p password --authenticationDatabase admin <<EOF
35+
use mongo_db;
36+
37+
db.companies.insertOne({
38+
id: 1,
39+
name: "Acme Corporation"
40+
});
41+
EOF
42+
```
43+
"""
44+
45+
import datafusion
46+
from datafusion_table_providers.mongodb import MongoDBTableFactory
47+
48+
49+
def main():
50+
# Create MongoDB connection parameters
51+
mongodb_params = {
52+
"connection_string": "mongodb://root:password@localhost:27017/mongo_db?authSource=admin&tls=false"
53+
}
54+
55+
# Create MongoDB table factory
56+
factory = MongoDBTableFactory(mongodb_params)
57+
58+
# List all tables
59+
tables = factory.tables()
60+
print(f"Tables: {tables}")
61+
62+
# Get table provider for 'companies' table
63+
table = factory.get_table("companies")
64+
65+
# Create DataFusion context and register the table
66+
ctx = datafusion.SessionContext()
67+
ctx.register_table_provider("companies", table)
68+
69+
# Query the table
70+
df = ctx.sql("SELECT * FROM companies")
71+
df.show()
72+
73+
74+
if __name__ == "__main__":
75+
main()

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ classifier = [
2424
"Programming Language :: Python",
2525
"Programming Language :: Rust",
2626
]
27-
dependencies = ["datafusion>=45.0.0"]
27+
dependencies = ["datafusion>=45.0.0,<52.0.0"]
2828

2929
[project.urls]
3030
repository = "https://github.com/datafusion-contrib/datafusion-table-providers"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Python interface for MongoDB table provider."""
18+
19+
from typing import Any, List
20+
from . import _internal
21+
22+
class MongoDBTableFactory:
23+
"""MongoDB table factory."""
24+
25+
def __init__(self, params: dict) -> None:
26+
"""Create a MongoDB table factory."""
27+
self._raw = _internal.mongodb.RawMongoDBTableFactory(params)
28+
29+
def tables(self) -> List[str]:
30+
"""Get all the table names."""
31+
return self._raw.tables()
32+
33+
def get_table(self, table_reference: str) -> Any:
34+
"""Return the table provider for table named `table_reference`.
35+
36+
Args:
37+
table_reference (str): table name
38+
"""
39+
return self._raw.get_table(table_reference)

python/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ pub mod clickhouse;
4949
pub mod duckdb;
5050
#[cfg(feature = "flight")]
5151
pub mod flight;
52+
#[cfg(feature = "mongodb")]
53+
pub mod mongodb;
5254
#[cfg(feature = "mysql")]
5355
pub mod mysql;
5456
#[cfg(feature = "odbc")]
@@ -113,5 +115,12 @@ fn _internal(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
113115
m.add_submodule(&clickhouse)?;
114116
}
115117

118+
#[cfg(feature = "mongodb")]
119+
{
120+
let mongodb = PyModule::new(py, "mongodb")?;
121+
mongodb::init_module(&mongodb)?;
122+
m.add_submodule(&mongodb)?;
123+
}
124+
116125
Ok(())
117126
}

0 commit comments

Comments
 (0)