Skip to content

Commit 2a9c928

Browse files
authored
rename save_fn to save_best_fn to avoid ambiguity (#575)
This PR also introduces `tianshou.utils.deprecation` for a unified deprecation wrapper.
1 parent: 10d9190 · commit: 2a9c928

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

69 files changed

+187
-155
lines changed

docs/tutorials/tictactoe.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ With the above preparation, we are close to the first learned agent. The followi
327327

328328
# ======== callback functions used during training =========
329329

330-
def save_fn(policy):
330+
def save_best_fn(policy):
331331
if hasattr(args, 'model_save_path'):
332332
model_save_path = args.model_save_path
333333
else:
@@ -358,8 +358,9 @@ With the above preparation, we are close to the first learned agent. The followi
358358
policy, train_collector, test_collector, args.epoch,
359359
args.step_per_epoch, args.step_per_collect, args.test_num,
360360
args.batch_size, train_fn=train_fn, test_fn=test_fn,
361-
stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
362-
logger=logger, test_in_train=False, reward_metric=reward_metric)
361+
stop_fn=stop_fn, save_best_fn=save_best_fn,
362+
update_per_step=args.update_per_step, logger=logger,
363+
test_in_train=False, reward_metric=reward_metric)
363364

364365
agent = policy.policies[args.agent_id - 1]
365366
# let's watch the match!

examples/atari/atari_c51.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_c51(args=get_args()):
133133
else: # wandb
134134
logger.load(writer)
135135

136-
def save_fn(policy):
136+
def save_best_fn(policy):
137137
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
138138

139139
def stop_fn(mean_rewards):
@@ -206,7 +206,7 @@ def watch():
206206
train_fn=train_fn,
207207
test_fn=test_fn,
208208
stop_fn=stop_fn,
209-
save_fn=save_fn,
209+
save_best_fn=save_best_fn,
210210
logger=logger,
211211
update_per_step=args.update_per_step,
212212
test_in_train=False,

examples/atari/atari_dqn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_dqn(args=get_args()):
165165
else: # wandb
166166
logger.load(writer)
167167

168-
def save_fn(policy):
168+
def save_best_fn(policy):
169169
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
170170

171171
def stop_fn(mean_rewards):
@@ -244,7 +244,7 @@ def watch():
244244
train_fn=train_fn,
245245
test_fn=test_fn,
246246
stop_fn=stop_fn,
247-
save_fn=save_fn,
247+
save_best_fn=save_best_fn,
248248
logger=logger,
249249
update_per_step=args.update_per_step,
250250
test_in_train=False,

examples/atari/atari_fqf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def test_fqf(args=get_args()):
150150
else: # wandb
151151
logger.load(writer)
152152

153-
def save_fn(policy):
153+
def save_best_fn(policy):
154154
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
155155

156156
def stop_fn(mean_rewards):
@@ -223,7 +223,7 @@ def watch():
223223
train_fn=train_fn,
224224
test_fn=test_fn,
225225
stop_fn=stop_fn,
226-
save_fn=save_fn,
226+
save_best_fn=save_best_fn,
227227
logger=logger,
228228
update_per_step=args.update_per_step,
229229
test_in_train=False,

examples/atari/atari_iqn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def test_iqn(args=get_args()):
145145
else: # wandb
146146
logger.load(writer)
147147

148-
def save_fn(policy):
148+
def save_best_fn(policy):
149149
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
150150

151151
def stop_fn(mean_rewards):
@@ -218,7 +218,7 @@ def watch():
218218
train_fn=train_fn,
219219
test_fn=test_fn,
220220
stop_fn=stop_fn,
221-
save_fn=save_fn,
221+
save_best_fn=save_best_fn,
222222
logger=logger,
223223
update_per_step=args.update_per_step,
224224
test_in_train=False,

examples/atari/atari_ppo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def dist(p):
209209
else: # wandb
210210
logger.load(writer)
211211

212-
def save_fn(policy):
212+
def save_best_fn(policy):
213213
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
214214

215215
def stop_fn(mean_rewards):
@@ -272,7 +272,7 @@ def watch():
272272
args.batch_size,
273273
step_per_collect=args.step_per_collect,
274274
stop_fn=stop_fn,
275-
save_fn=save_fn,
275+
save_best_fn=save_best_fn,
276276
logger=logger,
277277
test_in_train=False,
278278
resume_from_log=args.resume_id is not None,

examples/atari/atari_qrdqn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_qrdqn(args=get_args()):
129129
else: # wandb
130130
logger.load(writer)
131131

132-
def save_fn(policy):
132+
def save_best_fn(policy):
133133
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
134134

135135
def stop_fn(mean_rewards):
@@ -202,7 +202,7 @@ def watch():
202202
train_fn=train_fn,
203203
test_fn=test_fn,
204204
stop_fn=stop_fn,
205-
save_fn=save_fn,
205+
save_best_fn=save_best_fn,
206206
logger=logger,
207207
update_per_step=args.update_per_step,
208208
test_in_train=False,

examples/atari/atari_rainbow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_rainbow(args=get_args()):
162162
else: # wandb
163163
logger.load(writer)
164164

165-
def save_fn(policy):
165+
def save_best_fn(policy):
166166
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
167167

168168
def stop_fn(mean_rewards):
@@ -246,7 +246,7 @@ def watch():
246246
train_fn=train_fn,
247247
test_fn=test_fn,
248248
stop_fn=stop_fn,
249-
save_fn=save_fn,
249+
save_best_fn=save_best_fn,
250250
logger=logger,
251251
update_per_step=args.update_per_step,
252252
test_in_train=False,

examples/box2d/acrobot_dualdqn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_dqn(args=get_args()):
9999
writer = SummaryWriter(log_path)
100100
logger = TensorboardLogger(writer)
101101

102-
def save_fn(policy):
102+
def save_best_fn(policy):
103103
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
104104

105105
def stop_fn(mean_rewards):
@@ -132,7 +132,7 @@ def test_fn(epoch, env_step):
132132
train_fn=train_fn,
133133
test_fn=test_fn,
134134
stop_fn=stop_fn,
135-
save_fn=save_fn,
135+
save_best_fn=save_best_fn,
136136
logger=logger
137137
)
138138

examples/box2d/bipedal_hardcore_sac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_sac_bipedal(args=get_args()):
161161
writer = SummaryWriter(log_path)
162162
logger = TensorboardLogger(writer)
163163

164-
def save_fn(policy):
164+
def save_best_fn(policy):
165165
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
166166

167167
def stop_fn(mean_rewards):
@@ -180,7 +180,7 @@ def stop_fn(mean_rewards):
180180
update_per_step=args.update_per_step,
181181
test_in_train=False,
182182
stop_fn=stop_fn,
183-
save_fn=save_fn,
183+
save_best_fn=save_best_fn,
184184
logger=logger
185185
)
186186

0 commit comments

Comments (0)