|
| 1 | +import copy |
| 2 | +import logging |
| 3 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 4 | +from contextlib import ExitStack |
| 5 | +from typing import Any, Dict |
| 6 | +from unittest import TestCase |
| 7 | + |
| 8 | +import torch |
| 9 | +from torch import nn, optim |
| 10 | + |
| 11 | +from torchft.local_sgd import DiLoCo, LocalSGD |
| 12 | +from torchft.manager import Manager |
| 13 | +from torchft.manager_integ_test import FailureInjector, MyModel, Runner |
| 14 | +from torchft.process_group import ProcessGroupGloo |
| 15 | +from torchft.torchft import Lighthouse |
| 16 | + |
| 17 | +logger: logging.Logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +def local_sgd_train_loop( |
| 21 | + rank: int, |
| 22 | + store_port: int, |
| 23 | + runner: Runner, |
| 24 | +) -> Dict[str, Dict[str, object]]: |
| 25 | + with ExitStack() as stack: |
| 26 | + |
| 27 | + def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None: |
| 28 | + m.load_state_dict(state_dict["model"]) |
| 29 | + optimizer.load_state_dict(state_dict["optim"]) |
| 30 | + |
| 31 | + def state_dict() -> Dict[str, Dict[str, object]]: |
| 32 | + return { |
| 33 | + "model": m.state_dict(), |
| 34 | + "optim": optimizer.state_dict(), |
| 35 | + } |
| 36 | + |
| 37 | + print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting") |
| 38 | + |
| 39 | + pg = ProcessGroupGloo() |
| 40 | + manager = Manager( |
| 41 | + pg=pg, |
| 42 | + min_replica_size=2, |
| 43 | + load_state_dict=load_state_dict, |
| 44 | + state_dict=state_dict, |
| 45 | + replica_id=str(runner.replica_id), |
| 46 | + store_addr="localhost", |
| 47 | + store_port=store_port, |
| 48 | + rank=rank, |
| 49 | + world_size=runner.world_size, |
| 50 | + lighthouse_addr=runner.lighthouse_address, |
| 51 | + port=19530 + runner.replica_id, |
| 52 | + # pyre-fixme[6]: Incompatible parameter type |
| 53 | + **runner.manager_args, |
| 54 | + ) |
| 55 | + stack.callback(lambda: manager.shutdown(wait=False)) |
| 56 | + |
| 57 | + m: nn.Module = MyModel() |
| 58 | + optimizer: optim.Optimizer = optim.Adam(m.parameters()) |
| 59 | + criterion = nn.CrossEntropyLoss() |
| 60 | + |
| 61 | + with LocalSGD(manager, m, optimizer, sync_every=2): |
| 62 | + while True: |
| 63 | + inputs = torch.rand(2, 3) |
| 64 | + labels = torch.randint(4, (2,)) |
| 65 | + |
| 66 | + optimizer.zero_grad() |
| 67 | + out = m(inputs) |
| 68 | + loss = criterion(out, labels) |
| 69 | + |
| 70 | + loss.backward() |
| 71 | + |
| 72 | + optimizer.step() |
| 73 | + |
| 74 | + if manager.current_step() >= 4: |
| 75 | + break |
| 76 | + |
| 77 | + runner.failure_injector.check(rank, manager.current_step()) |
| 78 | + |
| 79 | + # return state_dict so we can check consistency |
| 80 | + return state_dict() |
| 81 | + |
| 82 | + |
| 83 | +def diloco_train_loop( |
| 84 | + rank: int, |
| 85 | + store_port: int, |
| 86 | + runner: Runner, |
| 87 | +) -> Dict[str, Dict[str, object]]: |
| 88 | + with ExitStack() as stack: |
| 89 | + # Declare the model and optimizers |
| 90 | + m: nn.Module = MyModel() |
| 91 | + model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"] |
| 92 | + m.load_state_dict(model_state_dict) |
| 93 | + |
| 94 | + # Setup optimizers |
| 95 | + inner_optimizer: optim.Optimizer = torch.optim.AdamW( |
| 96 | + m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) |
| 97 | + ) |
| 98 | + outer_optimizer: optim.Optimizer = torch.optim.SGD( |
| 99 | + m.parameters(), lr=0.7, momentum=0.9, nesterov=True |
| 100 | + ) |
| 101 | + |
| 102 | + # pyre-ignore[53] |
| 103 | + def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None: |
| 104 | + m.load_state_dict(state_dict["model"]) |
| 105 | + # TODO: make this cleaner so we don't have to save this |
| 106 | + diloco._backup_parameters = state_dict["backup_params"] |
| 107 | + inner_optimizer.load_state_dict(state_dict["inner_optim"]) |
| 108 | + outer_optimizer.load_state_dict(state_dict["outer_optim"]) |
| 109 | + |
| 110 | + def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53] |
| 111 | + return { |
| 112 | + "model": m.state_dict(), |
| 113 | + "backup_params": copy.deepcopy(diloco._backup_parameters), |
| 114 | + "inner_optim": inner_optimizer.state_dict(), |
| 115 | + "outer_optim": outer_optimizer.state_dict(), |
| 116 | + } |
| 117 | + |
| 118 | + print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting") |
| 119 | + |
| 120 | + pg = ProcessGroupGloo() |
| 121 | + manager = Manager( |
| 122 | + pg=pg, |
| 123 | + min_replica_size=2, |
| 124 | + use_async_quorum=False, |
| 125 | + load_state_dict=load_state_dict, |
| 126 | + state_dict=state_dict, |
| 127 | + replica_id=str(runner.replica_id), |
| 128 | + store_addr="localhost", |
| 129 | + store_port=store_port, |
| 130 | + rank=rank, |
| 131 | + world_size=runner.world_size, |
| 132 | + lighthouse_addr=runner.lighthouse_address, |
| 133 | + port=19530 + runner.replica_id, |
| 134 | + # pyre-fixme[6]: Incompatible parameter type |
| 135 | + **runner.manager_args, |
| 136 | + ) |
| 137 | + stack.callback(manager.shutdown) |
| 138 | + |
| 139 | + criterion = nn.CrossEntropyLoss() |
| 140 | + all_state_dicts = {} |
| 141 | + with DiLoCo( |
| 142 | + manager, m, inner_optimizer, outer_optimizer, sync_every=2 |
| 143 | + ) as diloco: |
| 144 | + while True: |
| 145 | + inputs = torch.rand(2, 3) |
| 146 | + labels = torch.randint(4, (2,)) |
| 147 | + |
| 148 | + out = m(inputs) |
| 149 | + loss = criterion(out, labels) |
| 150 | + |
| 151 | + inner_optimizer.zero_grad() |
| 152 | + loss.backward() |
| 153 | + inner_optimizer.step() |
| 154 | + manager_step_str = str(manager.current_step()) |
| 155 | + all_state_dicts[manager_step_str] = state_dict() |
| 156 | + |
| 157 | + # after 4 model updates then break |
| 158 | + if manager.current_step() >= 4: |
| 159 | + break |
| 160 | + |
| 161 | + runner.failure_injector.check(rank, manager.current_step()) |
| 162 | + |
| 163 | + # return state_dict so we can check consistency |
| 164 | + return all_state_dicts |
| 165 | + |
| 166 | + |
| 167 | +class ManagerIntegTest(TestCase): |
| 168 | + def test_local_sgd_recovery(self) -> None: |
| 169 | + lighthouse = Lighthouse( |
| 170 | + bind="[::]:0", |
| 171 | + min_replicas=2, |
| 172 | + ) |
| 173 | + num_replicas = 2 |
| 174 | + futures = [] |
| 175 | + |
| 176 | + failure_injectors = [ |
| 177 | + FailureInjector(), |
| 178 | + FailureInjector().fail_at(0, 2), |
| 179 | + ] |
| 180 | + |
| 181 | + with ThreadPoolExecutor(max_workers=num_replicas) as executor: |
| 182 | + for replica_id, failure_injector in zip( |
| 183 | + range(num_replicas), failure_injectors |
| 184 | + ): |
| 185 | + runner = Runner( |
| 186 | + replica_id=replica_id, |
| 187 | + lighthouse_address=lighthouse.address(), |
| 188 | + failure_injector=failure_injector, |
| 189 | + train_loop=local_sgd_train_loop, |
| 190 | + manager_args={ |
| 191 | + "use_async_quorum": False, |
| 192 | + }, |
| 193 | + ) |
| 194 | + futures.append(executor.submit(runner.run_replica)) |
| 195 | + |
| 196 | + state_dicts = [] |
| 197 | + |
| 198 | + for fut in as_completed(futures): |
| 199 | + try: |
| 200 | + state_dicts.append(fut.result()) |
| 201 | + except Exception as e: |
| 202 | + print(e) |
| 203 | + raise |
| 204 | + |
| 205 | + lighthouse.shutdown() |
| 206 | + |
| 207 | + for state_dict in state_dicts: |
| 208 | + # LocalSGD only guarantees that the model is consistent across |
| 209 | + # replicas but uses separate optimizer states. |
| 210 | + torch.testing.assert_close( |
| 211 | + state_dict[0]["model"], state_dicts[0][0]["model"] |
| 212 | + ) |
| 213 | + |
| 214 | + self.assertEqual(failure_injectors[1].count, 1) |
| 215 | + |
| 216 | + def test_diloco_healthy(self) -> None: |
| 217 | + lighthouse = Lighthouse( |
| 218 | + bind="[::]:0", |
| 219 | + min_replicas=2, |
| 220 | + ) |
| 221 | + num_replicas = 2 |
| 222 | + futures = [] |
| 223 | + |
| 224 | + torch.manual_seed(42) |
| 225 | + # Initialize the model so we can pass in the state_dict |
| 226 | + m: nn.Module = MyModel() |
| 227 | + |
| 228 | + with ThreadPoolExecutor(max_workers=num_replicas) as executor: |
| 229 | + for replica_id in range(num_replicas): |
| 230 | + failure_injector = FailureInjector() |
| 231 | + runner = Runner( |
| 232 | + replica_id=replica_id, |
| 233 | + lighthouse_address=lighthouse.address(), |
| 234 | + failure_injector=failure_injector, |
| 235 | + train_loop=diloco_train_loop, |
| 236 | + train_loop_args={ |
| 237 | + "model_state_dict": m.state_dict(), |
| 238 | + }, |
| 239 | + ) |
| 240 | + futures.append(executor.submit(runner.run_replica)) |
| 241 | + |
| 242 | + state_dicts = [] |
| 243 | + |
| 244 | + for fut in as_completed(futures): |
| 245 | + state_dicts.append(fut.result()[0]) |
| 246 | + |
| 247 | + lighthouse.shutdown() |
| 248 | + |
| 249 | + for replica_group in state_dicts: |
| 250 | + for step, state_dict in replica_group.items(): |
| 251 | + # inner optimizer will be different, outer optimizer and model should be the same |
| 252 | + torch.testing.assert_close( |
| 253 | + state_dict["backup_params"], |
| 254 | + state_dicts[0][str(step)]["backup_params"], |
| 255 | + ) |
| 256 | + torch.testing.assert_close( |
| 257 | + state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"] |
| 258 | + ) |
| 259 | + |
| 260 | + def test_diloco_recovery(self) -> None: |
| 261 | + lighthouse = Lighthouse( |
| 262 | + bind="[::]:0", |
| 263 | + min_replicas=2, |
| 264 | + ) |
| 265 | + num_replicas = 2 |
| 266 | + futures = [] |
| 267 | + |
| 268 | + failure_injectors = [ |
| 269 | + FailureInjector(), |
| 270 | + FailureInjector().fail_at(0, 2), |
| 271 | + ] |
| 272 | + |
| 273 | + torch.manual_seed(42) |
| 274 | + # Initialize the model so we can pass in the state_dict |
| 275 | + m: nn.Module = MyModel() |
| 276 | + |
| 277 | + with ThreadPoolExecutor(max_workers=num_replicas) as executor: |
| 278 | + for replica_id, failure_injector in zip( |
| 279 | + range(num_replicas), failure_injectors |
| 280 | + ): |
| 281 | + runner = Runner( |
| 282 | + replica_id=replica_id, |
| 283 | + lighthouse_address=lighthouse.address(), |
| 284 | + failure_injector=failure_injector, |
| 285 | + train_loop=diloco_train_loop, |
| 286 | + train_loop_args={ |
| 287 | + "model_state_dict": m.state_dict(), |
| 288 | + }, |
| 289 | + ) |
| 290 | + futures.append(executor.submit(runner.run_replica)) |
| 291 | + |
| 292 | + state_dicts = [] |
| 293 | + |
| 294 | + for fut in as_completed(futures): |
| 295 | + try: |
| 296 | + state_dicts.append(fut.result()[0]) |
| 297 | + except Exception as e: |
| 298 | + print(e) |
| 299 | + raise |
| 300 | + |
| 301 | + lighthouse.shutdown() |
| 302 | + for replica_group in state_dicts: |
| 303 | + for step, state_dict in replica_group.items(): |
| 304 | + str_step = str(step) |
| 305 | + if str_step in state_dicts[0]: |
| 306 | + # inner optimizer will be different, outer optimizer and model should be the same |
| 307 | + torch.testing.assert_close( |
| 308 | + state_dict["backup_params"], |
| 309 | + state_dicts[0][str_step]["backup_params"], |
| 310 | + ) |
| 311 | + torch.testing.assert_close( |
| 312 | + state_dict["outer_optim"], |
| 313 | + state_dicts[0][str_step]["outer_optim"], |
| 314 | + ) |
| 315 | + |
| 316 | + self.assertEqual(failure_injectors[1].count, 1) |
0 commit comments