Skip to content

Commit 2e00dc9

Browse files
committed
Fix inference code
1 parent 61da04a commit 2e00dc9

File tree

3 files changed

+33
-31
lines changed

3 files changed

+33
-31
lines changed

src/prm_evaluation/genprm_inference.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,14 @@ def execute(self, text):
6565

6666

6767
class GenPRM:
68-
def __init__(self, model_path):
68+
def __init__(self, model_path, tensor_parallel_size):
6969
# Load the model and tokenizer
7070
timestamped_print(f"Loading model from {model_path}", level="INFO")
71-
self.model = LLM(model=model_path, gpu_memory_utilization=0.90, enable_chunked_prefill=True)
71+
self.model = LLM(
72+
model=model_path,
73+
tensor_parallel_size=tensor_parallel_size,
74+
enable_chunked_prefill=True
75+
)
7276
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
7377
timestamped_print(f"GenPRM loaded successfully", level="INFO")
7478

@@ -301,16 +305,16 @@ def _single_inference(
301305
cur_time += 1
302306
new_prompts = []
303307
if output2.text.endswith('</output>\n'):
304-
output2.text = cur_prompt + output2.text
308+
output2.text = cur_prompts[0] + output2.text
305309
out_nodes.append(output2)
306310
else:
307311
if execute:
308312
# execute the code
309-
code_output = code_executor.execute(cur_prompt + output2.text)
313+
code_output = code_executor.execute(cur_prompts[0] + output2.text)
310314
code_content = f"[Code Output]\n\n```\n{code_output}\n```\n"
311-
new_prompts.append(cur_prompt + output2.text + code_content)
315+
new_prompts.append(cur_prompts[0] + output2.text + code_content)
312316
else:
313-
new_prompts.append(cur_prompt + output2.text + '[Code Output]\n\n```\n')
317+
new_prompts.append(cur_prompts[0] + output2.text + '[Code Output]\n\n```\n')
314318

315319
cur_prompts = new_prompts
316320

src/prm_evaluation/prm_evaluate.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
root_dir = os.path.abspath(os.path.join(current_dir, ".."))
88
sys.path.append(root_dir)
99
import argparse
10-
import os
10+
import json
1111
import random
1212
import time
1313
import threading
@@ -61,6 +61,9 @@ def parse_args():
6161
parser.add_argument("--analyze_template", type=str, default="<analyze>\nLet's analyze the Paragraph {cur_step} step by step: ")
6262
parser.add_argument("--verify_template", type=str, default="<verify>\nLet's use python code to find any potential error:\n```python\n")
6363
parser.add_argument("--output_template", type=str, default="<output>\n**Judgement**: $\\boxed")
64+
parser.add_argument("--tensor_parallel_size", type=int, default=1)
65+
parser.add_argument("--idd", type=int, default=1)
66+
6467
return parser.parse_args()
6568

6669

@@ -73,7 +76,7 @@ def parse_args():
7376

7477
##################################################### model load with VLLM ########################################################
7578

76-
genprm = GenPRM(args.reward_name_or_path)
79+
genprm = GenPRM(args.reward_name_or_path, args.tensor_parallel_size)
7780

7881
##################################################### load splited dataset ########################################################
7982

@@ -91,6 +94,7 @@ def get_shuffled_folders(directory):
9194
for data_path in target_list:
9295
folder_name = os.path.basename(data_path)
9396
save_path = os.path.join(args.split_out, folder_name)
97+
9498
if args.analyze:
9599
save_path += '_analyze'
96100
if args.verify:
@@ -125,11 +129,9 @@ def get_shuffled_folders(directory):
125129
thread.start()
126130
timestamped_print("Heartbeat thread started. Main thread continues...")
127131

128-
data = load_from_disk(os.path.join(args.data_path, folder_name))
129-
timestamped_print(data)
130-
data_new = data.to_list()
131-
132-
sample = deepcopy(data_new)[0]
132+
with open(os.path.join(args.data_path, folder_name, 'sample.json'), 'r') as f:
133+
data_new = json.load(f)
134+
sample = deepcopy(data_new)
133135
data_input = sample['steps']
134136
data_input[0] = sample['problem'] + '\n' + data_input[0]
135137
if data_input and data_input[-1] == '':
@@ -143,11 +145,11 @@ def get_shuffled_folders(directory):
143145
else:
144146
message = {
145147
'conversation': [
146-
{'role': 'system', 'content': 'You are a math teacher. Your task is to review and critique the paragraphs in solution directly. Output your judgement in the format of `boxed{Yes}` if the paragraph is correct, or `boxed{No}` if the paragraph is incorrect.'}
148+
{'role': 'system', 'content': 'You are a math teacher. Your task is to review and critique the paragraphs in solution directly. Output your judgement in the format of `\\boxed{Yes}` if the paragraph is correct, or `\\boxed{No}` if the paragraph is incorrect.'}
147149
]
148150
}
149151
for j1 in range(len(data_input)):
150-
line = {'content': data_input[j1], 'role': 'user'}
152+
line = {'role': 'user', 'content': data_input[j1]}
151153
message['conversation'].append(line)
152154
line = {'content': '', 'role': 'assistant'}
153155
message['conversation'].append(line)
@@ -192,12 +194,13 @@ def get_shuffled_folders(directory):
192194
step_scores.append(reward)
193195

194196
end = time.perf_counter()
195-
data_new[0]['time'] = end - start
196-
data_new[0]['value'] = step_scores
197-
data_new[0]['conversation'] = conversation
197+
data_new['time'] = end - start
198+
data_new['value'] = step_scores
199+
data_new['conversation'] = conversation
200+
198201
timestamped_print(type(data_new))
199-
timestamped_print(type(Dataset.from_list(data_new)))
200-
(Dataset.from_list(data_new)).save_to_disk(save_path)
202+
with open(os.path.join(save_path, f'result_{args.idd}.json'), 'w') as f:
203+
json.dump(data_new, f, indent=4)
201204
timestamped_print(f"dataset has been saved to: {save_path}")
202205
except Exception as e:
203206
traceback.print_exc()

src/utils/split_dataset.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
--split_dir _data/split_input/ProcessBench
99
"""
1010

11-
from datasets import load_dataset, Dataset
1211
import argparse
12+
import json
1313
import os
14+
from datasets import load_dataset, Dataset
1415

1516

1617
def export_all_splits(dataset_name, root_output_dir):
@@ -43,17 +44,11 @@ def process_split(dataset, split_name, root_output_dir):
4344
for idx, example in enumerate(dataset):
4445
example_dir = os.path.join(
4546
split_dir,
46-
f"{split_name}_example_{idx:05d}"
47+
f"{split_name}_{idx:03d}"
4748
)
48-
create_single_example_dataset(example, dataset.info, example_dir)
49-
50-
51-
def create_single_example_dataset(example, info, output_dir):
52-
"""Create self-contained dataset for one example"""
53-
os.makedirs(output_dir, exist_ok=True)
54-
single_ds = Dataset.from_list([example])
55-
# single_ds.info = info
56-
single_ds.save_to_disk(output_dir)
49+
os.makedirs(example_dir, exist_ok=True)
50+
with open(os.path.join(example_dir, "sample.json"), "w") as f:
51+
json.dump(example, f, ensure_ascii=False, indent=4)
5752

5853

5954
if __name__ == "__main__":

0 commit comments

Comments
 (0)