|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +LocalSGD |
| 9 | +========= |
| 10 | +
|
| 11 | +This module implements a fault tolerant version of LocalSGD and related methods. |
| 12 | +""" |
| 13 | + |
| 14 | +from typing import Any, Dict, List, Mapping, Optional |
| 15 | + |
| 16 | +import torch |
| 17 | +from torch import nn, optim |
| 18 | + |
| 19 | +from torchft.manager import Manager |
| 20 | + |
| 21 | + |
| 22 | +class LocalSGD(nn.Module): |
| 23 | + """ |
| 24 | + LocalSGD is a model wrapper similar to DistributedDataParallel that |
| 25 | + implements the algorithm described in https://arxiv.org/pdf/1805.09767 |
| 26 | +
|
| 27 | + This will synchronize the model parameters periodically in a fault tolerant |
| 28 | + way using a torchft Manager. The allreduce on the parameters will happen |
| 29 | + every sync_every steps after the optimizer.step call. |
| 30 | +
|
| 31 | + To implement safe and fault tolerant, this requires a backup copy of the |
| 32 | + weights. By default these are stored in CPU memory. If any error occurs |
| 33 | + during the LocalSGD step, the step will be discarded and the model |
| 34 | + parameters will reset back to the last time LocalSGD synchronized. |
| 35 | +
|
| 36 | + The backup weights could be eliminated by relaxing the guarantee of exactly |
| 37 | + `sync_every` steps but that would diverge from the LocalSGD algorithm. |
| 38 | + DiLoCo also needs this backup copy to compute the delta. |
| 39 | +
|
| 40 | + The torchft quorum is computed at the beginning of ``sync_every`` steps. If |
| 41 | + any error occurs, or a worker fails between syncs, ``sync_every`` steps will be |
| 42 | + discarded and a new quorum will be computed on the next step. |
| 43 | +
|
| 44 | + If running in async mode, on a joining worker the first ``sync_every`` steps |
| 45 | + will discarded as the model will be recovering during that period. When |
| 46 | + using sync mode, the checkpoint will be restored prior to the first step. |
| 47 | +
|
| 48 | + TODO: add a way via Manager to detect workers failing early for shrink only |
| 49 | + TODO: add DiLoCo support |
| 50 | + """ |
| 51 | + |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + manager: Manager, |
| 55 | + model: nn.Module, |
| 56 | + optimizer: optim.Optimizer, |
| 57 | + sync_every: int, |
| 58 | + backup_device: Optional[torch.device] = None, |
| 59 | + pin_memory: bool = True, |
| 60 | + ) -> None: |
| 61 | + """ |
| 62 | + Args: |
| 63 | + manager: The manager to use. |
| 64 | + model: The model to wrap. |
| 65 | + optimizer: The optimizer used by the model. |
| 66 | + sync_every: How often to sync the model weights. |
| 67 | + backup_device: The device to store the backup of the model parameters on. (default cpu) |
| 68 | + pin_memory: Whether to pin the memory used for the backup of the model parameters. |
| 69 | + """ |
| 70 | + super().__init__() |
| 71 | + |
| 72 | + self._manager = manager |
| 73 | + self._model = model |
| 74 | + self._local_step = 0 |
| 75 | + self._started_step = False |
| 76 | + self._sync_every = sync_every |
| 77 | + assert sync_every >= 1, "sync_every must be greater than or equal to 1" |
| 78 | + |
| 79 | + device = backup_device or torch.device("cpu") |
| 80 | + |
| 81 | + self._backup_parameters: Dict[str, torch.Tensor] = {} |
| 82 | + |
| 83 | + for name, p in self._model.named_parameters(): |
| 84 | + t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device) |
| 85 | + if ( |
| 86 | + pin_memory |
| 87 | + and t.device == torch.device("cpu") |
| 88 | + and torch.cuda.is_available() |
| 89 | + ): |
| 90 | + t = t.pin_memory() |
| 91 | + self._backup_parameters[name] = t |
| 92 | + |
| 93 | + # Need to copy the parameters to the host to be safe if we are on the first step. |
| 94 | + self._save_parameters() |
| 95 | + |
| 96 | + optimizer.register_step_post_hook(self._step_post_hook) |
| 97 | + |
| 98 | + def _save_parameters(self) -> None: |
| 99 | + # TODO: consider running copy on a separate stream |
| 100 | + for name, p in self._model.named_parameters(): |
| 101 | + self._backup_parameters[name].copy_(p.data, non_blocking=True) |
| 102 | + |
| 103 | + def _restore_parameters(self) -> None: |
| 104 | + # TODO: consider running copy on a separate stream |
| 105 | + for name, p in self._model.named_parameters(): |
| 106 | + p.data.copy_(self._backup_parameters[name], non_blocking=True) |
| 107 | + |
| 108 | + # pyre-fixme[14]: support state_dict args |
| 109 | + def state_dict(self) -> Dict[str, object]: |
| 110 | + """ |
| 111 | + state_dict returns the state_dict from the last time LocalSGD |
| 112 | + synchronized and not the current weights. |
| 113 | + """ |
| 114 | + state_dict = self._model.state_dict() |
| 115 | + for name, p in self._backup_parameters.items(): |
| 116 | + assert name in state_dict |
| 117 | + state_dict[name] = p |
| 118 | + return state_dict |
| 119 | + |
| 120 | + def load_state_dict( |
| 121 | + self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False |
| 122 | + ) -> None: |
| 123 | + """ |
| 124 | + Loads the state dict to the model and the backup parameters. |
| 125 | +
|
| 126 | + This must be called while the model weights aren't being modified to |
| 127 | + avoid corrupting the backup weights. |
| 128 | + """ |
| 129 | + self._model.load_state_dict(state_dict, strict=strict, assign=assign) |
| 130 | + self._save_parameters() |
| 131 | + |
| 132 | + def forward(self, *args: object, **kwargs: object) -> object: |
| 133 | + """ |
| 134 | + Run the model parameters. |
| 135 | +
|
| 136 | + This should be called before the optimizer step. |
| 137 | +
|
| 138 | + This will start the quorum and save the parameters if this is the first step. |
| 139 | + """ |
| 140 | + if self._local_step == 0: |
| 141 | + self._manager.start_quorum() |
| 142 | + |
| 143 | + self._started_step = True |
| 144 | + |
| 145 | + return self._model.forward(*args, **kwargs) |
| 146 | + |
| 147 | + def _step_post_hook( |
| 148 | + self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object] |
| 149 | + ) -> None: |
| 150 | + """ |
| 151 | + This hook is registered on the optimizer and is called after the optimizer step. |
| 152 | +
|
| 153 | + This will call the allreduce on the model weights every sync_every steps. |
| 154 | + If any errors occur it will restore to the weights from the previous sync. |
| 155 | +
|
| 156 | + ``forward`` must be called before this function. |
| 157 | + """ |
| 158 | + assert self._started_step, "forward must be called before step" |
| 159 | + self._started_step = False |
| 160 | + |
| 161 | + self._local_step += 1 |
| 162 | + |
| 163 | + if self._local_step >= self._sync_every: |
| 164 | + self._local_step = 0 |
| 165 | + self._average() |
| 166 | + |
| 167 | + if self._manager.should_commit(): |
| 168 | + # save the parameters so we can restore from them later if necessary. |
| 169 | + self._save_parameters() |
| 170 | + else: |
| 171 | + # commit failed, restore from the backup parameters |
| 172 | + self._restore_parameters() |
| 173 | + |
| 174 | + def _average(self) -> None: |
| 175 | + # TODO: do we need to broadcast buffers like DDP does? |
| 176 | + |
| 177 | + works = [] |
| 178 | + |
| 179 | + for p in self._model.parameters(): |
| 180 | + # TODO: bucketize parameters |
| 181 | + works.append(self._manager.allreduce(p)) |
| 182 | + |
| 183 | + for work in works: |
| 184 | + work.wait() |
0 commit comments