Skip to content

Commit 22cb81f

Browse files
committed
change mp_context default from spawn to None/auto, add warning
Signed-off-by: Romeo Kienzler <romeo.kienzler1@ibm.com>
1 parent e85322a commit 22cb81f

1 file changed

Lines changed: 23 additions & 3 deletions

File tree

gridfm_graphkit/__main__.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
import argparse
2+
import platform
3+
import warnings
24
from datetime import datetime
35
from gridfm_graphkit.cli import main_cli, benchmark_cli
46

57

68
import subprocess
79
import os
810

11+
12+
def _warn_mp_context_on_linux(mp_context):
13+
"""On Linux, recommend 'spawn' when mp_context is unset, 'fork', or 'forkserver'."""
14+
if platform.system() != "Linux":
15+
return
16+
if mp_context in (None, "fork", "forkserver"):
17+
chosen = mp_context if mp_context is not None else "PyTorch default"
18+
warnings.warn(
19+
f"--mp_context is '{chosen}' on Linux. 'spawn' is recommended for safety "
20+
"(avoids issues with CUDA initialization and forked processes), though "
21+
"'fork'/'forkserver' may be faster.",
22+
stacklevel=2,
23+
)
24+
925
def is_lsf():
1026
return (
1127
os.environ.get("LSB_JOBID") is not None
@@ -94,13 +110,15 @@ def main():
94110
_mp_context_kwargs = dict(
95111
dest="mp_context",
96112
type=str,
97-
default="spawn",
113+
default=None,
98114
choices=["spawn", "fork", "forkserver"],
99115
help=(
100116
"Multiprocessing start method for DataLoader workers. "
101-
"'spawn' (default) is safest and works everywhere. "
117+
"Defaults to None so PyTorch picks automatically. "
118+
"'spawn' is safest and works everywhere. "
102119
"'fork' avoids re-importing modules but is unsafe after CUDA init. "
103-
"'forkserver' uses a clean server process but requires file-descriptor passing."
120+
"'forkserver' uses a clean server process but requires file-descriptor passing. "
121+
"On Linux, 'spawn' is recommended; other choices emit a warning."
104122
),
105123
)
106124

@@ -370,6 +388,8 @@ def main():
370388

371389
args = parser.parse_args()
372390

391+
_warn_mp_context_on_linux(getattr(args, "mp_context", None))
392+
373393
if args.command == "benchmark":
374394
benchmark_cli(args)
375395
else:

0 commit comments

Comments
 (0)