Bounded Action Space #81

Open · wants to merge 1 commit into base: main

19 changes: 7 additions & 12 deletions rsl_rl/algorithms/ppo.py
@@ -141,7 +141,7 @@ def act(self, obs, critic_obs):
self.transition.values = self.policy.evaluate(critic_obs).detach()
self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach()
self.transition.action_mean = self.policy.action_mean.detach()
self.transition.action_sigma = self.policy.action_std.detach()
self.transition.actions_distribution = self.policy.actions_distribution.detach()
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.privileged_observations = critic_obs
@@ -214,12 +214,11 @@ def update(self): # noqa: C901
returns_batch,
old_actions_log_prob_batch,
old_mu_batch,
old_sigma_batch,
old_actions_distributions_parameters,
hid_states_batch,
masks_batch,
rnd_state_batch,
) in generator:

# number of augmentations per sample
# we start with 1 and increase it if we use symmetry augmentation
num_aug = 1
@@ -262,21 +261,16 @@ def update(self): # noqa: C901
# -- entropy
# we only keep the entropy of the first augmentation (the original one)
mu_batch = self.policy.action_mean[:original_batch_size]
sigma_batch = self.policy.action_std[:original_batch_size]
actions_distributions_batch = self.policy.actions_distribution[:original_batch_size]
entropy_batch = self.policy.entropy[:original_batch_size]

# KL
if self.desired_kl is not None and self.schedule == "adaptive":
current_dist = self.policy.build_distribution(actions_distributions_batch)
old_dist = self.policy.build_distribution(old_actions_distributions_parameters)
with torch.inference_mode():
kl = torch.sum(
torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
+ (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
/ (2.0 * torch.square(sigma_batch))
- 0.5,
axis=-1,
)
kl = torch.distributions.kl.kl_divergence(current_dist, old_dist)
kl_mean = torch.mean(kl)

# Reduce the KL divergence across all GPUs
if self.is_multi_gpu:
torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM)
@@ -304,6 +298,7 @@ def update(self): # noqa: C901

# Surrogate loss
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))

surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
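
For reference, a minimal standalone sketch of what the new KL computation does: for factorized Gaussians, torch.distributions.kl.kl_divergence reproduces the closed-form expression that the removed lines computed by hand. Note that kl_divergence(p, q) returns KL(p || q), so the argument order decides the direction of the divergence; the sketch checks equality for KL(old || new), which is what the removed expression computed. Shapes and values below are placeholders.

import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

# Placeholder batch of Gaussian policy parameters with shape (batch, action_dim).
mu_new, sigma_new = torch.zeros(4, 3), torch.ones(4, 3)
mu_old, sigma_old = torch.full((4, 3), 0.1), torch.full((4, 3), 1.2)

# Closed-form KL(old || new), as the removed hand-written expression computed it.
kl_manual = torch.sum(
    torch.log(sigma_new / sigma_old + 1.0e-5)
    + (torch.square(sigma_old) + torch.square(mu_old - mu_new)) / (2.0 * torch.square(sigma_new))
    - 0.5,
    dim=-1,
)

# kl_divergence returns the per-dimension KL for independent Normals; summing over the
# action dimension gives the KL of the diagonal Gaussian policy.
kl_torch = kl_divergence(Normal(mu_old, sigma_old), Normal(mu_new, sigma_new)).sum(dim=-1)
assert torch.allclose(kl_manual, kl_torch, atol=1e-4)
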
2 changes: 2 additions & 0 deletions rsl_rl/modules/__init__.py
@@ -6,6 +6,7 @@
"""Definitions for neural-network components for RL-agents."""

from .actor_critic import ActorCritic
from .actor_critic_beta import ActorCriticBeta
from .actor_critic_recurrent import ActorCriticRecurrent
from .normalizer import EmpiricalNormalization
from .rnd import RandomNetworkDistillation
@@ -14,6 +15,7 @@

__all__ = [
"ActorCritic",
"ActorCriticBeta",
"ActorCriticRecurrent",
"EmpiricalNormalization",
"RandomNetworkDistillation",
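
With the new export in place, the Beta policy can be imported alongside the existing modules, for example:

from rsl_rl.modules import ActorCritic, ActorCriticBeta
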
59 changes: 53 additions & 6 deletions rsl_rl/modules/actor_critic.py
@@ -25,6 +25,8 @@ def __init__(
activation="elu",
init_noise_std=1.0,
noise_std_type: str = "scalar",
clip_actions: bool = False,
clip_actions_range: tuple = (-1.0, 1.0),
**kwargs,
):
if kwargs:
@@ -49,6 +51,11 @@ def __init__(
actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)

self.clip_actions = clip_actions
self.clip_actions_range = clip_actions_range
if self.clip_actions:
self.clipping_layer = nn.Tanh()

# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
@@ -94,19 +101,34 @@ def forward(self):

@property
def action_mean(self):
return self.distribution.mean

mode = self.distribution.mean
if self.clip_actions:
mode = ((mode + 1) / 2.0) * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
return mode

@property
def action_std(self):
return self.distribution.stddev

@property
def actions_distribution(self) -> torch.Tensor:
# Mean and std stacked along an extra trailing dimension
return torch.stack([self.distribution.mean, self.distribution.stddev], dim=-1)

@property
def entropy(self):
return self.distribution.entropy().sum(dim=-1)

def build_distribution(self, parameters):
# build the distribution
return Normal(parameters[..., 0], parameters[..., 1])

def update_distribution(self, observations):
# compute mean
mean = self.actor(observations)
if self.clip_actions:
mean = self.clipping_layer(mean)

# compute standard deviation
if self.noise_std_type == "scalar":
std = self.std.expand_as(mean)
@@ -119,14 +141,30 @@ def update_distribution(self, observations):

def act(self, observations, **kwargs):
self.update_distribution(observations)
return self.distribution.sample()
act = self.distribution.sample()
if self.clip_actions:
# Apply tanh to clip the actions to [-1, 1]
act = self.clipping_layer(act)
# Rescale the actions to the desired range
act = ((act + 1) / 2.0) * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
return act

def get_actions_log_prob(self, actions):
return self.distribution.log_prob(actions).sum(dim=-1)
# Scale the actions to [-1, 1] before computing the log probability.
if self.clip_actions:
# The unscaled actions still have the tanh applied to them.
unscaled_actions = (actions - self.clip_actions_range[0]) / (self.clip_actions_range[1] - self.clip_actions_range[0]) * 2.0 - 1.0
# Revert the tanh to get the original actions. We use the TanhBijector to avoid numerical issues.
gaussian_actions = self.inverse_tanh(unscaled_actions)
return (self.distribution.log_prob(gaussian_actions) - torch.log(1 - unscaled_actions * unscaled_actions + 1e-6)).sum(dim=-1)
else:
return self.distribution.log_prob(actions).sum(dim=-1)

def act_inference(self, observations):
actions_mean = self.actor(observations)
return actions_mean
mode = self.actor(observations)
if self.clip_actions:
# Apply the same tanh squashing used during training, then rescale to the desired range.
mode = self.clipping_layer(mode)
mode = ((mode + 1) / 2.0) * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
return mode

def evaluate(self, critic_observations, **kwargs):
value = self.critic(critic_observations)
@@ -147,3 +185,12 @@ def load_state_dict(self, state_dict, strict=True):

super().load_state_dict(state_dict, strict=strict)
return True

@staticmethod
def atanh(x):
return 0.5 * (x.log1p() - (-x).log1p())

@staticmethod
def inverse_tanh(y):
eps = torch.finfo(y.dtype).eps
return ActorCritic.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
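
For illustration, a minimal self-contained sketch of the squash-and-rescale round trip that the new act() / get_actions_log_prob() pair implements for the tanh-bounded Gaussian. The action dimension, the mean/std values, and the [-2, 2] range are placeholders.

import torch
from torch.distributions import Normal

# Placeholder mean/std for a 2-dimensional action and a target range of [-2, 2].
low, high = -2.0, 2.0
dist = Normal(torch.zeros(2), torch.ones(2))

# Forward pass, mirroring act(): sample, squash with tanh, rescale to [low, high].
gaussian_action = dist.sample()
squashed = torch.tanh(gaussian_action)
action = (squashed + 1) / 2.0 * (high - low) + low

# Reverse pass, mirroring get_actions_log_prob(): map back to [-1, 1], invert the tanh,
# and apply the change-of-variables correction -log(1 - tanh(x)^2).
unscaled = (action - low) / (high - low) * 2.0 - 1.0
eps = torch.finfo(unscaled.dtype).eps
recovered = torch.atanh(unscaled.clamp(-1.0 + eps, 1.0 - eps))
log_prob = (dist.log_prob(recovered) - torch.log(1 - unscaled**2 + 1e-6)).sum(dim=-1)
print(action, log_prob)
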
166 changes: 166 additions & 0 deletions rsl_rl/modules/actor_critic_beta.py
@@ -0,0 +1,166 @@
# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import torch
import torch.nn as nn
from torch.distributions import Beta

from rsl_rl.utils import resolve_nn_activation


class ActorCriticBeta(nn.Module):
is_recurrent = False

def __init__(
self,
num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation="elu",
init_noise_std=1.0,
noise_std_type: str = "scalar",
clip_actions: bool = True,
clip_actions_range: tuple = (-1.0, 1.0),
**kwargs,
):
if kwargs:
print(
"ActorCritic.__init__ got unexpected arguments, which will be ignored: "
+ str([key for key in kwargs.keys()])
)
super().__init__()
activation = resolve_nn_activation(activation)

mlp_input_dim_a = num_actor_obs
mlp_input_dim_c = num_critic_obs
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for layer_index in range(len(actor_hidden_dims)):
if layer_index == len(actor_hidden_dims) - 1:
self.alpha = nn.Linear(actor_hidden_dims[layer_index], num_actions)
self.beta = nn.Linear(actor_hidden_dims[layer_index], num_actions)
self.alpha_activation = nn.Softplus()
self.beta_activation = nn.Softplus()
else:
actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)

self.clip_actions_range = clip_actions_range

# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for layer_index in range(len(critic_hidden_dims)):
if layer_index == len(critic_hidden_dims) - 1:
critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1))
else:
critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1]))
critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)

print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")

# Action distribution (populated in update_distribution)
self.distribution = None
self.a = None
self.b = None
# disable args validation for speedup
Beta.set_default_validate_args(False)

@staticmethod
# not used at the moment
def init_weights(sequential, scales):
[
torch.nn.init.orthogonal_(module.weight, gain=scales[idx])
for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))
]

def reset(self, dones=None):
pass

def forward(self):
raise NotImplementedError

@property
def action_mean(self):
mode = self.a / (self.a + self.b)
mode_rescaled = mode * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
return mode_rescaled

@property
def action_std(self):
return torch.sqrt(self.a * self.b / ((self.a + self.b + 1) * (self.a + self.b) ** 2))

@property
def actions_distribution(self):
# Alpha and beta stacked along an extra trailing dimension
return torch.stack([self.a, self.b], dim=-1)

@property
def entropy(self):
return self.distribution.entropy().sum(dim=-1)

def build_distribution(self, parameters):
# create distribution
return Beta(parameters[..., 0], parameters[..., 1])

def update_distribution(self, observations):
# compute the Beta parameters from the actor latent
latent = self.actor(observations)
self.a = self.alpha_activation(self.alpha(latent)) + 1.0
self.b = self.beta_activation(self.beta(latent)) + 1.0

# create distribution
self.distribution = Beta(self.a, self.b)

def act(self, observations, **kwargs):
self.update_distribution(observations)
act = self.distribution.sample()
act_rescaled = act * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
return act_rescaled

def get_actions_log_prob(self, actions):
# Unscale the actions to [0, 1] before computing the log probability.
unscaled_actions = (actions - self.clip_actions_range[0]) / (self.clip_actions_range[1] - self.clip_actions_range[0])
# For numerical stability, clip the actions to [1e-5, 1 - 1e-5].
unscaled_actions = torch.clamp(unscaled_actions, 1e-5, 1 - 1e-5)
return self.distribution.log_prob(unscaled_actions).sum(dim=-1)

def act_inference(self, observations):
latent = self.actor(observations)
self.a = self.alpha_activation(self.alpha(latent)) + 1.0
self.b = self.beta_activation(self.beta(latent)) + 1.0
mode = self.a / (self.a + self.b)
mode_rescaled = mode * (self.clip_actions_range[1] - self.clip_actions_range[0]) + self.clip_actions_range[0]
return mode_rescaled

def evaluate(self, critic_observations, **kwargs):
value = self.critic(critic_observations)
return value

def load_state_dict(self, state_dict, strict=True):
"""Load the parameters of the actor-critic model.

Args:
state_dict (dict): State dictionary of the model.
strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this
module's state_dict() function.

Returns:
bool: Whether this training resumes a previous training. This flag is used by the `load()` function of
`OnPolicyRunner` to determine how to load further parameters (relevant for, e.g., distillation).
"""

super().load_state_dict(state_dict, strict=strict)
return True
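
For illustration, a minimal sketch of the Beta parameterization this file introduces: softplus outputs shifted by +1 keep alpha, beta > 1 so the density stays unimodal inside the bounds, samples live in [0, 1] and are rescaled to the action range, and stored actions are mapped back to [0, 1] before evaluating log probabilities. The latent tensor below stands in for the two linear heads and is a placeholder.

import torch
from torch.distributions import Beta

# Placeholder "head outputs" for a 3-dimensional action bounded to [-1, 1].
low, high = -1.0, 1.0
latent = torch.randn(3)
alpha = torch.nn.functional.softplus(latent) + 1.0
beta = torch.nn.functional.softplus(-latent) + 1.0
dist = Beta(alpha, beta)

# Sample in [0, 1] and rescale to the action range, mirroring act().
action = dist.sample() * (high - low) + low

# Map a stored action back to [0, 1] to evaluate its log probability, mirroring
# get_actions_log_prob(); the clamp avoids log(0) at the boundaries.
unit = ((action - low) / (high - low)).clamp(1e-5, 1 - 1e-5)
log_prob = dist.log_prob(unit).sum(dim=-1)
print(action, log_prob)
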
3 changes: 2 additions & 1 deletion rsl_rl/runners/on_policy_runner.py
@@ -16,6 +16,7 @@
from rsl_rl.env import VecEnv
from rsl_rl.modules import (
ActorCritic,
ActorCriticBeta,
ActorCriticRecurrent,
EmpiricalNormalization,
StudentTeacher,
@@ -69,7 +70,7 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev

# evaluate the policy class
policy_class = eval(self.policy_cfg.pop("class_name"))
policy: ActorCritic | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent = policy_class(
policy: ActorCritic | ActorCriticBeta | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent = policy_class(
num_obs, num_privileged_obs, self.env.num_actions, **self.policy_cfg
).to(self.device)
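
As a usage sketch, the policy section of a training config could select the new class through the class_name lookup above. Only keys that appear in this diff are assumed; the concrete values are placeholders.

policy_cfg = {
    "class_name": "ActorCriticBeta",
    "actor_hidden_dims": [256, 256, 256],
    "critic_hidden_dims": [256, 256, 256],
    "activation": "elu",
    "clip_actions_range": (-1.0, 1.0),
}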
