A2C Mono CPU

We propose a simple implementation of A2C on a single CPU with an MLP policy.

Note on the forward pass

  • When first executing the agent on the workspace, states from t=0 to t=n_timesteps-1 are computed
  • When executing the agent a second time, states from t=n_timesteps to t=2*n_timesteps-1 are computed
  • The transition between n_timesteps-1 and n_timesteps is thus missing: it never appears in a single workspace
  • To avoid this effect, we:
    • Copy the last state of the workspace to the first position with workspace.copy_n_last_steps(1)
    • Then execute the agent from timestep=1 in the workspace
    • The resulting workspace now contains states from n_timesteps-1 to 2*n_timesteps-2, so the transition is no longer missing (see the illustration below)
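
As a concrete illustration of the indices involved, here is a small sketch with a hypothetical n_timesteps of 4 (this value is used only for this illustration):

n_timesteps = 4  # hypothetical value, used only for this illustration

# Without the trick, consecutive workspaces cover disjoint ranges, so the
# transition between timestep 3 and timestep 4 never appears in a single workspace:
first_pass = list(range(0, n_timesteps))                       # [0, 1, 2, 3]
second_pass_naive = list(range(n_timesteps, 2 * n_timesteps))  # [4, 5, 6, 7]

# With copy_n_last_steps(1) and execution from t=1, the last state of the first
# pass is duplicated at position 0, so the 3 -> 4 transition is now present:
second_pass_fixed = [first_pass[-1]] + list(range(n_timesteps, 2 * n_timesteps - 1))  # [3, 4, 5, 6]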

Execution

    PYTHONPATH=salina python salina/salina_examples/rl/a2c/mono_cpu/main.py

Detailed Explanations

The Agent

We first write an Agent that reads an observation (generated by an AutoResetGymAgent that models the environment) and writes action, action_probs and critic at time t in the Workspace:

import torch
import torch.nn as nn

from salina import TAgent


class A2CAgent(TAgent):
    def __init__(self, observation_size, hidden_size, n_actions):
        super().__init__()
        # Policy network: maps an observation to one score per action
        self.model = nn.Sequential(
            nn.Linear(observation_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )
        # Critic network: maps an observation to a scalar value estimate
        self.critic_model = nn.Sequential(
            nn.Linear(observation_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, t, stochastic, **kwargs):
        # Read the observation written by the environment agent at time t
        observation = self.get(("env/env_obs", t))
        scores = self.model(observation)
        probs = torch.softmax(scores, dim=-1)
        critic = self.critic_model(observation).squeeze(-1)
        if stochastic:
            # Sample an action from the categorical distribution (exploration)
            action = torch.distributions.Categorical(probs).sample()
        else:
            # Take the most probable action (evaluation)
            action = probs.argmax(1)

        # Write the outputs into the workspace at time t
        self.set(("action", t), action)
        self.set(("action_probs", t), probs)
        self.set(("critic", t), critic)

This agent also has additional forward arguments that allow us to control how it is executed (e.g., stochastic or deterministic mode, ...).
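
For instance, the extra keyword arguments given when calling the agent are forwarded to forward, so the same agent can be run in both modes. A minimal sketch (assuming the workspace already contains env/env_obs at time t=0):

a2c_agent(workspace, t=0, stochastic=True)   # sample an action (exploration)
a2c_agent(workspace, t=0, stochastic=False)  # greedy action via argmax (evaluation)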

Creating the agents for learning

  • Environment Agent: The first agent to create is the agent that models the environment. In the A2C case, this agent will automatically reset the environments when reaching a final state:
env_agent = AutoResetGymAgent(
    get_class(cfg.algorithm.env), get_arguments(cfg.algorithm.env), n_envs=cfg.algorithm.n_envs
)

Note that this agent takes the function/class name and its arguments as parameters and constructs the environment by itself (this is needed for parallelization).
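
As an illustration of this convention, here is a hypothetical configuration fragment; the concrete values (gym.make, CartPole-v1) are assumptions and not the ones from the example's configuration file:

from omegaconf import OmegaConf

from salina import get_arguments, get_class

# Hypothetical stand-in for cfg.algorithm.env, normally read from the hydra config
env_cfg = OmegaConf.create({"classname": "gym.make", "id": "CartPole-v1"})

make_env_fn = get_class(env_cfg)        # the callable named by `classname`
make_env_args = get_arguments(env_cfg)  # the remaining keys, passed as keyword arguments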

  • Agent at time t: Given the environment agent and the policy agent, we can compose them to obtain an agent that, at time t, produces observations, reward, etc., as well as action, action_probs and critic:
agent = Agents(env_agent, a2c_agent)
  • Complete acquisition agent: While the previous agent acts at a single time t, we can wrap it to obtain an agent that acts over a full Workspace:
agent = TemporalAgent(agent)
  • Defining a workspace: Now we can define the workspace on which our agents will be applied:
workspace = salina.Workspace()

Once this is done, a trajectory can be acquired by simply executing agent(workspace), as illustrated in the sketch below.
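
Here is a minimal end-to-end sketch of this acquisition setup. It assumes a CartPole-v1 Gym environment and hypothetical sizes (4 observations, 2 actions, 8 parallel environments, 16 timesteps) standing in for the values normally read from the configuration:

import gym

import salina
from salina.agents import Agents, TemporalAgent
from salina.agents.gyma import AutoResetGymAgent

# Hypothetical values standing in for cfg.algorithm.n_envs and cfg.algorithm.n_timesteps
n_envs, n_timesteps = 8, 16

env_agent = AutoResetGymAgent(gym.make, {"id": "CartPole-v1"}, n_envs=n_envs)
a2c_agent = A2CAgent(observation_size=4, hidden_size=64, n_actions=2)

agent = TemporalAgent(Agents(env_agent, a2c_agent))
workspace = salina.Workspace()

# Acquire a first trajectory of n_timesteps steps in stochastic mode
agent(workspace, t=0, n_steps=n_timesteps, stochastic=True)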

  • Executing the agent at each epoch: When executing the agent over a workspace, it is executed over timesteps 0 to time_size-1. At the next execution, the states from t=time_size to t=2*time_size-1 are acquired, and so on. This means that some transitions (here, the one between time_size-1 and time_size) never appear in a workspace since they are split across two different workspaces. To avoid this border effect, salina allows one to proceed as follows:
        if epoch > 0:
            workspace.copy_n_last_steps(1)
            agent(workspace, t=1, n_steps=cfg.algorithm.n_timesteps - 1, stochastic=True)
        else:
            agent(workspace, t=0, n_steps=cfg.algorithm.n_timesteps, stochastic=True)
  • Loss computation: To compute the loss, one can retrieve from the workspace the tensors generated by the agents:
critic, done, action_probs, reward, action = workspace[
    "critic", "env/done", "action_probs", "env/reward", "action"
]

Each tensor is of size time_size x batch_size x .... Losses can thus be computed directly over full trajectories, which makes implementing most RL algorithms straightforward.
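
As an illustration, here is one possible way to compute standard A2C losses from these tensors. This is a sketch, not the example's exact code: it assumes salina's convention that env/reward at time t is the reward received when entering the state at time t, and the discount factor and loss coefficients below are hypothetical hyperparameters:

discount_factor, critic_coef, a2c_coef, entropy_coef = 0.95, 1.0, 0.1, 0.001  # hypothetical values

# Temporal-difference error; done masks bootstrapping at episode boundaries
target = reward[1:] + discount_factor * critic[1:].detach() * (1 - done[1:].float())
td = target - critic[:-1]
critic_loss = (td ** 2).mean()

# Policy-gradient term: log-probability of the chosen action weighted by the TD error
action_logp = action_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1).log()
a2c_loss = (action_logp[:-1] * td.detach()).mean()

# Entropy bonus that encourages exploration
entropy_loss = torch.distributions.Categorical(action_probs).entropy().mean()

loss = critic_coef * critic_loss - a2c_coef * a2c_loss - entropy_coef * entropy_loss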

Next steps

  • In the next example, we show how modularity can be used to easily define complex agents without rewriting the base learning algorithm
  • In (here), we show that any agent can be parallelized over multiple processes
  • In (here), we show how to use GPU for speeding up computation
  • ...