Skip to content

Commit dcea736

Browse files
authored
Address a TODO about simplify Ray stages collection (#80)
1 parent 4c83d76 commit dcea736

File tree

1 file changed

+9
-44
lines changed

1 file changed

+9
-44
lines changed

src/dataframe.rs

+9-44
Original file line numberDiff line numberDiff line change
@@ -96,49 +96,10 @@ impl DFRayDataFrame {
9696
) -> PyResult<Vec<PyDFRayStage>> {
9797
let mut stages = vec![];
9898

99-
// TODO: This can be done more efficiently, likely in one pass but I'm
100-
// struggling to get the TreeNodeRecursion return values to make it do
101-
// what I want. So, two steps for now
102-
103-
// Step 2: we walk down this stage and replace stages earlier in the tree with
104-
// RayStageReaderExecs as we will need to consume their output instead of
105-
// execute that part of the tree ourselves
106-
let down = |plan: Arc<dyn ExecutionPlan>| {
107-
trace!(
108-
"examining plan down:\n{}",
109-
display_plan_with_partition_counts(&plan)
110-
);
111-
112-
if let Some(stage_exec) = plan.as_any().downcast_ref::<DFRayStageExec>() {
113-
let input = plan.children();
114-
assert!(input.len() == 1, "RayStageExec must have exactly one child");
115-
let input = input[0];
116-
117-
trace!(
118-
"inserting a ray stage reader to consume: {} with partitioning {}",
119-
displayable(plan.as_ref()).one_line(),
120-
plan.output_partitioning().partition_count()
121-
);
122-
123-
let replacement = Arc::new(DFRayStageReaderExec::try_new(
124-
plan.output_partitioning().clone(),
125-
input.schema(),
126-
stage_exec.stage_id,
127-
)?) as Arc<dyn ExecutionPlan>;
128-
129-
Ok(Transformed {
130-
data: replacement,
131-
transformed: true,
132-
tnr: TreeNodeRecursion::Jump,
133-
})
134-
} else {
135-
Ok(Transformed::no(plan))
136-
}
137-
};
138-
13999
let mut partition_groups = vec![];
140100
let mut full_partitions = false;
141-
// Step 1: we walk up the tree from the leaves to find the stages
101+
// We walk up the tree from the leaves to find the stages, record ray stages, and replace
102+
// each ray stage with a corresponding ray reader stage.
142103
let up = |plan: Arc<dyn ExecutionPlan>| {
143104
trace!(
144105
"Examining plan up: {}",
@@ -151,19 +112,23 @@ impl DFRayDataFrame {
151112
assert!(input.len() == 1, "RayStageExec must have exactly one child");
152113
let input = input[0];
153114

154-
let fixed_plan = input.clone().transform_down(down)?.data;
115+
let replacement = Arc::new(DFRayStageReaderExec::try_new(
116+
plan.output_partitioning().clone(),
117+
input.schema(),
118+
stage_exec.stage_id,
119+
)?) as Arc<dyn ExecutionPlan>;
155120

156121
let stage = PyDFRayStage::new(
157122
stage_exec.stage_id,
158-
fixed_plan,
123+
input.clone(),
159124
partition_groups.clone(),
160125
full_partitions,
161126
);
162127
partition_groups = vec![];
163128
full_partitions = false;
164129

165130
stages.push(stage);
166-
Ok(Transformed::no(plan))
131+
Ok(Transformed::yes(replacement))
167132
} else if plan.as_any().downcast_ref::<RepartitionExec>().is_some() {
168133
trace!("repartition exec");
169134
let (calculated_partition_groups, replacement) = build_replacement(

0 commit comments

Comments
 (0)