query_stage.rs
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
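
//! A query stage wraps a fragment of a DataFusion physical plan together with a
//! stage id, and exposes it to Python (as `raysql.QueryStage`) via [`PyQueryStage`].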

use crate::shuffle::{RayShuffleReaderExec, ShuffleCodec, ShuffleReaderExec};
use datafusion::error::Result;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion::prelude::SessionContext;
use datafusion_proto::bytes::physical_plan_from_bytes_with_extension_codec;
use datafusion_python::physical_plan::PyExecutionPlan;
use pyo3::prelude::*;
use std::sync::Arc;
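
/// Python-facing wrapper around a [`QueryStage`], exposed to Python as the
/// `raysql.QueryStage` class.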
#[pyclass(name = "QueryStage", module = "raysql", subclass)]
pub struct PyQueryStage {
    stage: Arc<QueryStage>,
}

impl PyQueryStage {
    pub fn from_rust(stage: Arc<QueryStage>) -> Self {
        Self { stage }
    }
}

#[pymethods]
impl PyQueryStage {
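    /// Deserialize a query stage from protobuf-encoded physical plan bytes,
    /// decoding any shuffle operators with the [`ShuffleCodec`] extension codec.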
    #[new]
    pub fn new(id: usize, bytes: Vec<u8>) -> Result<Self> {
        let ctx = SessionContext::new();
        let codec = ShuffleCodec {};
        let plan = physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?;
        Ok(PyQueryStage {
            stage: Arc::new(QueryStage { id, plan }),
        })
    }

    pub fn id(&self) -> usize {
        self.stage.id
    }

    pub fn get_execution_plan(&self) -> PyExecutionPlan {
        PyExecutionPlan::new(self.stage.plan.clone())
    }

    pub fn get_child_stage_ids(&self) -> Vec<usize> {
        self.stage.get_child_stage_ids()
    }

    pub fn get_input_partition_count(&self) -> usize {
        self.stage.get_input_partition_count()
    }

    pub fn get_output_partition_count(&self) -> usize {
        self.stage.get_output_partition_count()
    }
}
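
/// A query stage is a fragment of the physical plan that can be scheduled and
/// executed independently; it reads its inputs from child stages through
/// shuffle reader operators.
///
/// A minimal usage sketch (assumes `plan` is an already-built `Arc<dyn ExecutionPlan>`):
///
/// ```ignore
/// let stage = QueryStage::new(0, plan);
/// println!("stage {} reads from stages {:?}", stage.id, stage.get_child_stage_ids());
/// ```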
#[derive(Debug)]
pub struct QueryStage {
    pub id: usize,
    pub plan: Arc<dyn ExecutionPlan>,
}

fn _get_output_partition_count(plan: &dyn ExecutionPlan) -> usize {
    // UnknownPartitioning and HashPartitioning with empty expressions will
    // both return 1 partition.
    match plan.output_partitioning() {
        Partitioning::UnknownPartitioning(_) => 1,
        Partitioning::Hash(expr, _) if expr.is_empty() => 1,
        p => p.partition_count(),
    }
}

impl QueryStage {
    pub fn new(id: usize, plan: Arc<dyn ExecutionPlan>) -> Self {
        Self { id, plan }
    }
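
    /// Return the ids of the child stages that this stage reads from via its
    /// shuffle reader operators.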
    pub fn get_child_stage_ids(&self) -> Vec<usize> {
        let mut ids = vec![];
        collect_child_stage_ids(self.plan.as_ref(), &mut ids);
        ids
    }

    /// Get the input partition count. This is the same as the number of concurrent
    /// tasks that will be used when this query stage is scheduled for execution.
    pub fn get_input_partition_count(&self) -> usize {
        self.plan.children()[0]
            .output_partitioning()
            .partition_count()
    }
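
    /// Return the output partition count of the stage's root plan node, treating
    /// `UnknownPartitioning` and hash partitioning with no expressions as a
    /// single partition.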
    pub fn get_output_partition_count(&self) -> usize {
        _get_output_partition_count(self.plan.as_ref())
    }
}
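
/// Walk the plan tree and collect the stage ids of any shuffle readers; the
/// recursion stops at a shuffle reader, since it marks the boundary with a
/// child stage.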
fn collect_child_stage_ids(plan: &dyn ExecutionPlan, ids: &mut Vec<usize>) {
    if let Some(shuffle_reader) = plan.as_any().downcast_ref::<ShuffleReaderExec>() {
        ids.push(shuffle_reader.stage_id);
    } else if let Some(shuffle_reader) = plan.as_any().downcast_ref::<RayShuffleReaderExec>() {
        ids.push(shuffle_reader.stage_id);
    } else {
        for child_plan in plan.children() {
            collect_child_stage_ids(child_plan.as_ref(), ids);
        }
    }
}