Skip to content

Commit a5c34b9

Browse files
authored
Merge pull request #170 from Pennycook/custom-action-fixes
Fix several issues with custom argparse actions
2 parents b9b4e62 + 779e269 commit a5c34b9

File tree

2 files changed

+60
-12
lines changed

2 files changed

+60
-12
lines changed

codebasin/config.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def __call__(
110110
template = string.Template(self.format)
111111
split_values = [template.substitute(value=v) for v in split_values]
112112
if self.dest == "passes":
113-
passes = getattr(namespace, self.dest)
113+
passes = getattr(namespace, "_passes")
114114
passes[option_string] = split_values
115115
else:
116116
setattr(namespace, self.dest, split_values)
@@ -131,6 +131,8 @@ def __init__(
131131
):
132132
self.pattern = kwargs.pop("pattern", None)
133133
self.format = kwargs.pop("format", None)
134+
self.override = kwargs.pop("override", False)
135+
self.flag_name = option_strings[0]
134136
super().__init__(option_strings, dest, nargs=nargs, **kwargs)
135137

136138
def __call__(
@@ -147,13 +149,20 @@ def __call__(
147149
template = string.Template(self.format)
148150
matches = [template.substitute(value=v) for v in matches]
149151
if self.dest == "passes":
150-
passes = getattr(namespace, self.dest)
151-
if option_string not in passes:
152-
passes[option_string] = []
153-
passes[option_string].extend(matches)
152+
passes = getattr(namespace, "_passes")
153+
if self.flag_name not in passes:
154+
passes[self.flag_name] = []
155+
if self.override:
156+
passes[self.flag_name] = matches
157+
self.override = False
158+
else:
159+
passes[self.flag_name].extend(matches)
154160
else:
155-
dest = getattr(namespace, self.dest)
156-
dest.extend(matches)
161+
if self.override:
162+
setattr(namespace, self.dest, matches)
163+
else:
164+
dest = getattr(namespace, self.dest)
165+
dest.extend(matches)
157166

158167

159168
@dataclass

tests/compilers/test_actions.py

+44-5
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_store_split_init(self):
3030
def test_store_split(self):
3131
"""Check that argparse calls store_split correctly"""
3232
namespace = argparse.Namespace()
33-
namespace.passes = {}
33+
namespace._passes = {}
3434

3535
parser = argparse.ArgumentParser()
3636
parser.add_argument("--foo", action=_StoreSplitAction, sep=",")
@@ -58,7 +58,7 @@ def test_store_split(self):
5858
args, _ = parser.parse_known_args(["--baz=1"], namespace)
5959

6060
args, _ = parser.parse_known_args(["--qux=one,two"], namespace)
61-
self.assertEqual(args.passes, {"--qux": ["one", "two"]})
61+
self.assertEqual(args._passes, {"--qux": ["one", "two"]})
6262

6363
def test_extend_match_init(self):
6464
"""Check that extend_match recognizes custom arguments"""
@@ -78,7 +78,6 @@ def test_extend_match_init(self):
7878
def test_extend_match(self):
7979
"""Check that argparse calls store_split correctly"""
8080
namespace = argparse.Namespace()
81-
namespace.passes = {}
8281

8382
parser = argparse.ArgumentParser()
8483
parser.add_argument(
@@ -100,6 +99,23 @@ def test_extend_match(self):
10099
action=_ExtendMatchAction,
101100
pattern=r"option_(\d+)",
102101
dest="passes",
102+
default=["0"],
103+
override=True,
104+
)
105+
parser.add_argument(
106+
"--one",
107+
"--two",
108+
action=_ExtendMatchAction,
109+
pattern=r"option_(\d+)",
110+
dest="passes",
111+
)
112+
parser.add_argument(
113+
"--default-override",
114+
action=_ExtendMatchAction,
115+
pattern=r"option_(\d+)",
116+
default=["0"],
117+
dest="override",
118+
override=True,
103119
)
104120

105121
args, _ = parser.parse_known_args(
@@ -117,11 +133,34 @@ def test_extend_match(self):
117133
with self.assertRaises(TypeError):
118134
args, _ = parser.parse_known_args(["--baz=1"], namespace)
119135

136+
# Check that the default values defined by flags always exists.
137+
# Note that the caller must initialize the default.
138+
namespace.override = ["0"]
139+
namespace._passes = {"--qux": ["0"]}
140+
args, _ = parser.parse_known_args(
141+
[],
142+
namespace,
143+
)
144+
self.assertEqual(args.override, ["0"])
145+
self.assertEqual(args._passes, {"--qux": ["0"]})
146+
147+
# Check that the default pass is overridden by use of --qux.
148+
# Note that the caller must initialize the default.
149+
namespace.override = ["0"]
150+
namespace._passes = {"--qux": ["0"]}
151+
args, _ = parser.parse_known_args(
152+
["--qux=option_1,option_2", "--default-override=option_1"],
153+
namespace,
154+
)
155+
self.assertEqual(args.override, ["1"])
156+
self.assertEqual(args._passes, {"--qux": ["1", "2"]})
157+
158+
namespace._passes = {}
120159
args, _ = parser.parse_known_args(
121-
["--qux=option_1,option_2"],
160+
["--one=option_1", "--two=option_2"],
122161
namespace,
123162
)
124-
self.assertEqual(args.passes, {"--qux": ["1", "2"]})
163+
self.assertEqual(args._passes, {"--one": ["1", "2"]})
125164

126165

127166
if __name__ == "__main__":

0 commit comments

Comments
 (0)