Skip to content

TestTrainAgentAsync.test is flaky #578

@muupan

Description

@muupan
=================================== FAILURES ===================================
_____ TestTrainAgentAsync_param_1_{max_episode_len=None, num_envs=2}.test ______
self = <chainer.testing._bundle.TestTrainAgentAsync_param_1_{max_episode_len=None, num_envs=2} testMethod=test>
    def test(self):
    
        steps = 50
    
        outdir = tempfile.mkdtemp()
    
        agent = mock.Mock()
        agent.shared_attributes = []
    
        def _make_env(process_idx, test):
            env = mock.Mock()
            env.reset.side_effect = [('state', 0)] * 1000
            if self.max_episode_len is None:
                # Episodic env that terminates after 5 actions
                env.step.side_effect = [
                    (('state', 1), 0, False, {}),
                    (('state', 2), 0, False, {}),
                    (('state', 3), -0.5, False, {}),
                    (('state', 4), 0, False, {}),
                    (('state', 5), 1, True, {}),
                ] * 1000
            else:
                # Continuing env
                env.step.side_effect = [
                    (('state', 1), 0, False, {}),
                ] * 1000
            return env
    
        # Keep references to mock envs to check their states later
        envs = [_make_env(i, test=False) for i in range(self.num_envs)]
        eval_envs = [_make_env(i, test=True) for i in range(self.num_envs)]
    
        def make_env(process_idx, test):
            if test:
                return eval_envs[process_idx]
            else:
                return envs[process_idx]
    
        # Mock states cannot be shared among processes. To check states of mock
        # objects, threading is used instead of multiprocessing.
        # Because threading.Thread does not have .exitcode attribute, we
        # add the attribute manually to avoid an exception.
        import threading
    
        # Mock.call_args_list does not seem thread-safe
        hook_lock = threading.Lock()
        hook = mock.Mock()
    
        def hook_locked(*args, **kwargs):
            with hook_lock:
                return hook(*args, **kwargs)
    
        with mock.patch('multiprocessing.Process', threading.Thread),\
            mock.patch.object(
                threading.Thread, 'exitcode', create=True, new=0):
            chainerrl.experiments.train_agent_async(
                processes=self.num_envs,
                agent=agent,
                make_env=make_env,
                steps=steps,
                outdir=outdir,
                max_episode_len=self.max_episode_len,
                global_step_hooks=[hook_locked],
            )
    
        if self.num_envs == 1:
            self.assertEqual(agent.act_and_train.call_count, steps)
        elif self.num_envs > 1:
>           self.assertGreater(agent.act_and_train.call_count, steps)
E           AssertionError: 50 not greater than 50
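tests/experiments_tests/test_train_agent_async.py:94: AssertionError

Note: with `num_envs=2`, nothing guarantees that the two workers are still in flight when the global step counter reaches `steps`, so the total number of `agent.act_and_train` calls can be exactly `steps` (as in the 50 vs. 50 failure above). The strict `assertGreater(agent.act_and_train.call_count, steps)` is therefore nondeterministic; `assertGreaterEqual` matches what the multi-worker case can actually guarantee.
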
Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions