@@ -96,49 +96,10 @@ impl DFRayDataFrame {
96
96
) -> PyResult < Vec < PyDFRayStage > > {
97
97
let mut stages = vec ! [ ] ;
98
98
99
- // TODO: This can be done more efficiently, likely in one pass but I'm
100
- // struggling to get the TreeNodeRecursion return values to make it do
101
- // what I want. So, two steps for now
102
-
103
- // Step 2: we walk down this stage and replace stages earlier in the tree with
104
- // RayStageReaderExecs as we will need to consume their output instead of
105
- // execute that part of the tree ourselves
106
- let down = |plan : Arc < dyn ExecutionPlan > | {
107
- trace ! (
108
- "examining plan down:\n {}" ,
109
- display_plan_with_partition_counts( & plan)
110
- ) ;
111
-
112
- if let Some ( stage_exec) = plan. as_any ( ) . downcast_ref :: < DFRayStageExec > ( ) {
113
- let input = plan. children ( ) ;
114
- assert ! ( input. len( ) == 1 , "RayStageExec must have exactly one child" ) ;
115
- let input = input[ 0 ] ;
116
-
117
- trace ! (
118
- "inserting a ray stage reader to consume: {} with partitioning {}" ,
119
- displayable( plan. as_ref( ) ) . one_line( ) ,
120
- plan. output_partitioning( ) . partition_count( )
121
- ) ;
122
-
123
- let replacement = Arc :: new ( DFRayStageReaderExec :: try_new (
124
- plan. output_partitioning ( ) . clone ( ) ,
125
- input. schema ( ) ,
126
- stage_exec. stage_id ,
127
- ) ?) as Arc < dyn ExecutionPlan > ;
128
-
129
- Ok ( Transformed {
130
- data : replacement,
131
- transformed : true ,
132
- tnr : TreeNodeRecursion :: Jump ,
133
- } )
134
- } else {
135
- Ok ( Transformed :: no ( plan) )
136
- }
137
- } ;
138
-
139
99
let mut partition_groups = vec ! [ ] ;
140
100
let mut full_partitions = false ;
141
- // Step 1: we walk up the tree from the leaves to find the stages
101
+ // We walk up the tree from the leaves to find the stages, record ray stages, and replace
102
+ // each ray stage with a corresponding ray reader stage.
142
103
let up = |plan : Arc < dyn ExecutionPlan > | {
143
104
trace ! (
144
105
"Examining plan up: {}" ,
@@ -151,19 +112,23 @@ impl DFRayDataFrame {
151
112
assert ! ( input. len( ) == 1 , "RayStageExec must have exactly one child" ) ;
152
113
let input = input[ 0 ] ;
153
114
154
- let fixed_plan = input. clone ( ) . transform_down ( down) ?. data ;
115
+ let replacement = Arc :: new ( DFRayStageReaderExec :: try_new (
116
+ plan. output_partitioning ( ) . clone ( ) ,
117
+ input. schema ( ) ,
118
+ stage_exec. stage_id ,
119
+ ) ?) as Arc < dyn ExecutionPlan > ;
155
120
156
121
let stage = PyDFRayStage :: new (
157
122
stage_exec. stage_id ,
158
- fixed_plan ,
123
+ input . clone ( ) ,
159
124
partition_groups. clone ( ) ,
160
125
full_partitions,
161
126
) ;
162
127
partition_groups = vec ! [ ] ;
163
128
full_partitions = false ;
164
129
165
130
stages. push ( stage) ;
166
- Ok ( Transformed :: no ( plan ) )
131
+ Ok ( Transformed :: yes ( replacement ) )
167
132
} else if plan. as_any ( ) . downcast_ref :: < RepartitionExec > ( ) . is_some ( ) {
168
133
trace ! ( "repartition exec" ) ;
169
134
let ( calculated_partition_groups, replacement) = build_replacement (
0 commit comments