Skip to content

Commit

Permalink
[WIP] Compute lp during loss execution
Browse files Browse the repository at this point in the history
ghstack-source-id: f16d93a5fab2016d436c808896c9cf24f783a754
Pull Request resolved: #2688
  • Loading branch information
vmoens committed Jan 15, 2025
1 parent 7575e96 commit 2811962
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
16 changes: 7 additions & 9 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8180,18 +8180,19 @@ def _create_seq_mock_data_ppo(
obs = total_obs[:, :T]
next_obs = total_obs[:, 1:]
if atoms:
action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
-1, 1
)
action_shape = (batch, T, atoms, action_dim)
else:
action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
action_shape = (batch, T, action_dim)
params_mean = torch.randn(action_shape, device=device) / 10
params_scale = torch.rand(action_shape, device=device) / 10
action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
-1, 1
)
reward = torch.randn(batch, T, 1, device=device)
done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
mask = torch.ones(batch, T, dtype=torch.bool, device=device)
action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
params_mean = torch.randn_like(action) / 10
params_scale = torch.rand_like(action) / 10
loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
if sample_log_prob_key is None:
Expand All @@ -8218,9 +8219,6 @@ def _create_seq_mock_data_ppo(
},
"collector": {"mask": mask},
action_key: action,
sample_log_prob_key: (
torch.randn_like(action[..., 1]) / 10
).masked_fill_(~mask, 0.0),
},
device=device,
names=[None, "time"],
Expand Down
44 changes: 38 additions & 6 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,10 +518,7 @@ def _log_weight(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)
if isinstance(dist, CompositeDistribution):
is_composite = True
else:
is_composite = False
is_composite = isinstance(dist, CompositeDistribution)

# current log_prob of actions
if is_composite:
Expand All @@ -538,6 +535,32 @@ def _log_weight(
prev_log_prob = _maybe_get_or_select(
tensordict, self.tensor_keys.sample_log_prob
)
# TODO:
# # current log_prob of actions
# action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
#
# is_composite = None
# if all(key in tensordict for key in self.actor_network.dist_params_keys):
# prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
# kwargs, is_composite = _get_composite_kwargs(prev_dist)
# if is_composite:
# prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
# else:
# prev_log_prob = prev_dist.log_prob(action, **kwargs)
# print('prev_log_prob', prev_log_prob)
# else:
# try:
# prev_log_prob = _maybe_get_or_select(
# tensordict, self.tensor_keys.sample_log_prob
# )
# except KeyError as err:
# raise _make_lp_get_error(self.tensor_keys, tensordict, err)

with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
current_dist = self.actor_network.get_dist(tensordict)


if prev_log_prob.requires_grad:
raise RuntimeError(
Expand All @@ -558,6 +581,13 @@ def _log_weight(
"the beginning of your script to get a proper composite log-prob.",
category=UserWarning,
)
# TODO:
# if isinstance(action, torch.Tensor):
# log_prob = current_dist.log_prob(action)
# else:
# if is_composite is None:
# kwargs, is_composite = _get_composite_kwargs(current_dist)
# log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
if (
is_composite
and not is_tensor_collection(prev_log_prob)
Expand All @@ -571,7 +601,7 @@ def _log_weight(
if is_tensor_collection(kl_approx):
kl_approx = _sum_td_features(kl_approx)

return log_weight, dist, kl_approx
return log_weight, current_dist, kl_approx

def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
"""Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
Expand Down Expand Up @@ -647,6 +677,9 @@ def _cached_critic_network_params_detached(self):
@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)

log_weight, dist, kl_approx = self._log_weight(tensordict)

advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
self.value_estimator(
Expand All @@ -660,7 +693,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
scale = advantage.std().clamp_min(1e-6)
advantage = (advantage - loc) / scale

log_weight, dist, kl_approx = self._log_weight(tensordict)
if is_tensor_collection(log_weight):
log_weight = _sum_td_features(log_weight)
log_weight = log_weight.view(advantage.shape)
Expand Down

0 comments on commit 2811962

Please sign in to comment.