-
Notifications
You must be signed in to change notification settings - Fork 423
Expand file tree
/
Copy pathgsm8k_rl.py
More file actions
47 lines (39 loc) · 1.33 KB
/
gsm8k_rl.py
File metadata and controls
47 lines (39 loc) · 1.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import sys
from areal import PPOTrainer
from areal.api.cli_args import GRPOConfig, load_expr_config
from areal.dataset import get_custom_dataset
from areal.utils.hf_utils import load_hf_tokenizer
def main(args):
    """Train a model on GSM8K with ``PPOTrainer`` driven by a ``GRPOConfig``.

    The experiment configuration is parsed from ``args`` via
    ``load_expr_config``. Training data comes from the ``"train"`` split and
    validation data from the ``"test"`` split. Both the training and the
    evaluation rollouts use ``areal.workflow.rlvr.RLVRWorkflow``, scored by
    ``areal.reward.gsm8k.gsm8k_reward_fn``.

    Args:
        args: Command-line arguments (excluding the program name) that are
            forwarded unchanged to ``load_expr_config``.
    """
    # The second value returned by load_expr_config is not needed here.
    config, _ = load_expr_config(args, GRPOConfig)
    hf_tokenizer = load_hf_tokenizer(config.tokenizer_path)

    def load_split(split_name, dataset_cfg):
        # Both splits are tokenized with the same tokenizer loaded above.
        return get_custom_dataset(
            split=split_name,
            dataset_config=dataset_cfg,
            tokenizer=hf_tokenizer,
        )

    train_dataset = load_split("train", config.train_dataset)
    valid_dataset = load_split("test", config.valid_dataset)

    rlvr_workflow = "areal.workflow.rlvr.RLVRWorkflow"
    train_kwargs = {
        "reward_fn": "areal.reward.gsm8k.gsm8k_reward_fn",
        "gconfig": config.gconfig,
        # NOTE(review): the tokenizer *path* (not the loaded tokenizer object)
        # is passed, presumably so the workflow, which is itself referenced by
        # import path, can load its own copy -- confirm.
        "tokenizer": config.tokenizer_path,
        "enable_thinking": False,
    }
    # Evaluation reuses every training kwarg except gconfig, which is replaced
    # by a copy with temperature=0.6.
    eval_kwargs = {**train_kwargs, "gconfig": config.gconfig.new(temperature=0.6)}

    with PPOTrainer(
        config,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
    ) as trainer:
        trainer.train(
            workflow=rlvr_workflow,
            workflow_kwargs=train_kwargs,
            eval_workflow=rlvr_workflow,
            eval_workflow_kwargs=eval_kwargs,
        )
if __name__ == "__main__":
    # Drop the script name; the remaining CLI tokens are the experiment args.
    main(args=sys.argv[1:])