-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_enhanced.py
More file actions
262 lines (237 loc) · 6.81 KB
/
train_enhanced.py
File metadata and controls
262 lines (237 loc) · 6.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
强化版本DQN训练脚本
使用了多种增强技术:
1. Double DQN
2. Dueling DQN
3. Prioritized Experience Replay
"""
import torch
import numpy as np
import lightning as pl
import gymnasium as gym
import gym_2048
import argparse
from interface import DQNInterface
from utils import load_model_path_by_args
import lightning.pytorch.callbacks as plc
def load_callbacks():
    """Build the Lightning callbacks used for a training run.

    Returns:
        A new list containing, in order, a ``ModelCheckpoint``, a
        ``LearningRateMonitor`` and an ``EarlyStopping`` callback. The
        checkpoint and early-stopping callbacks both monitor the
        ``"tot_reward"`` metric in ``"max"`` mode.
    """
    # Keep several top checkpoints (save_top_k=3) as well as the last one.
    checkpoint = plc.ModelCheckpoint(
        monitor="tot_reward",
        filename="best-{steps:.1f}-{tot_reward:.2f}",
        save_top_k=3,
        mode="max",
        save_last=True,
    )

    # Learning-rate monitor, configured with a per-step logging interval.
    lr_monitor = plc.LearningRateMonitor(logging_interval="step")

    # Early stopping, but with a long patience so training is tolerated
    # for an extended period without improvement.
    early_stopping = plc.EarlyStopping(
        monitor="tot_reward",
        patience=2000,
        min_delta=0.01,
        mode="max",
        verbose=True,
        check_on_train_epoch_end=False,
    )

    return [checkpoint, lr_monitor, early_stopping]
def main():
    """Run one training session configured by the module-level ``args``.

    Reads the global ``args`` namespace (assigned under the ``__main__``
    guard), seeds the random number generators, resolves the checkpoint path
    to resume from, builds the ``DQNInterface`` model and fits it with a
    Lightning ``Trainer``.

    Side effects on ``args``:
        * ``args.reward_weights`` is set only when ``args.use_reward_shaping``
          is true; it is assigned before the model is built, so it is part of
          the keyword arguments passed to ``DQNInterface``.
        * ``args.callbacks`` is set to the callback list handed to the
          trainer; it is assigned after the model is built, so it is not
          passed to ``DQNInterface``.
    """
    pl.seed_everything(args.seed)
    load_path = load_model_path_by_args(args)

    # When reward shaping is enabled, bundle the individual weights into a
    # single dict on the namespace.
    if args.use_reward_shaping:
        args.reward_weights = {
            'merge': args.merge_weight,
            'empty': args.empty_weight,
            'monotonicity': args.monotonicity_weight,
            'corner': args.corner_weight,
            'game_over': args.game_over_weight,
        }

    # Every attribute currently on the namespace becomes a keyword argument.
    model = DQNInterface(**vars(args))

    # Build the callbacks exactly once so that the list recorded on ``args``
    # is the same list the trainer actually uses (previously two independent
    # sets of callbacks were constructed).
    callbacks = load_callbacks()
    args.callbacks = callbacks

    trainer = pl.Trainer(
        min_epochs=args.min_epochs,
        max_epochs=args.max_epochs,
        devices=args.devices,
        accelerator="auto",
        enable_checkpointing=True,
        callbacks=callbacks,
        check_val_every_n_epoch=1,
        default_root_dir=f"./train_logs/{args.model_name}",
        gradient_clip_algorithm="value",
        gradient_clip_val=1.0,
        accumulate_grad_batches=args.accumulate_grad_batches,  # gradient accumulation
    )
    trainer.fit(
        model,
        train_dataloaders=model.train_dataloader(),
        ckpt_path=load_path,
    )
def _build_parser():
    """Create the command-line parser for the enhanced DQN training script.

    Returns:
        An ``argparse.ArgumentParser`` with all training options registered
        in a fixed order (environment/model, replay buffer, prioritized
        replay, reward shaping, RL hyper-parameters, restart control and
        optimisation settings).
    """
    parser = argparse.ArgumentParser(description="Enhanced DQN training for 2048 game")
    add = parser.add_argument

    # Environment and model parameters
    add("--env_name", type=str, default="2048-extended-v2",
        help="Name of the environment to train on.")
    add("--obs_size", type=int, default=16)
    add("--n_actions", type=int, default=4)
    # Dueling DQN is the default model.
    add("--model_name", type=str, default="dueling_dqn",
        help="Name of the model to use: [dqn_net, dueling_dqn, prioritized_dqn]")

    # Replay buffer parameters
    add("--buffer_capacity", type=int, default=50000,
        help="Capacity of the replay buffer.")
    add("--warm_start_steps", type=int, default=1000,
        help="Number of warm start steps.")

    # Prioritized experience replay parameters
    add("--use_prioritized_replay", action="store_true",
        help="使用优先级经验回放")
    add("--per_alpha", type=float, default=0.6,
        help="优先级经验回放的alpha参数")
    add("--per_beta", type=float, default=0.4,
        help="优先级经验回放的beta初始值")
    add("--per_beta_increment", type=float, default=0.0001,
        help="优先级经验回放的beta增量")

    # Reward shaping parameters
    add("--use_reward_shaping", action="store_true",
        help="使用奖励整形来提升训练效果")
    add("--merge_weight", type=float, default=1.0,
        help="合并奖励权重")
    add("--empty_weight", type=float, default=0.1,
        help="空格奖励权重")
    add("--monotonicity_weight", type=float, default=0.2,
        help="单调性奖励权重")
    add("--corner_weight", type=float, default=0.5,
        help="角落奖励权重")
    add("--game_over_weight", type=float, default=-1.0,
        help="游戏结束惩罚权重")

    # Reinforcement learning parameters
    add("--gamma", type=float, default=0.99,
        help="Discount factor.")
    add("--target_net_sync_steps", type=int, default=1000,
        help="Number of steps before updating the target network.")
    add("--epsilon_decay_steps", type=int, default=100000,
        help="Number of steps for epsilon decay.")
    add("--epsilon_start", type=float, default=1.0,
        help="Starting value of epsilon.")
    add("--epsilon_end", type=float, default=0.1,
        help="Ending value of epsilon.")

    # Restart / checkpoint-loading control
    add("--load_best", action="store_true")
    add("--load_dir", default=None, type=str)
    add("--load_ver", default=None, type=str)
    add("--load_v_num", default=None, type=int)

    # Training parameters
    add("--lr", type=float, default=1e-4,
        help="Learning rate for the optimizer.")
    add("--weight_decay", type=float, default=1e-5,
        help="Weight decay for regularization.")
    add("--devices", default=-1, type=int)
    add("--min_epochs", default=100, type=int)
    add("--max_epochs", default=100000, type=int)
    add("--seed", default=42, type=int, help="随机种子")
    add("--batch_size", type=int, default=256,
        help="Batch size for training.")
    add("--accumulate_grad_batches", type=int, default=4,
        help="Number of batches to accumulate gradients for.")

    return parser


if __name__ == "__main__":
    # ``main`` reads this module-level name, so it must be assigned here.
    args = _build_parser().parse_args()
    main()