@@ -77,7 +77,9 @@ def _get_worker_inputs(
     plan_bytes = datafusion_ray.serialize_execution_plan(stage.get_execution_plan())
     futures = []
     opt = {}
-    opt["resources"] = {"worker": 1e-3}
+    # TODO not sure why we had this but my Ray cluster could not find suitable resources
+    # until I commented this out
+    # opt["resources"] = {"worker": 1e-3}
     opt["num_returns"] = output_partitions_count
     for part in range(concurrency):
         ids, inputs = _get_worker_inputs(part)
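For context on the two `opt` entries above: in Ray, a custom resource request such as `{"worker": 1e-3}` is only schedulable on nodes that declare that resource (e.g. `ray start --resources='{"worker": 1}'`), which would explain the TODO; `num_returns` tells Ray how many separate `ObjectRef`s a task produces. A minimal sketch of `num_returns` behavior, assuming the `opt` dict is ultimately passed via `.options(**opt)` as elsewhere in this file (the task name here is hypothetical, not part of this codebase):

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def produce_partitions(n: int):
    # Return n values; with num_returns=n, Ray exposes them
    # as n independent ObjectRefs instead of a single one.
    return tuple(range(n))

# The options dict is applied per call, as in the diff above.
opt = {"num_returns": 3}
refs = produce_partitions.options(**opt).remote(3)
print(ray.get(refs))  # [0, 1, 2]
```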
@@ -93,7 +95,6 @@ def _get_worker_inputs(
 def execute_query_stage(
     query_stages: list[QueryStage],
     stage_id: int,
-    use_ray_shuffle: bool,
 ) -> tuple[int, list[ray.ObjectRef]]:
     """
     Execute a query stage on the workers.
@@ -106,7 +107,7 @@ def execute_query_stage(
     child_futures = []
     for child_id in stage.get_child_stage_ids():
         child_futures.append(
-            execute_query_stage.remote(query_stages, child_id, use_ray_shuffle)
+            execute_query_stage.remote(query_stages, child_id)
         )

     # if the query stage has a single output partition then we need to execute for the output
@@ -133,33 +134,28 @@ def _get_worker_inputs(
     ) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
         ids = []
         futures = []
-        if use_ray_shuffle:
-            for child_stage_id, child_futures in child_outputs:
-                for i, lst in enumerate(child_futures):
-                    if isinstance(lst, list):
-                        for j, f in enumerate(lst):
-                            if concurrency == 1 or j == part:
-                                # If concurrency is 1, pass in all shuffle partitions. Otherwise,
-                                # only pass in the partitions that match the current worker partition.
-                                ids.append((child_stage_id, i, j))
-                                futures.append(f)
-                    elif concurrency == 1 or part == 0:
-                        ids.append((child_stage_id, i, 0))
-                        futures.append(lst)
+        for child_stage_id, child_futures in child_outputs:
+            for i, lst in enumerate(child_futures):
+                if isinstance(lst, list):
+                    for j, f in enumerate(lst):
+                        if concurrency == 1 or j == part:
+                            # If concurrency is 1, pass in all shuffle partitions. Otherwise,
+                            # only pass in the partitions that match the current worker partition.
+                            ids.append((child_stage_id, i, j))
+                            futures.append(f)
+                elif concurrency == 1 or part == 0:
+                    ids.append((child_stage_id, i, 0))
+                    futures.append(lst)

         return ids, futures

-    # if we are using disk-based shuffle, wait until the child stages finish
-    # writing the shuffle files to disk first.
-    if not use_ray_shuffle:
-        ray.get([f for _, lst in child_outputs for f in lst])
-
     # schedule the actual execution workers
     plan_bytes = datafusion_ray.serialize_execution_plan(stage.get_execution_plan())
     futures = []
     opt = {}
-    opt["resources"] = {"worker": 1e-3}
-    if use_ray_shuffle:
-        opt["num_returns"] = output_partitions_count
+    # TODO not sure why we had this but my Ray cluster could not find suitable resources
+    # until I commented this out
+    # opt["resources"] = {"worker": 1e-3}
+    opt["num_returns"] = output_partitions_count
     for part in range(concurrency):
         ids, inputs = _get_worker_inputs(part)
         futures.append(
@@ -210,10 +206,9 @@ def execute_query_partition(
 
 
 class DatafusionRayContext:
-    def __init__(self, num_workers: int = 1, use_ray_shuffle: bool = False):
-        self.ctx = Context(num_workers, use_ray_shuffle)
+    def __init__(self, num_workers: int = 1):
+        self.ctx = Context(num_workers)
         self.num_workers = num_workers
-        self.use_ray_shuffle = use_ray_shuffle
 
     def register_csv(self, table_name: str, path: str, has_header: bool):
         self.ctx.register_csv(table_name, path, has_header)
@@ -234,23 +229,7 @@ def sql(self, sql: str) -> pa.RecordBatch:
 
         graph = self.ctx.plan(sql)
         final_stage_id = graph.get_final_query_stage().id()
-        if self.use_ray_shuffle:
-            partitions = schedule_execution(graph, final_stage_id, True)
-        else:
-            # serialize the query stages and store in Ray object store
-            query_stages = [
-                datafusion_ray.serialize_execution_plan(
-                    graph.get_query_stage(i).get_execution_plan()
-                )
-                for i in range(final_stage_id + 1)
-            ]
-            # schedule execution
-            future = execute_query_stage.remote(
-                query_stages,
-                final_stage_id,
-                self.use_ray_shuffle,
-            )
-            _, partitions = ray.get(future)
+        partitions = schedule_execution(graph, final_stage_id, True)
         # assert len(partitions) == 1, len(partitions)
         result_set = ray.get(partitions[0])
         return result_set
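Taken together, these changes make Ray-based shuffle the only code path. A minimal usage sketch of the simplified API, assuming the package exposes `DatafusionRayContext` at the top level; the table name and CSV path are illustrative only:

```python
import ray
from datafusion_ray import DatafusionRayContext

ray.init()

# use_ray_shuffle is gone; Ray shuffle is now the only mode
ctx = DatafusionRayContext(num_workers=4)
ctx.register_csv("sales", "/path/to/sales.csv", has_header=True)

# sql() schedules the query stages on Ray workers and
# returns the materialized result
result = ctx.sql("SELECT COUNT(*) FROM sales")
print(result)
```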