
Commit 5672f16

Delaunay and pierre.delaunay authored
new RLHF benchmark (#273)
* new RLHF benchmark
* Add RLHF config to standard

---------

Co-authored-by: pierre.delaunay <[email protected]>
1 parent e327768 commit 5672f16

File tree

13 files changed: +735 −6 lines changed


benchmarks/rlhf/Makefile

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Use global base if possible
ifndef MILABENCH_BASE
MILABENCH_BASE="base"
endif

export MILABENCH_BASE

BENCH_NAME=rlhf
MILABENCH_CONFIG=dev.yaml
MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE)

all:
	install prepare single gpus nodes

install:
	milabench install $(MILABENCH_ARGS) --force

prepare:
	milabench prepare $(MILABENCH_ARGS)

tests: install prepare
	milabench run $(MILABENCH_ARGS)

single:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single

gpus:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus

nodes:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes

benchmarks/rlhf/README.md

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# Rlhf

Rewrite this README to explain what the benchmark is!

benchmarks/rlhf/benchfile.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
from milabench.pack import Package


class Rlhf(Package):
    # Requirements file installed by install(). It can be empty or absent.
    base_requirements = "requirements.in"

    # The preparation script called by prepare(). It must be executable,
    # but it can be any type of script. It can be empty or absent.
    prepare_script = "prepare.py"

    # The main script called by run(). It must be a Python file. It has to
    # be present.
    main_script = "main.py"

    # You can remove the functions below if you don't need to modify them.

    def make_env(self):
        # Return a dict of environment variables for prepare_script and
        # main_script.
        return super().make_env()

    async def install(self):
        await super().install()  # super() call installs the requirements

    async def prepare(self):
        await super().prepare()  # super() call executes prepare_script

    def build_run_plan(self):
        from milabench.commands import PackCommand, AccelerateAllNodes

        main = self.dirs.code / self.main_script
        plan = PackCommand(self, *self.argv, lazy=True)

        if False:
            plan = VoirCommand(plan, cwd=main.parent)

        return AccelerateAllNodes(plan).use_stdout()


__pack__ = Rlhf
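
As the comments above note, make_env() is the hook for passing extra environment variables to prepare.py and main.py. A minimal override sketch (hypothetical subclass, not part of this commit; the variable name is purely illustrative):

# Hypothetical sketch: extend make_env() to add environment variables on top
# of the defaults returned by Package.make_env().
class RlhfWithExtraEnv(Rlhf):
    def make_env(self):
        env = dict(super().make_env())
        env["EXAMPLE_FLAG"] = "1"  # placeholder variable, only to show the pattern
        return env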

benchmarks/rlhf/dev.yaml

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
rlhf_:
  inherits: _defaults
  definition: .
  install-variant: unpinned
  install_group: torch
  plan:
    method: per_gpu

  argv:
    --output_dir: "{milabench_extra}/output"
    --model_name_or_path: EleutherAI/pythia-1b-deduped
    --per_device_train_batch_size: 64
    --logging_strategy: "no"
    --log_level: "critical"
    --bf16: true


rlhf-single:
  inherits: rlhf_
  plan:
    method: per_gpu


rlhf-gpus:
  inherits: rlhf_
  plan:
    method: njobs
    n: 1
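
Each key under argv is forwarded to main.py as an ordinary command-line flag, where HfArgumentParser splits it across PPOv2Config and ModelConfig. A minimal sketch of that parsing step (values copied from the config above; the output path is a placeholder for "{milabench_extra}/output"):

# Minimal sketch of how the argv entries above are consumed by main.py.
# Not benchmark code; values mirror dev.yaml.
from transformers import HfArgumentParser
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config

parser = HfArgumentParser((PPOv2Config, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses(args=[
    "--output_dir", "/tmp/rlhf/output",   # placeholder path
    "--model_name_or_path", "EleutherAI/pythia-1b-deduped",
    "--per_device_train_batch_size", "64",
    "--logging_strategy", "no",
    "--log_level", "critical",
    # "--bf16", "true" is forwarded the same way (needs a bf16-capable device)
])
print(config.per_device_train_batch_size, model_config.model_name_or_path)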

benchmarks/rlhf/main.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
#!/usr/bin/env python

import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)

from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


class PPOv2TrainerInstrumented(PPOv2Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def batch_size_fn(batch):
            x, y = batch["input_ids"].shape
            return x * y

        from benchmate.observer import BenchObserver
        observer = BenchObserver(
            batch_size_fn=batch_size_fn,
            earlystop=70,
            raise_stop_program=True,
            stdout=True,
        )

        self.dataloader = observer.iterate(self.dataloader)

    def generate_completions(self, sampling: bool = False):
        pass

    def _save_checkpoint(self, *args, **kwargs):
        pass

    def save_model(self, *args, **kwargs):
        pass


def main():
    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()
    # remove output_dir if it exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    ################
    # Model & Tokenizer
    ################
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
    value_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
    )
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
    policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
    ################
    # Dataset
    ################
    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
    eval_samples = 20
    train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
    eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
    dataset_text_field = "prompt"

    def prepare_dataset(dataset, tokenizer):
        """pre-tokenize the dataset before training; only collate during training"""

        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field],
                padding=False,
            )
            return {"input_ids": outputs["input_ids"]}

        return dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=config.dataset_num_proc,
        )

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        train_dataset = prepare_dataset(train_dataset, tokenizer)
        eval_dataset = prepare_dataset(eval_dataset, tokenizer)

    ################
    # Training
    ################
    trainer = PPOv2TrainerInstrumented(
        config=config,
        tokenizer=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_model(config.output_dir)
    if config.push_to_hub:
        trainer.push_to_hub()
    trainer.generate_completions()


if __name__ == "__main__":
    from voir.phase import StopProgram
    from benchmate.monitor import bench_monitor

    try:
        with bench_monitor():
            main()
    except StopProgram:
        pass
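
The only instrumentation added by the trainer subclass above is the BenchObserver wrapper around self.dataloader: batch_size_fn receives each PPO batch and returns rows × columns of input_ids, so the reported work unit is padded token positions per step rather than prompt count. A standalone illustration with a dummy batch (not benchmark code):

# Standalone illustration of the batch_size_fn used above: it counts the
# token positions in a padded batch (rows x columns of input_ids).
import torch

def batch_size_fn(batch):
    x, y = batch["input_ids"].shape
    return x * y

dummy_batch = {"input_ids": torch.zeros(64, 128, dtype=torch.long)}  # 64 prompts padded to 128 tokens
print(batch_size_fn(dummy_batch))  # -> 8192 token positions for this step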

benchmarks/rlhf/prepare.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
#!/usr/bin/env python

import shutil

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)
from datasets import load_dataset
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


if __name__ == "__main__":
    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()

    # remove output_dir if it exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )

    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE

    value_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
        num_labels=1
    )
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
        num_labels=1
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path,
        trust_remote_code=model_config.trust_remote_code
    )
    policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path,
        trust_remote_code=model_config.trust_remote_code
    )

    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
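
prepare.py performs the same tokenizer, model, and dataset loads as main.py so that `milabench prepare` downloads everything into the local Hugging Face cache before the timed run. A minimal sketch of the effect (not part of the benchmark; local_files_only is only used here to show that nothing needs the network afterwards):

# Minimal sketch (assumption: prepare.py has already populated the cache):
# the same from_pretrained call now resolves without touching the network.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-1b-deduped",
    padding_side="left",
    local_files_only=True,  # raise instead of downloading during the timed run
)
print(tokenizer.padding_side)  # -> "left"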
