@@ -1,14 +1,16 @@
+import logging
+import os
+import unittest
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from tests.internal.common_utils import find_free_port
-import unittest
-import multiprocessing
-import os
 import bagua.torch_api as bagua
-from tests import skip_if_cuda_not_available
-import logging
+
 from bagua.torch_api.data_parallel import DistributedDataParallel as DDP
+from tests.internal.multi_process_v2 import MultiProcessTestCase, skip_if_lt_x_gpu
+
+logger = logging.getLogger(__name__)
 
 
 class Net(nn.Module):
@@ -26,21 +28,8 @@ def forward(self, x):
         return F.softmax(x, dim=1)
 
 
-def run_model_wrapper(rank, env, fn, warmup_steps):
-    # initialize subprocess env
-    os.environ["WORLD_SIZE"] = env["WORLD_SIZE"]
-    os.environ["LOCAL_WORLD_SIZE"] = env["LOCAL_WORLD_SIZE"]
-    os.environ["MASTER_ADDR"] = env["MASTER_ADDR"]
-    os.environ["MASTER_PORT"] = env["MASTER_PORT"]
-    os.environ["BAGUA_SERVICE_PORT"] = env["BAGUA_SERVICE_PORT"]
-    os.environ["RANK"] = str(rank)
-    os.environ["LOCAL_RANK"] = str(rank)
-
-    # init bagua distributed process group
-    torch.cuda.set_device(rank)
-    bagua.init_process_group()
-
-    # construct model and optimizer, etc.
+def create_model_and_optimizer(warmup_steps):
+    # construct model and optimizer
     model = Net().cuda()
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
     loss_fn = nn.MSELoss()
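
Note on the refactor: the hand-rolled per-rank bootstrap deleted above is now handled by MultiProcessTestCase, which spawns one worker per rank and runs each test method inside every worker. The _init_bagua_distributed() helper the new tests call is internal to the test suite and not shown in this diff; the sketch below is a hypothetical reconstruction of what it presumably does, based only on the environment setup the removed run_model_wrapper performed. The function name and signature here are illustrative, not the real helper.

# Hypothetical sketch only: reconstructs the per-rank setup that the removed
# run_model_wrapper did by hand. The real _init_bagua_distributed helper
# lives in tests/internal and may differ.
import os

import torch
import bagua.torch_api as bagua


def init_bagua_distributed_sketch(rank, world_size, master_port, service_port):
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(master_port)
    os.environ["BAGUA_SERVICE_PORT"] = str(service_port)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)

    torch.cuda.set_device(rank)  # bind this process to its own GPU
    bagua.init_process_group()  # join the bagua process group
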
@@ -52,84 +41,62 @@ def run_model_wrapper(rank, env, fn, warmup_steps):
     )
     ddp_model = DDP(model, optimizers=[optimizer], algorithm=algorithm)
 
-    fn(ddp_model, optimizer, loss_fn)
+    return ddp_model, optimizer
 
 
-def train_epoch(epoch, model, optimizer, loss_fn):
-    logging.debug("Training epoch {}".format(epoch))
+def train_epoch(epoch, model, optimizer):
+    logger.debug("Training epoch {}".format(epoch))
     for _ in range(10):
         data = torch.randn(4, 2).cuda()
         target = torch.randn(4, 4).cuda()
 
         optimizer.zero_grad()
         output = model(data)
-        loss = loss_fn(output, target)
+        loss = nn.MSELoss()(output, target)
 
         loss.backward()
         optimizer.step()
 
 
-def run_epochs(model, optimizer, loss_fn):
-    for epoch in range(5):
-        train_epoch(epoch, model, optimizer, loss_fn)
-    model.bagua_algorithm.abort(model)
+class TestAsyncModelAverage(MultiProcessTestCase):
+    def setUp(self):
+        super(TestAsyncModelAverage, self).setUp()
+        self._spawn_processes()
 
+    def tearDown(self):
+        super(TestAsyncModelAverage, self).tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
 
-def run_multiple_aborts(model, optimizer, loss_fn):
-    for epoch in range(10):
-        model.bagua_algorithm.resume(model)
-        model.bagua_algorithm.resume(model)
-        train_epoch(epoch, model, optimizer, loss_fn)
-        model.bagua_algorithm.abort(model)
-        model.bagua_algorithm.abort(model)
-
+    @property
+    def world_size(self) -> int:
+        return torch.cuda.device_count()
 
-class TestAsyncModelAverage(unittest.TestCase):
-    @skip_if_cuda_not_available()
+    @skip_if_lt_x_gpu(2)
     def test_algorithm(self):
-        nprocs = torch.cuda.device_count()
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(target=run_model_wrapper, args=(i, env, run_epochs, 0))
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
-
-    @skip_if_cuda_not_available()
+        self._init_bagua_distributed()
+        model, optimizer = create_model_and_optimizer(warmup_steps=0)
+
+        for epoch in range(100):
+            train_epoch(epoch, model, optimizer)
+        model.bagua_algorithm.abort(model)
+
+    @skip_if_lt_x_gpu(2)
     def test_multiple_aborts(self):
-        nprocs = torch.cuda.device_count()
-        env = {
-            "WORLD_SIZE": str(nprocs),
-            "LOCAL_WORLD_SIZE": str(nprocs),
-            "MASTER_ADDR": "127.0.0.1",
-            "MASTER_PORT": str(find_free_port(8000, 8100)),
-            "BAGUA_SERVICE_PORT": str(find_free_port(9000, 9100)),
-        }
-
-        mp = multiprocessing.get_context("spawn")
-        processes = []
-        for i in range(nprocs):
-            p = mp.Process(
-                target=run_model_wrapper, args=(i, env, run_multiple_aborts, 10)
-            )
-            p.start()
-            processes.append(p)
-
-        for p in processes:
-            p.join(timeout=60)
-            self.assertTrue(p.exitcode == 0)
+        self._init_bagua_distributed()
+        model, optimizer = create_model_and_optimizer(warmup_steps=10)
+
+        for i in range(2):
+            model.bagua_algorithm.resume(model)
+            model.bagua_algorithm.abort(model)
+            model.bagua_algorithm.resume(model)
+            for epoch in range(100):
+                train_epoch(i * 100 + epoch, model, optimizer)
+
+            model.bagua_algorithm.abort(model)
+            model.bagua_algorithm.abort(model)
 
 
 if __name__ == "__main__":
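
Both tests require at least two CUDA devices (skip_if_lt_x_gpu(2)) and run under plain unittest; MultiProcessTestCase executes each test body once per spawned rank. The diff context stops at the `if __name__ == "__main__":` guard; assuming the file keeps the standard unittest entry point (not shown in the diff), it would end like this:

# Assumed continuation, outside the diff context shown above:
if __name__ == "__main__":
    unittest.main()
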