-
Notifications
You must be signed in to change notification settings - Fork 423
Expand file tree
/
Copy pathhhrlhf_rw.py
More file actions
32 lines (24 loc) · 793 Bytes
/
hhrlhf_rw.py
File metadata and controls
32 lines (24 loc) · 793 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import sys
from areal import RWTrainer
from areal.api.cli_args import RWConfig, load_expr_config
from areal.dataset import get_custom_dataset
from areal.utils.hf_utils import load_hf_tokenizer
def main(args):
    """Train a reward model from command-line arguments.

    Loads the experiment configuration as an ``RWConfig``, loads the
    tokenizer named by ``config.tokenizer_path``, builds the training and
    validation datasets, and runs ``RWTrainer.train()`` inside the trainer's
    context manager.

    Args:
        args: Command-line arguments forwarded unchanged to
            ``load_expr_config`` (the script entry point passes
            ``sys.argv[1:]``).
    """
    # load_expr_config returns a pair; only the first element is used here.
    cfg, _ = load_expr_config(args, RWConfig)
    tok = load_hf_tokenizer(cfg.tokenizer_path)

    def build_split(split_name, split_cfg):
        # Both datasets share the same tokenizer; only split and config vary.
        return get_custom_dataset(
            split=split_name,
            dataset_config=split_cfg,
            tokenizer=tok,
        )

    training_data = build_split("train", cfg.train_dataset)
    # The validation dataset is read from the "test" split.
    validation_data = build_split("test", cfg.valid_dataset)

    with RWTrainer(
        cfg,
        train_dataset=training_data,
        valid_dataset=validation_data,
    ) as trainer:
        trainer.train()
# Script entry point: runs only when the file is executed directly, and
# forwards every CLI argument after the program name to main().
if __name__ == "__main__":
    main(sys.argv[1:])