Skip to content

Commit 1802bec

Browse files
AlanShao-zy and lewtun authored
fix dataset parsing error (#540)
* fix dataset parsing error

  support defined question field to fix errors when datasets' question field is not 'problem'

* add question field config

  add script_args: question field

* refactor: datasets prompt column

---------

Co-authored-by: lewtun <[email protected]>
1 parent 4ec555b commit 1802bec

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

.gitignore

+2-1
Original file line number | Diff line number | Diff line change
@@ -177,4 +177,5 @@ logs/
177177
eval_results/
178178
results/
179179

180-
.vscode/
180+
.vscode/
181+
.python-version

src/open_r1/configs.py

+4
Original file line number | Diff line number | Diff line change
@@ -154,3 +154,7 @@ class GRPOScriptArguments(trl.ScriptArguments):
154154
"help": "for each generation, evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server + save time on wrong solutions"
155155
},
156156
)
157+
dataset_prompt_column: str = field(
158+
default="prompt",
159+
metadata={"help": "Column to use as prompts for training."},
160+
)

src/open_r1/grpo.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -84,13 +84,16 @@ def main(script_args, training_args, model_args):
8484
reward_funcs = get_reward_funcs(script_args)
8585

8686
# Format into conversation
87-
def make_conversation(example):
87+
def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column):
8888
prompt = []
8989

9090
if training_args.system_prompt is not None:
9191
prompt.append({"role": "system", "content": training_args.system_prompt})
9292

93-
prompt.append({"role": "user", "content": example["problem"]})
93+
if prompt_column not in example:
94+
raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")
95+
96+
prompt.append({"role": "user", "content": example[prompt_column]})
9497
return {"prompt": prompt}
9598

9699
dataset = dataset.map(make_conversation)

0 commit comments

Comments
 (0)