Support for continuous action spaces. #82

Status: Open. Wants to merge 2 commits into main. Changes shown from all commits.

1,085 changes: 1,085 additions & 0 deletions epymarl_continuous_action_demo.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion requirements.txt
@@ -18,7 +18,7 @@ pygame
pyparsing
pytest
python-dateutil
PyYAML==5.3.1
PyYAML
requests
sacred
seaborn
@@ -34,3 +34,4 @@ urllib3
websocket-client
whichcraft
wrapt
sacred
52 changes: 27 additions & 25 deletions src/components/action_selectors.py
@@ -1,14 +1,13 @@
import torch as th
import torch.nn.functional as F
from torch.distributions import Categorical
from .epsilon_schedules import DecayThenFlatSchedule
REGISTRY = {}


class MultinomialActionSelector():

class MultinomialActionSelector(): # Multinomial distribution of action probabilities
def __init__(self, args):
self.args = args

self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
decay="linear")
self.epsilon = self.schedule.eval(0)
@@ -17,62 +16,65 @@ def __init__(self, args):
def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
masked_policies = agent_inputs.clone()
masked_policies[avail_actions == 0.0] = 0.0

self.epsilon = self.schedule.eval(t_env)

if test_mode and self.test_greedy:
picked_actions = masked_policies.max(dim=2)[1]
else:
picked_actions = Categorical(masked_policies).sample().long()

return picked_actions


REGISTRY["multinomial"] = MultinomialActionSelector


class EpsilonGreedyActionSelector():

class EpsilonGreedyActionSelector(): # epsilon greedy action selection
def __init__(self, args):
self.args = args

self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time,
decay="linear")
self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, decay="linear")
self.epsilon = self.schedule.eval(0)

def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):

# Assuming agent_inputs is a batch of Q-Values for each agent bav
self.epsilon = self.schedule.eval(t_env)

if test_mode:
# Greedy action selection only
self.epsilon = self.args.evaluation_epsilon

# mask actions that are excluded from selection
masked_q_values = agent_inputs.clone()
masked_q_values[avail_actions == 0.0] = -float("inf") # should never be selected!

random_numbers = th.rand_like(agent_inputs[:, :, 0])
pick_random = (random_numbers < self.epsilon).long()
random_actions = Categorical(avail_actions.float()).sample().long()

picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1]
return picked_actions


REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector


class SoftPoliciesSelector():

class SoftPoliciesSelector(): # Categorical distribution, softmaxed action logits
def __init__(self, args):
self.args = args

def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
m = Categorical(agent_inputs)
picked_actions = m.sample().long()
return picked_actions

REGISTRY["soft_policies"] = SoftPoliciesSelector

REGISTRY["soft_policies"] = SoftPoliciesSelector
class ContinuousSelector():  # Agent outputs are concatenated means and raw scales of normal distributions: [:k], [k:]
    def __init__(self, args):
        self.args = args

    def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):
        with th.no_grad():
            k = agent_inputs.shape[-1] // 2
            u, var = agent_inputs[:, :, :k], agent_inputs[:, :, k:]
            # tanh squashes the means; sqrt(softplus(.)) turns the raw scales into positive stds.
            u, var = F.tanh(u), th.sqrt(F.softplus(var))
            action_dist = th.distributions.Normal(u, var)
            action = action_dist.sample()
            # For now, clip actions to [0, 1] manually.
            action = th.clip(action, min=0, max=1)
            return action

REGISTRY["continuous"] = ContinuousSelector
6 changes: 3 additions & 3 deletions src/components/episode_buffer.py
@@ -86,7 +86,8 @@ def to(self, device):

def update(self, data, bs=slice(None), ts=slice(None), mark_filled=True):
slices = self._parse_slices((bs, ts))
for k, v in data.items():
#print(f"!!! Updating buffer with data: {data}")
for k, v in data.items():  # data maps each field being stored to its values.
if k in self.data.transition_data:
target = self.data.transition_data
if mark_filled:
@@ -245,5 +246,4 @@ def __repr__(self):
return "ReplayBuffer. {}/{} episodes. Keys:{} Groups:{}".format(self.episodes_in_buffer,
self.buffer_size,
self.scheme.keys(),
self.groups.keys())

self.groups.keys())
34 changes: 34 additions & 0 deletions src/config/algs/mappo_c.yaml
@@ -0,0 +1,34 @@
# --- MAPPO specific parameters ---

action_selector: "continuous"
mask_before_softmax: False

runner: "parallel"

buffer_size: 10
batch_size_run: 10
batch_size: 10

# update the target network every {} training steps
target_update_interval_or_tau: 0.01

lr: 0.0003
hidden_dim: 128

obs_agent_id: False
obs_last_action: False
obs_individual_obs: False

agent_output_type: "continuous_u_std" # Continuous means and stds
learner: "ppo_c_learner"
entropy_coef: 0.001
use_rnn: False
standardise_returns: False
standardise_rewards: True
q_nstep: 5 # 1 corresponds to normal r + gammaV
critic_type: "cv_critic"
epochs: 4
eps_clip: 0.2
name: "mappo_c"

t_max: 20050000
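
Assuming the standard EPyMARL entry point is unchanged, this algorithm config would be paired with one of the env configs below, along the lines of python src/main.py --config=mappo_c --env-config=mpe_continuous with env_args.key=<MPE task key> (the task key is left unspecified here). Note that the two env configs that follow are identical except that mpe_continuous.yaml sets continuous_actions: True.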
16 changes: 16 additions & 0 deletions src/config/envs/mpe.yaml
@@ -0,0 +1,16 @@
env: "gymma"

env_args:
key: null
time_limit: 100
pretrained_wrapper: null
max_cycles: 100
continuous_actions: False

test_greedy: True
test_nepisode: 100
test_interval: 50000
log_interval: 50000
runner_log_interval: 10000
learner_log_interval: 10000
t_max: 2050000
16 changes: 16 additions & 0 deletions src/config/envs/mpe_continuous.yaml
@@ -0,0 +1,16 @@
env: "gymma"

env_args:
key: null
time_limit: 100
pretrained_wrapper: null
max_cycles: 100
continuous_actions: True

test_greedy: True
test_nepisode: 100
test_interval: 50000
log_interval: 50000
runner_log_interval: 10000
learner_log_interval: 10000
t_max: 2050000
12 changes: 7 additions & 5 deletions src/controllers/basic_controller.py
@@ -19,24 +19,26 @@ def __init__(self, scheme, groups, args):
def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
# Only select actions for the selected batch elements in bs
avail_actions = ep_batch["avail_actions"][:, t_ep]
# agent_outputs: Q-values or action probabilities for discrete agents, concatenated means/scales for continuous agents.
agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode)
chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode)
return chosen_actions

def forward(self, ep_batch, t, test_mode=False):
# Build per-agent inputs and run the agent network for this timestep.
agent_inputs = self._build_inputs(ep_batch, t)
avail_actions = ep_batch["avail_actions"][:, t]
agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states)

# Softmax the agent outputs if they're policy logits
if self.agent_output_type == "pi_logits":

if getattr(self.args, "mask_before_softmax", True):
# Make the logits for unavailable actions very negative to minimise their effect on the softmax
reshaped_avail_actions = avail_actions.reshape(ep_batch.batch_size * self.n_agents, -1)
agent_outs[reshaped_avail_actions == 0] = -1e10
agent_outs = th.nn.functional.softmax(agent_outs, dim=-1)

elif self.agent_output_type == "continuous_u_std":
# forward() runs before select_actions; the continuous selector consumes these raw means/scales as-is, so nothing to do here.
pass
return agent_outs.view(ep_batch.batch_size, self.n_agents, -1)

def init_hidden(self, batch_size):
@@ -58,6 +60,7 @@ def load_models(self, path):
self.agent.load_state_dict(th.load("{}/agent.th".format(path), map_location=lambda storage, loc: storage))

def _build_agents(self, input_shape):
print(f"BUILDING AGENTS! ARGS.N_ACTIONS = {self.args.n_actions}")
self.agent = agent_REGISTRY[self.args.agent](input_shape, self.args)

def _build_inputs(self, batch, t):
@@ -83,5 +86,4 @@ def _get_input_shape(self, scheme):
input_shape += scheme["actions_onehot"]["vshape"][0]
if self.args.obs_agent_id:
input_shape += self.n_agents

return input_shape
return input_shape
2 changes: 1 addition & 1 deletion src/envs/__init__.py
@@ -3,7 +3,7 @@

from .multiagentenv import MultiAgentEnv
from .gymma import GymmaWrapper
from .smaclite_wrapper import SMACliteWrapper
#from .smaclite_wrapper import SMACliteWrapper # Appears to be broken at the moment.


if sys.platform == "linux":
51 changes: 34 additions & 17 deletions src/envs/gymma.py
@@ -8,22 +8,9 @@

from .multiagentenv import MultiAgentEnv
from .wrappers import FlattenObservation
from .pz_wrapper import PettingZooWrapper # noqa
import envs.pretrained as pretrained # noqa

try:
from .pz_wrapper import PettingZooWrapper # noqa
except ImportError:
warnings.warn(
"PettingZoo is not installed, so these environments will not be available! To install, run `pip install pettingzoo`"
)

try:
from .vmas_wrapper import VMASWrapper # noqa
except ImportError:
warnings.warn(
"VMAS is not installed, so these environments will not be available! To install, run `pip install 'vmas[gymnasium]'`"
)


class GymmaWrapper(MultiAgentEnv):
def __init__(
@@ -48,7 +35,17 @@ def __init__(
self._obs = None
self._info = None

self.longest_action_space = max(self._env.action_space, key=lambda x: x.n)
try:
self.longest_action_space = max(self._env.action_space, key=lambda x: x.n)
self.cont_space = False
self.continuous_action_space = False
except Exception as e:
print('!!! Using continuous action space')
self.cont_space = True
self.longest_action_space = max(self._env.action_space, key=lambda x: x.shape)
self.action_space_min = min(self._env.action_space, key=lambda x: x.low).low
self.action_space_max = max(self._env.action_space, key=lambda x: x.high).high
self.continuous_action_space = True
self.longest_observation_space = max(
self._env.observation_space, key=lambda x: x.shape
)
@@ -83,7 +80,11 @@ def _pad_observation(self, obs):

def step(self, actions):
"""Returns obss, reward, terminated, truncated, info"""
actions = [int(a) for a in actions]
#print(f"!!!!! processing actions: {actions}")
if (self.cont_space):
actions = [np.array(a) for a in actions]
else:
actions = [int(a) for a in actions]
obs, reward, done, truncated, self._info = self._env.step(actions)
self._obs = self._pad_observation(obs)

@@ -129,14 +130,30 @@ def get_avail_actions(self):
def get_avail_agent_actions(self, agent_id):
"""Returns the available actions for agent_id"""
valid = flatdim(self._env.action_space[agent_id]) * [1]
invalid = [0] * (self.longest_action_space.n - len(valid))
if (self.cont_space):
invalid = [0] * (self.longest_action_space.shape[0] - len(valid))
else:
invalid = [0] * (self.longest_action_space.n - len(valid))
return valid + invalid

def get_total_actions(self):
"""Returns the total number of actions an agent could ever take"""
# TODO: This is only suitable for a discrete 1 dimensional action space for each agent
return flatdim(self.longest_action_space)

def get_env_info(self):
env_info = {
"state_shape": self.get_state_size(),
"obs_shape": self.get_obs_size(),
"n_actions": self.get_total_actions(),
"n_agents": self.n_agents,
"episode_limit": self.episode_limit,
}
if (self.continuous_action_space):
env_info['action_min'] = self.action_space_min
env_info['action_max'] = self.action_space_max
return env_info

def reset(self, seed=None, options=None):
"""Returns initial observations and info"""
obs, info = self._env.reset(seed=seed, options=options)
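The wrapper decides between discrete and continuous action spaces by catching the exception that x.n raises on Box spaces. Below is a rough sketch of the same check written with explicit isinstance tests; it is not part of the PR, the helper name and return dict are invented for illustration, and it assumes gymnasium-style Discrete/Box spaces:

from gymnasium.spaces import Box, Discrete

def inspect_action_spaces(action_spaces):
    # All-discrete: pad against the space with the largest number of actions.
    if all(isinstance(s, Discrete) for s in action_spaces):
        return {"continuous": False, "longest": max(action_spaces, key=lambda s: s.n)}
    # All-continuous: track the widest Box plus global bounds for later clipping/scaling.
    if all(isinstance(s, Box) for s in action_spaces):
        return {
            "continuous": True,
            "longest": max(action_spaces, key=lambda s: s.shape),
            "low": min(float(s.low.min()) for s in action_spaces),
            "high": max(float(s.high.max()) for s in action_spaces),
        }
    raise ValueError("Mixed or unsupported action spaces")
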
4 changes: 2 additions & 2 deletions src/envs/pz_wrapper.py
@@ -22,10 +22,10 @@ def __init__(self, lib_name, env_name, **kwargs):
self.last_obs = None

self.action_space = Tuple(
tuple([self._env.action_spaces[k] for k in self._env.agents])
tuple([self._env.action_space(k) for k in self._env.agents])
)
self.observation_space = Tuple(
tuple([self._env.observation_spaces[k] for k in self._env.agents])
tuple([self._env.observation_space(k) for k in self._env.agents])
)

def reset(self, *args, **kwargs):
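This change presumably tracks PettingZoo's newer API, which exposes per-agent spaces through the action_space(agent) and observation_space(agent) accessors rather than the older action_spaces / observation_spaces dicts.
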
4 changes: 3 additions & 1 deletion src/learners/__init__.py
@@ -6,6 +6,7 @@
from .actor_critic_pac_dcg_learner import PACDCGLearner
from .maddpg_learner import MADDPGLearner
from .ppo_learner import PPOLearner
from .ppo_c_learner import PPOLearner_C


REGISTRY = {}
@@ -15,5 +16,6 @@
REGISTRY["actor_critic_learner"] = ActorCriticLearner
REGISTRY["maddpg_learner"] = MADDPGLearner
REGISTRY["ppo_learner"] = PPOLearner
REGISTRY['ppo_c_learner'] = PPOLearner_C
REGISTRY["pac_learner"] = PACActorCriticLearner
REGISTRY["pac_dcg_learner"] = PACDCGLearner
REGISTRY["pac_dcg_learner"] = PACDCGLearner