Skip to content

Commit

Permalink
Improve OptimizerWrapper composability (#85)
Browse files Browse the repository at this point in the history
* Improve OptimizerWrapper composability

OptimizerWrapper currently miss several attributes that are required for training integration. This PR adds the missing gap.
  • Loading branch information
fegin authored Jan 30, 2025
1 parent 6e4ae38 commit c3d5d54
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
11 changes: 10 additions & 1 deletion torchft/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"""

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional

import torch
from torch.optim import Optimizer

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,3 +53,11 @@ def step(self, closure: Optional[object] = None) -> None:
assert closure is None, "optimizers that use closures are not supported"
if self.manager.should_commit():
self.optim.step()

@property
def param_groups(self) -> List[Dict[str, Any]]:
return self.optim.param_groups

@property
def state(self) -> Mapping[torch.Tensor, Any]: # pyre-fixme[3]
return self.optim.state
8 changes: 8 additions & 0 deletions torchft/optim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest import TestCase
from unittest.mock import MagicMock, create_autospec

import torch
from torch.nn import Linear
from torch.optim import AdamW

Expand Down Expand Up @@ -34,9 +35,16 @@ def test_optimizer_wrapper(self) -> None:
optim.zero_grad()
self.assertEqual(manager.start_quorum.call_count, 1)

b = torch.rand(3)
m(b).sum().backward()

manager.should_commit.return_value = True
optim.step()
manager.should_commit.return_value = False
optim.step()
self.assertEqual(len(optim.param_groups), 2)
self.assertEqual(optim.param_groups[1]["lr"], 1e-4)
self.assertEqual(optim.param_groups[1]["params"], [])
self.assertEqual(len(optim.state), len(list(m.parameters())))

self.assertEqual(manager.should_commit.call_count, 2)

0 comments on commit c3d5d54

Please sign in to comment.