diff --git a/src/agent/gs_agent/algos/config/registry.py b/src/agent/gs_agent/algos/config/registry.py
index 43d70155..6f1c9838 100644
--- a/src/agent/gs_agent/algos/config/registry.py
+++ b/src/agent/gs_agent/algos/config/registry.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from gs_agent.algos.config.schema import BCArgs, OptimizerType, PPOArgs
+from gs_agent.algos.config.schema import BCArgs, LearningRateType, OptimizerType, PPOArgs
 from gs_agent.modules.config.registry import DEFAULT_MLP
 
 # default PPO config
@@ -8,6 +8,7 @@
     policy_backbone=DEFAULT_MLP,
     critic_backbone=DEFAULT_MLP,
     lr=3e-4,
+    lr_type=LearningRateType.FIXED,
     gamma=0.99,
     gae_lambda=0.95,
     clip_ratio=0.2,
@@ -28,6 +29,7 @@
     policy_backbone=DEFAULT_MLP,
     critic_backbone=DEFAULT_MLP,
     lr=3e-4,
+    lr_type=LearningRateType.FIXED,
     value_lr=None,
     gamma=0.99,
     gae_lambda=0.95,
@@ -49,6 +51,7 @@
     policy_backbone=DEFAULT_MLP,
     critic_backbone=DEFAULT_MLP,
     lr=3e-4,
+    lr_type=LearningRateType.FIXED,
     value_lr=None,
     gamma=0.99,
     gae_lambda=0.95,
@@ -78,3 +81,53 @@
     optimizer_type=OptimizerType.ADAM,
     weight_decay=0.0,
 )
+
+
+# walking PPO config
+PPO_WALKING_MLP = PPOArgs(
+    policy_backbone=DEFAULT_MLP,
+    critic_backbone=DEFAULT_MLP,
+    lr=1e-3,
+    lr_type=LearningRateType.ADAPTIVE,
+    lr_adaptive_factor=1.5,
+    lr_min=1e-5,
+    lr_max=1e-2,
+    value_lr=None,
+    gamma=0.99,
+    gae_lambda=0.95,
+    clip_ratio=0.2,
+    use_clipped_value_loss=True,
+    value_loss_coef=1.0,
+    entropy_coef=0.003,
+    max_grad_norm=1.0,
+    target_kl=0.01,
+    num_epochs=5,
+    num_mini_batches=4,
+    rollout_length=24,
+    optimizer_type=OptimizerType.ADAM,
+    weight_decay=0.0,
+)
+
+# teleop PPO config
+PPO_TELEOP_MLP = PPOArgs(
+    policy_backbone=DEFAULT_MLP,
+    critic_backbone=DEFAULT_MLP,
+    lr=1e-3,
+    lr_type=LearningRateType.ADAPTIVE,
+    lr_adaptive_factor=1.5,
+    lr_min=1e-5,
+    lr_max=1e-2,
+    value_lr=None,
+    gamma=0.99,
+    gae_lambda=0.95,
+    clip_ratio=0.2,
+    value_loss_coef=1.0,
+    entropy_coef=0.003,
+    max_grad_norm=1.0,
+    target_kl=0.01,
+    num_epochs=5,
+    num_mini_batches=4,
+    rollout_length=24,
+    optimizer_type=OptimizerType.ADAM,
+    weight_decay=0.0,
+)
diff --git a/src/agent/gs_agent/algos/config/schema.py b/src/agent/gs_agent/algos/config/schema.py
index 9da5770b..297c2cbd 100644
--- a/src/agent/gs_agent/algos/config/schema.py
+++ b/src/agent/gs_agent/algos/config/schema.py
@@ -12,6 +12,11 @@ class OptimizerType(GenesisEnum):
     SGD = "SGD"
 
 
+class LearningRateType(GenesisEnum):
+    FIXED = "FIXED"
+    ADAPTIVE = "ADAPTIVE"
+
+
 class AlgorithmType(GenesisEnum):
     PPO = "PPO"
     BC = "BC"
@@ -37,6 +42,12 @@ class PPOArgs(BaseModel):
     value_lr: PositiveFloat | None = None
     """None means use the same learning rate as the policy"""
 
+    # Adaptive learning rate
+    lr_type: LearningRateType = LearningRateType.FIXED
+    lr_adaptive_factor: PositiveFloat = 1.5
+    lr_min: PositiveFloat = 1e-5
+    lr_max: PositiveFloat = 1e-2
+
     # Discount and GAE
     gamma: PositiveFloat = Field(default=0.99, ge=0, le=1)
     gae_lambda: PositiveFloat = Field(default=0.95, ge=0, le=1)
@@ -47,6 +58,7 @@
     entropy_coef: NonNegativeFloat = 0.0
     max_grad_norm: PositiveFloat = 1.0
     target_kl: PositiveFloat = 0.02
+    use_clipped_value_loss: bool = False
 
     # Training
     num_epochs: NonNegativeInt = 10
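
The ADAPTIVE setting wired up above shrinks the policy learning rate when the measured KL overshoots target_kl and grows it when an update barely moves the policy. A minimal standalone sketch of that rule follows; the function name and the sample KL values are illustrative only, while the 1.5x/0.5x thresholds and the clamp to [lr_min, lr_max] mirror _update_learning_rate in ppo.py below.

def adapt_lr(lr: float, kl: float, target_kl: float = 0.01,
             factor: float = 1.5, lr_min: float = 1e-5, lr_max: float = 1e-2) -> float:
    """Return the next learning rate under the ADAPTIVE rule."""
    if kl > target_kl * 1.5:        # policy moved too far -> slow down
        lr = max(lr / factor, lr_min)
    elif kl < target_kl * 0.5:      # policy barely moved -> speed up
        lr = min(lr * factor, lr_max)
    return lr


lr = 1e-3
for kl in (0.05, 0.03, 0.004, 0.02):
    lr = adapt_lr(lr, kl)
    print(f"kl={kl:.3f} -> lr={lr:.2e}")
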
diff --git a/src/agent/gs_agent/algos/ppo.py b/src/agent/gs_agent/algos/ppo.py
index f8156e9b..0de2e6ac 100644
--- a/src/agent/gs_agent/algos/ppo.py
+++ b/src/agent/gs_agent/algos/ppo.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn as nn
 
-from gs_agent.algos.config.schema import PPOArgs
+from gs_agent.algos.config.schema import LearningRateType, PPOArgs
 from gs_agent.bases.algo import BaseAlgo
 from gs_agent.bases.env_wrapper import BaseEnvWrapper
 from gs_agent.bases.policy import Policy
@@ -47,6 +47,12 @@ def __init__(
             self.env.num_envs, device=self.device, dtype=torch.float
         )
         self._curr_ep_len = torch.zeros(self.env.num_envs, device=self.device, dtype=torch.float)
+
+        # Adaptive learning rate tracking
+        self._current_lr = cfg.lr
+
+        self.use_clipped_value_loss = cfg.use_clipped_value_loss
+
         #
         self._build_actor_critic()
         self._build_rollouts()
@@ -64,7 +70,7 @@ def _build_actor_critic(self) -> None:
             policy_backbone=policy_backbone,
             action_dim=self._action_dim,
         ).to(self.device)
-        self._actor_optimizer = torch.optim.Adam(self._actor.parameters(), lr=self.cfg.lr)
+        self._actor_optimizer = torch.optim.Adam(self._actor.parameters(), lr=self._current_lr)
 
         critic_backbone = NetworkFactory.create_network(
             network_backbone_args=self.cfg.critic_backbone,
@@ -84,28 +90,60 @@ def _build_rollouts(self) -> None:
             num_envs=self._num_envs,
             max_steps=self._num_steps,
             actor_obs_size=self._actor_obs_dim,
+            critic_obs_size=self._critic_obs_dim,
             action_size=self._action_dim,
             device=self.device,
         )
 
+    def _update_learning_rate(self, kl_mean: float) -> None:
+        """Update learning rate adaptively based on KL divergence."""
+        if self.cfg.lr_type == LearningRateType.ADAPTIVE:
+            # If KL is too high, reduce learning rate
+            if kl_mean > self.cfg.target_kl * 1.5:  # 1.5x threshold
+                self._current_lr = max(
+                    self._current_lr / self.cfg.lr_adaptive_factor, self.cfg.lr_min
+                )
+            # If KL is too low, increase learning rate
+            elif kl_mean < self.cfg.target_kl * 0.5:  # 0.5x threshold
+                self._current_lr = min(
+                    self._current_lr * self.cfg.lr_adaptive_factor, self.cfg.lr_max
+                )
+
+            # Update only actor optimizer learning rate
+            for param_group in self._actor_optimizer.param_groups:
+                param_group["lr"] = self._current_lr
+
     def _collect_rollouts(self, num_steps: int) -> dict[str, Any]:
         """Collect rollouts from the environment."""
-        obs = self.env.get_observations()
+        actor_obs, critic_obs = self.env.get_observations()
+        termination_buffer = []
+        reward_terms_buffer = []
+        info_buffer = []
         with torch.inference_mode():
             # collect rollouts and compute returns & advantages
             for _step in range(num_steps):
-                action, log_prob = self._actor(obs)
+                action, log_prob, mu, sigma = self._actor.forward_with_dist_params(actor_obs)
                 # Step environment
-                next_obs, reward, terminated, truncated, _extra_infos = self.env.step(action)
+                _, reward, terminated, truncated, _extra_infos = self.env.step(action)
+                next_actor_obs, next_critic_obs = self.env.get_observations()
+
+                # add next value to reward of truncated steps
+                if "time_outs" in _extra_infos:
+                    time_outs = _extra_infos["time_outs"]
+                    next_values = self._critic(_extra_infos["observations"]["critic"])
+                    reward = reward + next_values * time_outs
 
                 # all tensors are of shape: [num_envs, dim]
                 transition = {
-                    GAEBufferKey.ACTOR_OBS: obs,
+                    GAEBufferKey.ACTOR_OBS: actor_obs,
+                    GAEBufferKey.CRITIC_OBS: critic_obs,
                     GAEBufferKey.ACTIONS: action,
                     GAEBufferKey.REWARDS: reward,
                     GAEBufferKey.DONES: terminated,
-                    GAEBufferKey.VALUES: self._critic(obs),
+                    GAEBufferKey.VALUES: self._critic(critic_obs),
                     GAEBufferKey.ACTION_LOGPROBS: log_prob,
+                    GAEBufferKey.MU: mu,
+                    GAEBufferKey.SIGMA: sigma,
                 }
                 self._rollouts.append(transition)
@@ -114,6 +152,11 @@ def _collect_rollouts(self, num_steps: int) -> dict[str, Any]:
                 self._curr_reward_sum += reward.squeeze(-1)
                 self._curr_ep_len += 1
 
+                # Update termination buffer
+                termination_buffer.append(_extra_infos["termination"])
+                reward_terms_buffer.append(_extra_infos["reward_terms"])
+                if "info" in _extra_infos:
+                    info_buffer.append(_extra_infos["info"])
                 # Check for episode completions and reset tracking
                 done_mask = terminated.unsqueeze(-1) | truncated.unsqueeze(-1)
                 new_ids = (done_mask > 0).nonzero(as_tuple=False)
@@ -127,9 +170,9 @@ def _collect_rollouts(self, num_steps: int) -> dict[str, Any]:
                     self._curr_reward_sum[new_ids] = 0
                     self._curr_ep_len[new_ids] = 0
 
-                obs = next_obs
+                actor_obs, critic_obs = next_actor_obs, next_critic_obs
 
         with torch.no_grad():
-            last_value = self._critic(self.env.get_observations())
+            last_value = self._critic(critic_obs)
         self._rollouts.set_final_value(last_value)
 
         mean_reward = 0.0
@@ -137,21 +180,46 @@ def _collect_rollouts(self, num_steps: int) -> dict[str, Any]:
         if len(self._rewbuffer) > 0:
             mean_reward = statistics.mean(self._rewbuffer)
             mean_ep_len = statistics.mean(self._lenbuffer)
+        # import ipdb; ipdb.set_trace()
+        mean_termination = {}
+        mean_reward_terms = {}
+        mean_info = {}
+        if len(termination_buffer) > 0:
+            for key in termination_buffer[0].keys():
+                terminations = torch.stack([termination[key] for termination in termination_buffer])
+                mean_termination[key] = terminations.to(torch.float).mean().item()
+        if len(reward_terms_buffer) > 0:
+            for key in reward_terms_buffer[0].keys():
+                reward_terms = torch.stack(
+                    [reward_term[key] for reward_term in reward_terms_buffer]
+                )
+                mean_reward_terms[key] = reward_terms.mean().item()
+        if len(info_buffer) > 0:
+            for key in info_buffer[0].keys():
+                infos = torch.tensor([info[key] for info in info_buffer])
+                mean_info[key] = infos.mean().item()
         return {
             "mean_reward": mean_reward,
             "mean_ep_len": mean_ep_len,
+            "termination": mean_termination,
+            "reward_terms": mean_reward_terms,
+            "info": mean_info,
         }
 
     def _train_one_batch(self, mini_batch: dict[GAEBufferKey, torch.Tensor]) -> dict[str, Any]:
         """Train one batch of rollouts."""
-        obs = mini_batch[GAEBufferKey.ACTOR_OBS]
+        actor_obs = mini_batch[GAEBufferKey.ACTOR_OBS]
+        critic_obs = mini_batch[GAEBufferKey.CRITIC_OBS]
         act = mini_batch[GAEBufferKey.ACTIONS]
         old_log_prob = mini_batch[GAEBufferKey.ACTION_LOGPROBS]
+        old_mu = mini_batch[GAEBufferKey.MU]
+        old_sigma = mini_batch[GAEBufferKey.SIGMA]
         advantage = mini_batch[GAEBufferKey.ADVANTAGES]
         returns = mini_batch[GAEBufferKey.RETURNS]
+        target_values = mini_batch[GAEBufferKey.VALUES]
         #
-        new_log_prob = self._actor.evaluate_log_prob(obs, act)
+        new_log_prob = self._actor.evaluate_log_prob(actor_obs, act)
         ratio = torch.exp(new_log_prob - old_log_prob)
         surr1 = -advantage * ratio
         surr2 = -advantage * torch.clamp(
@@ -159,14 +227,45 @@ def _train_one_batch(self, mini_batch: dict[GAEBufferKey, torch.Tensor]) -> dict
         )
         policy_loss = torch.max(surr1, surr2).mean()
 
-        approx_kl = (new_log_prob - old_log_prob).mean()
+        # Mean log-ratio between new and old policies
+        log_ratio = (new_log_prob - old_log_prob).mean()
+
+        # Calculate true KL divergence between old and new distributions
+        # Get new distribution parameters
+        dist = self._actor.dist_from_obs(actor_obs)
+        new_mu = dist.mean
+        new_sigma = dist.stddev
+
+        # KL divergence between two Gaussian distributions
+        # KL(P||Q) = 0.5 * [log(det(Σ_Q)/det(Σ_P)) - d + tr(Σ_Q^{-1} * Σ_P) + (μ_Q - μ_P)^T * Σ_Q^{-1} * (μ_Q - μ_P)]
+        # For diagonal Gaussians: KL = 0.5 * [sum(log(σ_Q^2/σ_P^2)) - d + sum(σ_P^2/σ_Q^2) + sum((μ_Q - μ_P)^2/σ_Q^2)]
+        kl_div = 0.5 * (
+            torch.log(new_sigma**2 / (old_sigma**2 + 1e-8)).sum(
+                dim=-1, keepdim=True
+            )  # log ratio of variances
+            - new_mu.shape[-1]  # dimension
+            + (old_sigma**2 / (new_sigma**2 + 1e-8)).sum(dim=-1, keepdim=True)  # trace term
+            + ((new_mu - old_mu) ** 2 / (new_sigma**2 + 1e-8)).sum(
+                dim=-1, keepdim=True
+            )  # mean difference term
+        )
+        kl_mean = kl_div.mean()
 
         # Calculate value loss
-        values = self._critic(obs)
-        value_loss = (returns - values).pow(2).mean()
+        values = self._critic(critic_obs)
+
+        if self.use_clipped_value_loss:
+            clipped_values = target_values + (values - target_values).clamp(
+                -self.cfg.clip_ratio, self.cfg.clip_ratio
+            )
+            value_loss = (values - returns).pow(2)
+            clipped_value_loss = (clipped_values - returns).pow(2)
+            value_loss = torch.max(value_loss, clipped_value_loss).mean()
+        else:
+            value_loss = (returns - values).pow(2).mean()
 
         # Calculate entropy loss
-        entropy = self._actor.entropy_on(obs)
+        entropy = self._actor.entropy_on(actor_obs)
         entropy_loss = entropy.mean()
 
         # Total loss
@@ -189,7 +288,10 @@ def _train_one_batch(self, mini_batch: dict[GAEBufferKey, torch.Tensor]) -> dict
             "policy_loss": policy_loss.item(),
             "value_loss": value_loss.item(),
             "entropy_loss": entropy_loss.item(),
-            "approx_kl": approx_kl.item(),
+            "log_ratio": log_ratio.item(),
+            "kl_mean": kl_mean.item(),
+            "mean_std": self._actor.action_std.mean().item(),
+            "max_std": self._actor.action_std.max().item(),
         }
 
     def train_one_iteration(self) -> dict[str, Any]:
@@ -211,6 +313,10 @@ def train_one_iteration(self) -> dict[str, Any]:
         t2 = time.time()
         train_time = t2 - t1
 
+        # Update learning rate adaptively based on KL divergence
+        avg_kl_mean = statistics.mean([metrics["kl_mean"] for metrics in train_metrics_list])
+        self._update_learning_rate(avg_kl_mean)
+
         self._rollouts.reset()
 
         iteration_infos = {
@@ -228,9 +334,15 @@ def train_one_iteration(self) -> dict[str, Any]:
                 "entropy_loss": statistics.mean(
                     [metrics["entropy_loss"] for metrics in train_metrics_list]
                 ),
-                "approx_kl": statistics.mean(
-                    [metrics["approx_kl"] for metrics in train_metrics_list]
+                "log_ratio": statistics.mean(
+                    [metrics["log_ratio"] for metrics in train_metrics_list]
+                ),
+                "kl_mean": statistics.mean([metrics["kl_mean"] for metrics in train_metrics_list]),
+                "mean_std": statistics.mean(
+                    [metrics["mean_std"] for metrics in train_metrics_list]
                 ),
+                "max_std": statistics.mean([metrics["max_std"] for metrics in train_metrics_list]),
+                "learning_rate": self._current_lr,
             },
             "speed": {
                 "rollout_time": rollouts_time,
@@ -238,6 +350,9 @@
                 "train_time": train_time,
                 "rollout_step": self._num_steps * self._num_envs,
             },
+            "termination": rollout_infos["termination"],
+            "reward_terms": rollout_infos["reward_terms"],
+            "info": rollout_infos["info"],
         }
 
         return iteration_infos
@@ -253,7 +368,7 @@ def save(self, path: Path) -> None:
 
     def load(self, path: Path, load_optimizer: bool = True) -> None:
         """Load the algorithm from a file."""
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, map_location=self.device)
         self._actor.load_state_dict(checkpoint["model_state_dict"])
         if load_optimizer:
             self._actor_optimizer.load_state_dict(checkpoint["actor_optimizer_state_dict"])
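
The hand-rolled diagonal-Gaussian KL in _train_one_batch can be cross-checked against torch.distributions. A minimal sketch, under the assumption that the 1e-8 stabilizer is dropped so the two values agree exactly:

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
old_mu, new_mu = torch.randn(4, 6), torch.randn(4, 6)
old_sigma, new_sigma = torch.rand(4, 6) + 0.1, torch.rand(4, 6) + 0.1

# Closed-form KL(old || new) for diagonal Gaussians, as in the patch (without eps)
kl_manual = 0.5 * (
    torch.log(new_sigma**2 / old_sigma**2).sum(-1)
    - old_mu.shape[-1]
    + (old_sigma**2 / new_sigma**2).sum(-1)
    + ((new_mu - old_mu) ** 2 / new_sigma**2).sum(-1)
)

# Reference value from torch.distributions
kl_ref = kl_divergence(Normal(old_mu, old_sigma), Normal(new_mu, new_sigma)).sum(-1)

assert torch.allclose(kl_manual, kl_ref, atol=1e-5)
print(kl_manual.mean().item(), kl_ref.mean().item())
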
"ACTION_LOGPROBS" + MU = "MU" + SIGMA = "SIGMA" ADVANTAGES = "ADVANTAGES" RETURNS = "RETURNS" diff --git a/src/agent/gs_agent/buffers/gae_buffer.py b/src/agent/gs_agent/buffers/gae_buffer.py index 21724c28..be00a91f 100644 --- a/src/agent/gs_agent/buffers/gae_buffer.py +++ b/src/agent/gs_agent/buffers/gae_buffer.py @@ -85,6 +85,7 @@ def __init__( num_envs: int, max_steps: int, actor_obs_size: int, + critic_obs_size: int, action_size: int, device: torch.device = _DEFAULT_DEVICE, gae_gamma: float = 0.98, @@ -104,6 +105,7 @@ def __init__( self._num_envs = num_envs self._max_steps = max_steps self._actor_obs_dim = actor_obs_size + self._critic_obs_dim = critic_obs_size self._action_dim = action_size self._device = device @@ -119,18 +121,21 @@ def __init__( self._final_value = None # Initialize buffer - self._buffer = self._init_buffers(actor_obs_size, action_size) + self._buffer = self._init_buffers(actor_obs_size, critic_obs_size, action_size) - def _init_buffers(self, actor_obs_dim: int, action_dim: int) -> TensorDict: + def _init_buffers(self, actor_obs_dim: int, critic_obs_dim: int, action_dim: int) -> TensorDict: max_steps, num_envs = self._max_steps, self._num_envs buffer = TensorDict( { GAEBufferKey.ACTOR_OBS: torch.zeros(max_steps, num_envs, actor_obs_dim), + GAEBufferKey.CRITIC_OBS: torch.zeros(max_steps, num_envs, critic_obs_dim), GAEBufferKey.ACTIONS: torch.zeros(max_steps, num_envs, action_dim), GAEBufferKey.REWARDS: torch.zeros(max_steps, num_envs, 1), GAEBufferKey.DONES: torch.zeros(max_steps, num_envs, 1).byte(), GAEBufferKey.VALUES: torch.zeros(max_steps, num_envs, 1), GAEBufferKey.ACTION_LOGPROBS: torch.zeros(max_steps, num_envs, 1), + GAEBufferKey.MU: torch.zeros(max_steps, num_envs, action_dim), + GAEBufferKey.SIGMA: torch.zeros(max_steps, num_envs, action_dim), }, batch_size=[self._max_steps, self._num_envs], device=self._device, @@ -146,11 +151,14 @@ def append(self, transition: dict[GAEBufferKey, torch.Tensor]) -> None: raise ValueError(f"Buffer full! 
diff --git a/src/agent/gs_agent/modules/policies.py b/src/agent/gs_agent/modules/policies.py
index bbed01cb..ce9957be 100644
--- a/src/agent/gs_agent/modules/policies.py
+++ b/src/agent/gs_agent/modules/policies.py
@@ -11,7 +11,7 @@ class GaussianPolicy(Policy):
     def __init__(self, policy_backbone: NetworkBackbone, action_dim: int) -> None:
         super().__init__(policy_backbone, action_dim)
         self.mu = nn.Linear(self.backbone.output_dim, self.action_dim)
-        self.log_std = nn.Parameter(torch.ones(self.action_dim) * 1.0)
+        self.log_std = nn.Parameter(torch.ones(self.action_dim) * 0.0)
         Normal.set_default_validate_args(False)
         self._init_params()
 
@@ -33,10 +33,10 @@ def forward(
             deterministic: Whether to use deterministic action.
 
         Returns:
-            GaussianPolicyOutput: Policy output.
+            tuple: (action, log_prob)
         """
         # Convert observation to tensor format
-        dist = self._dist_from_obs(obs)
+        dist = self.dist_from_obs(obs)
         if deterministic:
             action = dist.mean
         else:
@@ -45,7 +45,37 @@ def forward(
             log_prob = dist.log_prob(action).sum(-1, keepdim=True)
         return action, log_prob
 
-    def _dist_from_obs(self, obs: torch.Tensor) -> Normal:
+    def forward_with_dist_params(
+        self,
+        obs: torch.Tensor,
+        *,
+        deterministic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward pass of the policy with distribution parameters.
+
+        Args:
+            obs: Typed observation batch.
+            deterministic: Whether to use deterministic action.
+
+        Returns:
+            tuple: (action, log_prob, mu, sigma)
+        """
+        # Convert observation to tensor format
+        dist = self.dist_from_obs(obs)
+        if deterministic:
+            action = dist.mean
+        else:
+            action = dist.sample()
+        # Compute log probability of the sampled action
+        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
+
+        # Extract mu and sigma from the distribution
+        mu = dist.mean
+        sigma = dist.stddev
+
+        return action, log_prob, mu, sigma
+
+    def dist_from_obs(self, obs: torch.Tensor) -> Normal:
         feature = self.backbone(obs)
         action_mu = self.mu(feature)
         action_std = torch.exp(self.log_std)
@@ -61,7 +91,7 @@ def evaluate_log_prob(self, obs: torch.Tensor, act: torch.Tensor) -> torch.Tenso
         Returns:
             Log probability of the action.
         """
-        dist = self._dist_from_obs(obs)
+        dist = self.dist_from_obs(obs)
         return dist.log_prob(act).sum(-1, keepdim=True)
 
     def entropy_on(self, obs: torch.Tensor) -> torch.Tensor:
@@ -73,12 +103,16 @@ def entropy_on(self, obs: torch.Tensor) -> torch.Tensor:
         Returns:
             Entropy of the action distribution.
         """
-        dist = self._dist_from_obs(obs)
+        dist = self.dist_from_obs(obs)
         return dist.entropy().sum(-1)
 
     def get_action_shape(self) -> tuple[int, ...]:
         return (self.action_dim,)
 
+    @property
+    def action_std(self) -> torch.Tensor:
+        return self.log_std.exp()
+
 
 class DeterministicPolicy(Policy):
     def __init__(self, policy_backbone: NetworkBackbone, action_dim: int) -> None:
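
Note that initializing log_std to 0.0 starts the exploration std at exp(0) = 1.0 per action dimension (the previous init of 1.0 gave exp(1) ≈ 2.72). A standalone sketch of the same state-independent parameterization, using a plain linear stack as a stand-in for the real NetworkBackbone:

import torch
import torch.nn as nn
from torch.distributions import Normal

class TinyGaussianHead(nn.Module):
    """Stand-in for GaussianPolicy: state-independent log_std, as in policies.py."""

    def __init__(self, obs_dim: int, action_dim: int) -> None:
        super().__init__()
        self.backbone = nn.Linear(obs_dim, 32)                 # placeholder for the MLP backbone
        self.mu = nn.Linear(32, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))   # std starts at exp(0) = 1.0

    def dist_from_obs(self, obs: torch.Tensor) -> Normal:
        return Normal(self.mu(torch.tanh(self.backbone(obs))), self.log_std.exp())

    @property
    def action_std(self) -> torch.Tensor:
        return self.log_std.exp()

head = TinyGaussianHead(obs_dim=8, action_dim=3)
dist = head.dist_from_obs(torch.randn(4, 8))
action = dist.sample()
log_prob = dist.log_prob(action).sum(-1, keepdim=True)
print(head.action_std)               # all ones at init
print(action.shape, log_prob.shape)  # torch.Size([4, 3]) torch.Size([4, 1])
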
diff --git a/src/agent/gs_agent/runners/config/registry.py b/src/agent/gs_agent/runners/config/registry.py
index 6274eda6..7b3ee088 100644
--- a/src/agent/gs_agent/runners/config/registry.py
+++ b/src/agent/gs_agent/runners/config/registry.py
@@ -30,3 +30,18 @@
     save_interval=100,
     save_path=Path("./logs/ppo_gs_goal_reaching"),
 )
+
+
+RUNNER_WALKING_MLP = RunnerArgs(
+    total_iterations=6001,
+    log_interval=5,
+    save_interval=500,
+    save_path=Path("./logs/ppo_gs_walking"),
+)
+
+RUNNER_TELEOP_MLP = RunnerArgs(
+    total_iterations=3001,
+    log_interval=5,
+    save_interval=500,
+    save_path=Path("./logs/ppo_gs_teleop"),
+)
diff --git a/src/agent/gs_agent/runners/onpolicy_runner.py b/src/agent/gs_agent/runners/onpolicy_runner.py
index 241e9373..2ec8325e 100644
--- a/src/agent/gs_agent/runners/onpolicy_runner.py
+++ b/src/agent/gs_agent/runners/onpolicy_runner.py
@@ -90,7 +90,6 @@ def train(self, metric_logger: Any) -> dict[str, int | float | str]:
             "total_steps": total_steps,
             "total_time": training_time,
             "final_reward": reward_list[-1],
-            "final_iteration": total_iterations,
         }
 
     def _log_metrics(
diff --git a/src/agent/gs_agent/utils/logger.py b/src/agent/gs_agent/utils/logger.py
index efe31a96..c6d2e117 100644
--- a/src/agent/gs_agent/utils/logger.py
+++ b/src/agent/gs_agent/utils/logger.py
@@ -14,7 +14,7 @@
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from io import TextIOBase
-from typing import Any, TextIO
+from typing import Any, TextIO, Literal
 
 import numpy as np
 import pandas
@@ -476,10 +476,12 @@ def __init__(
         entity: str | None = None,
         config: dict[str, Any] | None = None,
         log_dir: str | None = None,
+        mode: Literal["online", "offline", "disabled"] = "online",
+        exp_name: str | None = None,
     ):
         if log_dir is not None:
             os.environ["WANDB_DIR"] = log_dir
-        wandb.init(project=project, entity=entity, config=config)
+        wandb.init(project=project, entity=entity, config=config, mode=mode, name=exp_name)
         self.run = wandb.run
 
     def write(
@@ -528,6 +530,8 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = "", **kwarg
             entity=kwargs.get("entity", None),
             config=kwargs.get("config", None),
             log_dir=log_dir,
+            mode=kwargs.get("mode", "online"),
+            exp_name=kwargs.get("exp_name", None)
         )
     else:
         raise ValueError(f"Unknown format specified: {_format}")
@@ -712,7 +716,7 @@ def _do_log(self, args: tuple[Any, ...]) -> None:
                 _format.write_sequence(list(map(str, args)))
 
 
-def configure(folder: str | None = None, format_strings: list[str] | None = None) -> Logger:
+def configure(folder: str | None = None, format_strings: list[str] | None = None, **kwargs) -> Logger:
     """
     Configure the current logger.
 
@@ -737,7 +741,7 @@ def configure(folder: str | None = None, format_strings: list[str] | None = None
         format_strings = os.getenv("GSRL_LOG_FORMAT", "stdout,log,csv").split(",")
 
     format_strings = list(filter(None, format_strings))
-    output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings]
+    output_formats = [make_output_format(f, folder, log_suffix, **kwargs) for f in format_strings]
     logger = Logger(folder=folder, output_formats=output_formats)
 
     # Only print when some files will be saved
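
Since configure() now forwards **kwargs through make_output_format() into WandbOutputFormat, a run can be named and switched offline at configuration time. A hypothetical call sketch: the "wandb" format string, the project and exp_name values, and the assumption that project is forwarded like the other kwargs are all illustrative, not taken from this patch.

from gs_agent.utils.logger import configure

# Hypothetical usage; only the mode and exp_name forwarding is shown by this patch.
logger = configure(
    folder="./logs/ppo_gs_walking",
    format_strings=["stdout", "csv", "wandb"],   # assumed name for the wandb-backed format
    project="gs-agent",                          # assumed to be forwarded via **kwargs
    mode="offline",                              # "online", "offline", or "disabled"
    exp_name="ppo_walking_mlp",
)
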
diff --git a/src/agent/gs_agent/wrappers/gs_env_wrapper.py b/src/agent/gs_agent/wrappers/gs_env_wrapper.py
index ed3b4f00..543b6aec 100644
--- a/src/agent/gs_agent/wrappers/gs_env_wrapper.py
+++ b/src/agent/gs_agent/wrappers/gs_env_wrapper.py
@@ -32,12 +32,6 @@ def step(
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
         # apply action
         self.env.apply_action(action)
-        # get observations
-        next_obs = self.env.get_observations()
-        # get reward
-        reward, reward_terms = self.env.get_reward()
-        if reward.dim() == 1:
-            reward = reward.unsqueeze(-1)
         # get terminated
         terminated = self.env.get_terminated()
         if terminated.dim() == 1:
@@ -46,13 +40,21 @@
         truncated = self.env.get_truncated()
         if truncated.dim() == 1:
             truncated = truncated.unsqueeze(-1)
+        # get reward
+        reward, reward_terms = self.env.get_reward()
+        if reward.dim() == 1:
+            reward = reward.unsqueeze(-1)
+        # update history
+        self.env.update_history()
+        # get extra infos
+        extra_infos = self.env.get_extra_infos()
+        extra_infos["reward_terms"] = reward_terms
         # reset if terminated or truncated
         done_idx = terminated.nonzero(as_tuple=True)[0]
         if len(done_idx) > 0:
             self.env.reset_idx(done_idx)
-        # get extra infos
-        extra_infos = self.env.get_extra_infos()
-        extra_infos["reward_terms"] = reward_terms
+        # get observations
+        next_obs, _ = self.env.get_observations()
         return next_obs, reward, terminated, truncated, extra_infos
 
     def get_observations(self) -> torch.Tensor: