|
1 | 1 | import argparse |
| 2 | +import platform |
| 3 | +import warnings |
2 | 4 | from datetime import datetime |
3 | 5 | from gridfm_graphkit.cli import main_cli, benchmark_cli |
4 | 6 |
|
5 | 7 |
|
6 | 8 | import subprocess |
7 | 9 | import os |
8 | 10 |
|
| 11 | + |
| 12 | +def _warn_mp_context_on_linux(mp_context): |
| 13 | + """On Linux, recommend 'spawn' when mp_context is unset, 'fork', or 'forkserver'.""" |
| 14 | + if platform.system() != "Linux": |
| 15 | + return |
| 16 | + if mp_context in (None, "fork", "forkserver"): |
| 17 | + chosen = mp_context if mp_context is not None else "PyTorch default" |
| 18 | + warnings.warn( |
| 19 | + f"--mp_context is '{chosen}' on Linux. 'spawn' is recommended for safety " |
| 20 | + "(avoids issues with CUDA initialization and forked processes), though " |
| 21 | + "'fork'/'forkserver' may be faster.", |
| 22 | + stacklevel=2, |
| 23 | + ) |
| 24 | + |
9 | 25 | def is_lsf(): |
10 | 26 | return ( |
11 | 27 | os.environ.get("LSB_JOBID") is not None |
@@ -94,13 +110,15 @@ def main(): |
94 | 110 | _mp_context_kwargs = dict( |
95 | 111 | dest="mp_context", |
96 | 112 | type=str, |
97 | | - default="spawn", |
| 113 | + default=None, |
98 | 114 | choices=["spawn", "fork", "forkserver"], |
99 | 115 | help=( |
100 | 116 | "Multiprocessing start method for DataLoader workers. " |
101 | | - "'spawn' (default) is safest and works everywhere. " |
| 117 | + "Defaults to None so PyTorch picks automatically. " |
| 118 | + "'spawn' is safest and works everywhere. " |
102 | 119 | "'fork' avoids re-importing modules but is unsafe after CUDA init. " |
103 | | - "'forkserver' uses a clean server process but requires file-descriptor passing." |
| 120 | + "'forkserver' uses a clean server process but requires file-descriptor passing. " |
| 121 | + "On Linux, 'spawn' is recommended; other choices emit a warning." |
104 | 122 | ), |
105 | 123 | ) |
106 | 124 |
|
@@ -370,6 +388,8 @@ def main(): |
370 | 388 |
|
371 | 389 | args = parser.parse_args() |
372 | 390 |
|
| 391 | + _warn_mp_context_on_linux(getattr(args, "mp_context", None)) |
| 392 | + |
373 | 393 | if args.command == "benchmark": |
374 | 394 | benchmark_cli(args) |
375 | 395 | else: |
|
0 commit comments