Skip to content

Commit 11d7939

Browse files
authored
Make torchx error out when the same arg is passed in twice
Differential Revision: D55216738 Pull Request resolved: #854
1 parent 591965c commit 11d7939

File tree

4 files changed

+65
-5
lines changed

4 files changed

+65
-5
lines changed

torchx/cli/argparse_util.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,25 @@
66

77
# pyre-strict
88

9+
import logging
10+
import sys
911
from argparse import Action, ArgumentParser, Namespace
10-
from typing import Any, Dict, Optional, Sequence, Text
12+
from typing import Any, Dict, List, Optional, Sequence, Set, Text
1113

1214
from torchx.runner import config
1315

16+
logger: logging.Logger = logging.getLogger(__name__)
1417

15-
class _torchxconfig(Action):
18+
19+
class torchxconfig(Action):
1620
"""
1721
Custom argparse action that loads default torchx CLI options
1822
from .torchxconfig file.
1923
2024
"""
2125

26+
called_args: Set[str] = set()
27+
2228
# since this action is used for each argparse argument
2329
# load the config section for the subcmd once
2430
_subcmd_configs: Dict[str, Dict[str, str]] = {}
@@ -66,13 +72,18 @@ def __call__(
6672
values: Any, # pyre-ignore[2] declared as Any in superclass Action
6773
option_string: Optional[str] = None,
6874
) -> None:
75+
if option_string is not None:
76+
if option_string in self.called_args:
77+
logger.error(f"{option_string} is specified more than once")
78+
sys.exit(1)
79+
self.called_args.add(option_string)
6980
setattr(namespace, self.dest, values)
7081

7182

7283
# argparse takes the action as a Type[Action] so we can't have custom constructors
7384
# hence for each subcommand we need to subclass the base _torchxconfig Action
7485
# this is also how store_true and store_false builtin actions are implemented in argparse
75-
class torchxconfig_run(_torchxconfig):
86+
class torchxconfig_run(torchxconfig):
7687
"""
7788
Custom action that gets the default argument from .torchxconfig.
7889
"""
@@ -94,3 +105,25 @@ def __init__(
94105
option_strings=option_strings,
95106
**kwargs,
96107
)
108+
109+
110+
class ArgOnceAction(Action):
111+
"""
112+
Custom argparse action only allows argument to be specified once
113+
"""
114+
115+
called_args: Set[str] = set()
116+
117+
def __call__(
118+
self,
119+
parser: ArgumentParser,
120+
namespace: Namespace,
121+
values: List[str],
122+
option_string: Optional[str] = None,
123+
) -> None:
124+
if option_string is not None:
125+
if option_string in self.called_args:
126+
logger.error(f"{option_string} is specified more than once")
127+
sys.exit(1)
128+
self.called_args.add(option_string)
129+
setattr(namespace, self.dest, values)

torchx/cli/cmd_run.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Dict, List, Optional, Tuple
1818

1919
import torchx.specs as specs
20-
from torchx.cli.argparse_util import torchxconfig_run
20+
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
2121
from torchx.cli.cmd_base import SubCommand
2222
from torchx.cli.cmd_log import get_logs
2323
from torchx.runner import config, get_runner, Runner
@@ -133,6 +133,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
133133
"-cfg",
134134
"--scheduler_args",
135135
type=str,
136+
action=ArgOnceAction,
136137
help="Arguments to pass to the scheduler (Ex:`cluster=foo,user=bar`)."
137138
" For a list of scheduler run options run: `torchx runopts`",
138139
)
@@ -165,6 +166,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
165166
subparser.add_argument(
166167
"--parent_run_id",
167168
type=str,
169+
action=ArgOnceAction,
168170
help="optional parent run ID that this run belongs to."
169171
" It can be used to group runs for experiment tracking purposes",
170172
)

torchx/cli/test/argparse_util_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
class ArgparseUtilTest(TestWithTmpDir):
2020
def setUp(self) -> None:
2121
super().setUp()
22-
argparse_util._torchxconfig._subcmd_configs.clear()
22+
argparse_util.torchxconfig._subcmd_configs.clear()
23+
argparse_util.torchxconfig.called_args = set()
2324

2425
def test_torchxconfig_action(self) -> None:
2526
with mock.patch(DEFAULT_CONFIG_DIRS, [str(self.tmpdir)]):

torchx/cli/test/cmd_run_test.py

+24
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from typing import Generator
2121
from unittest.mock import MagicMock, patch
2222

23+
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig
24+
2325
from torchx.cli.cmd_run import _parse_component_name_and_args, CmdBuiltins, CmdRun
2426
from torchx.schedulers.local_scheduler import SignalException
2527

@@ -45,6 +47,28 @@ def setUp(self) -> None:
4547

4648
def tearDown(self) -> None:
4749
shutil.rmtree(self.tmpdir, ignore_errors=True)
50+
ArgOnceAction.called_args = set()
51+
torchxconfig.called_args = set()
52+
53+
def test_run_with_multiple_scheduler_args(self) -> None:
54+
55+
args = ["--scheduler_args", "first_args", "--scheduler_args", "second_args"]
56+
with self.assertRaises(SystemExit) as cm:
57+
self.parser.parse_args(args)
58+
self.assertEqual(cm.exception.code, 1)
59+
60+
def test_run_with_multiple_schedule_args(self) -> None:
61+
62+
args = [
63+
"--scheduler",
64+
"local_conda",
65+
"--scheduler",
66+
"local_cwd",
67+
]
68+
69+
with self.assertRaises(SystemExit) as cm:
70+
self.parser.parse_args(args)
71+
self.assertEqual(cm.exception.code, 1)
4872

4973
def test_run_with_user_conf_abs_path(self) -> None:
5074
args = self.parser.parse_args(

0 commit comments

Comments
 (0)