Status: Open
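Description

`TestTrainAgentAsync` fails for the parameter set `max_episode_len=None, num_envs=2` (tests/experiments_tests/test_train_agent_async.py, line 94). When `num_envs == 1`, the test expects `agent.act_and_train` to be called exactly `steps` (50) times. When `num_envs > 1`, it expects strictly more than `steps` calls. In this run the call count was exactly 50, so the `num_envs > 1` check failed. The test replaces `multiprocessing.Process` with `threading.Thread` (see the comments in the test). Whether the extra call happens therefore presumably depends on thread scheduling, which would make this assertion flaky (to be confirmed).

Failure output: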
=================================== FAILURES ===================================
_____ TestTrainAgentAsync_param_1_{max_episode_len=None, num_envs=2}.test ______
self = <chainer.testing._bundle.TestTrainAgentAsync_param_1_{max_episode_len=None, num_envs=2} testMethod=test>
def test(self):
steps = 50
outdir = tempfile.mkdtemp()
agent = mock.Mock()
agent.shared_attributes = []
def _make_env(process_idx, test):
env = mock.Mock()
env.reset.side_effect = [('state', 0)] * 1000
if self.max_episode_len is None:
# Episodic env that terminates after 5 actions
env.step.side_effect = [
(('state', 1), 0, False, {}),
(('state', 2), 0, False, {}),
(('state', 3), -0.5, False, {}),
(('state', 4), 0, False, {}),
(('state', 5), 1, True, {}),
] * 1000
else:
# Continuing env
env.step.side_effect = [
(('state', 1), 0, False, {}),
] * 1000
return env
# Keep references to mock envs to check their states later
envs = [_make_env(i, test=False) for i in range(self.num_envs)]
eval_envs = [_make_env(i, test=True) for i in range(self.num_envs)]
def make_env(process_idx, test):
if test:
return eval_envs[process_idx]
else:
return envs[process_idx]
# Mock states cannot be shared among processes. To check states of mock
# objects, threading is used instead of multiprocessing.
# Because threading.Thread does not have .exitcode attribute, we
# add the attribute manually to avoid an exception.
import threading
# Mock.call_args_list does not seem thread-safe
hook_lock = threading.Lock()
hook = mock.Mock()
def hook_locked(*args, **kwargs):
with hook_lock:
return hook(*args, **kwargs)
with mock.patch('multiprocessing.Process', threading.Thread),\
mock.patch.object(
threading.Thread, 'exitcode', create=True, new=0):
chainerrl.experiments.train_agent_async(
processes=self.num_envs,
agent=agent,
make_env=make_env,
steps=steps,
outdir=outdir,
max_episode_len=self.max_episode_len,
global_step_hooks=[hook_locked],
)
if self.num_envs == 1:
self.assertEqual(agent.act_and_train.call_count, steps)
elif self.num_envs > 1:
> self.assertGreater(agent.act_and_train.call_count, steps)
E AssertionError: 50 not greater than 50
tests/experiments_tests/test_train_agent_async.py:94: AssertionError
Metadata
Assignees
Labels
No labels