Commit 4c83d76

fix query 15, update validator, formatting, ray scheduling fix (#83)
1 parent 42681a1 commit 4c83d76

6 files changed: +99 -131 lines

datafusion_ray/core.py (+2, -2)

@@ -249,7 +249,7 @@ async def all_done(self):
         log.info("all processors shutdown")
 
 
-@ray.remote(num_cpus=0)
+@ray.remote(num_cpus=0.01, scheduling_strategy="SPREAD")
 class DFRayProcessor:
     def __init__(self, processor_key):
         self.processor_key = processor_key
@@ -317,7 +317,7 @@ def __str__(self):
         return f"""Stage: {self.stage_id}, pg: {self.partition_group}, child_stages:{self.child_stage_ids}, listening addr:{self.remote_addr}"""
 
 
-@ray.remote(num_cpus=0)
+@ray.remote(num_cpus=0.01, scheduling_strategy="SPREAD")
 class DFRayContextSupervisor:
     def __init__(
         self,
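
For context, num_cpus=0.01 with scheduling_strategy="SPREAD" asks Ray to reserve almost no CPU per actor and to spread the actors across nodes instead of packing them onto one. A minimal standalone sketch of that scheduling option (illustrative only, not part of this commit; the Probe actor is made up):

import ray
import socket

# A tiny fractional CPU request plus SPREAD scheduling: Ray can place many of
# these actors even on a busy cluster, and it tries to distribute them across
# the available nodes rather than co-locating them.
@ray.remote(num_cpus=0.01, scheduling_strategy="SPREAD")
class Probe:
    def node(self) -> str:
        return socket.gethostname()

if __name__ == "__main__":
    ray.init()
    probes = [Probe.remote() for _ in range(4)]
    # On a multi-node cluster this typically prints several distinct hostnames.
    print(ray.get([p.node.remote() for p in probes]))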

datafusion_ray/util.py (+1, -1)

@@ -1,4 +1,4 @@
 from datafusion_ray._datafusion_ray_internal import (
-    exec_sql_on_tables,
+    LocalValidator,
     prettify,
 )

src/dataframe.rs (+1, -1)

@@ -50,10 +50,10 @@ use crate::max_rows::MaxRowsExec;
 use crate::pre_fetch::PrefetchExec;
 use crate::stage::DFRayStageExec;
 use crate::stage_reader::DFRayStageReaderExec;
+use crate::util::ResultExt;
 use crate::util::collect_from_stage;
 use crate::util::display_plan_with_partition_counts;
 use crate::util::physical_plan_to_bytes;
-use crate::util::ResultExt;
 
 /// Internal rust class beyind the DFRayDataFrame python object
 ///

src/lib.rs (+1, -1)

@@ -43,8 +43,8 @@ fn _datafusion_ray_internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<dataframe::DFRayDataFrame>()?;
     m.add_class::<dataframe::PyDFRayStage>()?;
     m.add_class::<processor_service::DFRayProcessorService>()?;
+    m.add_class::<util::LocalValidator>()?;
     m.add_function(wrap_pyfunction!(util::prettify, m)?)?;
-    m.add_function(wrap_pyfunction!(util::exec_sql_on_tables, m)?)?;
     Ok(())
 }
 
src/util.rs (+75, -110)

@@ -8,13 +8,12 @@ use std::task::{Context, Poll};
 use std::time::Duration;
 
 use arrow::array::RecordBatch;
-use arrow::compute::concat_batches;
 use arrow::datatypes::SchemaRef;
 use arrow::error::ArrowError;
 use arrow::ipc::convert::fb_to_schema;
 use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
-use arrow::ipc::{root_as_message, MetadataVersion};
+use arrow::ipc::{MetadataVersion, root_as_message};
 use arrow::pyarrow::*;
 use arrow::util::pretty;
 use arrow_flight::{FlightClient, FlightData, Ticket};
@@ -30,16 +29,16 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, SessionStateBuilder};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion::physical_plan::{displayable, ExecutionPlan, ExecutionPlanProperties};
-use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, displayable};
+use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_python::utils::wait_for_future;
 use futures::{Stream, StreamExt};
 use log::debug;
+use object_store::ObjectStore;
 use object_store::aws::AmazonS3Builder;
 use object_store::gcp::GoogleCloudStorageBuilder;
 use object_store::http::HttpBuilder;
-use object_store::ObjectStore;
 use parking_lot::Mutex;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList};
@@ -411,62 +410,77 @@ fn print_node(plan: &Arc<dyn ExecutionPlan>, indent: usize, output: &mut String)
     }
 }
 
-async fn exec_sql(
-    query: String,
-    tables: Vec<(String, String)>,
-) -> Result<RecordBatch, DataFusionError> {
-    let ctx = SessionContext::new();
-    for (name, path) in tables {
-        let opt =
-            ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(".parquet");
-        debug!("exec_sql: registering table {} at {}", name, path);
+#[pyclass]
+pub struct LocalValidator {
+    ctx: SessionContext,
+}
+
+#[pymethods]
+impl LocalValidator {
+    #[new]
+    fn new() -> Self {
+        let ctx = SessionContext::new();
+        Self { ctx }
+    }
+
+    pub fn register_parquet(&self, py: Python, name: String, path: String) -> PyResult<()> {
+        let options = ParquetReadOptions::default();
 
-        let url = ListingTableUrl::parse(&path)?;
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
 
-        maybe_register_object_store(&ctx, url.as_ref())?;
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_parquet: registering table {} at {}", name, path);
 
-        ctx.register_listing_table(&name, &path, opt, None, None)
-            .await?;
+        wait_for_future(py, self.ctx.register_parquet(&name, &path, options.clone()))?;
+        Ok(())
     }
-    let df = ctx.sql(&query).await?;
-    let schema = df.schema().inner().clone();
-    let batches = df.collect().await?;
-    concat_batches(&schema, batches.iter()).map_err(|e| DataFusionError::ArrowError(e, None))
-}
 
-/// Executes a query on the specified tables using DataFusion without Ray.
-///
-/// Returns the query results as a RecordBatch that can be used to verify the
-/// correctness of DataFusion-Ray execution of the same query.
-///
-/// # Arguments
-///
-/// * `py`: the Python token
-/// * `query`: the SQL query string to execute
-/// * `tables`: a list of `(name, url)` tuples specifying the tables to query;
-///   the `url` identifies the parquet files for each listing table and see
-///   [`datafusion::datasource::listing::ListingTableUrl::parse`] for details
-///   of supported URL formats
-/// * `listing`: boolean indicating whether this is a listing table path or not
-#[pyfunction]
-#[pyo3(signature = (query, tables, listing=false))]
-pub fn exec_sql_on_tables(
-    py: Python,
-    query: String,
-    tables: Bound<'_, PyList>,
-    listing: bool,
-) -> PyResult<PyObject> {
-    let table_vec = {
-        let mut v = Vec::with_capacity(tables.len());
-        for entry in tables.iter() {
-            let (name, path) = entry.extract::<(String, String)>()?;
-            let path = if listing { format!("{path}/") } else { path };
-            v.push((name, path));
-        }
-        v
-    };
-    let batch = wait_for_future(py, exec_sql(query, table_vec))?;
-    batch.to_pyarrow(py)
+    #[pyo3(signature = (name, path, file_extension=".parquet"))]
+    pub fn register_listing_table(
+        &mut self,
+        py: Python,
+        name: &str,
+        path: &str,
+        file_extension: &str,
+    ) -> PyResult<()> {
+        let options =
+            ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);
+
+        let path = format!("{path}/");
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+
+        debug!(
+            "register_listing_table: registering table {} at {}",
+            name, path
+        );
+        wait_for_future(
+            py,
+            self.ctx
+                .register_listing_table(name, path, options, None, None),
+        )
+        .to_py_err()
+    }
+
+    #[pyo3(signature = (query))]
+    fn collect_sql(&self, py: Python, query: String) -> PyResult<PyObject> {
+        let fut = async || {
+            let df = self.ctx.sql(&query).await?;
+            let batches = df.collect().await?;
+
+            Ok::<_, DataFusionError>(batches)
+        };
+
+        let batches = wait_for_future(py, fut())
+            .to_py_err()?
+            .iter()
+            .map(|batch| batch.to_pyarrow(py))
+            .collect::<PyResult<Vec<_>>>()?;
+
+        let pylist = PyList::new(py, batches)?;
+        Ok(pylist.into())
+    }
 }
 
 pub(crate) fn register_object_store_for_paths_in_plan(
@@ -570,62 +584,14 @@ mod test {
     use std::{sync::Arc, vec};
 
     use arrow::{
-        array::{Int32Array, StringArray},
+        array::Int32Array,
         datatypes::{DataType, Field, Schema},
     };
-    use datafusion::{
-        parquet::file::properties::WriterProperties, test_util::parquet::TestParquetFile,
-    };
+
     use futures::stream;
 
     use super::*;
 
-    #[tokio::test]
-    async fn test_exec_sql() {
-        let dir = tempfile::tempdir().unwrap();
-        let path = dir.path().join("people.parquet");
-
-        let batch = RecordBatch::try_new(
-            Arc::new(Schema::new(vec![
-                Field::new("age", DataType::Int32, false),
-                Field::new("name", DataType::Utf8, false),
-            ])),
-            vec![
-                Arc::new(Int32Array::from(vec![11, 12, 13])),
-                Arc::new(StringArray::from(vec!["alice", "bob", "cindy"])),
-            ],
-        )
-        .unwrap();
-        let props = WriterProperties::builder().build();
-        let file = TestParquetFile::try_new(path.clone(), props, Some(batch.clone())).unwrap();
-
-        // test with file
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}", file.path().to_str().unwrap()),
-        )];
-        let query = "SELECT * FROM people ORDER BY age".to_string();
-        let res = exec_sql(query.clone(), tables).await.unwrap();
-        assert_eq!(
-            format!(
-                "{}",
-                pretty::pretty_format_batches(&[batch.clone()]).unwrap()
-            ),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-
-        // test with dir
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}/", dir.path().to_str().unwrap()),
-        )];
-        let res = exec_sql(query, tables).await.unwrap();
-        assert_eq!(
-            format!("{}", pretty::pretty_format_batches(&[batch]).unwrap()),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-    }
-
     #[test]
     fn test_ipc_roundtrip() {
         let batch = RecordBatch::try_new(
@@ -641,10 +607,9 @@ mod test {
     #[tokio::test]
     async fn test_max_rows_stream() {
         let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
-        let batch = RecordBatch::try_new(
-            schema.clone(),
-            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
-        )
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(Int32Array::from(vec![
+            1, 2, 3, 4, 5, 6, 7, 8,
+        ]))])
         .unwrap();
 
         // 24 total rows
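
The new LocalValidator pyclass replaces the removed exec_sql_on_tables helper: tables are registered once on a plain DataFusion SessionContext, and collect_sql returns the batches used as the expected answer for a Ray run of the same query. A rough Python usage sketch (the table name and paths below are hypothetical; the real call sites are in tpch/tpcbench.py):

from datafusion_ray.util import LocalValidator, prettify

local = LocalValidator()
# Either a single parquet file...
local.register_parquet("lineitem", "/data/tpch/lineitem.parquet")
# ...or a directory of parquet files registered as a listing table:
# local.register_listing_table("orders", "/data/tpch/orders.parquet")

# collect_sql runs the query on plain DataFusion (no Ray) and returns a list
# of pyarrow RecordBatches, which prettify renders as a table string.
batches = local.collect_sql("SELECT count(*) FROM lineitem")
print(prettify(batches))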

tpch/tpcbench.py (+19, -16)

@@ -18,7 +18,7 @@
 import argparse
 import ray
 from datafusion_ray import DFRayContext, df_ray_runtime_env
-from datafusion_ray.util import exec_sql_on_tables, prettify
+from datafusion_ray.util import LocalValidator, prettify
 from datetime import datetime
 import json
 import os
@@ -63,6 +63,8 @@ def main(
         worker_pool_min=worker_pool_min,
     )
 
+    local = LocalValidator()
+
     ctx.set("datafusion.execution.target_partitions", f"{concurrency}")
     # ctx.set("datafusion.execution.parquet.pushdown_filters", "true")
     ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")
@@ -73,8 +75,10 @@ def main(
         print(f"Registering table {table} using path {path}")
         if listing_tables:
             ctx.register_listing_table(table, path)
+            local.register_listing_table(table, path)
         else:
             ctx.register_parquet(table, path)
+            local.register_parquet(table, path)
 
     current_time_millis = int(datetime.now().timestamp() * 1000)
     results_path = f"datafusion-ray-tpch-{current_time_millis}.json"
@@ -99,28 +103,27 @@ def main(
     for qnum in queries:
         sql = tpch_query(qnum)
 
-        statements = sql.split(";")
-        sql = statements[0]
-
-        print("executing ", sql)
+        statements = list(
+            filter(lambda x: len(x) > 0, map(lambda x: x.strip(), sql.split(";")))
+        )
+        print(f"statements = {statements}")
 
         start_time = time.time()
-        df = ctx.sql(sql)
-        batches = df.collect()
+        all_batches = []
+        for sql in statements:
+            print("executing ", sql)
+            df = ctx.sql(sql)
+            all_batches.append(df.collect())
         end_time = time.time()
         results["queries"][qnum] = end_time - start_time
 
-        calculated = prettify(batches)
+        calculated = "\n".join([prettify(b) for b in all_batches])
        print(calculated)
         if validate:
-            tables = [
-                (name, os.path.join(data_path, f"{name}.parquet"))
-                for name in table_names
-            ]
-            answer_batches = [
-                b for b in [exec_sql_on_tables(sql, tables, listing_tables)] if b
-            ]
-            expected = prettify(answer_batches)
+            all_batches = []
+            for sql in statements:
+                all_batches.append(local.collect_sql(sql))
+            expected = "\n".join([prettify(b) for b in all_batches])
 
             results["validated"][qnum] = calculated == expected
         print(f"done with query {qnum}")
