[Feature Request] Purely functional loss objectives #338

Open · @XuehaiPan

Description

Motivation

1. Consistent style for torch.nn.modules.loss.*Loss

In torch.nn.modules.loss, there are many *Loss classes subclassing nn.Module. Loss.__init__() does not take other nn.Module instances as arguments, and the Loss.forward() method is purely functional: it directly calls the corresponding nn.functional.*_loss function.

I think the motivation for using torch.nn.modules.loss.*Loss is to compose networks with nn.Sequential(...).
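For reference, the built-in losses follow this pattern (a minimal sketch mirroring torch.nn.MSELoss; the real implementation lives in torch/nn/modules/loss.py):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MSELoss(nn.Module):
    # Stateless: holds no parameters or sub-modules, only the `reduction` flag.
    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__()
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # forward() is purely functional -- it simply delegates to nn.functional
        return F.mse_loss(input, target, reduction=self.reduction)

Because such a module is stateless, it can be dropped into a larger module or called directly.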

2. More straightforward implementation of functional-style algorithms, such as meta-RL algorithms

In many meta-RL algorithms, the policy is trained with meta-parameters that may not be registered to the LossModule.

Case.1 MGRL: Register leaf meta-parameters as buffers in the loss module

For Meta-Gradient Reinforcement Learning (MGRL, https://arxiv.org/abs/1805.09801), the discount factor gamma is a meta-parameter learned across RL updates.

Using PPO as an example:

import torch
import torch.nn as nn

from torchrl.objectives import PPOLoss

import torchopt

### Setup ###
meta_param = nn.Parameter(torch.tensor(0.95))
loss_module = PPOLoss(
    actor, critic, ...,
    gamma=None,  # whatever value
)
loss_module.register_buffer('gamma', meta_param)  # register gamma as a buffer

### Optimizers ###
inner_optim = torchopt.MetaAdam(loss_module)
outer_optim = torchopt.Adam([meta_param])

### Inner update (update network parameters) ###
inner_loss1 = loss_module(tensordict)
inner_optim.step(inner_loss1['loss_objective'])  # inner update 1: param(0) -> param(1)

...

inner_lossN = loss_module(tensordict)
inner_optim.step(inner_lossN['loss_objective'])  # inner update N: param(N - 1) -> param(N)

### Outer update (update meta-parameters (gamma)) ###
outer_loss = loss_module(tensordict)  # sampled by param(N)
outer_optim.zero_grad()
outer_loss.backward()
outer_optim.step()

See https://github.com/metaopt/TorchOpt#torchopt-as-differentiable-optimizer-for-meta-learning for figures.

That is, we have to register the meta-parameter gamma as a buffer of the loss module, instead of giving the user full control of the parameters.

For integration with functorch, registering the meta-parameter as a module buffer works fine:

meta_param = nn.Parameter(torch.tensor(0.95))
loss_module = PPOLoss(
    actor, critic, ...,
    gamma=None,  # whatever value
)
loss_module.register_buffer('gamma', meta_param)  # register gamma as a buffer

# Make functional
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
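The functional model then takes params and buffers explicitly, so the meta-gradient with respect to gamma can be taken through the buffers. A hedged sketch (assuming, as above, that the loss module returns a tensordict whose 'loss_objective' entry is a scalar):

import functorch

def compute_loss(params, buffers, tensordict):
    return fmodel(params, buffers, tensordict)['loss_objective']

# Meta-gradient w.r.t. the buffers (argnums=1), where gamma lives
meta_grads = functorch.grad(compute_loss, argnums=1)(params, buffers, tensordict)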

Case.2 LPG: Register non-leaf meta-parameters as buffers in the loss module on every outer update

For Learned Policy Gradient (LPG, https://arxiv.org/abs/2007.08794), the meta-parameters are the weights of an LSTM meta-network.

Unlike MGRL, on each update the meta-network output is no longer a leaf tensor, so these outputs would have to be re-registered before every call of loss_module.forward(). This makes

fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

fail: the functional snapshot goes stale as soon as the meta-network produces new outputs.
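A hypothetical sketch of the problem (meta_net, meta_input, and the sizes are illustrative, not from the source):

import functorch
import torch.nn as nn

meta_net = nn.LSTM(input_size=8, hidden_size=1)  # LPG-style meta-network

# Snapshot taken once, before the outer loop
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

for outer_step in range(num_outer_steps):
    output, _ = meta_net(meta_input)    # non-leaf tensor: output of the meta-network
    loss_module.gamma = output[-1, 0]   # must be re-registered on every outer step
    # `buffers` still holds the previous gamma -- the snapshot is stale, and
    # make_functional_with_buffers would have to be re-run each iteration.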

cc @Benjamin-eecs @waterhorse1

Solution


Split the forward method of the loss module into a separate pure function, i.e., a stateless function that does not hold any parameters. The model parameters should be organized by other modules. The loss function only takes a tensordict as input and adds a new key "loss_objective" to it. All tensor inputs (e.g., value = self.critic(...)) should be computed before calling the loss function, because the loss function is purely functional, i.e., it does not host parameters (e.g., actor.parameters(), critic.parameters()).

Here is a prototype example:

from typing import Type

def ppo_loss(tensordict: TensorDictBase, dist_class: Type[Distribution], **kwargs) -> TensorDictBase:
    tensordict = tensordict.clone()
    # In MGRL, the `gamma` parameter can be a tensor rather than a Python scalar
    gamma = tensordict.get('gamma', kwargs.get('gamma'))
    critic_coef = tensordict.get('critic_coef', kwargs.get('critic_coef'))
    dist = dist_class(tensordict.get('actor_output'))
    ...
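A hedged usage sketch of what this enables (the meta_gamma name is illustrative): a differentiable gamma can be carried in the tensordict itself, with no buffer registration:

tensordict.set('gamma', meta_gamma)  # non-leaf, differentiable tensor is fine
tensordict = ppo_loss(tensordict, dist_class=dist_class)
tensordict['loss_objective'].backward()  # gradients flow back to meta_gamma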

For backward compatibility, the PPOLoss module can be refactored as:

class PPOLoss(LossModule):
    actor: nn.Module
    critic: nn.Module
    gamma: float
    entropy_coef: float
    critic_coef: float
    advantage_module: nn.Module

    def __init__(self, ...):
        ...

        self.floss_kwargs = dict(
            gamma=self.gamma,
            entropy_coef=self.entropy_coef,
            critic_coef=self.critic_coef,
        )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Prepare all inputs for loss function
        if self.advantage_module is not None:
            tensordict = self.advantage_module(
                tensordict,
            )
        tensordict = tensordict.clone()
        ...  # compute actor/critic outputs, resolve dist_class, etc.

        # Call purely functional version of loss function
        return ppo_loss(
            tensordict,
            dist_class,
            **self.floss_kwargs
        )
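With this split, loss_module(tensordict) keeps working for composition as before, while MGRL/LPG-style workflows can route meta-parameters through the tensordict. A sketch of the MGRL outer update under the proposed API (hedged; it relies on the lookup order in the ppo_loss prototype, where the tensordict entry takes precedence over floss_kwargs):

meta_param = nn.Parameter(torch.tensor(0.95))
outer_optim = torchopt.Adam([meta_param])

tensordict.set('gamma', meta_param)       # differentiable gamma, no buffer needed
outer_loss = loss_module(tensordict)      # forward() delegates to ppo_loss
outer_optim.zero_grad()
outer_loss['loss_objective'].backward()   # gradients flow to meta_param
outer_optim.step()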

Alternatives


Copy and paste the loss module source code, then apply the specific customizations by hand.


Checklist

  • I have checked that there is no similar issue in the repo (required)
