from net import Net

# Training settings
-parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                    help='input batch size for training (default: 64)')
-parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                    help='input batch size for testing (default: 1000)')
-parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                    help='number of epochs to train (default: 14)')
-parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                    help='learning rate (default: 0.01)')
-parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
-                    help='SGD momentum (default: 0.5)')
-parser.add_argument('--no-cuda', action='store_true', default=False,
-                    help='disables CUDA training')
-parser.add_argument('--seed', type=int, default=42, metavar='S',
-                    help='random seed (default: 42)')
-parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                    help='how many batches to wait before logging training status')
-parser.add_argument('--fp16-allreduce', action='store_true', default=False,
-                    help='use fp16 compression during allreduce')
-parser.add_argument('--use-adasum', action='store_true', default=False,
-                    help='use adasum algorithm to do reduction')
+parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+parser.add_argument(
+    "--batch-size",
+    type=int,
+    default=64,
+    metavar="N",
+    help="input batch size for training (default: 64)",
+)
+parser.add_argument(
+    "--test-batch-size",
+    type=int,
+    default=1000,
+    metavar="N",
+    help="input batch size for testing (default: 1000)",
+)
+parser.add_argument(
+    "--epochs",
+    type=int,
+    default=14,
+    metavar="N",
+    help="number of epochs to train (default: 14)",
+)
+parser.add_argument(
+    "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)"
+)
+parser.add_argument(
+    "--momentum",
+    type=float,
+    default=0.5,
+    metavar="M",
+    help="SGD momentum (default: 0.5)",
+)
+parser.add_argument(
+    "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+)
+parser.add_argument(
+    "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
+)
+parser.add_argument(
+    "--log-interval",
+    type=int,
+    default=10,
+    metavar="N",
+    help="how many batches to wait before logging training status",
+)
+parser.add_argument(
+    "--fp16-allreduce",
+    action="store_true",
+    default=False,
+    help="use fp16 compression during allreduce",
+)
+parser.add_argument(
+    "--use-adasum",
+    action="store_true",
+    default=False,
+    help="use adasum algorithm to do reduction",
+)


def train(model, train_sampler, train_loader, args, optimizer, epoch):
@@ -53,9 +89,15 @@ def train(model, train_sampler, train_loader, args, optimizer, epoch):
        if batch_idx % args.log_interval == 0:
            # Horovod: use train_sampler to determine the number of examples in
            # this worker's partition.
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, hvd.size() * batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    hvd.size() * batch_idx * len(data),
+                    len(train_loader.dataset),
+                    100.0 * batch_idx / len(train_loader),
+                    loss.item(),
+                )
+            )


def metric_average(val, name):
@@ -67,14 +109,14 @@ def metric_average(val, name):

def test(model, test_sampler, test_loader, args):
    model.eval()
-    test_loss = 0.
-    test_accuracy = 0.
+    test_loss = 0.0
+    test_accuracy = 0.0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        output = model(data)
        # sum up batch loss
-        test_loss += F.nll_loss(output, target, reduction='sum').item()
+        test_loss += F.nll_loss(output, target, reduction="sum").item()
        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=True)[1]
        test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum()
@@ -85,13 +127,16 @@ def test(model, test_sampler, test_loader, args):
    test_accuracy /= len(test_sampler)

    # Horovod: average metric values across workers.
-    test_loss = metric_average(test_loss, 'avg_loss')
-    test_accuracy = metric_average(test_accuracy, 'avg_accuracy')
+    test_loss = metric_average(test_loss, "avg_loss")
+    test_accuracy = metric_average(test_accuracy, "avg_accuracy")

    # Horovod: print output only on first rank.
    if hvd.rank() == 0:
-        print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
-            test_loss, 100. * test_accuracy))
+        print(
+            "\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n".format(
+                test_loss, 100.0 * test_accuracy
+            )
+        )


def main():
@@ -107,44 +152,54 @@ def main():
        torch.cuda.set_device(hvd.local_rank())
        torch.cuda.manual_seed(args.seed)

-
    # Horovod: limit # of CPU threads to be used per worker.
    torch.set_num_threads(1)

-    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
+    kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {}
    # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
    # issues with Infiniband implementations that are not fork-safe
-    if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
-            mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
-        kwargs['multiprocessing_context'] = 'forkserver'
-
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-    ])
+    if (
+        kwargs.get("num_workers", 0) > 0
+        and hasattr(mp, "_supports_context")
+        and mp._supports_context
+        and "forkserver" in mp.get_all_start_methods()
+    ):
+        kwargs["multiprocessing_context"] = "forkserver"
+
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )

    if hvd.rank() != 0:
        # might be downloading mnist data, let rank 0 download first
        hvd.barrier()

    # train_dataset = datasets.MNIST('data-%d' % hvd.rank(), train=True, download=True, transform=transform)
-    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
+    train_dataset = datasets.MNIST(
+        "./data", train=True, download=True, transform=transform
+    )

    if hvd.rank() == 0:
        # mnist data is downloaded, indicate other ranks can proceed
        hvd.barrier()

    # Horovod: use DistributedSampler to partition the training data.
-    train_sampler = dist.DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
+    train_sampler = dist.DistributedSampler(
+        train_dataset, num_replicas=hvd.size(), rank=hvd.rank()
+    )
    train_loader = torch.utils.data.DataLoader(
-        train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
+        train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs
+    )

    # test_dataset = datasets.MNIST('data-%d' % hvd.rank(), train=False, transform=transform)
-    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
+    test_dataset = datasets.MNIST("./data", train=False, transform=transform)
    # Horovod: use DistributedSampler to partition the test data.
-    test_sampler = dist.DistributedSampler(test_dataset, num_replicas=hvd.size(), rank=hvd.rank())
-    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size,
-                                              sampler=test_sampler, **kwargs)
+    test_sampler = dist.DistributedSampler(
+        test_dataset, num_replicas=hvd.size(), rank=hvd.rank()
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset, batch_size=args.test_batch_size, sampler=test_sampler, **kwargs
+    )

    model = Net()

@@ -159,8 +214,9 @@ def main():
        lr_scaler = hvd.local_size()

    # Horovod: scale learning rate by lr_scaler.
-    optimizer = optim.SGD(model.parameters(), lr=args.lr * lr_scaler,
-                          momentum=args.momentum)
+    optimizer = optim.SGD(
+        model.parameters(), lr=args.lr * lr_scaler, momentum=args.momentum
+    )

    # Horovod: broadcast parameters & optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
@@ -170,12 +226,14 @@ def main():
    compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

    # Horovod: wrap optimizer with DistributedOptimizer.
-    optimizer = hvd.DistributedOptimizer(optimizer,
-                                         named_parameters=model.named_parameters(),
-                                         compression=compression,
-                                         op=hvd.Adasum if args.use_adasum else hvd.Average)
+    optimizer = hvd.DistributedOptimizer(
+        optimizer,
+        named_parameters=model.named_parameters(),
+        compression=compression,
+        op=hvd.Adasum if args.use_adasum else hvd.Average,
+    )

-    total_time = 0.
+    total_time = 0.0

    for epoch in range(1, args.epochs + 1):
        start = time.time()
@@ -186,6 +244,6 @@ def main():
    return hvd.rank(), total_time


-if __name__ == '__main__':
+if __name__ == "__main__":
    rk, tt = main()
-    print(f'[{rk}] Total time elapsed: {tt} seconds')
+    print(f"[{rk}] Total time elapsed: {tt} seconds")
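
The test hunk above averages test_loss and test_accuracy across workers with metric_average, whose body is untouched by this commit and therefore not shown in the diff. For context, here is a minimal sketch of what such a helper typically looks like in the upstream Horovod MNIST example; this is an assumption about the unchanged code, not part of this commit, and it relies only on hvd.allreduce, which averages a tensor across ranks by default.

import torch
import horovod.torch as hvd

def metric_average(val, name):
    # Wrap the per-worker scalar in a tensor so Horovod can reduce it.
    tensor = torch.tensor(val)
    # hvd.allreduce averages the value across all ranks; the name tags the op.
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()

With a helper like this, every rank ends up holding the same averaged loss and accuracy after test(), and only rank 0 prints them, as in the hunk above.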