Skip to content

Commit 1f9416f

Browse files
committed
Merge branch 'dev-v1' into dev-v2
Conflicts: test/discrete/test_drqn.py
2 parents b2fd31f + 0c385f9 commit 1f9416f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

test/discrete/test_drqn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from tianshou.utils import TensorboardLogger
1818
from tianshou.utils.net.common import Recurrent
1919
from tianshou.utils.space_info import SpaceInfo
20+
from tianshou.utils.torch_utils import policy_within_training_step
2021

2122

2223
def get_args() -> argparse.Namespace:
@@ -92,6 +93,7 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T
9293
n_step_return_horizon=args.n_step,
9394
target_update_freq=args.target_update_freq,
9495
)
96+
9597
# collector
9698
buffer = VectorReplayBuffer(
9799
args.buffer_size,
@@ -102,8 +104,12 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T
102104
train_collector = Collector[CollectStats](algorithm, train_envs, buffer, exploration_noise=True)
103105
# the stack_num is for RNN training: sample framestack obs
104106
test_collector = Collector[CollectStats](algorithm, test_envs, exploration_noise=True)
105-
train_collector.reset()
106-
train_collector.collect(n_step=args.batch_size * args.training_num)
107+
108+
# initial data collection
109+
with policy_within_training_step(policy):
110+
train_collector.reset()
111+
train_collector.collect(n_step=args.batch_size * args.training_num)
112+
107113
# log
108114
log_path = os.path.join(args.logdir, args.task, "drqn")
109115
writer = SummaryWriter(log_path)

0 commit comments

Comments
 (0)