Description
Motivation
1. Consistent style for `torch.nn.modules.loss.*Loss`

In `torch.nn.modules.loss`, there are many `*Loss` classes subclassing `nn.Module`. `Loss.__init__()` does not take other `nn.Module`s as arguments, and `Loss.forward()` is purely functional: it directly calls the corresponding `nn.functional.*_loss` function. I think the motivation for using `torch.nn.modules.loss.*Loss` is to compose networks with `nn.Sequential(...)`.
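For instance (a minimal illustration of the pattern, not from the torchrl codebase):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x, y = torch.randn(4, 3), torch.randn(4, 3)

# nn.MSELoss holds no parameters; its forward simply delegates to nn.functional.
loss_mod = nn.MSELoss(reduction='sum')
assert torch.equal(loss_mod(x, y), F.mse_loss(x, y, reduction='sum'))
```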
2. More straightforward implementation for functional-style algorithms, such as meta-RL algorithms

In many meta-RL algorithms, the policy is trained with meta-parameters that may not be registered to the `LossModule`.
Case 1. MGRL: Register leaf meta-parameters as buffers in the loss module
Meta-Gradient Reinforcement Learning (MGRL, https://arxiv.org/abs/1805.09801) treats the discount factor `gamma` as a meta-parameter learned across RL updates. Take PPO as an example:
```python
import torch
import torch.nn as nn
import torchopt
from torchrl.objectives import PPOLoss

### Setup ###
meta_param = nn.Parameter(torch.tensor(0.95))
loss_module = PPOLoss(
    actor, critic, ...,
    gamma=None,  # whatever value
)
loss_module.register_buffer('gamma', meta_param)  # register gamma as a buffer

### Optimizers ###
inner_optim = torchopt.MetaAdam(loss_module)
outer_optim = torchopt.Adam([meta_param])

### Inner updates (update network parameters) ###
inner_loss1 = loss_module(tensordict)
inner_optim.step(inner_loss1['loss_objective'])  # inner update 1: param(0) -> param(1)
...
inner_lossN = loss_module(tensordict)
inner_optim.step(inner_lossN['loss_objective'])  # inner update N: param(N - 1) -> param(N)

### Outer update (update the meta-parameter gamma) ###
outer_loss = loss_module(tensordict)  # sampled by param(N)
outer_optim.zero_grad()
outer_loss['loss_objective'].backward()
outer_optim.step()
```
See https://github.com/metaopt/TorchOpt#torchopt-as-differentiable-optimizer-for-meta-learning for illustrative figures. We need to register the meta-parameter `gamma` as a buffer of the loss module, rather than leaving the user in full control of the parameters.
For integration with functorch, registering the meta-parameter as a module buffer works out of the box:
```python
import functorch

meta_param = nn.Parameter(torch.tensor(0.95))
loss_module = PPOLoss(
    actor, critic, ...,
    gamma=None,  # whatever value
)
loss_module.register_buffer('gamma', meta_param)  # register gamma as a buffer

# Make functional
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
```
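A minimal sketch of why this matters (assuming the functional loss module keeps the TensorDict interface; `tensordict` and `meta_param` are from the snippets above): once `gamma` lives in `buffers`, the outer loss can be differentiated with respect to the meta-parameter directly.

```python
# With the functional version, `gamma` travels through `buffers`, so the outer
# loss can be differentiated w.r.t. the meta-parameter.
loss_td = fmodel(params, buffers, tensordict)
meta_grads = torch.autograd.grad(loss_td['loss_objective'], meta_param)
```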
Case 2. LPG: Register non-leaf meta-parameters as buffers in the loss module on every outer update
Learned Policy Gradient (LPG, https://arxiv.org/abs/2007.08794) uses an LSTM meta-network as the meta-parameter. Unlike in MGRL, the meta-network's output is no longer a leaf tensor on each update, so we would have to re-register these outputs over and over before every call to `loss_module.forward`. This breaks `fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)`, which captures the buffers only once at creation time.
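A rough sketch of the failure mode (`meta_net`, `meta_input`, and `num_outer_steps` are hypothetical names, not part of the torchrl API):

```python
# Hypothetical sketch of the LPG outer loop.
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
for outer_step in range(num_outer_steps):
    meta_out = meta_net(meta_input)  # non-leaf tensor, rebuilt every iteration
    # The output must be re-registered before every forward call, but `fmodel`
    # above captured its buffers once at creation and never sees the update:
    loss_module.register_buffer('meta_out', meta_out)
```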
cc @Benjamin-eecs @waterhorse1
Solution
Split the `forward` method of the loss module into a separate pure function, i.e., a stateless function that does not hold any parameters. The model parameters should be managed by other modules. The loss function only takes a `tensordict` as input and adds a new key `"loss_objective"` to it. All tensor inputs (e.g., `value = self.critic(...)`) should be computed before calling the loss function, because the loss function is purely functional, i.e., it does not host parameters (such as `actor.parameters()` or `critic.parameters()`).
Here is a prototype example:
```python
def ppo_loss(tensordict: TensorDictBase, dist_class: Type[Distribution], **kwargs) -> TensorDictBase:
    tensordict = tensordict.clone()
    # In MGRL, `gamma` can be a tensor rather than a Python scalar
    gamma = tensordict.get('gamma', kwargs.get('gamma'))
    critic_coef = tensordict.get('critic_coef', kwargs.get('critic_coef'))
    dist = dist_class(tensordict.get('actor_output'))
    ...
```
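Hypothetical usage in the MGRL setting (reusing `meta_param` and `tensordict` from the snippets above; the coefficient value is made up): the meta-parameter is passed as a kwarg, so it stays a differentiable tensor and never needs to live inside a module.

```python
# Hypothetical: pass a differentiable tensor where a scalar would normally go.
out = ppo_loss(tensordict, dist_class, gamma=meta_param, critic_coef=0.5)
outer_loss = out['loss_objective']  # differentiable w.r.t. meta_param
```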
For backward compatibility, refactor the `PPOLoss` module as:
```python
class PPOLoss(LossModule):
    actor: nn.Module
    critic: nn.Module
    gamma: float
    entropy_coef: float
    critic_coef: float
    advantage_module: nn.Module

    def __init__(self, ...):
        ...
        self.floss_kwargs = dict(
            gamma=self.gamma,
            entropy_coef=self.entropy_coef,
            critic_coef=self.critic_coef,
        )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Prepare all inputs for the loss function
        if self.advantage_module is not None:
            tensordict = self.advantage_module(tensordict)
        tensordict = tensordict.clone()
        ...
        # Call the purely functional version of the loss function
        return ppo_loss(tensordict, self.dist_class, **self.floss_kwargs)
```
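Both call paths then stay available; a hypothetical comparison, reusing the names from above:

```python
# Backward-compatible module path:
loss_td = loss_module(tensordict)

# Purely functional path (e.g., for meta-RL): compute all inputs first, then
# call the stateless loss with possibly-differentiable kwargs.
tensordict = advantage_module(tensordict)
loss_td = ppo_loss(tensordict, dist_class, gamma=meta_param, critic_coef=0.5)
```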
Alternatives
Copy and paste the loss module's source code, then apply specific customizations.
Checklist
- I have checked that there is no similar issue in the repo (required)