Commit 85c22a4

Author: rpehkone committed:
fix issue for single GPU training/inference by always initializing the process group. Megvii-BaseDetection#1722
1 parent: 68e0286

File tree: 1 file changed (+26 −17 lines)


yolox/core/launch.py

+26-17
@@ -57,17 +57,34 @@ def launch(
         args (tuple): arguments passed to main_func
     """
     world_size = num_machines * num_gpus_per_machine
+    if world_size <= 0:
+        raise ValueError('`world_size` should be positive, currently {}'.format(world_size))
+
+    # Even if `world_size == 1`, we have to initialize the process group,
+    # so the user code can use all the `torch.dist`` facilities. This
+    # makes the code uniform whether there is one or more processes.
+
+    if dist_url == "auto":
+        assert (
+            num_machines == 1
+        ), "`dist_url=auto` cannot work with distributed training."
+        port = _find_free_port()
+        dist_url = f"tcp://127.0.0.1:{port}"
+
+    worker_args = (
+        main_func,
+        world_size,
+        num_gpus_per_machine,
+        machine_rank,
+        backend,
+        dist_url,
+        args,
+    )
+
     if world_size > 1:
         # https://github.com/pytorch/pytorch/pull/14391
         # TODO prctl in spawned processes

-        if dist_url == "auto":
-            assert (
-                num_machines == 1
-            ), "dist_url=auto cannot work with distributed training."
-            port = _find_free_port()
-            dist_url = f"tcp://127.0.0.1:{port}"
-
         start_method = "spawn"
         cache = vars(args[1]).get("cache", False)

@@ -82,20 +99,12 @@ def launch(
         mp.start_processes(
             _distributed_worker,
             nprocs=num_gpus_per_machine,
-            args=(
-                main_func,
-                world_size,
-                num_gpus_per_machine,
-                machine_rank,
-                backend,
-                dist_url,
-                args,
-            ),
+            args=worker_args,
             daemon=False,
             start_method=start_method,
         )
     else:
-        main_func(*args)
+        _distributed_worker(0, *worker_args)


 def _distributed_worker(
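After this change the single-GPU path also goes through _distributed_worker, so a process group is initialized even when world_size == 1 and user code can call torch.distributed facilities regardless of GPU count. The following is a minimal standalone sketch (not part of the commit; the gloo backend and port 29500 are assumptions) of why a one-process group keeps such code working unchanged:

import torch
import torch.distributed as dist

def demo():
    # With a single process the collectives still run but act on one rank,
    # so the same training code works on one GPU or many.
    dist.init_process_group(
        backend="gloo",                       # "nccl" is typical when CUDA is used
        init_method="tcp://127.0.0.1:29500",  # assumed free local port
        world_size=1,
        rank=0,
    )
    x = torch.ones(3)
    dist.all_reduce(x)  # sums across ranks; identity with a single process
    print(dist.get_rank(), dist.get_world_size(), x)
    dist.destroy_process_group()

if __name__ == "__main__":
    demo()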
