Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions mighty/mighty_exploration/mighty_exploration_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,36 @@


def sample_nondeterministic_logprobs(
z: torch.Tensor, mean: torch.Tensor, log_std: torch.Tensor, sac: bool = False
z: torch.Tensor,
mean: torch.Tensor,
log_std: torch.Tensor,
tanh_squash: bool = False,
) -> torch.Tensor:
"""
Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)),
and if sac=True apply the tanh-squash correction to get log π(a).
Compute log-prob of a Gaussian sample z ~ N(mean, exp(log_std)).

If tanh_squash=True, applies the change-of-variables correction for
the squashed action a = tanh(z):

log π(a) = log N(z | mean, std) − ∑ log(1 − tanh(z)²)

Both old and new log-probs must use the same value of tanh_squash so
the correction cancels correctly in the PPO importance-sampling ratio.
"""
std = torch.exp(log_std) # [batch, action_dim]
dist = Normal(mean, std)
# base Gaussian log‐prob of z
log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True) # [batch, 1]

if sac:
if tanh_squash:
# subtract the ∑_i log(d tanh/dz_i) = ∑ log(1 - tanh(z)^2)
eps = 1e-4
eps = 1e-6
log_correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps).sum(
dim=-1, keepdim=True
) # [batch, 1]
return log_pz - log_correction
else:
# PPO-style or other: no squash correction
return log_pz

return log_pz


class MightyExplorationPolicy:
Expand Down Expand Up @@ -111,7 +120,8 @@ def sample_func_logits(self, state_array):
elif isinstance(out, tuple) and len(out) == 4:
action = out[0] # [batch, action_dim]
log_prob = sample_nondeterministic_logprobs(
z=out[1], mean=out[2], log_std=out[3], sac=self.ago == "sac"
z=out[1], mean=out[2], log_std=out[3],
tanh_squash=getattr(self.model, "tanh_squash", False),
)
return action.detach().cpu().numpy(), log_prob

Expand Down
55 changes: 17 additions & 38 deletions mighty/mighty_exploration/stochastic_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,10 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
# 4-tuple case (Tanh squashing): (action, z, mean, log_std)
elif isinstance(model_output, tuple) and len(model_output) == 4:
action, z, mean, log_std = model_output

if not self.algo == "sac":
log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)
log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=log_std,
tanh_squash=getattr(self.model, "tanh_squash", False),
)

if return_logp:
return action.detach().cpu().numpy(), log_prob
Expand All @@ -119,15 +113,10 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
elif len(model_output) == 4:
# Tanh squashing mode: (action, z, mean, log_std)
action, z, mean, log_std = model_output
if not self.algo == "sac":
log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)
log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=log_std,
tanh_squash=getattr(self.model, "tanh_squash", False),
)
else:
raise ValueError(
f"Unexpected model output length: {len(model_output)}"
Expand All @@ -144,15 +133,10 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
if self.model.output_style == "squashed_gaussian":
# Should be 4-tuple: (action, z, mean, log_std)
action, z, mean, log_std = model_output
if not self.algo == "sac":
log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)
log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=log_std,
tanh_squash=getattr(self.model, "tanh_squash", False),
)

if return_logp:
return action.detach().cpu().numpy(), log_prob
Expand All @@ -166,16 +150,11 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
dist = Normal(mean, std)
z = dist.rsample()
action = torch.tanh(z)

if not self.algo == "sac":
log_prob = sample_nondeterministic_logprobs(
z=z,
mean=mean,
log_std=log_std,
sac=False,
)
else:
log_prob = self.model.policy_log_prob(z, mean, log_std)
log_std = torch.log(std)
log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=log_std,
tanh_squash=getattr(self.model, "tanh_squash", False),
)

entropy = dist.entropy().sum(dim=-1, keepdim=True)
weighted_log_prob = log_prob * entropy
Expand Down
1 change: 1 addition & 0 deletions mighty/mighty_models/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class SACModel(nn.Module):
output_style = (
"squashed_gaussian" # For continuous actions, we use squashed Gaussian output
)
tanh_squash: bool = True # SAC always uses tanh squashing

def __init__(
self,
Expand Down
Loading