Commit 42681a1

Add object store support (#78)
1 parent 2fc0694 commit 42681a1

9 files changed, +282 -96 lines

Cargo.toml (+1, -1)

```diff
@@ -52,6 +52,7 @@ object_store = { version = "0.11.0", features = [
 ] }
 parking_lot = { version = "0.12", features = ["deadlock_detection"] }
 prost = "0.13"
+protobuf-src = "2.1"
 pyo3 = { version = "0.23", features = [
     "extension-module",
     "abi3",
@@ -85,7 +86,6 @@ tonic-build = { version = "0.8", default-features = false, features = [
     "prost",
 ] }
 url = "2"
-protobuf-src = "2.1"

 [dev-dependencies]
 tempfile = "3.17"
```

README.md (+10, -2)

````diff
@@ -60,12 +60,20 @@ Once installed, you can run queries using DataFusion's familiar API while levera
 capabilities of Ray.

 ```python
+# from example in ./examples/http_csv.py
 import ray
 from datafusion_ray import DFRayContext, df_ray_runtime_env

 ray.init(runtime_env=df_ray_runtime_env)
-session = DFRayContext()
-df = session.sql("SELECT * FROM my_table WHERE value > 100")
+
+ctx = DFRayContext()
+ctx.register_csv(
+    "aggregate_test_100",
+    "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
+)
+
+df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")
+
 df.show()
 ```
````

datafusion_ray/core.py (+59, -34)

```diff
@@ -86,9 +86,7 @@ async def wait_for(coros, name=""):
     # wrap the coro in a task to work with python 3.10 and 3.11+ where asyncio.wait semantics
     # changed to not accept any awaitable
     start = time.time()
-    done, _ = await asyncio.wait(
-        [asyncio.create_task(_ensure_coro(c)) for c in coros]
-    )
+    done, _ = await asyncio.wait([asyncio.create_task(_ensure_coro(c)) for c in coros])
     end = time.time()
     log.info(f"waiting for {name} took {end - start}s")
     for d in done:
@@ -166,9 +164,7 @@ async def acquire(self, need=1):
         need_to_make = need - have

         if need_to_make > can_make:
-            raise Exception(
-                f"Cannot allocate workers above {self.max_workers}"
-            )
+            raise Exception(f"Cannot allocate workers above {self.max_workers}")

         if need_to_make > 0:
             log.debug(f"creating {need_to_make} additional processors")
@@ -197,9 +193,9 @@ def _new_processor(self):
         self.processors_ready.clear()
         processor_key = new_friendly_name()
         log.debug(f"starting processor: {processor_key}")
-        processor = DFRayProcessor.options(
-            name=f"Processor : {processor_key}"
-        ).remote(processor_key)
+        processor = DFRayProcessor.options(name=f"Processor : {processor_key}").remote(
+            processor_key
+        )
         self.pool[processor_key] = processor
         self.processors_started.add(processor.start_up.remote())
         self.available.add(processor_key)
@@ -248,9 +244,7 @@ async def _wait_for_serve(self):

     async def all_done(self):
         log.info("calling processor all done")
-        refs = [
-            processor.all_done.remote() for processor in self.pool.values()
-        ]
+        refs = [processor.all_done.remote() for processor in self.pool.values()]
         await wait_for(refs, "processors to be all done")
         log.info("all processors shutdown")

@@ -293,9 +287,7 @@ async def update_plan(
         )

     async def serve(self):
-        log.info(
-            f"[{self.processor_key}] serving on {self.processor_service.addr()}"
-        )
+        log.info(f"[{self.processor_key}] serving on {self.processor_service.addr()}")
         await self.processor_service.serve()
         log.info(f"[{self.processor_key}] done serving")

@@ -332,9 +324,7 @@ def __init__(
         worker_pool_min: int,
         worker_pool_max: int,
     ) -> None:
-        log.info(
-            f"Creating DFRayContextSupervisor worker_pool_min: {worker_pool_min}"
-        )
+        log.info(f"Creating DFRayContextSupervisor worker_pool_min: {worker_pool_min}")
         self.pool = DFRayProcessorPool(worker_pool_min, worker_pool_max)
         self.stages: dict[str, InternalStageData] = {}
         log.info("Created DFRayContextSupervisor")
@@ -347,9 +337,7 @@ async def wait_for_ready(self):

     async def get_stage_addrs(self, stage_id: int):
         addrs = [
-            sd.remote_addr
-            for sd in self.stages.values()
-            if sd.stage_id == stage_id
+            sd.remote_addr for sd in self.stages.values() if sd.stage_id == stage_id
         ]
         return addrs

@@ -399,10 +387,7 @@ async def new_query(
             refs.append(
                 isd.remote_processor.update_plan.remote(
                     isd.stage_id,
-                    {
-                        stage_id: val["child_addrs"]
-                        for (stage_id, val) in kid.items()
-                    },
+                    {stage_id: val["child_addrs"] for (stage_id, val) in kid.items()},
                     isd.partition_group,
                     isd.plan_bytes,
                 )
@@ -434,9 +419,7 @@ async def sort_out_addresses(self):
             ]

             # sanity check
-            assert all(
-                [op == output_partitions[0] for op in output_partitions]
-            )
+            assert all([op == output_partitions[0] for op in output_partitions])
             output_partitions = output_partitions[0]

             for child_stage_isd in child_stage_datas:
@@ -520,9 +503,7 @@ def collect(self) -> list[pa.RecordBatch]:
             )
             log.debug(f"last stage addrs {last_stage_addrs}")

-            reader = self.df.read_final_stage(
-                last_stage_id, last_stage_addrs[0]
-            )
+            reader = self.df.read_final_stage(last_stage_id, last_stage_addrs[0])
             log.debug("got reader")
             self._batches = list(reader)
         return self._batches
@@ -589,11 +570,55 @@ def __init__(
         )

     def register_parquet(self, name: str, path: str):
+        """
+        Register a Parquet file with the given name and path.
+        The path can be a local filesystem path, absolute filesystem path, or a url.
+
+        If the path is an object store url, the appropriate object store will be registered.
+        Configuration of the object store will be gathered from the environment.
+
+        For example, for s3:// urls, credentials will be looked for by the AWS SDK,
+        which will check environment variables, credential files, etc.
+
+        Parameters:
+        path (str): The file path to the Parquet file.
+        name (str): The name to register the Parquet file under.
+        """
         self.ctx.register_parquet(name, path)

-    def register_listing_table(
-        self, name: str, path: str, file_extention="parquet"
-    ):
+    def register_csv(self, name: str, path: str):
+        """
+        Register a csv file with the given name and path.
+        The path can be a local filesystem path, absolute filesystem path, or a url.
+
+        If the path is an object store url, the appropriate object store will be registered.
+        Configuration of the object store will be gathered from the environment.
+
+        For example, for s3:// urls, credentials will be looked for by the AWS SDK,
+        which will check environment variables, credential files, etc.
+
+        Parameters:
+        path (str): The file path to the csv file.
+        name (str): The name to register the csv file under.
+        """
+        self.ctx.register_csv(name, path)
+
+    def register_listing_table(self, name: str, path: str, file_extention="parquet"):
+        """
+        Register a directory of parquet files with the given name.
+        The path can be a local filesystem path, absolute filesystem path, or a url.
+
+        If the path is an object store url, the appropriate object store will be registered.
+        Configuration of the object store will be gathered from the environment.
+
+        For example, for s3:// urls, credentials will be looked for by the AWS SDK,
+        which will check environment variables, credential files, etc.
+
+        Parameters:
+        path (str): The file path to the Parquet file directory.
+        name (str): The name to register the table under.
+        """
+
         self.ctx.register_listing_table(name, path, file_extention)

     def sql(self, query: str) -> DFRayDataFrame:
```
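The docstrings added above all describe the same environment-driven object store behavior: pass an object store url and the matching store is registered automatically. A minimal usage sketch (the bucket and object names below are hypothetical; credentials are assumed to come from the standard AWS environment variables or a credentials file, as the docstrings describe):

```python
# Hypothetical sketch: "my-bucket" and "data/trips.parquet" are made-up names.
# No credentials appear in code; the AWS SDK resolves them from the environment
# (e.g. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY) or a credentials file.
import ray

from datafusion_ray import DFRayContext, df_ray_runtime_env

ray.init(runtime_env=df_ray_runtime_env)

ctx = DFRayContext()
# an s3:// url causes the appropriate object store to be registered internally
ctx.register_parquet("trips", "s3://my-bucket/data/trips.parquet")

df = ctx.sql("SELECT count(*) FROM trips")
df.show()
```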

examples/http_csv.py (+40)

```diff
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# this is a port of the example at
+# https://github.com/apache/datafusion/blob/45.0.0/datafusion-examples/examples/query-http-csv.rs
+
+import ray
+
+from datafusion_ray import DFRayContext, df_ray_runtime_env
+
+
+def main():
+    ctx = DFRayContext()
+    ctx.register_csv(
+        "aggregate_test_100",
+        "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
+    )
+
+    df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")
+
+    df.show()
+
+
+if __name__ == "__main__":
+    ray.init(namespace="http_csv", runtime_env=df_ray_runtime_env)
+    main()
```

examples/tips.py (+5, -18)

```diff
@@ -16,40 +16,27 @@
 # under the License.

 import argparse
-import datafusion
+import os
 import ray

-from datafusion_ray import DFRayContext
+from datafusion_ray import DFRayContext, df_ray_runtime_env


 def go(data_dir: str):
     ctx = DFRayContext()
-    # we could set this value to however many CPUs we plan to give each
-    # ray task
-    ctx.set("datafusion.execution.target_partitions", "1")
-    ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")

-    ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
+    ctx.register_parquet("tips", os.path.join(data_dir, "tips.parquet"))

     df = ctx.sql(
         "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker order by sex, smoker"
     )
     df.show()

-    print("no ray result:")
-
-    # compare to non ray version
-    ctx = datafusion.SessionContext()
-    ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
-    ctx.sql(
-        "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker order by sex, smoker"
-    ).show()
-

 if __name__ == "__main__":
-    ray.init(namespace="tips")
+    ray.init(namespace="tips", runtime_env=df_ray_runtime_env)
     parser = argparse.ArgumentParser()
-    parser.add_argument("--data-dir", required=True, help="path to tips*.parquet files")
+    parser.add_argument("--data-dir", required=True, help="path to tips.parquet files")
     args = parser.parse_args()

     go(args.data_dir)
```

src/context.rs (+30, -17)

```diff
@@ -16,17 +16,17 @@
 // under the License.

 use datafusion::datasource::file_format::parquet::ParquetFormat;
-use datafusion::datasource::listing::ListingOptions;
-use datafusion::{execution::SessionStateBuilder, prelude::*};
+use datafusion::datasource::listing::{ListingOptions, ListingTableUrl};
+use datafusion::execution::SessionStateBuilder;
+use datafusion::prelude::{CsvReadOptions, ParquetReadOptions, SessionConfig, SessionContext};
 use datafusion_python::utils::wait_for_future;
-use object_store::aws::AmazonS3Builder;
+use log::debug;
 use pyo3::prelude::*;
 use std::sync::Arc;

 use crate::dataframe::DFRayDataFrame;
 use crate::physical::RayStageOptimizerRule;
-use crate::util::ResultExt;
-use url::Url;
+use crate::util::{maybe_register_object_store, ResultExt};

 /// Internal Session Context object for the python class DFRayContext
 #[pyclass]
@@ -54,23 +54,27 @@ impl DFRayContext {
         Ok(Self { ctx })
     }

-    pub fn register_s3(&self, bucket_name: String) -> PyResult<()> {
-        let s3 = AmazonS3Builder::from_env()
-            .with_bucket_name(&bucket_name)
-            .build()
-            .to_py_err()?;
+    pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
+        let options = ParquetReadOptions::default();
+
+        let url = ListingTableUrl::parse(&path).to_py_err()?;

-        let path = format!("s3://{bucket_name}");
-        let s3_url = Url::parse(&path).to_py_err()?;
-        let arc_s3 = Arc::new(s3);
-        self.ctx.register_object_store(&s3_url, arc_s3.clone());
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_parquet: registering table {} at {}", name, path);
+
+        wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
         Ok(())
     }

-    pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
-        let options = ParquetReadOptions::default();
+    pub fn register_csv(&self, py: Python, name: String, path: String) -> PyResult<()> {
+        let options = CsvReadOptions::default();

-        wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_csv: registering table {} at {}", name, path);
+
+        wait_for_future(py, self.ctx.register_csv(&name, &path, options.clone()))?;
         Ok(())
     }

@@ -85,6 +89,15 @@ impl DFRayContext {
         let options =
             ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);

+        let path = format!("{path}/");
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+
+        debug!(
+            "register_listing_table: registering table {} at {}",
+            name, path
+        );
         wait_for_future(
             py,
             self.ctx
```
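The `maybe_register_object_store` helper imported from `crate::util` is added elsewhere in this commit (in one of the nine changed files, not shown in this excerpt). Note that `register_listing_table` appends a trailing `/` before parsing, since a listing table names a directory of files rather than a single file. A hedged sketch of how this surfaces in Python (the bucket and prefix are made-up names):

```python
# Hypothetical sketch: "my-bucket" and the "events" prefix are made up;
# credentials are resolved from the environment as in the other examples.
from datafusion_ray import DFRayContext

ctx = DFRayContext()
# the path names a directory of parquet files; the trailing "/" is
# appended internally before the object store is registered
ctx.register_listing_table("events", "s3://my-bucket/events")

df = ctx.sql("SELECT count(*) FROM events")
df.show()
```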
