55 changes: 54 additions & 1 deletion src/agent/gs_agent/algos/config/registry.py
@@ -1,13 +1,14 @@
from pathlib import Path

from gs_agent.algos.config.schema import BCArgs, OptimizerType, PPOArgs
from gs_agent.algos.config.schema import BCArgs, LearningRateType, OptimizerType, PPOArgs
from gs_agent.modules.config.registry import DEFAULT_MLP

# default PPO config
PPO_DEFAULT = PPOArgs(
policy_backbone=DEFAULT_MLP,
critic_backbone=DEFAULT_MLP,
lr=3e-4,
lr_type=LearningRateType.FIXED,
gamma=0.99,
gae_lambda=0.95,
clip_ratio=0.2,
@@ -28,6 +29,7 @@
policy_backbone=DEFAULT_MLP,
critic_backbone=DEFAULT_MLP,
lr=3e-4,
lr_type=LearningRateType.FIXED,
value_lr=None,
gamma=0.99,
gae_lambda=0.95,
@@ -49,6 +51,7 @@
policy_backbone=DEFAULT_MLP,
critic_backbone=DEFAULT_MLP,
lr=3e-4,
lr_type=LearningRateType.FIXED,
value_lr=None,
gamma=0.99,
gae_lambda=0.95,
@@ -78,3 +81,53 @@
optimizer_type=OptimizerType.ADAM,
weight_decay=0.0,
)


# walking PPO config
PPO_WALKING_MLP = PPOArgs(
policy_backbone=DEFAULT_MLP,
critic_backbone=DEFAULT_MLP,
lr=1e-3,
lr_type=LearningRateType.ADAPTIVE,
lr_adaptive_factor=1.5,
lr_min=1e-5,
lr_max=1e-2,
value_lr=None,
gamma=0.99,
gae_lambda=0.95,
clip_ratio=0.2,
use_clipped_value_loss=True,
value_loss_coef=1.0,
entropy_coef=0.003,
max_grad_norm=1.0,
target_kl=0.01,
num_epochs=5,
num_mini_batches=4,
rollout_length=24,
optimizer_type=OptimizerType.ADAM,
weight_decay=0.0,
)

# teleop PPO config
PPO_TELEOP_MLP = PPOArgs(
policy_backbone=DEFAULT_MLP,
critic_backbone=DEFAULT_MLP,
lr=1e-3,
lr_type=LearningRateType.ADAPTIVE,
lr_adaptive_factor=1.5,
lr_min=1e-5,
lr_max=1e-2,
value_lr=None,
gamma=0.99,
gae_lambda=0.95,
clip_ratio=0.2,
value_loss_coef=1.0,
entropy_coef=0.003,
max_grad_norm=1.0,
target_kl=0.01,
num_epochs=5,
num_mini_batches=4,
rollout_length=24,
optimizer_type=OptimizerType.ADAM,
weight_decay=0.0,
)
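
As a quick illustration of how the adaptive-schedule bounds in these configs interact, here is a standalone sketch (not part of the PR) using the same numbers as PPO_WALKING_MLP and PPO_TELEOP_MLP:

# Standalone sketch: the multiplicative schedule implied by lr=1e-3,
# lr_adaptive_factor=1.5, lr_min=1e-5, lr_max=1e-2. The clamps guarantee the
# learning rate never leaves [lr_min, lr_max], whatever sequence of KL signals occurs.
lr, factor, lr_min, lr_max = 1e-3, 1.5, 1e-5, 1e-2
for kl_too_high in [True, True, False, False, False]:
    if kl_too_high:      # KL above 1.5 * target_kl -> shrink the learning rate
        lr = max(lr / factor, lr_min)
    else:                # KL below 0.5 * target_kl -> grow the learning rate
        lr = min(lr * factor, lr_max)
    print(f"lr = {lr:.3e}")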
12 changes: 12 additions & 0 deletions src/agent/gs_agent/algos/config/schema.py
@@ -12,6 +12,11 @@ class OptimizerType(GenesisEnum):
SGD = "SGD"


class LearningRateType(GenesisEnum):
FIXED = "FIXED"
ADAPTIVE = "ADAPTIVE"


class AlgorithmType(GenesisEnum):
PPO = "PPO"
BC = "BC"
@@ -37,6 +42,12 @@ class PPOArgs(BaseModel):
value_lr: PositiveFloat | None = None
"""None means use the same learning rate as the policy"""

# Adaptive learning rate
lr_type: LearningRateType = LearningRateType.FIXED
lr_adaptive_factor: PositiveFloat = 1.5
lr_min: PositiveFloat = 1e-5
lr_max: PositiveFloat = 1e-2

# Discount and GAE
gamma: PositiveFloat = Field(default=0.99, ge=0, le=1)
gae_lambda: PositiveFloat = Field(default=0.95, ge=0, le=1)
@@ -47,6 +58,7 @@
entropy_coef: NonNegativeFloat = 0.0
max_grad_norm: PositiveFloat = 1.0
target_kl: PositiveFloat = 0.02
use_clipped_value_loss: bool = False

# Training
num_epochs: NonNegativeInt = 10
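
For reference, a minimal sketch (not part of the PR) of how the new fields might be set when constructing PPOArgs directly; all field names come from the diff above:

# Minimal sketch: turning on the adaptive schedule and the clipped value loss
# via the new PPOArgs fields. Backbone configs are reused from the registry.
from gs_agent.algos.config.schema import LearningRateType, PPOArgs
from gs_agent.modules.config.registry import DEFAULT_MLP

args = PPOArgs(
    policy_backbone=DEFAULT_MLP,
    critic_backbone=DEFAULT_MLP,
    lr=3e-4,
    lr_type=LearningRateType.ADAPTIVE,  # default is LearningRateType.FIXED
    lr_adaptive_factor=1.5,             # multiplicative step of the KL controller
    lr_min=1e-5,
    lr_max=1e-2,
    use_clipped_value_loss=True,        # clipped value loss, off by default
)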
153 changes: 134 additions & 19 deletions src/agent/gs_agent/algos/ppo.py
@@ -7,7 +7,7 @@
import torch
import torch.nn as nn

from gs_agent.algos.config.schema import PPOArgs
from gs_agent.algos.config.schema import LearningRateType, PPOArgs
from gs_agent.bases.algo import BaseAlgo
from gs_agent.bases.env_wrapper import BaseEnvWrapper
from gs_agent.bases.policy import Policy
@@ -47,6 +47,12 @@ def __init__(
self.env.num_envs, device=self.device, dtype=torch.float
)
self._curr_ep_len = torch.zeros(self.env.num_envs, device=self.device, dtype=torch.float)

# Adaptive learning rate tracking
self._current_lr = cfg.lr

self.use_clipped_value_loss = cfg.use_clipped_value_loss

#
self._build_actor_critic()
self._build_rollouts()
@@ -64,7 +70,7 @@ def _build_actor_critic(self) -> None:
policy_backbone=policy_backbone,
action_dim=self._action_dim,
).to(self.device)
self._actor_optimizer = torch.optim.Adam(self._actor.parameters(), lr=self.cfg.lr)
self._actor_optimizer = torch.optim.Adam(self._actor.parameters(), lr=self._current_lr)

critic_backbone = NetworkFactory.create_network(
network_backbone_args=self.cfg.critic_backbone,
@@ -84,28 +90,60 @@ def _build_rollouts(self) -> None:
num_envs=self._num_envs,
max_steps=self._num_steps,
actor_obs_size=self._actor_obs_dim,
critic_obs_size=self._critic_obs_dim,
action_size=self._action_dim,
device=self.device,
)

def _update_learning_rate(self, kl_mean: float) -> None:
"""Update learning rate adaptively based on KL divergence."""
if self.cfg.lr_type == LearningRateType.ADAPTIVE:
# If KL is too high, reduce learning rate
if kl_mean > self.cfg.target_kl * 1.5: # 1.5x threshold
self._current_lr = max(
self._current_lr / self.cfg.lr_adaptive_factor, self.cfg.lr_min
)
# If KL is too low, increase learning rate
elif kl_mean < self.cfg.target_kl * 0.5: # 0.5x threshold
self._current_lr = min(
self._current_lr * self.cfg.lr_adaptive_factor, self.cfg.lr_max
)

# Update only actor optimizer learning rate
for param_group in self._actor_optimizer.param_groups:
param_group["lr"] = self._current_lr

def _collect_rollouts(self, num_steps: int) -> dict[str, Any]:
"""Collect rollouts from the environment."""
obs = self.env.get_observations()
actor_obs, critic_obs = self.env.get_observations()
termination_buffer = []
reward_terms_buffer = []
info_buffer = []
with torch.inference_mode():
# collect rollouts and compute returns & advantages
for _step in range(num_steps):
action, log_prob = self._actor(obs)
action, log_prob, mu, sigma = self._actor.forward_with_dist_params(actor_obs)
# Step environment
next_obs, reward, terminated, truncated, _extra_infos = self.env.step(action)
_, reward, terminated, truncated, _extra_infos = self.env.step(action)
next_actor_obs, next_critic_obs = self.env.get_observations()

# add next value to reward of truncated steps
if "time_outs" in _extra_infos:
time_outs = _extra_infos["time_outs"]
next_values = self._critic(_extra_infos["observations"]["critic"])
reward = reward + next_values * time_outs

# all tensors are of shape: [num_envs, dim]
transition = {
GAEBufferKey.ACTOR_OBS: obs,
GAEBufferKey.ACTOR_OBS: actor_obs,
GAEBufferKey.CRITIC_OBS: critic_obs,
GAEBufferKey.ACTIONS: action,
GAEBufferKey.REWARDS: reward,
GAEBufferKey.DONES: terminated,
GAEBufferKey.VALUES: self._critic(obs),
GAEBufferKey.VALUES: self._critic(critic_obs),
GAEBufferKey.ACTION_LOGPROBS: log_prob,
GAEBufferKey.MU: mu,
GAEBufferKey.SIGMA: sigma,
}
self._rollouts.append(transition)

@@ -114,6 +152,11 @@ def _collect_rollouts(self, num_steps: int) -> dict[str, Any]:
self._curr_reward_sum += reward.squeeze(-1)
self._curr_ep_len += 1

# Update termination buffer
termination_buffer.append(_extra_infos["termination"])
reward_terms_buffer.append(_extra_infos["reward_terms"])
if "info" in _extra_infos:
info_buffer.append(_extra_infos["info"])
# Check for episode completions and reset tracking
done_mask = terminated.unsqueeze(-1) | truncated.unsqueeze(-1)
new_ids = (done_mask > 0).nonzero(as_tuple=False)
@@ -127,46 +170,102 @@
self._curr_reward_sum[new_ids] = 0
self._curr_ep_len[new_ids] = 0

obs = next_obs
actor_obs, critic_obs = next_actor_obs, next_critic_obs
with torch.no_grad():
last_value = self._critic(self.env.get_observations())
last_value = self._critic(critic_obs)
self._rollouts.set_final_value(last_value)

mean_reward = 0.0
mean_ep_len = 0.0
if len(self._rewbuffer) > 0:
mean_reward = statistics.mean(self._rewbuffer)
mean_ep_len = statistics.mean(self._lenbuffer)
mean_termination = {}
mean_reward_terms = {}
mean_info = {}
if len(termination_buffer) > 0:
for key in termination_buffer[0].keys():
terminations = torch.stack([termination[key] for termination in termination_buffer])
mean_termination[key] = terminations.to(torch.float).mean().item()
if len(reward_terms_buffer) > 0:
for key in reward_terms_buffer[0].keys():
reward_terms = torch.stack(
[reward_term[key] for reward_term in reward_terms_buffer]
)
mean_reward_terms[key] = reward_terms.mean().item()
if len(info_buffer) > 0:
for key in info_buffer[0].keys():
infos = torch.tensor([info[key] for info in info_buffer])
mean_info[key] = infos.mean().item()
return {
"mean_reward": mean_reward,
"mean_ep_len": mean_ep_len,
"termination": mean_termination,
"reward_terms": mean_reward_terms,
"info": mean_info,
}

def _train_one_batch(self, mini_batch: dict[GAEBufferKey, torch.Tensor]) -> dict[str, Any]:
"""Train one batch of rollouts."""
obs = mini_batch[GAEBufferKey.ACTOR_OBS]
actor_obs = mini_batch[GAEBufferKey.ACTOR_OBS]
critic_obs = mini_batch[GAEBufferKey.CRITIC_OBS]
act = mini_batch[GAEBufferKey.ACTIONS]
old_log_prob = mini_batch[GAEBufferKey.ACTION_LOGPROBS]
old_mu = mini_batch[GAEBufferKey.MU]
old_sigma = mini_batch[GAEBufferKey.SIGMA]
advantage = mini_batch[GAEBufferKey.ADVANTAGES]
returns = mini_batch[GAEBufferKey.RETURNS]
target_values = mini_batch[GAEBufferKey.VALUES]

#
new_log_prob = self._actor.evaluate_log_prob(obs, act)
new_log_prob = self._actor.evaluate_log_prob(actor_obs, act)
ratio = torch.exp(new_log_prob - old_log_prob)
surr1 = -advantage * ratio
surr2 = -advantage * torch.clamp(
ratio, 1.0 - self.cfg.clip_ratio, 1.0 + self.cfg.clip_ratio
)
policy_loss = torch.max(surr1, surr2).mean()

approx_kl = (new_log_prob - old_log_prob).mean()
# Mean log-probability ratio between new and old policy (previously reported as approx_kl)
log_ratio = (new_log_prob - old_log_prob).mean()

# Calculate true KL divergence between old and new distributions
# Get new distribution parameters
dist = self._actor.dist_from_obs(actor_obs)
new_mu = dist.mean
new_sigma = dist.stddev

# KL divergence between two Gaussian distributions
# KL(P||Q) = 0.5 * [log(det(Σ_Q)/det(Σ_P)) - d + tr(Σ_Q^{-1} * Σ_P) + (μ_Q - μ_P)^T * Σ_Q^{-1} * (μ_Q - μ_P)]
# For diagonal Gaussians: KL = 0.5 * [sum(log(σ_Q^2/σ_P^2)) - d + sum(σ_P^2/σ_Q^2) + sum((μ_Q - μ_P)^2/σ_Q^2)]
kl_div = 0.5 * (
torch.log(new_sigma**2 / (old_sigma**2 + 1e-8)).sum(
dim=-1, keepdim=True
) # log ratio of variances
- new_mu.shape[-1] # dimension
+ (old_sigma**2 / (new_sigma**2 + 1e-8)).sum(dim=-1, keepdim=True) # trace term
+ ((new_mu - old_mu) ** 2 / (new_sigma**2 + 1e-8)).sum(
dim=-1, keepdim=True
) # mean difference term
)
kl_mean = kl_div.mean()

# Calculate value loss
values = self._critic(obs)
value_loss = (returns - values).pow(2).mean()
values = self._critic(critic_obs)

if self.use_clipped_value_loss:
clipped_values = target_values + (values - target_values).clamp(
-self.cfg.clip_ratio, self.cfg.clip_ratio
)
value_loss = (values - returns).pow(2)
clipped_value_loss = (clipped_values - returns).pow(2)
value_loss = torch.max(value_loss, clipped_value_loss).mean()
else:
value_loss = (returns - values).pow(2).mean()

# Calculate entropy loss
entropy = self._actor.entropy_on(obs)
entropy = self._actor.entropy_on(actor_obs)
entropy_loss = entropy.mean()

# Total loss
@@ -189,7 +288,10 @@ def _train_one_batch(self, mini_batch: dict[GAEBufferKey, torch.Tensor]) -> dict[str, Any]:
"policy_loss": policy_loss.item(),
"value_loss": value_loss.item(),
"entropy_loss": entropy_loss.item(),
"approx_kl": approx_kl.item(),
"log_ratio": log_ratio.item(),
"kl_mean": kl_mean.item(),
"mean_std": self._actor.action_std.mean().item(),
"max_std": self._actor.action_std.max().item(),
}

def train_one_iteration(self) -> dict[str, Any]:
@@ -211,6 +313,10 @@ def train_one_iteration(self) -> dict[str, Any]:
t2 = time.time()
train_time = t2 - t1

# Update learning rate adaptively based on KL divergence
avg_kl_mean = statistics.mean([metrics["kl_mean"] for metrics in train_metrics_list])
self._update_learning_rate(avg_kl_mean)

self._rollouts.reset()

iteration_infos = {
@@ -228,16 +334,25 @@
"entropy_loss": statistics.mean(
[metrics["entropy_loss"] for metrics in train_metrics_list]
),
"approx_kl": statistics.mean(
[metrics["approx_kl"] for metrics in train_metrics_list]
"log_ratio": statistics.mean(
[metrics["log_ratio"] for metrics in train_metrics_list]
),
"kl_mean": statistics.mean([metrics["kl_mean"] for metrics in train_metrics_list]),
"mean_std": statistics.mean(
[metrics["mean_std"] for metrics in train_metrics_list]
),
"max_std": statistics.mean([metrics["max_std"] for metrics in train_metrics_list]),
"learning_rate": self._current_lr,
},
"speed": {
"rollout_time": rollouts_time,
"rollout_fps": fps,
"train_time": train_time,
"rollout_step": self._num_steps * self._num_envs,
},
"termination": rollout_infos["termination"],
"reward_terms": rollout_infos["reward_terms"],
"info": rollout_infos["info"],
}
return iteration_infos

@@ -253,7 +368,7 @@ def save(self, path: Path) -> None:

def load(self, path: Path, load_optimizer: bool = True) -> None:
"""Load the algorithm from a file."""
checkpoint = torch.load(path)
checkpoint = torch.load(path, map_location=self.device)
self._actor.load_state_dict(checkpoint["model_state_dict"])
if load_optimizer:
self._actor_optimizer.load_state_dict(checkpoint["actor_optimizer_state_dict"])
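
The diagonal-Gaussian KL used in _train_one_batch can be checked against PyTorch's built-in KL. A standalone sanity check (not part of the PR; the variable names here are illustrative):

# Standalone check: the closed-form diagonal-Gaussian KL(old || new) used in
# _train_one_batch agrees with torch.distributions.kl_divergence.
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
d = 4
old_mu, new_mu = torch.randn(d), torch.randn(d)
old_sigma, new_sigma = torch.rand(d) + 0.1, torch.rand(d) + 0.1

closed_form = 0.5 * (
    torch.log(new_sigma**2 / old_sigma**2).sum()       # log ratio of variances
    - d                                                 # dimension
    + (old_sigma**2 / new_sigma**2).sum()               # trace term
    + ((new_mu - old_mu) ** 2 / new_sigma**2).sum()     # mean difference term
)
reference = kl_divergence(Normal(old_mu, old_sigma), Normal(new_mu, new_sigma)).sum()
print(torch.allclose(closed_form, reference))  # True (up to float tolerance)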