Skip to content

Commit bf82ca9

Browse files
authored
Tweak cli. add --sbatch and --salloc as alternatives to --alloc and --persist (#119)
* Split `mila` function into chunks Signed-off-by: Fabrice Normandin <[email protected]> * Make `mila code` default to `mila code .` Signed-off-by: Fabrice Normandin <[email protected]> * Make the job_id an int instead of str Signed-off-by: Fabrice Normandin <[email protected]> * Add --salloc and --sbatch flags (see desc.) - Adds a --salloc flag which is exaclty the same as using the '--alloc' flag (without the --persist) flag. - Adds a --sbatch flag which is the same as doing --persist --alloc ... I think these are more naturally understood as the argument that are passed to `salloc` and `sbatch` respectively. Also, these two new args are in a mutually exclusive group with --persist. Signed-off-by: Fabrice Normandin <[email protected]> * Put the --alloc/--salloc/--sbatch args last Signed-off-by: Fabrice Normandin <[email protected]> * Add missing regression test file Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
1 parent 9b4a7cd commit bf82ca9

10 files changed

+219
-81
lines changed

milatools/cli/commands.py

Lines changed: 85 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from contextlib import ExitStack
2525
from logging import getLogger as get_logger
2626
from pathlib import Path
27-
from typing import Any
27+
from typing import Any, Callable
2828
from urllib.parse import urlencode
2929

3030
import questionary as qn
@@ -54,6 +54,7 @@
5454
from .profile import ensure_program, setup_profile
5555
from .utils import (
5656
CLUSTERS,
57+
AllocationFlagsAction,
5758
Cluster,
5859
CommandNotFoundError,
5960
MilatoolsUserError,
@@ -130,6 +131,13 @@ def main():
130131

131132
def mila():
132133
parser = ArgumentParser(prog="mila", description=__doc__, add_help=True)
134+
add_arguments(parser)
135+
verbose, function, args_dict = parse_args(parser)
136+
setup_logging(verbose)
137+
return function(**args_dict)
138+
139+
140+
def add_arguments(parser: argparse.ArgumentParser):
133141
parser.add_argument(
134142
"--version",
135143
action="version",
@@ -198,24 +206,26 @@ def mila():
198206
code_parser = subparsers.add_parser(
199207
"code",
200208
help="Open a remote VSCode session on a compute node.",
201-
formatter_class=SortingHelpFormatter,
202209
)
203210
code_parser.add_argument(
204-
"PATH", help="Path to open on the remote machine", type=str
211+
"PATH",
212+
help=(
213+
"Path to open on the remote machine. Defaults to $HOME.\n"
214+
"Can be a relative or absolute path. When a relative path (that doesn't "
215+
"start with a '/', like foo/bar) is passed, the path is relative to the "
216+
"$HOME directory on the selected cluster.\n"
217+
"For example, foo/project will be interpreted as $HOME/foo/project."
218+
),
219+
type=str,
220+
default=".",
221+
nargs="?",
205222
)
206223
code_parser.add_argument(
207224
"--cluster",
208225
choices=CLUSTERS,
209226
default="mila",
210227
help="Which cluster to connect to.",
211228
)
212-
code_parser.add_argument(
213-
"--alloc",
214-
nargs=argparse.REMAINDER,
215-
help="Extra options to pass to slurm",
216-
metavar="VALUE",
217-
default=[],
218-
)
219229
code_parser.add_argument(
220230
"--command",
221231
default=get_code_command(),
@@ -227,23 +237,20 @@ def mila():
227237
)
228238
code_parser.add_argument(
229239
"--job",
230-
type=str,
240+
type=int,
231241
default=None,
232242
help="Job ID to connect to",
233-
metavar="VALUE",
243+
metavar="JOB_ID",
234244
)
235245
code_parser.add_argument(
236246
"--node",
237247
type=str,
238248
default=None,
239249
help="Node to connect to",
240-
metavar="VALUE",
241-
)
242-
code_parser.add_argument(
243-
"--persist",
244-
action="store_true",
245-
help="Whether the server should persist or not",
250+
metavar="NODE",
246251
)
252+
_add_allocation_options(code_parser)
253+
247254
code_parser.set_defaults(function=code)
248255

249256
# ----- mila sync vscode-extensions ------
@@ -353,7 +360,6 @@ def mila():
353360
serve_lab_parser = serve_subparsers.add_parser(
354361
"lab",
355362
help="Start a Jupyterlab server.",
356-
formatter_class=SortingHelpFormatter,
357363
)
358364
serve_lab_parser.add_argument(
359365
"PATH",
@@ -369,7 +375,6 @@ def mila():
369375
serve_notebook_parser = serve_subparsers.add_parser(
370376
"notebook",
371377
help="Start a Jupyter Notebook server.",
372-
formatter_class=SortingHelpFormatter,
373378
)
374379
serve_notebook_parser.add_argument(
375380
"PATH",
@@ -385,7 +390,6 @@ def mila():
385390
serve_tensorboard_parser = serve_subparsers.add_parser(
386391
"tensorboard",
387392
help="Start a Tensorboard server.",
388-
formatter_class=SortingHelpFormatter,
389393
)
390394
serve_tensorboard_parser.add_argument(
391395
"LOGDIR", type=str, help="Path to the experiment logs"
@@ -398,7 +402,6 @@ def mila():
398402
serve_mlflow_parser = serve_subparsers.add_parser(
399403
"mlflow",
400404
help="Start an MLFlow server.",
401-
formatter_class=SortingHelpFormatter,
402405
)
403406
serve_mlflow_parser.add_argument(
404407
"LOGDIR", type=str, help="Path to the experiment logs"
@@ -411,22 +414,29 @@ def mila():
411414
serve_aim_parser = serve_subparsers.add_parser(
412415
"aim",
413416
help="Start an AIM server.",
414-
formatter_class=SortingHelpFormatter,
415417
)
416418
serve_aim_parser.add_argument(
417419
"LOGDIR", type=str, help="Path to the experiment logs"
418420
)
419421
_add_standard_server_args(serve_aim_parser)
420422
serve_aim_parser.set_defaults(function=aim)
421423

424+
425+
def parse_args(parser: argparse.ArgumentParser) -> tuple[int, Callable, dict[str, Any]]:
426+
"""Parses the command-line arguments.
427+
428+
Returns the verbosity level, the function (or awaitable) to call, and the arguments
429+
to the function.
430+
"""
422431
args = parser.parse_args()
423432
args_dict = vars(args)
433+
424434
verbose: int = args_dict.pop("verbose")
435+
425436
function = args_dict.pop("function")
426437
_ = args_dict.pop("<command>")
427438
_ = args_dict.pop("<serve_subcommand>", None)
428439
_ = args_dict.pop("<sync_subcommand>", None)
429-
setup_logging(verbose)
430440
# replace SEARCH -> "search", REMOTE -> "remote", etc.
431441
args_dict = _convert_uppercase_keys_to_lowercase(args_dict)
432442

@@ -438,7 +448,7 @@ def mila():
438448
return
439449

440450
assert callable(function)
441-
return function(**args_dict)
451+
return verbose, function, args_dict
442452

443453

444454
def setup_logging(verbose: int) -> None:
@@ -550,7 +560,7 @@ def code(
550560
path: str,
551561
command: str,
552562
persist: bool,
553-
job: str | None,
563+
job: int | None,
554564
node: str | None,
555565
alloc: list[str],
556566
cluster: Cluster = "mila",
@@ -788,7 +798,7 @@ class StandardServerArgs(TypedDict):
788798
alloc: list[str]
789799
"""Extra options to pass to slurm."""
790800

791-
job: str | None
801+
job: int | None
792802
"""Job ID to connect to."""
793803

794804
name: str | None
@@ -931,20 +941,56 @@ def add_arguments(self, actions):
931941
super().add_arguments(actions)
932942

933943

934-
def _add_standard_server_args(parser: ArgumentParser):
935-
parser.add_argument(
944+
def _add_allocation_options(parser: ArgumentParser):
945+
# note: Ideally we'd like [--persist --alloc] | [--salloc] | [--sbatch] (i.e. a
946+
# subgroup with alloc and persist within a mutually exclusive group with salloc and
947+
# sbatch) but that doesn't seem possible with argparse as far as I can tell.
948+
arg_group = parser.add_argument_group(
949+
"Allocation options", description="Extra options to pass to slurm."
950+
)
951+
alloc_group = arg_group.add_mutually_exclusive_group()
952+
common_kwargs = {
953+
"dest": "alloc",
954+
"nargs": argparse.REMAINDER,
955+
"action": AllocationFlagsAction,
956+
"metavar": "VALUE",
957+
"default": [],
958+
}
959+
alloc_group.add_argument(
960+
"--persist",
961+
action="store_true",
962+
help="Whether the server should persist or not when using --alloc",
963+
)
964+
# --persist can be used with --alloc
965+
arg_group.add_argument(
936966
"--alloc",
937-
nargs=argparse.REMAINDER,
938-
help="Extra options to pass to slurm",
939-
metavar="VALUE",
940-
default=[],
967+
**common_kwargs,
968+
help="Extra options to pass to salloc or to sbatch if --persist is set.",
969+
)
970+
# --persist cannot be used with --salloc or --sbatch.
971+
# Note: REMAINDER args like --alloc, --sbatch and --salloc are already mutually
972+
# exclusive in a sense, since it's only possible to use one correctly, the other
973+
# args are stored in the first one (e.g. mila code --alloc --salloc bob will have
974+
# alloc of ["--salloc", "bob"]).
975+
alloc_group.add_argument(
976+
"--salloc",
977+
**common_kwargs,
978+
help="Extra options to pass to salloc. Same as using --alloc without --persist.",
941979
)
980+
alloc_group.add_argument(
981+
"--sbatch",
982+
**common_kwargs,
983+
help="Extra options to pass to sbatch. Same as using --alloc with --persist.",
984+
)
985+
986+
987+
def _add_standard_server_args(parser: ArgumentParser):
942988
parser.add_argument(
943989
"--job",
944-
type=str,
990+
type=int,
945991
default=None,
946992
help="Job ID to connect to",
947-
metavar="VALUE",
993+
metavar="JOB_ID",
948994
)
949995
parser.add_argument(
950996
"--name",
@@ -960,11 +1006,6 @@ def _add_standard_server_args(parser: ArgumentParser):
9601006
help="Node to connect to",
9611007
metavar="VALUE",
9621008
)
963-
parser.add_argument(
964-
"--persist",
965-
action="store_true",
966-
help="Whether the server should persist or not",
967-
)
9681009
parser.add_argument(
9691010
"--port",
9701011
type=int,
@@ -979,6 +1020,8 @@ def _add_standard_server_args(parser: ArgumentParser):
9791020
help="Name of the profile to use",
9801021
metavar="VALUE",
9811022
)
1023+
# Add these arguments last because we want them to show up last in the usage message
1024+
_add_allocation_options(parser)
9821025

9831026

9841027
def _standard_server(
@@ -992,7 +1035,7 @@ def _standard_server(
9921035
port: int | None,
9931036
name: str | None,
9941037
node: str | None,
995-
job: str | None,
1038+
job: int | None,
9961039
alloc: list[str],
9971040
port_pattern=None,
9981041
token_pattern=None,
@@ -1277,7 +1320,7 @@ def get_colour(used: float, max: float) -> str:
12771320
def _find_allocation(
12781321
remote: RemoteV1,
12791322
node: str | None,
1280-
job: str | None,
1323+
job: int | str | None,
12811324
alloc: list[str],
12821325
cluster: Cluster = "mila",
12831326
job_name: str = "mila-tools",

milatools/cli/utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import argparse
34
import contextvars
45
import functools
56
import itertools
@@ -13,6 +14,7 @@
1314
import warnings
1415
from collections.abc import Callable, Iterable
1516
from contextlib import contextmanager
17+
from logging import getLogger as get_logger
1618
from pathlib import Path
1719
from typing import Any, Literal, TypeVar, Union, get_args
1820

@@ -27,6 +29,7 @@
2729
from milatools.utils.remote_v1 import RemoteV1
2830

2931

32+
logger = get_logger(__name__)
3033
control_file_var = contextvars.ContextVar("control_file", default="/dev/null")
3134

3235
SSH_CONFIG_FILE = Path.home() / ".ssh" / "config"
@@ -375,3 +378,31 @@ def removesuffix(s: str, suffix: str) -> str:
375378
return s
376379
else:
377380
removesuffix = str.removesuffix
381+
382+
383+
class AllocationFlagsAction(argparse._StoreAction):
384+
def __call__(
385+
self,
386+
parser: argparse.ArgumentParser,
387+
namespace,
388+
values: list[str],
389+
option_string: str | None = None,
390+
):
391+
persist: bool | None = namespace.persist
392+
if option_string == "--alloc":
393+
namespace.alloc = values
394+
elif option_string == "--salloc":
395+
# --salloc is in a mutually exclusive group with --persist
396+
assert not persist
397+
if persist:
398+
raise argparse.ArgumentError(
399+
argument=self,
400+
message="Cannot use --salloc with --persist, use only --sbatch for a persistent session.",
401+
)
402+
namespace.alloc = values
403+
else:
404+
assert option_string == "--sbatch", option_string
405+
# --sbatch is in a mutually exclusive group with --persist
406+
assert not persist
407+
namespace.alloc = values
408+
namespace.persist = True

tests/cli/test_commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_help(
6363
[
6464
"mila", # Error: Missing a subcommand.
6565
"mila search conda",
66-
"mila code", # Error: Missing the required PATH argument.
66+
"mila code --boo", # Error: Unknown argument.
6767
"mila serve", # Error: Missing the subcommand.
6868
"mila forward", # Error: Missing the REMOTE argument.
6969
],
Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,33 @@
11
usage: mila code [-h] [--cluster {mila,cedar,narval,beluga,graham}]
2-
[--alloc ...] [--command VALUE] [--job VALUE] [--node VALUE]
3-
[--persist]
4-
PATH
2+
[--command VALUE] [--job JOB_ID] [--node NODE] [--persist]
3+
[--alloc ...] [--salloc ...] [--sbatch ...]
4+
[PATH]
55

66
positional arguments:
7-
PATH Path to open on the remote machine
7+
PATH Path to open on the remote machine. Defaults to $HOME.
8+
Can be a relative or absolute path. When a relative
9+
path (that doesn't start with a '/', like foo/bar) is
10+
passed, the path is relative to the $HOME directory on
11+
the selected cluster. For example, foo/project will be
12+
interpreted as $HOME/foo/project.
813

914
optional arguments:
1015
-h, --help show this help message and exit
11-
--alloc ... Extra options to pass to slurm
1216
--cluster {mila,cedar,narval,beluga,graham}
1317
Which cluster to connect to.
1418
--command VALUE Command to use to start vscode (defaults to "code" or
1519
the value of $MILATOOLS_CODE_COMMAND)
16-
--job VALUE Job ID to connect to
17-
--node VALUE Node to connect to
18-
--persist Whether the server should persist or not
20+
--job JOB_ID Job ID to connect to
21+
--node NODE Node to connect to
22+
23+
Allocation optional arguments:
24+
Extra options to pass to slurm.
25+
26+
--persist Whether the server should persist or not when using
27+
--alloc
28+
--alloc ... Extra options to pass to salloc or to sbatch if
29+
--persist is set.
30+
--salloc ... Extra options to pass to salloc. Same as using --alloc
31+
without --persist.
32+
--sbatch ... Extra options to pass to sbatch. Same as using --alloc
33+
with --persist.

0 commit comments

Comments
 (0)