2424from contextlib import ExitStack
2525from logging import getLogger as get_logger
2626from pathlib import Path
27- from typing import Any
27+ from typing import Any , Callable
2828from urllib .parse import urlencode
2929
3030import questionary as qn
5454from .profile import ensure_program , setup_profile
5555from .utils import (
5656 CLUSTERS ,
57+ AllocationFlagsAction ,
5758 Cluster ,
5859 CommandNotFoundError ,
5960 MilatoolsUserError ,
@@ -130,6 +131,13 @@ def main():
130131
131132def mila ():
132133 parser = ArgumentParser (prog = "mila" , description = __doc__ , add_help = True )
134+ add_arguments (parser )
135+ verbose , function , args_dict = parse_args (parser )
136+ setup_logging (verbose )
137+ return function (** args_dict )
138+
139+
140+ def add_arguments (parser : argparse .ArgumentParser ):
133141 parser .add_argument (
134142 "--version" ,
135143 action = "version" ,
@@ -198,24 +206,26 @@ def mila():
198206 code_parser = subparsers .add_parser (
199207 "code" ,
200208 help = "Open a remote VSCode session on a compute node." ,
201- formatter_class = SortingHelpFormatter ,
202209 )
203210 code_parser .add_argument (
204- "PATH" , help = "Path to open on the remote machine" , type = str
211+ "PATH" ,
212+ help = (
213+ "Path to open on the remote machine. Defaults to $HOME.\n "
214+ "Can be a relative or absolute path. When a relative path (that doesn't "
215+ "start with a '/', like foo/bar) is passed, the path is relative to the "
216+ "$HOME directory on the selected cluster.\n "
217+ "For example, foo/project will be interpreted as $HOME/foo/project."
218+ ),
219+ type = str ,
220+ default = "." ,
221+ nargs = "?" ,
205222 )
206223 code_parser .add_argument (
207224 "--cluster" ,
208225 choices = CLUSTERS ,
209226 default = "mila" ,
210227 help = "Which cluster to connect to." ,
211228 )
212- code_parser .add_argument (
213- "--alloc" ,
214- nargs = argparse .REMAINDER ,
215- help = "Extra options to pass to slurm" ,
216- metavar = "VALUE" ,
217- default = [],
218- )
219229 code_parser .add_argument (
220230 "--command" ,
221231 default = get_code_command (),
@@ -227,23 +237,20 @@ def mila():
227237 )
228238 code_parser .add_argument (
229239 "--job" ,
230- type = str ,
240+ type = int ,
231241 default = None ,
232242 help = "Job ID to connect to" ,
233- metavar = "VALUE " ,
243+ metavar = "JOB_ID " ,
234244 )
235245 code_parser .add_argument (
236246 "--node" ,
237247 type = str ,
238248 default = None ,
239249 help = "Node to connect to" ,
240- metavar = "VALUE" ,
241- )
242- code_parser .add_argument (
243- "--persist" ,
244- action = "store_true" ,
245- help = "Whether the server should persist or not" ,
250+ metavar = "NODE" ,
246251 )
252+ _add_allocation_options (code_parser )
253+
247254 code_parser .set_defaults (function = code )
248255
249256 # ----- mila sync vscode-extensions ------
@@ -353,7 +360,6 @@ def mila():
353360 serve_lab_parser = serve_subparsers .add_parser (
354361 "lab" ,
355362 help = "Start a Jupyterlab server." ,
356- formatter_class = SortingHelpFormatter ,
357363 )
358364 serve_lab_parser .add_argument (
359365 "PATH" ,
@@ -369,7 +375,6 @@ def mila():
369375 serve_notebook_parser = serve_subparsers .add_parser (
370376 "notebook" ,
371377 help = "Start a Jupyter Notebook server." ,
372- formatter_class = SortingHelpFormatter ,
373378 )
374379 serve_notebook_parser .add_argument (
375380 "PATH" ,
@@ -385,7 +390,6 @@ def mila():
385390 serve_tensorboard_parser = serve_subparsers .add_parser (
386391 "tensorboard" ,
387392 help = "Start a Tensorboard server." ,
388- formatter_class = SortingHelpFormatter ,
389393 )
390394 serve_tensorboard_parser .add_argument (
391395 "LOGDIR" , type = str , help = "Path to the experiment logs"
@@ -398,7 +402,6 @@ def mila():
398402 serve_mlflow_parser = serve_subparsers .add_parser (
399403 "mlflow" ,
400404 help = "Start an MLFlow server." ,
401- formatter_class = SortingHelpFormatter ,
402405 )
403406 serve_mlflow_parser .add_argument (
404407 "LOGDIR" , type = str , help = "Path to the experiment logs"
@@ -411,22 +414,29 @@ def mila():
411414 serve_aim_parser = serve_subparsers .add_parser (
412415 "aim" ,
413416 help = "Start an AIM server." ,
414- formatter_class = SortingHelpFormatter ,
415417 )
416418 serve_aim_parser .add_argument (
417419 "LOGDIR" , type = str , help = "Path to the experiment logs"
418420 )
419421 _add_standard_server_args (serve_aim_parser )
420422 serve_aim_parser .set_defaults (function = aim )
421423
424+
425+ def parse_args (parser : argparse .ArgumentParser ) -> tuple [int , Callable , dict [str , Any ]]:
426+ """Parses the command-line arguments.
427+
428+ Returns the verbosity level, the function (or awaitable) to call, and the arguments
429+ to the function.
430+ """
422431 args = parser .parse_args ()
423432 args_dict = vars (args )
433+
424434 verbose : int = args_dict .pop ("verbose" )
435+
425436 function = args_dict .pop ("function" )
426437 _ = args_dict .pop ("<command>" )
427438 _ = args_dict .pop ("<serve_subcommand>" , None )
428439 _ = args_dict .pop ("<sync_subcommand>" , None )
429- setup_logging (verbose )
430440 # replace SEARCH -> "search", REMOTE -> "remote", etc.
431441 args_dict = _convert_uppercase_keys_to_lowercase (args_dict )
432442
@@ -438,7 +448,7 @@ def mila():
438448 return
439449
440450 assert callable (function )
441- return function ( ** args_dict )
451+ return verbose , function , args_dict
442452
443453
444454def setup_logging (verbose : int ) -> None :
@@ -550,7 +560,7 @@ def code(
550560 path : str ,
551561 command : str ,
552562 persist : bool ,
553- job : str | None ,
563+ job : int | None ,
554564 node : str | None ,
555565 alloc : list [str ],
556566 cluster : Cluster = "mila" ,
@@ -788,7 +798,7 @@ class StandardServerArgs(TypedDict):
788798 alloc : list [str ]
789799 """Extra options to pass to slurm."""
790800
791- job : str | None
801+ job : int | None
792802 """Job ID to connect to."""
793803
794804 name : str | None
@@ -931,20 +941,56 @@ def add_arguments(self, actions):
931941 super ().add_arguments (actions )
932942
933943
934- def _add_standard_server_args (parser : ArgumentParser ):
935- parser .add_argument (
944+ def _add_allocation_options (parser : ArgumentParser ):
945+ # note: Ideally we'd like [--persist --alloc] | [--salloc] | [--sbatch] (i.e. a
946+ # subgroup with alloc and persist within a mutually exclusive group with salloc and
947+ # sbatch) but that doesn't seem possible with argparse as far as I can tell.
948+ arg_group = parser .add_argument_group (
949+ "Allocation options" , description = "Extra options to pass to slurm."
950+ )
951+ alloc_group = arg_group .add_mutually_exclusive_group ()
952+ common_kwargs = {
953+ "dest" : "alloc" ,
954+ "nargs" : argparse .REMAINDER ,
955+ "action" : AllocationFlagsAction ,
956+ "metavar" : "VALUE" ,
957+ "default" : [],
958+ }
959+ alloc_group .add_argument (
960+ "--persist" ,
961+ action = "store_true" ,
962+ help = "Whether the server should persist or not when using --alloc" ,
963+ )
964+ # --persist can be used with --alloc
965+ arg_group .add_argument (
936966 "--alloc" ,
937- nargs = argparse .REMAINDER ,
938- help = "Extra options to pass to slurm" ,
939- metavar = "VALUE" ,
940- default = [],
967+ ** common_kwargs ,
968+ help = "Extra options to pass to salloc or to sbatch if --persist is set." ,
969+ )
970+ # --persist cannot be used with --salloc or --sbatch.
971+ # Note: REMAINDER args like --alloc, --sbatch and --salloc are already mutually
972+ # exclusive in a sense, since it's only possible to use one correctly, the other
973+ # args are stored in the first one (e.g. mila code --alloc --salloc bob will have
974+ # alloc of ["--salloc", "bob"]).
975+ alloc_group .add_argument (
976+ "--salloc" ,
977+ ** common_kwargs ,
978+ help = "Extra options to pass to salloc. Same as using --alloc without --persist." ,
941979 )
980+ alloc_group .add_argument (
981+ "--sbatch" ,
982+ ** common_kwargs ,
983+ help = "Extra options to pass to sbatch. Same as using --alloc with --persist." ,
984+ )
985+
986+
987+ def _add_standard_server_args (parser : ArgumentParser ):
942988 parser .add_argument (
943989 "--job" ,
944- type = str ,
990+ type = int ,
945991 default = None ,
946992 help = "Job ID to connect to" ,
947- metavar = "VALUE " ,
993+ metavar = "JOB_ID " ,
948994 )
949995 parser .add_argument (
950996 "--name" ,
@@ -960,11 +1006,6 @@ def _add_standard_server_args(parser: ArgumentParser):
9601006 help = "Node to connect to" ,
9611007 metavar = "VALUE" ,
9621008 )
963- parser .add_argument (
964- "--persist" ,
965- action = "store_true" ,
966- help = "Whether the server should persist or not" ,
967- )
9681009 parser .add_argument (
9691010 "--port" ,
9701011 type = int ,
@@ -979,6 +1020,8 @@ def _add_standard_server_args(parser: ArgumentParser):
9791020 help = "Name of the profile to use" ,
9801021 metavar = "VALUE" ,
9811022 )
1023+ # Add these arguments last because we want them to show up last in the usage message
1024+ _add_allocation_options (parser )
9821025
9831026
9841027def _standard_server (
@@ -992,7 +1035,7 @@ def _standard_server(
9921035 port : int | None ,
9931036 name : str | None ,
9941037 node : str | None ,
995- job : str | None ,
1038+ job : int | None ,
9961039 alloc : list [str ],
9971040 port_pattern = None ,
9981041 token_pattern = None ,
@@ -1277,7 +1320,7 @@ def get_colour(used: float, max: float) -> str:
12771320def _find_allocation (
12781321 remote : RemoteV1 ,
12791322 node : str | None ,
1280- job : str | None ,
1323+ job : int | str | None ,
12811324 alloc : list [str ],
12821325 cluster : Cluster = "mila" ,
12831326 job_name : str = "mila-tools" ,
0 commit comments