Skip to content

Commit 83a0340

Browse files
authored
allow command (#1836)
1 parent a62fc1b commit 83a0340

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

python/mlx/distributed_run.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import os
88
import shlex
9+
import shutil
910
import sys
1011
import tempfile
1112
import threading
@@ -125,7 +126,7 @@ def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
125126
script += "\n"
126127

127128
# Replace the process with the script
128-
script += shlex.join(["exec", sys.executable, *command])
129+
script += shlex.join(["exec", *command])
129130
script += "\n"
130131

131132
return script
@@ -210,7 +211,7 @@ def node_thread(rank, host, hostfile):
210211
ring_hosts.append(node)
211212
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
212213

213-
log(args.verbose, "Running", shlex.join([sys.executable, *command]))
214+
log(args.verbose, "Running", shlex.join(command))
214215

215216
threads = []
216217
for i, h in enumerate(hosts):
@@ -261,7 +262,6 @@ def launch_mpi(parser, hosts, args, command):
261262
*sum((["-x", e] for e in args.env), []),
262263
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
263264
"--",
264-
sys.executable,
265265
*command,
266266
]
267267
log(args.verbose, "Running", " ".join(cmd))
@@ -323,9 +323,12 @@ def main():
323323
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
324324

325325
# Check if the script is a file and convert it to a full path
326-
script = Path(rest[0])
327-
if script.exists():
328-
rest[0] = str(script.resolve())
326+
if (script := Path(rest[0])).exists():
327+
rest[0:1] = [sys.executable, str(script.resolve())]
328+
elif (command := shutil.which(rest[0])) is not None:
329+
rest[0] = command
330+
else:
331+
raise ValueError(f"Invalid script or command {rest[0]}")
329332

330333
# Launch
331334
if args.backend == "ring":

0 commit comments

Comments
 (0)