|
6 | 6 | import json |
7 | 7 | import os |
8 | 8 | import shlex |
| 9 | +import shutil |
9 | 10 | import sys |
10 | 11 | import tempfile |
11 | 12 | import threading |
@@ -125,7 +126,7 @@ def make_monitor_script(rank, hostfile, cwd, env, command, verbose): |
125 | 126 | script += "\n" |
126 | 127 |
|
127 | 128 | # Replace the process with the script |
128 | | - script += shlex.join(["exec", sys.executable, *command]) |
| 129 | + script += shlex.join(["exec", *command]) |
129 | 130 | script += "\n" |
130 | 131 |
|
131 | 132 | return script |
@@ -210,7 +211,7 @@ def node_thread(rank, host, hostfile): |
210 | 211 | ring_hosts.append(node) |
211 | 212 | hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else "" |
212 | 213 |
|
213 | | - log(args.verbose, "Running", shlex.join([sys.executable, *command])) |
| 214 | + log(args.verbose, "Running", shlex.join(command)) |
214 | 215 |
|
215 | 216 | threads = [] |
216 | 217 | for i, h in enumerate(hosts): |
@@ -261,7 +262,6 @@ def launch_mpi(parser, hosts, args, command): |
261 | 262 | *sum((["-x", e] for e in args.env), []), |
262 | 263 | *sum([shlex.split(arg) for arg in args.mpi_arg], []), |
263 | 264 | "--", |
264 | | - sys.executable, |
265 | 265 | *command, |
266 | 266 | ] |
267 | 267 | log(args.verbose, "Running", " ".join(cmd)) |
@@ -323,9 +323,12 @@ def main(): |
323 | 323 | hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts) |
324 | 324 |
|
325 | 325 | # Check if the script is a file and convert it to a full path |
326 | | - script = Path(rest[0]) |
327 | | - if script.exists(): |
328 | | - rest[0] = str(script.resolve()) |
| 326 | + if (script := Path(rest[0])).exists(): |
| 327 | + rest[0:1] = [sys.executable, str(script.resolve())] |
| 328 | + elif (command := shutil.which(rest[0])) is not None: |
| 329 | + rest[0] = command |
| 330 | + else: |
| 331 | + raise ValueError(f"Invalid script or command {rest[0]}") |
329 | 332 |
|
330 | 333 | # Launch |
331 | 334 | if args.backend == "ring": |
|
0 commit comments