
Commit 1759b73

Remove legacy shuffle, add docs for distributed testing (#19)
* always use Ray shuffle
* remove legacy shuffle
* remove reference to use_ray_shuffle
* remove unused imports
* remove unused struct
* update example
* update example
* update docs
* update expected plans
* cargo fmt
* address feedback
1 parent b91705c commit 1759b73

29 files changed (+378 −1009 lines)
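The caller-facing effect is the simplified context constructor; a minimal before/after sketch taken from the diffs below:

```python
from datafusion_ray import DatafusionRayContext

# Before this commit: callers chose between Ray-based and legacy disk-based shuffle.
# ctx = DatafusionRayContext(2, use_ray_shuffle=True)

# After this commit: Ray shuffle is always used, so the flag is gone.
ctx = DatafusionRayContext(2)
```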

datafusion_ray/context.py (+23 −44)
@@ -77,7 +77,9 @@ def _get_worker_inputs(
     plan_bytes = datafusion_ray.serialize_execution_plan(stage.get_execution_plan())
     futures = []
     opt = {}
-    opt["resources"] = {"worker": 1e-3}
+    # TODO not sure why we had this but my Ray cluster could not find suitable resource
+    # until I commented this out
+    # opt["resources"] = {"worker": 1e-3}
     opt["num_returns"] = output_partitions_count
     for part in range(concurrency):
         ids, inputs = _get_worker_inputs(part)
@@ -93,7 +95,6 @@ def _get_worker_inputs(
 def execute_query_stage(
     query_stages: list[QueryStage],
     stage_id: int,
-    use_ray_shuffle: bool,
 ) -> tuple[int, list[ray.ObjectRef]]:
     """
     Execute a query stage on the workers.
@@ -106,7 +107,7 @@ def execute_query_stage(
     child_futures = []
     for child_id in stage.get_child_stage_ids():
         child_futures.append(
-            execute_query_stage.remote(query_stages, child_id, use_ray_shuffle)
+            execute_query_stage.remote(query_stages, child_id)
         )

     # if the query stage has a single output partition then we need to execute for the output
@@ -133,33 +134,28 @@ def _get_worker_inputs(
     ) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
         ids = []
         futures = []
-        if use_ray_shuffle:
-            for child_stage_id, child_futures in child_outputs:
-                for i, lst in enumerate(child_futures):
-                    if isinstance(lst, list):
-                        for j, f in enumerate(lst):
-                            if concurrency == 1 or j == part:
-                                # If concurrency is 1, pass in all shuffle partitions. Otherwise,
-                                # only pass in the partitions that match the current worker partition.
-                                ids.append((child_stage_id, i, j))
-                                futures.append(f)
-                    elif concurrency == 1 or part == 0:
-                        ids.append((child_stage_id, i, 0))
-                        futures.append(lst)
+        for child_stage_id, child_futures in child_outputs:
+            for i, lst in enumerate(child_futures):
+                if isinstance(lst, list):
+                    for j, f in enumerate(lst):
+                        if concurrency == 1 or j == part:
+                            # If concurrency is 1, pass in all shuffle partitions. Otherwise,
+                            # only pass in the partitions that match the current worker partition.
+                            ids.append((child_stage_id, i, j))
+                            futures.append(f)
+                elif concurrency == 1 or part == 0:
+                    ids.append((child_stage_id, i, 0))
+                    futures.append(lst)
         return ids, futures

-    # if we are using disk-based shuffle, wait until the child stages to finish
-    # writing the shuffle files to disk first.
-    if not use_ray_shuffle:
-        ray.get([f for _, lst in child_outputs for f in lst])
-
     # schedule the actual execution workers
     plan_bytes = datafusion_ray.serialize_execution_plan(stage.get_execution_plan())
     futures = []
     opt = {}
-    opt["resources"] = {"worker": 1e-3}
-    if use_ray_shuffle:
-        opt["num_returns"] = output_partitions_count
+    # TODO not sure why we had this but my Ray cluster could not find suitable resource
+    # until I commented this out
+    #opt["resources"] = {"worker": 1e-3}
+    opt["num_returns"] = output_partitions_count
     for part in range(concurrency):
         ids, inputs = _get_worker_inputs(part)
         futures.append(
@@ -210,10 +206,9 @@ def execute_query_partition(


 class DatafusionRayContext:
-    def __init__(self, num_workers: int = 1, use_ray_shuffle: bool = False):
-        self.ctx = Context(num_workers, use_ray_shuffle)
+    def __init__(self, num_workers: int = 1):
+        self.ctx = Context(num_workers)
         self.num_workers = num_workers
-        self.use_ray_shuffle = use_ray_shuffle

     def register_csv(self, table_name: str, path: str, has_header: bool):
         self.ctx.register_csv(table_name, path, has_header)
@@ -234,23 +229,7 @@ def sql(self, sql: str) -> pa.RecordBatch:

         graph = self.ctx.plan(sql)
         final_stage_id = graph.get_final_query_stage().id()
-        if self.use_ray_shuffle:
-            partitions = schedule_execution(graph, final_stage_id, True)
-        else:
-            # serialize the query stages and store in Ray object store
-            query_stages = [
-                datafusion_ray.serialize_execution_plan(
-                    graph.get_query_stage(i).get_execution_plan()
-                )
-                for i in range(final_stage_id + 1)
-            ]
-            # schedule execution
-            future = execute_query_stage.remote(
-                query_stages,
-                final_stage_id,
-                self.use_ray_shuffle,
-            )
-            _, partitions = ray.get(future)
+        partitions = schedule_execution(graph, final_stage_id, True)
         # assert len(partitions) == 1, len(partitions)
         result_set = ray.get(partitions[0])
         return result_set
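Aside: the surviving `_get_worker_inputs` branch routes each shuffle output to the worker whose partition index matches. A standalone sketch of that routing, with plain strings standing in for Ray ObjectRefs (all names here are illustrative only):

```python
# Sketch only: mirrors the routing in _get_worker_inputs above.
def select_inputs(part, concurrency, child_outputs):
    ids, futures = [], []
    for child_stage_id, child_futures in child_outputs:
        for i, lst in enumerate(child_futures):
            if isinstance(lst, list):
                for j, f in enumerate(lst):
                    # A single worker (concurrency == 1) reads every shuffle
                    # partition; otherwise worker `part` reads only partition
                    # index j == part from each map output.
                    if concurrency == 1 or j == part:
                        ids.append((child_stage_id, i, j))
                        futures.append(f)
            elif concurrency == 1 or part == 0:
                ids.append((child_stage_id, i, 0))
                futures.append(lst)
    return ids, futures

# Child stage 0 produced two map outputs, each with two shuffle partitions.
child_outputs = [(0, [["s0-m0-p0", "s0-m0-p1"], ["s0-m1-p0", "s0-m1-p1"]])]
print(select_inputs(0, 2, child_outputs))  # worker 0 gets the p0 partitions
print(select_inputs(1, 2, child_outputs))  # worker 1 gets the p1 partitions
```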

datafusion_ray/main.py (+4 −5)
@@ -31,9 +31,9 @@
 RESULTS_DIR = f"results-sf{SF}"


-def setup_context(use_ray_shuffle: bool, num_workers: int = 2) -> DatafusionRayContext:
+def setup_context(num_workers: int = 2) -> DatafusionRayContext:
     print(f"Using {num_workers} workers")
-    ctx = DatafusionRayContext(num_workers, use_ray_shuffle)
+    ctx = DatafusionRayContext(num_workers)
     for table in [
         "customer",
         "lineitem",
@@ -103,10 +103,9 @@ def compare(q: int):
 def tpch_bench():
     ray.init(resources={"worker": 1})
     num_workers = int(ray.cluster_resources().get("worker", 1)) * NUM_CPUS_PER_WORKER
-    use_ray_shuffle = False
-    ctx = setup_context(use_ray_shuffle, num_workers)
+    ctx = setup_context(num_workers)
     # t = tpch_timing(ctx, 11, print_result=True)
-    # print(f"query,{t},{use_ray_shuffle},{num_workers}")
+    # print(f"query,{t},{num_workers}")
     # return
     run_id = time.strftime("%Y-%m-%d-%H-%M-%S")
     with open(f"results-sf{SF}-{run_id}.csv", "w") as fout:

docs/testing.md (+50, new file)
@@ -0,0 +1,50 @@
+# Distributed Testing
+
+Install Ray on at least two nodes.
+
+https://docs.ray.io/en/latest/ray-overview/installation.html
+
+```shell
+sudo apt install -y python3-pip python3.12-venv
+python3 -m venv venv
+source venv/bin/activate
+pip3 install -U "ray[default]"
+```
+
+## Start Ray Head Node
+
+```shell
+ray start --head --node-ip-address=10.0.0.23 --port=6379 --dashboard-host=0.0.0.0
+```
+
+## Start Ray Worker Node(s)
+
+```shell
+ray start --address=10.0.0.23:6379 --redis-password='5241590000000000'
+```
+
+## Install DataFusion Ray (on each node)
+
+Clone the repo with the version that you want to test, then run `maturin develop --release` in the virtual env.
+
+```shell
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+. "$HOME/.cargo/env"
+```
+
+```shell
+pip3 install maturin
+```
+
+```shell
+git clone https://github.com/apache/datafusion-ray.git
+cd datafusion-ray
+maturin develop --release
+```
+
+## Submit Job
+
+```shell
+cd examples
+RAY_ADDRESS='http://10.0.0.23:8265' ray job submit --working-dir `pwd` -- python3 tips.py
+```
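Before submitting, it can be worth confirming that the worker nodes actually joined; a small sketch using standard Ray APIs, run from any node in the cluster:

```python
import ray

# Attach to the running cluster started above rather than creating a new one.
ray.init(address="auto")

# One entry per node; the custom "worker" resource appears if configured.
print(ray.nodes())
print(ray.cluster_resources())
```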

examples/tips.py (+4 −2)
@@ -16,7 +16,6 @@
 # under the License.

 import os
-import pandas as pd
 import ray

 from datafusion_ray import DatafusionRayContext
@@ -26,8 +25,11 @@
 # Start a local cluster
 ray.init(resources={"worker": 1})

+# Connect to a cluster
+# ray.init()
+
 # Create a context and register a table
-ctx = DatafusionRayContext(2, use_ray_shuffle=True)
+ctx = DatafusionRayContext(2)
 # Register either a CSV or Parquet file
 # ctx.register_csv("tips", f"{SCRIPT_DIR}/tips.csv", True)
 ctx.register_parquet("tips", f"{SCRIPT_DIR}/tips.parquet")
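A hypothetical next step for tips.py, using the `sql` method shown in the context.py diff above (the column names assume the standard tips dataset shipped with the example):

```python
# Run an aggregate query against the registered table; per the updated
# DatafusionRayContext.sql, the result arrives as Arrow record batches.
result_set = ctx.sql("SELECT sex, SUM(tip) FROM tips GROUP BY sex")
print(result_set)
```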

src/context.rs (+7 −13)
@@ -21,12 +21,10 @@ use crate::utils::wait_for_future;
 use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::pyarrow::ToPyArrow;
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::config::Extensions;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::context::TaskContext;
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::FairSpillPool;
-use datafusion::execution::options::ReadOptions;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::physical_plan::{displayable, ExecutionPlan};
 use datafusion::prelude::*;
@@ -47,13 +45,12 @@ type PyResultSet = Vec<PyObject>;
 #[pyclass(name = "Context", module = "datafusion_ray", subclass)]
 pub struct PyContext {
     pub(crate) ctx: SessionContext,
-    use_ray_shuffle: bool,
 }

 #[pymethods]
 impl PyContext {
     #[new]
-    pub fn new(target_partitions: usize, use_ray_shuffle: bool) -> Result<Self> {
+    pub fn new(target_partitions: usize) -> Result<Self> {
         let config = SessionConfig::default()
             .with_target_partitions(target_partitions)
             .with_batch_size(16 * 1024)
@@ -67,11 +64,8 @@ impl PyContext {
             .with_memory_pool(Arc::new(FairSpillPool::new(mem_pool_size)))
             .with_disk_manager(DiskManagerConfig::new_specified(vec!["/tmp".into()]));
         let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
-        let ctx = SessionContext::with_config_rt(config, runtime);
-        Ok(Self {
-            ctx,
-            use_ray_shuffle,
-        })
+        let ctx = SessionContext::new_with_config_rt(config, runtime);
+        Ok(Self { ctx })
     }

     pub fn register_csv(
@@ -94,9 +88,9 @@ impl PyContext {

     pub fn register_datalake_table(
         &self,
-        name: &str,
-        path: Vec<String>,
-        py: Python,
+        _name: &str,
+        _path: Vec<String>,
+        _py: Python,
     ) -> PyResult<()> {
         // let options = ParquetReadOptions::default();
         // let listing_options = options.to_listing_options(&self.ctx.state().config());
@@ -119,7 +113,7 @@ impl PyContext {
         let df = wait_for_future(py, self.ctx.sql(sql))?;
         let plan = wait_for_future(py, df.create_physical_plan())?;

-        let graph = make_execution_graph(plan.clone(), self.use_ray_shuffle)?;
+        let graph = make_execution_graph(plan.clone())?;

         // debug logging
         let mut stages = graph.query_stages.values().collect::<Vec<_>>();
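Seen from Python, the `PyContext::new` change means the low-level `Context` binding defined here now takes a single argument; a minimal sketch (the import path is an assumption, context.py imports the binding internally):

```python
from datafusion_ray import Context  # assumed import path for the native binding

# Before: Context(target_partitions, use_ray_shuffle)
# After this commit the shuffle flag is gone from the native constructor too.
ctx = Context(4)
```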
