|
108 | 108 | " model = nn.parallel.DistributedDataParallel(model)\n", |
109 | 109 | " optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n", |
110 | 110 | "\n", |
111 | | - " # Download FashionMNIST dataset only on local_rank=0 process.\n", |
112 | | - " if local_rank == 0:\n", |
113 | | - " dataset = datasets.FashionMNIST(\n", |
| 111 | + " # Download FashionMNIST dataset only on global rank 0 process.\n", |
| 112 | + " if dist.get_rank() == 0:\n", |
| 113 | + " datasets.FashionMNIST(\n", |
114 | 114 | " \"./data\",\n", |
115 | 115 | " train=True,\n", |
116 | 116 | " download=True,\n", |
117 | | - " transform=transforms.Compose([\n", |
118 | | - " transforms.ToTensor(),\n", |
119 | | - " transforms.Normalize((0.1307,), (0.3081,)),\n", |
120 | | - " ]),\n", |
121 | 117 | " )\n", |
122 | 118 | " dist.barrier()\n", |
123 | 119 | " dataset = datasets.FashionMNIST(\n", |
|
127 | 123 | " transform=transforms.Compose([transforms.ToTensor()]),\n", |
128 | 124 | " )\n", |
129 | 125 | "\n", |
130 | | - " # Shard the dataset accross workers.\n", |
| 126 | + " # Shard the dataset across workers.\n", |
131 | 127 | " train_loader = DataLoader(\n", |
132 | 128 | " dataset,\n", |
133 | 129 | " batch_size=100,\n", |
|
374 | 370 | ], |
375 | 371 | "metadata": { |
376 | 372 | "kernelspec": { |
377 | | - "display_name": "Python 3 (ipykernel)", |
| 373 | + "display_name": "Python 3", |
378 | 374 | "language": "python", |
379 | 375 | "name": "python3" |
380 | 376 | }, |
|
388 | 384 | "name": "python", |
389 | 385 | "nbconvert_exporter": "python", |
390 | 386 | "pygments_lexer": "ipython3", |
391 | | - "version": "3.11.13" |
| 387 | + "version": "3.13.7" |
392 | 388 | } |
393 | 389 | }, |
394 | 390 | "nbformat": 4, |
|
0 commit comments