@@ -1,8 +1,8 @@
---- mnist_main_original.py 2024-08-10 17:30:08.552324326 -0500
-+++ pnetcdf_mnist.py 2024-08-11 16:10:31.895471785 -0500
+--- mnist_main.py.orig 2025-05-09 10:51:06.814200110 -0500
++++ mnist_main.py 2025-05-09 11:15:17.198167820 -0500
 @@ -1,3 +1,8 @@
 +#
-+# Copyright (C) 2024, Northwestern University and Argonne National Laboratory
++# Copyright (C) 2025, Northwestern University and Argonne National Laboratory
 +# See COPYRIGHT notice in top-level directory.
 +#
 +
@@ -15,13 +15,13 @@
  from torch.optim.lr_scheduler import StepLR
 +from torch.nn.parallel import DistributedDataParallel as DDP
 +from torch.utils.data.distributed import DistributedSampler
-
+
 +import comm_file, pnetcdf_io
 +from mpi4py import MPI
-
+
  class Net(nn.Module):
      def __init__(self):
-@@ -42,14 +51,13 @@
+@@ -42,7 +51,7 @@
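The imports added above wire the script into two layers of parallelism: mpi4py and comm_file handle MPI-side communication and PnetCDF I/O, while DDP and DistributedSampler handle data-parallel training. The diff never shows comm_file itself, so the following is only a minimal sketch of what its init_parallel() helper has to provide for the rest of the patch to work (a comm object exposing comm.mpi_comm, a per-rank device, and a torch.distributed process group for DDP); every detail beyond the two returned values is an assumption, not the example's actual code.

    import os
    import torch
    import torch.distributed as dist
    from mpi4py import MPI

    class Comm:
        """Wrapper exposing the raw MPI communicator as comm.mpi_comm."""
        def __init__(self, mpi_comm):
            self.mpi_comm = mpi_comm
            self.rank = mpi_comm.Get_rank()
            self.nprocs = mpi_comm.Get_size()

    def init_parallel():
        comm = Comm(MPI.COMM_WORLD)
        if torch.cuda.is_available():
            # one GPU per rank, wrapping around if ranks outnumber devices
            device = torch.device(f"cuda:{comm.rank % torch.cuda.device_count()}")
            torch.cuda.set_device(device)
            backend = "nccl"
        else:
            device = torch.device("cpu")
            backend = "gloo"
        # DDP needs a torch.distributed process group alongside the MPI communicator
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend, rank=comm.rank, world_size=comm.nprocs)
        return comm, device

The bare rank and nprocs names used later in the patch are presumably module-level values taken from this communicator.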
@@ -30,62 +30,47 @@
              print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_idx * len(data), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader), loss.item()))
-             if args.dry_run:
-                 break
-
--
- def test(model, device, test_loader):
-     model.eval()
-     test_loss = 0
-@@ -62,9 +70,14 @@
+@@ -62,9 +71,14 @@
              pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
              correct += pred.eq(target.view_as(pred)).sum().item()
-
+
 +    # aggregate loss among all ranks
 +    test_loss = comm.mpi_comm.allreduce(test_loss, op=MPI.SUM)
 +    correct = comm.mpi_comm.allreduce(correct, op=MPI.SUM)
 +
      test_loss /= len(test_loader.dataset)
-
+
 -    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
 +    if rank == 0:
 +        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
          test_loss, correct, len(test_loader.dataset),
          100. * correct / len(test_loader.dataset)))
-
-@@ -94,6 +107,8 @@
+
+@@ -92,6 +106,8 @@
                          help='how many batches to wait before logging training status')
-     parser.add_argument('--save-model', action='store_true', default=False,
+     parser.add_argument('--save-model', action='store_true',
                          help='For Saving the current Model')
 +    parser.add_argument('--input-file', type=str, required=True,
 +                        help='NetCDF file storing train and test samples')
      args = parser.parse_args()
-     use_cuda = not args.no_cuda and torch.cuda.is_available()
-     use_mps = not args.no_mps and torch.backends.mps.is_available()
-@@ -101,18 +116,18 @@
-     torch.manual_seed(args.seed)
-
-     if use_cuda:
--        device = torch.device("cuda")
-+        torch.cuda.set_device(rank) # Set the GPU device by rank
-+        device = torch.device(f"cuda:{rank}")
-     elif use_mps:
-         device = torch.device("mps")
+
+     use_accel = not args.no_accel and torch.accelerator.is_available()
+@@ -103,12 +119,11 @@
      else:
          device = torch.device("cpu")
-
+
 -    train_kwargs = {'batch_size': args.batch_size}
 +    train_kwargs = {'batch_size': args.batch_size//nprocs}
      test_kwargs = {'batch_size': args.test_batch_size}
-     if use_cuda:
-         cuda_kwargs = {'num_workers': 1,
+     if use_accel:
+         accel_kwargs = {'num_workers': 1,
 -                       'pin_memory': True,
 -                       'shuffle': True}
 +                       'pin_memory': True}
-         train_kwargs.update(cuda_kwargs)
-         test_kwargs.update(cuda_kwargs)
-
-@@ -120,25 +135,53 @@
+         train_kwargs.update(accel_kwargs)
+         test_kwargs.update(accel_kwargs)
+
+@@ -116,25 +131,53 @@
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
          ])
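The allreduce calls above are the heart of the distributed test(): each rank evaluates only its shard of the test set, so test_loss and correct are partial sums until MPI.SUM combines them, after which dividing by the full len(test_loader.dataset) is correct on every rank. The same logic motivates args.batch_size//nprocs: splitting the global batch across ranks keeps the effective batch size equal to the serial run. A standalone illustration with made-up per-rank values:

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    local_loss = 1.5      # this rank's summed nll_loss over its test shard
    local_correct = 120   # this rank's count of correct predictions
    total_loss = comm.allreduce(local_loss, op=MPI.SUM)
    total_correct = comm.allreduce(local_correct, op=MPI.SUM)
    # every rank now holds identical global totals, so any rank can
    # compute the global average loss and accuracy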
@@ -108,14 +93,14 @@
 +    # add distributed samplers to DataLoaders
 +    train_loader = torch.utils.data.DataLoader(train_file, sampler=train_sampler, **train_kwargs)
 +    test_loader = torch.utils.data.DataLoader(test_file, sampler=test_sampler, **test_kwargs, drop_last=False)
-
+
      model = Net().to(device)
 +
 +    # use DDP
-+    model = DDP(model, device_ids=[device] if use_cuda else None)
++    model = DDP(model, device_ids=[device] if use_accel else None)
 +
      optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
-
+
      scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
      for epoch in range(1, args.epochs + 1):
 +        # train sampler set epoch
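This hunk does the data-parallel wiring: DistributedSampler hands each rank a disjoint shard of indices (which is why 'shuffle': True was dropped from the loader kwargs earlier), and wrapping the model in DDP makes every backward() all-reduce gradients so the replicas stay in sync. Below is a self-contained sketch of the same wiring with a random stand-in dataset; the real example reads its samples through pnetcdf_io, and the sketch assumes a torchrun-style launch that sets the rank and world-size environment variables.

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dist.init_process_group("gloo")        # CPU backend, enough for the sketch

    dataset = TensorDataset(torch.randn(1024, 784),
                            torch.randint(0, 10, (1024,)))
    sampler = DistributedSampler(dataset)  # disjoint index shard per rank
    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    model = DDP(torch.nn.Linear(784, 10))  # stand-in for Net()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(1, 3):
        sampler.set_epoch(epoch)           # new shuffle order every epoch
        for data, target in loader:
            opt.zero_grad()
            loss = torch.nn.functional.cross_entropy(model(data), target)
            loss.backward()                # DDP all-reduces gradients here
            opt.step()

    dist.destroy_process_group()

The set_epoch() call is exactly what the patch adds before each training epoch; without it the sampler reuses the same shuffle order in every epoch.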
@@ -125,16 +110,16 @@
          train(args, model, device, train_loader, optimizer, epoch)
          test(model, device, test_loader)
          scheduler.step()
-
+
      if args.save_model:
 -        torch.save(model.state_dict(), "mnist_cnn.pt")
 +        if rank == 0:
 +            torch.save(model.state_dict(), "mnist_cnn.pt")
-
+
 +    # close files
 +    train_file.close()
 +    test_file.close()
-
+
  if __name__ == '__main__':
 +    ## initialize parallel environment
 +    comm, device = comm_file.init_parallel()
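Since DDP keeps every replica identical, rank 0 alone writes the checkpoint, while train_file.close() and test_file.close() run on all ranks to match the collective semantics of PnetCDF file handling. One caveat worth hedging, in the context of the patched script rather than lines the patch itself contains: state_dict() of a DDP-wrapped model prefixes every key with "module.", so a common convention is to unwrap through model.module before saving, and a barrier keeps other ranks from exiting while rank 0 is still writing.

    # suggested variation, not lines from the patch
    if args.save_model:
        if rank == 0:
            torch.save(model.module.state_dict(), "mnist_cnn.pt")  # plain Net keys
        comm.mpi_comm.Barrier()  # don't let other ranks exit mid-write

    # reloading later then needs no DDP at all
    net = Net()
    net.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))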