File tree Expand file tree Collapse file tree 1 file changed +35
-1
lines changed
Expand file tree Collapse file tree 1 file changed +35
-1
lines changed Original file line number Diff line number Diff line change 11import os
22import random
3+ import subprocess
4+ import pytest
35import time
46
57import torch
@@ -86,13 +88,45 @@ def run():
8688 ps .update (checkpoint_name , queue .put , ranks = ranks )
8789 # sleep 3s to wait process group is destroyed
8890 time .sleep (3 )
89- except RuntimeError as e :
91+ except Exception as e :
9092 print (f"[rank{ rank } ] Caught exception from worker process: { e } " )
9193 assert isinstance (e , RuntimeError )
94+ assert "failed to update weights due to remote error(s)" in str (e )
9295 finally :
9396 ps .unregister_checkpoint (checkpoint_name )
9497 queue .put (None )
9598
9699
100+ @pytest .mark .gpu
101+ def test_update ():
102+ world_size = torch .cuda .device_count ()
103+ assert world_size >= 2 , "This test requires at least 2 GPUs."
104+
105+ master_addr = "localhost"
106+ master_port = random .randint (20000 , 30000 )
107+
108+ cmd = [
109+ "torchrun" ,
110+ "--nproc_per_node" ,
111+ str (world_size ),
112+ "--master_addr" ,
113+ master_addr ,
114+ "--master_port" ,
115+ str (master_port ),
116+ "tests/test_error_quit.py" ,
117+ ]
118+
119+ result = subprocess .run (
120+ cmd ,
121+ capture_output = False ,
122+ text = True ,
123+ cwd = os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))),
124+ shell = False ,
125+ check = False ,
126+ )
127+
128+ assert result .returncode == 0
129+
130+
97131if __name__ == "__main__" :
98132 run ()
You can’t perform that action at this time.
0 commit comments