77root_dir = os .path .abspath (os .path .join (current_dir , ".." ))
88sys .path .append (root_dir )
99import argparse
10- import os
10+ import json
1111import random
1212import time
1313import threading
@@ -61,6 +61,9 @@ def parse_args():
6161 parser .add_argument ("--analyze_template" , type = str , default = "<analyze>\n Let's analyze the Paragraph {cur_step} step by step: " )
6262 parser .add_argument ("--verify_template" , type = str , default = "<verify>\n Let's use python code to find any potential error:\n ```python\n " )
6363 parser .add_argument ("--output_template" , type = str , default = "<output>\n **Judgement**: $\\ boxed" )
64+ parser .add_argument ("--tensor_parallel_size" , type = int , default = 1 )
65+ parser .add_argument ("--idd" , type = int , default = 1 )
66+
6467 return parser .parse_args ()
6568
6669
@@ -73,7 +76,7 @@ def parse_args():
7376
7477##################################################### model load with VLLM ########################################################
7578
76- genprm = GenPRM (args .reward_name_or_path )
79+ genprm = GenPRM (args .reward_name_or_path , args . tensor_parallel_size )
7780
7881##################################################### load splited dataset ########################################################
7982
@@ -91,6 +94,7 @@ def get_shuffled_folders(directory):
9194for data_path in target_list :
9295 folder_name = os .path .basename (data_path )
9396 save_path = os .path .join (args .split_out , folder_name )
97+
9498 if args .analyze :
9599 save_path += '_analyze'
96100 if args .verify :
@@ -125,11 +129,9 @@ def get_shuffled_folders(directory):
125129 thread .start ()
126130 timestamped_print ("Heartbeat thread started. Main thread continues..." )
127131
128- data = load_from_disk (os .path .join (args .data_path , folder_name ))
129- timestamped_print (data )
130- data_new = data .to_list ()
131-
132- sample = deepcopy (data_new )[0 ]
132+ with open (os .path .join (args .data_path , folder_name , 'sample.json' ), 'r' ) as f :
133+ data_new = json .load (f )
134+ sample = deepcopy (data_new )
133135 data_input = sample ['steps' ]
134136 data_input [0 ] = sample ['problem' ] + '\n ' + data_input [0 ]
135137 if data_input and data_input [- 1 ] == '' :
@@ -143,11 +145,11 @@ def get_shuffled_folders(directory):
143145 else :
144146 message = {
145147 'conversation' : [
146- {'role' : 'system' , 'content' : 'You are a math teacher. Your task is to review and critique the paragraphs in solution directly. Output your judgement in the format of `boxed{Yes}` if the paragraph is correct, or `boxed{No}` if the paragraph is incorrect.' }
148+ {'role' : 'system' , 'content' : 'You are a math teacher. Your task is to review and critique the paragraphs in solution directly. Output your judgement in the format of `\\ boxed{Yes}` if the paragraph is correct, or `\\ boxed{No}` if the paragraph is incorrect.' }
147149 ]
148150 }
149151 for j1 in range (len (data_input )):
150- line = {'content ' : data_input [ j1 ] , 'role ' : 'user' }
152+ line = {'role ' : 'user' , 'content ' : data_input [ j1 ] }
151153 message ['conversation' ].append (line )
152154 line = {'content' : '' , 'role' : 'assistant' }
153155 message ['conversation' ].append (line )
@@ -192,12 +194,13 @@ def get_shuffled_folders(directory):
192194 step_scores .append (reward )
193195
194196 end = time .perf_counter ()
195- data_new [0 ]['time' ] = end - start
196- data_new [0 ]['value' ] = step_scores
197- data_new [0 ]['conversation' ] = conversation
197+ data_new ['time' ] = end - start
198+ data_new ['value' ] = step_scores
199+ data_new ['conversation' ] = conversation
200+
198201 timestamped_print (type (data_new ))
199- timestamped_print ( type ( Dataset . from_list ( data_new )))
200- ( Dataset . from_list (data_new )). save_to_disk ( save_path )
202+ with open ( os . path . join ( save_path , f'result_ { args . idd } .json' ), 'w' ) as f :
203+ json . dump (data_new , f , indent = 4 )
201204 timestamped_print (f"dataset has been saved to: { save_path } " )
202205 except Exception as e :
203206 traceback .print_exc ()
0 commit comments