
Support grad_clip_norm_() for FSDP #20784


Open · wants to merge 12 commits into base: master
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
- Support `grad_clip_norm_()` for FSDP ([#20784](https://github.com/Lightning-AI/pytorch-lightning/pull/20784))


### Changed
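For context, a minimal usage sketch of what this change enables (not part of the diff; `DemoModule` and the `Trainer` arguments below are illustrative): norm-based gradient clipping configured on the `Trainer` while training under FSDP, which previously raised a `MisconfigurationException`.

```python
# Illustrative sketch only: `DemoModule` is a placeholder LightningModule.
import lightning.pytorch as pl
from lightning.pytorch.strategies import FSDPStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=FSDPStrategy(),
    precision="16-mixed",
    gradient_clip_val=1.0,           # max gradient norm
    gradient_clip_algorithm="norm",  # now routed to FSDP's clip_grad_norm_
)
trainer.fit(DemoModule())
```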
4 changes: 3 additions & 1 deletion src/lightning/pytorch/core/module.py
@@ -1207,7 +1207,9 @@ def clip_gradients(
)

gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
self.trainer.precision_plugin.clip_gradients(
self.trainer.model, optimizer, gradient_clip_val, gradient_clip_algorithm
)

def configure_gradient_clipping(
self,
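The call above is what user code reaches through the standard hooks. As a hedged sketch (hook signatures as documented for Lightning 2.x; the class name is a placeholder), overriding `configure_gradient_clipping` and delegating to `self.clip_gradients` now threads `self.trainer.model` into the precision plugin:

```python
from lightning.pytorch import LightningModule


class ClippedModule(LightningModule):  # placeholder name
    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        # Delegates to the precision plugin; under FSDP this now ends up
        # calling clip_grad_norm_ on the root wrapped module.
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )
```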
6 changes: 5 additions & 1 deletion src/lightning/pytorch/plugins/precision/amp.py
@@ -15,6 +15,7 @@

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer
from typing_extensions import override

@@ -100,6 +101,7 @@ def optimizer_step( # type: ignore[override]
@override
def clip_gradients(
self,
module: Optional[Module],
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -109,7 +111,9 @@ def clip_gradients(
f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
" because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
)
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
super().clip_gradients(
module=module, optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm
)

def autocast_context_manager(self) -> torch.autocast:
return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
1 change: 1 addition & 0 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -141,6 +141,7 @@ def optimizer_step( # type: ignore[override]
@override
def clip_gradients(
self,
module: Optional[Module],
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
14 changes: 6 additions & 8 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import AbstractContextManager
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import get_args, override

import lightning.pytorch as pl
@@ -81,14 +82,11 @@ def convert_module(self, module: Module) -> Module:
return module

@override
def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP.
# To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference
# to the root module
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)
if module is None:
return
module.clip_grad_norm_(clip_val)

@property
def mixed_precision_config(self) -> "TorchMixedPrecision":
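The comment in the diff points at the underlying constraint: with sharded parameters, `torch.nn.utils.clip_grad_norm_` only sees each rank's local shards, so the computed total norm is wrong; FSDP's own `clip_grad_norm_` reduces the norm across ranks before clipping. A raw PyTorch sketch of that pattern (model, batch, and optimizer below are placeholders):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(MyModel())  # MyModel is a placeholder nn.Module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

loss = model(batch).sum()  # `batch` is a placeholder input
loss.backward()
model.clip_grad_norm_(max_norm=1.0)  # correct: clip via the root FSDP module
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # incorrect under FSDP
optimizer.step()
optimizer.zero_grad()
```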
5 changes: 3 additions & 2 deletions src/lightning/pytorch/plugins/precision/precision.py
@@ -143,6 +143,7 @@ def _clip_gradients(

def clip_gradients(
self,
module: Optional[Module],
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
@@ -153,14 +154,14 @@
if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
self.clip_grad_by_value(optimizer, clip_val)
elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
self.clip_grad_by_norm(optimizer, clip_val)
self.clip_grad_by_norm(module, optimizer, clip_val)

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by value."""
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by norm."""
parameters = self.main_params(optimizer)
torch.nn.utils.clip_grad_norm_(parameters, clip_val)
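Since the base-class signature now carries the module, custom precision plugins can hook into it as well. A sketch, not part of the PR: the subclass name is hypothetical, and the import path assumes the current `lightning.pytorch.plugins.precision.Precision` base class.

```python
from typing import Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from lightning.pytorch.plugins.precision import Precision


class ModuleAwarePrecision(Precision):  # hypothetical subclass
    def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        # The wrapped module is now available for strategy-specific handling;
        # here we simply fall back to the default parameter-wise clipping.
        super().clip_grad_by_norm(module, optimizer, clip_val)
```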
11 changes: 7 additions & 4 deletions tests/tests_pytorch/plugins/precision/test_amp.py
@@ -14,6 +14,7 @@
from unittest.mock import Mock

import pytest
from torch.nn import Module
from torch.optim import Optimizer

from lightning.pytorch.plugins import MixedPrecision
@@ -22,22 +23,23 @@

def test_clip_gradients():
"""Test that `.clip_gradients()` is a no-op when clipping is disabled."""
module = Mock(spec=Module)
optimizer = Mock(spec=Optimizer)
precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
precision.clip_grad_by_value = Mock()
precision.clip_grad_by_norm = Mock()
precision.clip_gradients(optimizer)
precision.clip_gradients(module, optimizer)
precision.clip_grad_by_value.assert_not_called()
precision.clip_grad_by_norm.assert_not_called()

precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
precision.clip_grad_by_value.assert_called_once()
precision.clip_grad_by_norm.assert_not_called()

precision.clip_grad_by_value.reset_mock()
precision.clip_grad_by_norm.reset_mock()

precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
precision.clip_grad_by_value.assert_not_called()
precision.clip_grad_by_norm.assert_called_once()

@@ -46,8 +48,9 @@ def test_optimizer_amp_scaling_support_in_step_method():
"""Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
gradient clipping (example: fused Adam)."""

module = Mock(spec=Module)
optimizer = Mock(_step_supports_amp_scaling=True)
precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())

with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
precision.clip_gradients(optimizer, clip_val=1.0)
precision.clip_gradients(module, optimizer, clip_val=1.0)
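A companion test for the FSDP path is not shown in this diff; a hedged sketch of what it could look like (assuming `FSDPPrecision` is importable from `lightning.pytorch.plugins.precision` and accepts `"32-true"`):

```python
from unittest.mock import Mock

from lightning.pytorch.plugins.precision import FSDPPrecision


def test_fsdp_clip_grad_by_norm():
    precision = FSDPPrecision(precision="32-true")

    # With a module reference, clipping delegates to FSDP's clip_grad_norm_.
    module = Mock()
    precision.clip_grad_by_norm(module, Mock(), 1.0)
    module.clip_grad_norm_.assert_called_once_with(1.0)

    # Without a module reference, clipping is silently skipped.
    precision.clip_grad_by_norm(None, Mock(), 1.0)
```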