Skip to content

Commit 302fd39

Browse files
tushar00jainfacebook-github-bot
authored andcommitted
load/store outer optimizer state dict (#277)
Summary: We don't restore outer optimizer state currently which can lead to bumps in loss because of high learning rate from a new replica. So save the outer optimizer state in the diloco specific state dict. Pull Request resolved: #277 Reviewed By: d4l3k Differential Revision: D83512078 fbshipit-source-id: 07c3ca7f4830f2115c3a4586d93c6d0883a38660
1 parent 6393e6d commit 302fd39

File tree

4 files changed

+33
-30
lines changed

4 files changed

+33
-30
lines changed

torchft/_test/diloco_trainer.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,6 @@ def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
227227
self.model.to(self.device)
228228

229229
self.inner_optimizer.load_state_dict(state_dict["inner_optim"])
230-
for i, optimizer in enumerate(self.outer_optimizers):
231-
optimizer.load_state_dict(
232-
cast(dict[str, torch.Tensor], state_dict[f"outer_optim"][f"{i}"])
233-
)
234230

235231
def state_dict(self) -> Dict[str, Dict[str, object]]:
236232
"""
@@ -244,10 +240,6 @@ def state_dict(self) -> Dict[str, Dict[str, object]]:
244240
return {
245241
"model": self.model.state_dict(),
246242
"inner_optim": self.inner_optimizer.state_dict(),
247-
"outer_optim": {
248-
f"{i}": optimizer.state_dict()
249-
for i, optimizer in enumerate(self.outer_optimizers)
250-
},
251243
}
252244

253245
def train_loop(self) -> dict[str, Any]:

torchft/diloco_regression_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,18 @@ def train_loop(self) -> Dict[str, Any]:
221221

222222
for i in range(self.n_fragments):
223223
value = cast(
224-
dict[str, torch.Tensor],
224+
dict[str, dict[str, torch.Tensor]],
225225
user_state_dict[f"StreamingDiLoCoFragment_{i}"],
226226
)
227227
parameter_history["global_parameter_history"][local_step][
228228
f"layers.{i}.weight"
229-
] = value["weight"].data.clone().detach().cpu().tolist()
229+
] = (
230+
value["original_parameters"]["weight"]
231+
.data.clone()
232+
.detach()
233+
.cpu()
234+
.tolist()
235+
)
230236

231237
manager_steps.add(manager_curr_step)
232238

torchft/local_sgd.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,16 +259,21 @@ def register_state_dict_fn(self) -> None:
259259
fragment_key = f"StreamingDiLoCoFragment_{self._fragment_id}"
260260

261261
# Define load function for this fragment
262-
def load_fn(state_dict: Dict[str, torch.Tensor]) -> None:
263-
for name, param in state_dict.items():
262+
def load_fn(state_dict: Dict[str, Dict[str, torch.Tensor]]) -> None:
263+
for name, param in state_dict["original_parameters"].items():
264264
if name in self.original_parameters:
265265
self.original_parameters[name].copy_(param)
266266

267+
self._outer_optimizer.load_state_dict(state_dict["outer_optimizer"])
268+
267269
# Define save function for this fragment
268-
def save_fn() -> Dict[str, torch.Tensor]:
270+
def save_fn() -> Dict[str, Dict[str, torch.Tensor]]:
269271
return {
270-
name: extract_local_tensor(param)
271-
for name, param in self.original_parameters.items()
272+
"outer_optimizer": self._outer_optimizer.state_dict(),
273+
"original_parameters": {
274+
name: extract_local_tensor(param)
275+
for name, param in self.original_parameters.items()
276+
},
272277
}
273278

274279
# Register the functions with the manager

torchft/local_sgd_integ_test.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -140,29 +140,29 @@ def assert_equal_global_state(
140140
for step in rep0.keys():
141141
for i in range(n_fragments):
142142
torch.testing.assert_close(
143-
rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"],
144-
rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"],
143+
rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"][
144+
"original_parameters"
145+
],
146+
rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"][
147+
"original_parameters"
148+
],
145149
check_device=False,
146150
msg=f"{step=} {i=}",
147151
)
148-
# Check all outer optimizers
149-
for i in range(
150-
len(
151-
cast(
152-
dict[str, dict[str, torch.Tensor]],
153-
rep0[step]["user"]["default"]["outer_optim"],
154-
).keys()
155-
)
156-
):
152+
# Check all outer optimizers
157153
torch.testing.assert_close(
158154
cast(
159155
dict[str, dict[str, torch.Tensor]],
160-
rep1[step]["user"]["default"]["outer_optim"],
161-
)[f"{i}"],
156+
rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"][
157+
"outer_optimizer"
158+
],
159+
),
162160
cast(
163161
dict[str, dict[str, torch.Tensor]],
164-
rep0[step]["user"]["default"]["outer_optim"],
165-
)[f"{i}"],
162+
rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"][
163+
"outer_optimizer"
164+
],
165+
),
166166
check_device=False,
167167
)
168168

0 commit comments

Comments
 (0)