
Conversation


@pchalasani pchalasani commented Nov 27, 2022

Stochastic Weight Averaging (SWA) is (quoting/paraphrasing from the PyTorch SWA page):

a simple procedure that improves generalization in deep learning over Stochastic Gradient Descent (SGD) at no additional cost, and can be used as a drop-in replacement for any other optimizer in PyTorch. SWA has a wide range of applications and features, [...] including [...] improve the stability of training as well as the final average rewards of policy-gradient methods in deep reinforcement learning.

See the PyTorch SWA page for more.

Description

This is a relatively simple change to exp_manager.py. It allows an additional key "swa" to be included in policy_kwargs, e.g.:

hyperparams["policy_kwargs"]["swa"] = {
    "swa_start": 5,
    "swa_freq": 3,
    "swa_lr": 0.05,
}
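For context, these keys mirror the constructor arguments of torchcontrib's SWA wrapper (swa_start, swa_freq, swa_lr). The sketch below is a plain-Python illustration of the averaging schedule these settings imply, not the PR's actual code: starting at step swa_start, a snapshot of the weights is folded into an equal-weight running mean every swa_freq steps. All names here (swa_update, the loop) are hypothetical.

```python
def swa_update(swa_weights, current_weights, n_averaged):
    """Fold current weights into the SWA running average:
    w_swa <- (w_swa * n + w) / (n + 1), elementwise."""
    return [
        (swa * n_averaged + w) / (n_averaged + 1)
        for swa, w in zip(swa_weights, current_weights)
    ]

# Mimic swa_start=5, swa_freq=3: first snapshot at step 5,
# then another snapshot every 3 steps after that.
swa_start, swa_freq = 5, 3
swa_weights, n_averaged = None, 0
for step in range(1, 12):
    current = [float(step)]  # stand-in for the model's parameters
    if step >= swa_start and (step - swa_start) % swa_freq == 0:
        if swa_weights is None:
            swa_weights, n_averaged = list(current), 1
        else:
            swa_weights = swa_update(swa_weights, current, n_averaged)
            n_averaged += 1

print(swa_weights)  # mean of snapshots taken at steps 5, 8, 11 -> [8.0]
```

With these settings, snapshots land at steps 5, 8, and 11, so the averaged "weight" is (5 + 8 + 11) / 3 = 8.0. The swa_lr key (not modeled above) sets the constant learning rate used during the averaging phase.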

Motivation and Context

SWA might help improve stability and reduce sensitivity to random seeds in some DRL applications.

Closes #321

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

Note: we are using a maximum length of 127 characters per line

@pchalasani pchalasani changed the title Support for Stoch Wt Avg (SWA) Support for Stoch Wt Avg (SWA) closes #321 Nov 27, 2022
@pchalasani pchalasani marked this pull request as ready for review November 27, 2022 02:51
@pchalasani
Contributor Author

I realized we need to call opt.swap_swa_sgd() at the end of training. Some further thought is also needed on how this interacts with the computation of validation metrics by EvalCallback etc.

@pchalasani pchalasani marked this pull request as draft November 27, 2022 17:43
@pchalasani
Contributor Author

I added opt.swap_swa_sgd() after model.learn.
We also need to call it before and after each evaluate_policy() call in EvalCallback (which lives in the original sb3 repo), so that validation metrics are computed with the SWA-averaged model weights. We could potentially subclass EvalCallback to accomplish this.



Development

Successfully merging this pull request may close these issues.

[Feature Request] Support Stochastic Weight Averaging (SWA) for improved stability
