@@ -73,26 +73,18 @@ def generate_records(
7373 subprocess .run (gensort_cmd .split ()).check_returncode ()
7474 runtime_task .add_elapsed_time ("generate records (secs)" )
7575 shm_file .seek (0 )
76- buffer = arrow .py_buffer (
77- shm_file .read (record_count * record_nbytes )
78- )
76+ buffer = arrow .py_buffer (shm_file .read (record_count * record_nbytes ))
7977 runtime_task .add_elapsed_time ("read records (secs)" )
8078 # https://arrow.apache.org/docs/format/Columnar.html#fixed-size-primitive-layout
81- records = arrow .Array .from_buffers (
82- arrow .binary (record_nbytes ), record_count , [None , buffer ]
83- )
79+ records = arrow .Array .from_buffers (arrow .binary (record_nbytes ), record_count , [None , buffer ])
8480 keys = pc .binary_slice (records , 0 , key_nbytes )
8581 # get first 2 bytes and convert to big-endian uint16
8682 binary_prefix = pc .binary_slice (records , 0 , 2 ).cast (arrow .binary ())
87- reversed_prefix = pc .binary_reverse (binary_prefix ).cast (
88- arrow .binary (2 )
89- )
83+ reversed_prefix = pc .binary_reverse (binary_prefix ).cast (arrow .binary (2 ))
9084 uint16_prefix = reversed_prefix .view (arrow .uint16 ())
9185 buckets = pc .shift_right (uint16_prefix , 16 - bucket_nbits )
9286 runtime_task .add_elapsed_time ("build arrow table (secs)" )
93- yield arrow .Table .from_arrays (
94- [buckets , keys , records ], schema = schema
95- )
87+ yield arrow .Table .from_arrays ([buckets , keys , records ], schema = schema )
9688 yield StreamOutput (
9789 schema .empty_table (),
9890 batch_indices = [batch_idx ],
@@ -108,9 +100,7 @@ def sort_records(
108100 write_io_nbytes = 500 * MB ,
109101) -> bool :
110102 runtime_task : PythonScriptTask = runtime_ctx .task
111- data_file_path = os .path .join (
112- runtime_task .runtime_output_abspath , f"{ runtime_task .output_filename } .dat"
113- )
103+ data_file_path = os .path .join (runtime_task .runtime_output_abspath , f"{ runtime_task .output_filename } .dat" )
114104
115105 if sort_engine == "polars" :
116106 input_data = polars .read_parquet (
@@ -134,9 +124,7 @@ def sort_records(
134124 record_arrays = sorted_table .column ("records" ).chunks
135125 runtime_task .add_elapsed_time ("convert to chunks (secs)" )
136126 elif sort_engine == "duckdb" :
137- with duckdb .connect (
138- database = ":memory:" , config = {"allow_unsigned_extensions" : "true" }
139- ) as conn :
127+ with duckdb .connect (database = ":memory:" , config = {"allow_unsigned_extensions" : "true" }) as conn :
140128 runtime_task .prepare_connection (conn )
141129 input_views = runtime_task .create_input_views (conn , input_datasets )
142130 sql_query = "select records from {0} order by keys" .format (* input_views )
@@ -154,8 +142,7 @@ def sort_records(
154142 buffer_mem = memoryview (values )
155143
156144 total_write_nbytes = sum (
157- fout .write (buffer_mem [offset : offset + write_io_nbytes ])
158- for offset in range (0 , len (buffer_mem ), write_io_nbytes )
145+ fout .write (buffer_mem [offset : offset + write_io_nbytes ]) for offset in range (0 , len (buffer_mem ), write_io_nbytes )
159146 )
160147 assert total_write_nbytes == len (buffer_mem )
161148
@@ -164,16 +151,10 @@ def sort_records(
164151 return True
165152
166153
167- def validate_records (
168- runtime_ctx : RuntimeContext , input_datasets : List [DataSet ], output_path : str
169- ) -> bool :
154+ def validate_records (runtime_ctx : RuntimeContext , input_datasets : List [DataSet ], output_path : str ) -> bool :
170155 for data_path in input_datasets [0 ].resolved_paths :
171- summary_path = os .path .join (
172- output_path , PurePath (data_path ).with_suffix (".sum" ).name
173- )
174- cmdstr = (
175- f"{ SortBenchTool .valsort_path } -o { summary_path } { data_path } ,buf,trans=10m"
176- )
156+ summary_path = os .path .join (output_path , PurePath (data_path ).with_suffix (".sum" ).name )
157+ cmdstr = f"{ SortBenchTool .valsort_path } -o { summary_path } { data_path } ,buf,trans=10m"
177158 logging .debug (f"running command: { cmdstr } " )
178159 result = subprocess .run (cmdstr .split (), capture_output = True , encoding = "utf8" )
179160 if result .stderr :
@@ -185,9 +166,7 @@ def validate_records(
185166 return True
186167
187168
188- def validate_summary (
189- runtime_ctx : RuntimeContext , input_datasets : List [DataSet ], output_path : str
190- ) -> bool :
169+ def validate_summary (runtime_ctx : RuntimeContext , input_datasets : List [DataSet ], output_path : str ) -> bool :
191170 concated_summary_path = os .path .join (output_path , "merged.sum" )
192171 with open (concated_summary_path , "wb" ) as fout :
193172 for path in input_datasets [0 ].resolved_paths :
@@ -224,22 +203,13 @@ def generate_random_records(
224203 )
225204
226205 range_begin_at = [pos for pos in range (0 , total_num_records , record_range_size )]
227- range_num_records = [
228- min (total_num_records , record_range_size * (range_idx + 1 )) - begin_at
229- for range_idx , begin_at in enumerate (range_begin_at )
230- ]
206+ range_num_records = [min (total_num_records , record_range_size * (range_idx + 1 )) - begin_at for range_idx , begin_at in enumerate (range_begin_at )]
231207 assert sum (range_num_records ) == total_num_records
232208 record_range = DataSourceNode (
233209 ctx ,
234- ArrowTableDataSet (
235- arrow .Table .from_arrays (
236- [range_begin_at , range_num_records ], names = ["begin_at" , "num_records" ]
237- )
238- ),
239- )
240- record_range_partitions = DataSetPartitionNode (
241- ctx , (record_range ,), npartitions = num_data_partitions , partition_by_rows = True
210+ ArrowTableDataSet (arrow .Table .from_arrays ([range_begin_at , range_num_records ], names = ["begin_at" , "num_records" ])),
242211 )
212+ record_range_partitions = DataSetPartitionNode (ctx , (record_range ,), npartitions = num_data_partitions , partition_by_rows = True )
243213
244214 random_records = ArrowStreamNode (
245215 ctx ,
@@ -288,9 +258,7 @@ def gray_sort_benchmark(
288258 if input_paths :
289259 input_dataset = ParquetDataSet (input_paths )
290260 input_nbytes = sum (os .path .getsize (p ) for p in input_dataset .resolved_paths )
291- logging .warning (
292- f"input data size: { input_nbytes / GB :.3f} GB, { input_dataset .num_files } files"
293- )
261+ logging .warning (f"input data size: { input_nbytes / GB :.3f} GB, { input_dataset .num_files } files" )
294262 random_records = DataSourceNode (ctx , input_dataset )
295263 else :
296264 random_records = generate_random_records (
@@ -335,12 +303,8 @@ def gray_sort_benchmark(
335303 process_func = validate_records ,
336304 output_name = "partitioned_summaries" ,
337305 )
338- merged_summaries = DataSetPartitionNode (
339- ctx , (partitioned_summaries ,), npartitions = 1
340- )
341- final_check = PythonScriptNode (
342- ctx , (merged_summaries ,), process_func = validate_summary
343- )
306+ merged_summaries = DataSetPartitionNode (ctx , (partitioned_summaries ,), npartitions = 1 )
307+ final_check = PythonScriptNode (ctx , (merged_summaries ,), process_func = validate_summary )
344308 root = final_check
345309 else :
346310 root = sorted_records
@@ -359,17 +323,11 @@ def main():
359323 driver .add_argument ("-n" , "--num_data_partitions" , type = int , default = None )
360324 driver .add_argument ("-t" , "--num_sort_partitions" , type = int , default = None )
361325 driver .add_argument ("-i" , "--input_paths" , nargs = "+" , default = [])
362- driver .add_argument (
363- "-e" , "--shuffle_engine" , default = "duckdb" , choices = ("duckdb" , "arrow" )
364- )
365- driver .add_argument (
366- "-s" , "--sort_engine" , default = "duckdb" , choices = ("duckdb" , "arrow" , "polars" )
367- )
326+ driver .add_argument ("-e" , "--shuffle_engine" , default = "duckdb" , choices = ("duckdb" , "arrow" ))
327+ driver .add_argument ("-s" , "--sort_engine" , default = "duckdb" , choices = ("duckdb" , "arrow" , "polars" ))
368328 driver .add_argument ("-H" , "--hive_partitioning" , action = "store_true" )
369329 driver .add_argument ("-V" , "--validate_results" , action = "store_true" )
370- driver .add_argument (
371- "-C" , "--shuffle_cpu_limit" , type = int , default = ShuffleNode .default_cpu_limit
372- )
330+ driver .add_argument ("-C" , "--shuffle_cpu_limit" , type = int , default = ShuffleNode .default_cpu_limit )
373331 driver .add_argument (
374332 "-M" ,
375333 "--shuffle_memory_limit" ,
@@ -378,12 +336,8 @@ def main():
378336 )
379337 driver .add_argument ("-TC" , "--sort_cpu_limit" , type = int , default = 8 )
380338 driver .add_argument ("-TM" , "--sort_memory_limit" , type = int , default = None )
381- driver .add_argument (
382- "-NC" , "--cpus_per_node" , type = int , default = psutil .cpu_count (logical = False )
383- )
384- driver .add_argument (
385- "-NM" , "--memory_per_node" , type = int , default = psutil .virtual_memory ().total
386- )
339+ driver .add_argument ("-NC" , "--cpus_per_node" , type = int , default = psutil .cpu_count (logical = False ))
340+ driver .add_argument ("-NM" , "--memory_per_node" , type = int , default = psutil .virtual_memory ().total )
387341 driver .add_argument ("-CP" , "--parquet_compression" , default = None )
388342 driver .add_argument ("-LV" , "--parquet_compression_level" , type = int , default = None )
389343
@@ -393,16 +347,9 @@ def main():
393347 total_num_cpus = max (1 , driver_args .num_executors ) * user_args .cpus_per_node
394348 memory_per_cpu = user_args .memory_per_node // user_args .cpus_per_node
395349
396- user_args .sort_cpu_limit = (
397- 1 if user_args .sort_engine == "arrow" else user_args .sort_cpu_limit
398- )
399- sort_memory_limit = (
400- user_args .sort_memory_limit or user_args .sort_cpu_limit * memory_per_cpu
401- )
402- user_args .total_data_nbytes = (
403- user_args .total_data_nbytes
404- or max (1 , driver_args .num_executors ) * user_args .memory_per_node
405- )
350+ user_args .sort_cpu_limit = 1 if user_args .sort_engine == "arrow" else user_args .sort_cpu_limit
351+ sort_memory_limit = user_args .sort_memory_limit or user_args .sort_cpu_limit * memory_per_cpu
352+ user_args .total_data_nbytes = user_args .total_data_nbytes or max (1 , driver_args .num_executors ) * user_args .memory_per_node
406353 user_args .num_data_partitions = user_args .num_data_partitions or total_num_cpus // 2
407354 user_args .num_sort_partitions = user_args .num_sort_partitions or max (
408355 total_num_cpus // user_args .sort_cpu_limit ,
0 commit comments