Skip to content

Commit 380e9e9

Browse files
authored
fix atari examples (#206)
1 parent 8bb8ecb commit 380e9e9

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

examples/atari/atari_dqn.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
8484
# replay buffer: `save_last_obs` and `stack_num` can be removed together
8585
# when you have enough RAM
8686
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
87-
save_last_obs=True, stack_num=args.frames_stack)
87+
save_only_last_obs=True, stack_num=args.frames_stack)
8888
# collector
8989
train_collector = Collector(policy, train_envs, buffer)
9090
test_collector = Collector(policy, test_envs)
@@ -100,17 +100,19 @@ def stop_fn(x):
100100
return x >= env.spec.reward_threshold
101101
elif 'Pong' in args.task:
102102
return x >= 20
103+
else:
104+
return False
103105

104106
def train_fn(x):
105107
# nature DQN setting, linear decay in the first 1M steps
106108
now = x * args.collect_per_step * args.step_per_epoch
107109
if now <= 1e6:
108110
eps = args.eps_train - now / 1e6 * \
109111
(args.eps_train - args.eps_train_final)
110-
policy.set_eps(eps)
111112
else:
112-
policy.set_eps(args.eps_train_final)
113-
print("set eps =", policy.eps)
113+
eps = args.eps_train_final
114+
policy.set_eps(eps)
115+
writer.add_scalar('train/eps', eps, global_step=now)
114116

115117
def test_fn(x):
116118
policy.set_eps(args.eps_test)

0 commit comments

Comments
 (0)