
Evaluation mode

cpnota released this on 18 Apr 18:39 (commit 57536b2).

This release contains some minor changes to several key APIs.

Agent Evaluation Mode

We added a new method, eval, to the Agent interface. eval behaves like act, except that the agent does not perform any training updates. This is useful for measuring an agent's performance at the end of a training run. Speaking of which...
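To make the act/eval split concrete, here is a toy sketch. The Agent base class, BanditAgent, and the act(state, reward) signature are hypothetical stand-ins for illustration, not the library's own classes: act() updates its value estimates and explores, while eval() only picks the greedy action and changes nothing.

```python
import random
from abc import ABC, abstractmethod


class Agent(ABC):
    '''Hypothetical minimal interface mirroring the act/eval split.'''

    @abstractmethod
    def act(self, state, reward):
        '''Choose an action and perform training updates.'''

    @abstractmethod
    def eval(self, state, reward):
        '''Choose an action without performing any training updates.'''


class BanditAgent(Agent):
    '''Toy epsilon-greedy agent for a stateless multi-armed bandit (hypothetical).'''

    def __init__(self, n_actions, epsilon=0.1, lr=0.5, seed=0):
        self.values = [0.0] * n_actions
        self.epsilon = epsilon
        self.lr = lr
        self.rng = random.Random(seed)
        self.last_action = None
        self.updates = 0  # number of training updates performed so far

    def _greedy(self):
        # Index of the highest value estimate (first one wins ties).
        return max(range(len(self.values)), key=lambda a: self.values[a])

    def act(self, state, reward):
        # Training mode: credit the reward to the previous action...
        if self.last_action is not None:
            v = self.values[self.last_action]
            self.values[self.last_action] = v + self.lr * (reward - v)
            self.updates += 1
        # ...then choose the next action epsilon-greedily.
        if self.rng.random() < self.epsilon:
            self.last_action = self.rng.randrange(len(self.values))
        else:
            self.last_action = self._greedy()
        return self.last_action

    def eval(self, state, reward):
        # Evaluation mode: act greedily and leave the value estimates untouched.
        return self._greedy()
```

Calling eval() any number of times leaves `values` and `updates` unchanged, which is what makes it safe to use for measuring performance after training.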

Experiment Refactoring: Train/Test

We completely refactored the all.experiments module. First of all, the primary public entry point is now a function called run_experiment. Under the hood, there is a new Experiment interface:

```python
from abc import ABC, abstractmethod

import numpy as np


class Experiment(ABC):
    '''An Experiment manages the basic train/test loop and logs results.'''

    @property
    @abstractmethod
    def frame(self):
        '''The index of the current training frame.'''

    @property
    @abstractmethod
    def episode(self):
        '''The index of the current training episode.'''

    @abstractmethod
    def train(self, frames=np.inf, episodes=np.inf):
        '''
        Train the agent for a certain number of frames or episodes.
        If both frames and episodes are specified, then the training loop will exit
        when either condition is satisfied.

        Args:
            frames (int): The maximum number of training frames.
            episodes (int): The maximum number of training episodes.
        '''

    @abstractmethod
    def test(self, episodes=100):
        '''
        Test the agent in eval mode for a certain number of episodes.

        Args:
            episodes (int): The number of test episodes.

        Returns:
            list(float): A list of all returns received during testing.
        '''
```
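The exit rule described in the train() docstring can be sketched as a standalone function. This is a hypothetical helper, not the library's implementation: it assumes a run_episode() callable that plays one episode and returns the number of frames it consumed, and it only checks the budgets between episodes, so the final frame count may overshoot frames.

```python
import numpy as np


def train(run_episode, frames=np.inf, episodes=np.inf):
    '''
    Hypothetical training loop. Stops as soon as either the frame budget or
    the episode budget is reached. Budgets are checked only between episodes.
    Note: if both budgets are left at np.inf, the loop never terminates.
    '''
    frame, episode = 0, 0
    while frame < frames and episode < episodes:
        frame += run_episode()  # play one episode, count its frames
        episode += 1
    return frame, episode
```

For example, with 10-frame episodes, `train(lambda: 10, frames=100, episodes=2)` stops after two episodes because the episode budget is reached first.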

Notice the new method, experiment.test(). This method runs the agent in eval mode for a given number of episodes, logs summary statistics (the mean and standard deviation of the returns), and returns the list of episode returns.
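A rough sketch of what a test loop like this does, written as a standalone function. The evaluate() name, the reset()/step() environment protocol, and the toy CountdownEnv and ConstantAgent classes are all assumptions for illustration; they are not part of the library.

```python
import numpy as np


def evaluate(agent, env, episodes=100):
    '''
    Hypothetical test loop: play `episodes` episodes using agent.eval()
    (no training updates), then print the mean and standard deviation
    of the episode returns and return the full list.
    '''
    returns = []
    for _ in range(episodes):
        state, reward, done = env.reset(), 0.0, False
        episode_return = 0.0
        while not done:
            action = agent.eval(state, reward)
            state, reward, done = env.step(action)
            episode_return += reward
        returns.append(episode_return)
    print(f"test returns: mean {np.mean(returns):.2f} +/- std {np.std(returns):.2f}")
    return returns


class CountdownEnv:
    '''Hypothetical toy environment: each episode lasts `length` steps, reward 1 per step.'''

    def __init__(self, length=3):
        self.length = length

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.length


class ConstantAgent:
    '''Hypothetical agent that always picks action 0.'''

    def eval(self, state, reward):
        return 0
```

With these toy classes, `evaluate(ConstantAgent(), CountdownEnv(length=3), episodes=5)` returns five returns of 3.0 each.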

Approximation: no_grad vs. eval

Finally, we clarified the usage of Approximation.eval(*inputs) by adding a second method, Approximation.no_grad(*inputs). eval() both puts the network in evaluation mode and runs the forward pass with torch.no_grad(). no_grad() also runs the forward pass with torch.no_grad(), but leaves the network in its current mode (training or evaluation). The various Policy implementations were also adjusted to correctly execute the greedy behavior in eval mode.
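The difference between the two methods can be illustrated in plain PyTorch. The Approximation class below is a hypothetical stand-in wrapping an nn.Module, not the library's class; in particular, restoring the module's previous mode after eval() is an assumption of this sketch. A dropout layer makes the mode difference visible: it is active in training mode and disabled in evaluation mode.

```python
import torch
from torch import nn


class Approximation:
    '''Hypothetical stand-in for a wrapper around an nn.Module.'''

    def __init__(self, model):
        self.model = model

    def __call__(self, *inputs):
        # Ordinary forward pass: current mode, gradients tracked.
        return self.model(*inputs)

    def no_grad(self, *inputs):
        # Forward pass without gradient tracking; train/eval mode is left as-is.
        with torch.no_grad():
            return self.model(*inputs)

    def eval(self, *inputs):
        # Forward pass in evaluation mode without gradient tracking.
        # Assumption: the module's previous mode is restored afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(*inputs)
        finally:
            self.model.train(was_training)


approx = Approximation(nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5)))
x = torch.ones(1, 4)
y_eval = approx.eval(x)      # dropout disabled, no autograd graph
y_nograd = approx.no_grad(x)  # dropout still active, no autograd graph
```

Because eval() disables dropout, repeated calls on the same input give identical outputs, while no_grad() outputs still vary from call to call in training mode.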