Skip to content

Commit 347a661

Browse files
committed
Enable flight
1 parent 44c3a98 commit 347a661

8 files changed

Lines changed: 112 additions & 3 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ brew install roapi
111111
# cargo install --locked --git https://github.com/roapi/roapi --branch main --bins roapi
112112
roapi -t taxi=https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-01.parquet &
113113

114-
cargo run --example flight-sql --features flight
114+
cargo run -p datafusion-table-providers --example flight-sql --no-default-features --features flight
115115
```
116116

117117
### ODBC

core/src/flight/exec.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::str::FromStr;
2424
use std::sync::Arc;
2525

2626
use crate::flight::{flight_channel, to_df_err, FlightMetadata, FlightProperties, SizeLimits};
27+
use crate::sql::db_connection_pool::runtime::run_async_with_tokio;
2728
use arrow_flight::error::FlightError;
2829
use arrow_flight::flight_service_client::FlightServiceClient;
2930
use arrow_flight::{FlightClient, FlightEndpoint, Ticket};
@@ -190,7 +191,9 @@ async fn flight_stream(
190191
) -> Result<SendableRecordBatchStream> {
191192
let mut errors: Vec<Box<dyn Error + Send + Sync>> = vec![];
192193
for loc in partition.locations.iter() {
193-
let client = flight_client(loc, grpc_headers.as_ref(), &size_limits).await?;
194+
let get_client = || async { flight_client(loc, grpc_headers.as_ref(), &size_limits).await };
195+
let client = run_async_with_tokio(get_client).await?;
196+
// let client = flight_client(loc, grpc_headers.as_ref(), &size_limits).await?;
194197
match try_fetch_stream(client, &partition.ticket, schema.clone()).await {
195198
Ok(stream) => return Ok(stream),
196199
Err(e) => errors.push(Box::new(e)),

python/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ doc = false
1515

1616
[dependencies]
1717
arrow = { workspace = true }
18+
arrow-flight = {workspace = true}
1819
datafusion = { workspace = true, features = ["pyarrow"] }
1920
datafusion-ffi = { workspace = true }
20-
datafusion-table-providers = { workspace = true, features = ["sqlite", "duckdb", "odbc", "mysql", "postgres"] }
21+
datafusion-table-providers = { workspace = true, features = ["sqlite", "duckdb", "odbc", "mysql", "postgres", "flight"] }
2122
pyo3 = { version = "0.23" }
2223
tokio = { version = "1.42", features = ["macros", "rt", "rt-multi-thread", "sync"] }
2324
duckdb = { workspace = true }

python/examples/flight_demo.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from datafusion import SessionContext
2+
from datafusion_table_providers import flight
3+
4+
ctx = SessionContext()
5+
pool = flight.FlightTableFactory()
6+
table_provider = pool.get_table("http://localhost:32010", {"flight.sql.query": "SELECT * FROM taxi"})
7+
table_name = "taxi_flight_table"
8+
ctx.register_table_provider(table_name, table_provider)
9+
ctx.sql(f"""
10+
SELECT "VendorID", COUNT(*), SUM(passenger_count), SUM(total_amount)
11+
FROM {table_name}
12+
GROUP BY "VendorID"
13+
ORDER BY COUNT(*) DESC
14+
""").show()
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Python interface for the Flight table provider."""
18+
19+
from typing import Any, List
20+
from . import _internal
21+
22+
class FlightTableFactory:
23+
"""Flight table factory."""
24+
25+
def __init__(self) -> None:
26+
"""Create a Flight table factory."""
27+
self._raw = _internal.flight.RawFlightTableFactory()
28+
29+
def get_table(self, entry_point: str, options: dict) -> Any:
30+
"""Return the table provider for table.
31+
32+
Args:
33+
entry_point: URI of the Flight server to connect to, e.g. ``http://localhost:32010``
34+
options: driver options describing the table, e.g. ``{"flight.sql.query": "SELECT * FROM taxi"}``
35+
"""
36+
return self._raw.get_table(entry_point, options)

python/src/flight.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,51 @@
1+
use std::sync::Arc;
12

3+
use datafusion::catalog::TableProvider;
4+
use datafusion_table_providers::flight::{sql::FlightSqlDriver, FlightDriver, FlightTableFactory};
5+
use pyo3::{prelude::*, types::PyDict};
6+
7+
use crate::{
8+
utils::{pydict_to_hashmap, to_pyerr, wait_for_future},
9+
RawTableProvider,
10+
};
11+
12+
#[pyclass(module = "datafusion_table_providers._internal.Flight")]
13+
struct RawFlightTableFactory {
14+
factory: FlightTableFactory,
15+
}
16+
17+
#[pymethods]
18+
impl RawFlightTableFactory {
19+
#[new]
20+
#[pyo3(signature = ())]
21+
pub fn new() -> PyResult<Self> {
22+
let driver: Arc<dyn FlightDriver> = Arc::new(FlightSqlDriver::new());
23+
24+
Ok(Self {
25+
factory: FlightTableFactory::new(Arc::clone(&driver)),
26+
})
27+
}
28+
29+
pub fn get_table(
30+
&self,
31+
py: Python,
32+
entry_point: &str,
33+
options: &Bound<'_, PyDict>,
34+
) -> PyResult<RawTableProvider> {
35+
let options = pydict_to_hashmap(options)?;
36+
let table: Arc<dyn TableProvider> = Arc::new(
37+
wait_for_future(py, self.factory.open_table(entry_point, options)).map_err(to_pyerr)?,
38+
);
39+
40+
Ok(RawTableProvider {
41+
table,
42+
supports_pushdown_filters: true,
43+
})
44+
}
45+
}
46+
47+
pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
48+
m.add_class::<RawFlightTableFactory>()?;
49+
50+
Ok(())
51+
}

python/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,9 @@ fn _internal(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
6161
postgres::init_module(&postgres)?;
6262
m.add_submodule(&postgres)?;
6363

64+
let flight = PyModule::new(py, "flight")?;
65+
flight::init_module(&flight)?;
66+
m.add_submodule(&flight)?;
67+
6468
Ok(())
6569
}

0 commit comments

Comments
 (0)