"""Factory helpers for assembling RL training configurations.

Each ``get_*_config`` helper accepts loosely-typed keyword arguments, filters
them down to the fields declared on the corresponding pydantic config class,
and returns an instantiated config object.
"""

from pathlib import Path
from typing import Any, Dict, Optional

from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
from xtuner.v1.datasets import RLTokenizeFnConfig
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.ray.base import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.evaluator import EvaluatorConfig
from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.grpo import GRPOLossConfig
from xtuner.v1.utils.rl_test_utils import get_eos_token


def _filter_pydantic_kwargs(target_class: Any, kwargs: Dict) -> Dict:
    """Keep only the kwargs that match fields declared on ``target_class``."""
    accepted_keys = set(target_class.model_fields.keys())
    return {k: v for k, v in kwargs.items() if k in accepted_keys}


def _build_config(config_class, **kwargs):
    """Instantiate ``config_class``, silently dropping unrecognized kwargs."""
    filtered_params = _filter_pydantic_kwargs(config_class, kwargs)
    return config_class(**filtered_params)


def get_resources_config(**kwargs) -> AcceleratorResourcesConfig:
    return _build_config(AcceleratorResourcesConfig, **kwargs)


def get_rollout_config(**kwargs) -> RolloutConfig:
    return _build_config(RolloutConfig, **kwargs)


def get_dataflow_config(**kwargs) -> DataFlowConfig:
    return _build_config(DataFlowConfig, **kwargs)


def get_replay_buffer_config(tokenizer: Any, **kwargs) -> ReplayBufferConfig:
    """Build the replay buffer config from a training annotation file."""
    tokenizer_config = RLTokenizeFnConfig(max_length=kwargs["max_prompt_length"])
    train_dataset = DatasetConfig(anno_path=kwargs["data_path"])
    train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
    dataloader_config = DataloaderConfig(collator="fake_collator", pack_level="none")
    return ReplayBufferConfig(
        dataset_cfg=train_dataset_cfg,
        dataloader_cfg=dataloader_config,
        tokenizer=tokenizer,
        postprocessor_func=kwargs.get("filter_func"),
    )


def get_dapo_judger_config(tokenizer: Any, **kwargs) -> JudgerConfig:
    """Build a DAPO math judger config with overlong-response penalty defaults."""
    from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig

    dapo_default_args = {
        "enable_overlong_buffer": True,
        "overlong_buffer_len": 4096,
        "overlong_penalty_factor": 1.0,
    }
    dapo_config_params = {**dapo_default_args, **kwargs}
    # Resolve the EOS token string so the judger can detect truncated responses.
    eos_token_id = get_eos_token(kwargs["model_path"])
    eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
    filtered_params = _filter_pydantic_kwargs(DapoMathJudgerConfig, dapo_config_params)
    dapomath_judger_config = DapoMathJudgerConfig(
        judger_name="dapo_math",
        eos_token=eos_token_str,
        max_response_len=kwargs["max_response_length"],
        tokenizer=tokenizer,
        **filtered_params,
    )
    return JudgerConfig(reward_judger_configs=[dapomath_judger_config])


def get_gsm8k_judger_config() -> JudgerConfig:
    """Build a judger config that scores rollouts with the GSM8K reward."""
    from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig

    gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
    return JudgerConfig(reward_judger_configs=[gsm8k_judger_config])


def get_evaluator_config(tokenizer: Any, **kwargs) -> Optional[EvaluatorConfig]:
    """Build the evaluator config, or return ``None`` when evaluation is disabled."""
    if not kwargs["enable_evaluate"]:
        return None
    eval_dataset = DatasetConfig(anno_path=kwargs["eval_data_path"])
    tokenizer_config = RLTokenizeFnConfig(max_length=kwargs["max_prompt_length"])
    eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
    filtered_params = _filter_pydantic_kwargs(EvaluatorConfig, kwargs)
    return EvaluatorConfig(
        dataset_cfg=eval_dataset_cfg,
        tokenizer=tokenizer,
        **filtered_params,
    )


def get_train_worker_config(**kwargs) -> WorkerConfig:
    """Build the training worker config with GRPO-oriented optimizer and loss defaults."""
    from xtuner.v1.model import get_model_config_from_hf

    model_cfg = get_model_config_from_hf(Path(kwargs["model_path"]))
    defaults = {
        "optim_cfg": AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False),
        "loss_cfg": GRPOLossConfig(
            policy_loss_cfg={
                "cliprange_high": 0.28,
                "cliprange_low": 0.2,
                "loss_type": "vanilla",
                "clip_ratio_c": 10.0,
                "log_prob_diff_min": -20.0,
                "log_prob_diff_max": 20.0,
            },
            ignore_idx=-100,
            use_kl_loss=False,
            kl_loss_coef=0.0,
            kl_loss_type="low_var_kl",
            mode="chunk",
            chunk_size=512,
        ),
        "lr_cfg": LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6),
        "fsdp_cfg": FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1),
        "sp_size": 1,
        "optimizer_steps": 16,
        "pack_max_length": 4096,
    }
    # Caller-supplied kwargs take precedence over the defaults above.
    config_params = {**defaults, **kwargs}
    filtered_params = _filter_pydantic_kwargs(WorkerConfig, config_params)
    return WorkerConfig(load_from=config_params["model_path"], model_cfg=model_cfg, **filtered_params)
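

# --- Usage sketch (illustrative, not part of the factory API) ------------------
# A minimal sketch of how these helpers might be combined when configuring a GRPO
# run. The keyword arguments shown (data_path, max_prompt_length, model_path, ...)
# mirror how the helpers above read them, but the concrete paths, the tokenizer
# loading call, and the chosen values are assumptions for illustration only.
#
#   from transformers import AutoTokenizer
#
#   model_path = "/path/to/hf/checkpoint"                 # hypothetical path
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#
#   replay_cfg = get_replay_buffer_config(
#       tokenizer,
#       data_path="/path/to/train.jsonl",                 # assumed annotation file
#       max_prompt_length=2048,                           # assumed prompt budget
#   )
#   judger_cfg = get_gsm8k_judger_config()
#   worker_cfg = get_train_worker_config(model_path=model_path)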