-
Notifications
You must be signed in to change notification settings - Fork 373
Expand file tree
/
Copy pathconfig.py
More file actions
169 lines (149 loc) · 6.28 KB
/
config.py
File metadata and controls
169 lines (149 loc) · 6.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PPO config
"""
import os
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Optional, Tuple
from ..workers.config import WorkerConfig
def recursive_post_init(dataclass_obj):
    """Run ``post_init`` on *dataclass_obj*, then recurse into every
    dataclass-typed attribute, top-down.

    The object's own ``post_init`` (when defined) executes before any of
    its children's, so parents may fill in values children rely on.
    """
    hook = getattr(dataclass_obj, "post_init", None)
    if hook is not None:
        hook()

    for dc_field in fields(dataclass_obj):
        child = getattr(dataclass_obj, dc_field.name)
        if is_dataclass(child):
            recursive_post_init(child)
@dataclass
class DataConfig:
    """Dataset and preprocessing options for PPO training.

    ``post_init`` normalizes the optional filesystem paths (``image_dir``,
    ``format_prompt``) to absolute paths, dropping them with a warning when
    they do not exist.
    """

    train_files: str = ""
    val_files: str = ""
    prompt_key: str = "prompt"
    answer_key: str = "answer"
    image_key: str = "images"
    image_dir: Optional[str] = None
    max_prompt_length: int = 512
    max_response_length: int = 512
    rollout_batch_size: int = 512
    mini_rollout_batch_size: Optional[int] = None
    val_batch_size: int = -1  # -1 presumably means "use the full validation set" — confirm against trainer
    format_prompt: Optional[str] = None  # path to a prompt-format template file
    override_chat_template: Optional[str] = None
    shuffle: bool = True
    seed: int = 1
    min_pixels: Optional[int] = 262144  # image preprocessing bounds (processor-dependent) — TODO confirm units
    max_pixels: Optional[int] = 4194304
    filter_overlong_prompts: bool = True

    @staticmethod
    def _resolve_path(path: str, description: str) -> Optional[str]:
        """Return ``path`` as an absolute path if it exists, else warn and return None.

        Ray jobs require absolute paths, hence the ``os.path.abspath`` call.
        """
        if os.path.exists(path):  # ray job uses absolute path
            return os.path.abspath(path)

        print(f"{description} {path} not found.")
        return None

    def post_init(self):
        # Shared resolution logic for both optional paths; warning messages
        # are kept identical to the original per-field branches.
        if self.image_dir is not None:
            self.image_dir = self._resolve_path(self.image_dir, "Image directory")

        if self.format_prompt is not None:
            self.format_prompt = self._resolve_path(self.format_prompt, "Format prompt file")
@dataclass
class AlgorithmConfig:
    """RL algorithm hyper-parameters: advantage estimation and KL control."""

    gamma: float = 1.0
    """discount factor for ppo gae advantage estimator"""
    lam: float = 1.0
    """lambda value for ppo gae advantage estimator"""
    adv_estimator: str = "grpo"
    """advantage estimator, support `gae`, `grpo`, `reinforce_plus_plus`, `remax`, `rloo`"""
    disable_kl: bool = False
    """disable reference model"""
    use_kl_loss: bool = False
    """use kl loss instead of kl in reward"""
    kl_penalty: str = "kl"
    """kl penalty type, support `kl`, `abs`, `mse`, `low_var_kl`, `full`"""
    kl_coef: float = 1e-3
    """kl coefficient"""
    kl_type: str = "fixed"
    """kl controller type, support `fixed`, `adaptive`"""
    kl_horizon: float = 10000.0
    """kl horizon for adaptive kl controller"""
    kl_target: float = 0.1
    """target kl for adaptive kl controller"""
    online_filtering: bool = False
    """use online filtering"""
@dataclass
class TrainerConfig:
    """Training-loop orchestration: epochs, logging, validation and checkpointing."""

    total_epochs: int = 15
    """total epochs for training"""
    max_steps: Optional[int] = None
    """max steps for training, if specified, total_epochs is ignored"""
    project_name: str = "easy_r1"
    """project name for logger"""
    experiment_name: str = "demo"
    """experiment name for logger"""
    logger: Tuple[str, ...] = ("console", "wandb")
    """logger type, support `console`, `mlflow`, `swanlab`, `tensorboard`, `wandb`"""
    nnodes: int = 1
    """number of nodes for training"""
    n_gpus_per_node: int = 8
    """number of gpus per node for training"""
    max_try_make_batch: int = 20
    """max number of generations for online filtering, -1 means no limit"""
    critic_warmup: int = 0
    """critic warmup steps"""
    val_freq: int = -1
    """validation frequency, -1 means no validation"""
    val_before_train: bool = True
    """validate before training"""
    val_only: bool = False
    """validate only, skip training"""
    val_generations_to_log: int = 0
    """number of generations to log for validation"""
    save_freq: int = -1
    """save frequency, -1 means no saving"""
    save_limit: int = -1
    """max number of checkpoints to save, -1 means no limit"""
    save_model_only: bool = False
    """save model only, no optimizer state dict"""
    save_checkpoint_path: Optional[str] = None
    """save checkpoint path, if not specified, use `checkpoints/project_name/experiment_name`"""
    load_checkpoint_path: Optional[str] = None
    """load checkpoint path"""

    def post_init(self):
        # Derive the default save path from project/experiment names when unset,
        # and normalize it to an absolute path (ray jobs require absolute paths).
        if self.save_checkpoint_path is None:
            self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

        self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)  # ray job uses absolute path
        # The load path is optional: keep it only if it exists on disk, otherwise
        # warn and fall back to training from scratch.
        if self.load_checkpoint_path is not None:
            if os.path.exists(self.load_checkpoint_path):  # ray job uses absolute path
                self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)
            else:
                print(f"Model checkpoint {self.load_checkpoint_path} not found.")
                self.load_checkpoint_path = None
@dataclass
class PPOConfig:
    """Top-level PPO configuration aggregating the data, worker, algorithm
    and trainer sections."""

    data: DataConfig = field(default_factory=DataConfig)
    worker: WorkerConfig = field(default_factory=WorkerConfig)
    algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
    trainer: TrainerConfig = field(default_factory=TrainerConfig)

    def post_init(self):
        """Propagate cross-section settings into the worker sub-config."""
        rollout = self.worker.rollout
        actor = self.worker.actor

        rollout.prompt_length = self.data.max_prompt_length
        rollout.response_length = self.data.max_response_length
        rollout.trust_remote_code = actor.model.trust_remote_code

        actor.disable_kl = self.algorithm.disable_kl
        actor.use_kl_loss = self.algorithm.use_kl_loss
        actor.kl_penalty = self.algorithm.kl_penalty
        actor.kl_coef = self.algorithm.kl_coef

    def deep_post_init(self):
        """Run ``post_init`` on this config and every nested dataclass config."""
        recursive_post_init(self)

    def to_dict(self):
        """Serialize the whole config tree to a plain nested dict."""
        return asdict(self)