1 file changed, +26 -17 lines changed

@@ -57,17 +57,34 @@ def launch(
         args (tuple): arguments passed to main_func
     """
     world_size = num_machines * num_gpus_per_machine
+    if world_size <= 0:
+        raise ValueError('`world_size` should be positive, currently {}'.format(world_size))
+
+    # Even if `world_size == 1`, we have to initialize the process group,
+    # so the user code can use all the `torch.dist` facilities. This
+    # makes the code uniform whether there is one or more processes.
+
+    if dist_url == "auto":
+        assert (
+            num_machines == 1
+        ), "`dist_url=auto` cannot work with distributed training."
+        port = _find_free_port()
+        dist_url = f"tcp://127.0.0.1:{port}"
+
+    worker_args = (
+        main_func,
+        world_size,
+        num_gpus_per_machine,
+        machine_rank,
+        backend,
+        dist_url,
+        args,
+    )
+
     if world_size > 1:
         # https://github.com/pytorch/pytorch/pull/14391
         # TODO prctl in spawned processes

-        if dist_url == "auto":
-            assert (
-                num_machines == 1
-            ), "dist_url=auto cannot work with distributed training."
-            port = _find_free_port()
-            dist_url = f"tcp://127.0.0.1:{port}"
-
         start_method = "spawn"
         cache = vars(args[1]).get("cache", False)

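Note on the hoisted `dist_url == "auto"` branch: it still relies on `_find_free_port()`, whose body is outside this hunk. For reference only, such a helper is usually written by binding a socket to port 0 so the OS picks an unused port; the sketch below is an assumption about its shape, not the code from this file.

import socket

def _find_free_port():
    # Bind to port 0 so the OS assigns an ephemeral port, then release it.
    # The port is only likely to remain free until the launcher binds it again.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port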
@@ -82,20 +99,12 @@ def launch(
         mp.start_processes(
             _distributed_worker,
             nprocs=num_gpus_per_machine,
-            args=(
-                main_func,
-                world_size,
-                num_gpus_per_machine,
-                machine_rank,
-                backend,
-                dist_url,
-                args,
-            ),
+            args=worker_args,
             daemon=False,
             start_method=start_method,
         )
     else:
-        main_func(*args)
+        _distributed_worker(0, *worker_args)


 def _distributed_worker(
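With this refactor the multi-process and single-process paths share one entry point: the spawned workers and the `else:` branch both run `_distributed_worker` with the same `worker_args`, so `torch.distributed` gets initialized even when `world_size == 1`, as the new comment states. The worker's body is not shown in this diff; a minimal sketch of the pattern its call sites imply (the argument order comes from `worker_args`, everything inside the function is assumed) might look like:

import torch
import torch.distributed as dist

def _distributed_worker(
    local_rank, main_func, world_size, num_gpus_per_machine,
    machine_rank, backend, dist_url, args,
):
    # Global rank is the machine offset plus the local GPU index.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    # Initialize the process group unconditionally, even for world_size == 1,
    # so user code can use torch.distributed in both launch modes.
    dist.init_process_group(
        backend=backend,
        init_method=dist_url,
        world_size=world_size,
        rank=global_rank,
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    main_func(*args)

The actual worker likely does more (synchronization, per-machine communicators, cleanup), but the behavior change visible in this diff is confined to the `else:` branch, where `main_func(*args)` is replaced by running the same worker with local rank 0.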