Skip to content

Commit 1866926

Browse files
committed
Apply codeformat fixes and tighten validation
Signed-off-by: Zixun Wang <craddywang@gmail.com>
1 parent b10b7e1 commit 1866926

3 files changed

Lines changed: 6 additions & 26 deletions

File tree

monai/transforms/post/dictionary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def __init__(
196196
super().__init__(keys, allow_missing_keys)
197197
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
198198
self.rankseg = ensure_tuple_rep(rankseg, len(self.keys))
199-
if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg)):
199+
if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg, strict=True)):
200200
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")
201201
self.to_onehot = []
202202
for flag in ensure_tuple_rep(to_onehot, len(self.keys)):

tests/transforms/test_as_discrete.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,23 +67,11 @@
6767
)
6868

6969
TEST_CASES.append(
70-
[
71-
{"rankseg": False, "argmax": True},
72-
p([[[0.3, 0.6]], [[0.7, 0.4]]]),
73-
p([[[1.0, 0.0]]]),
74-
(1, 1, 2),
75-
]
70+
[{"rankseg": False, "argmax": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 0.0]]]), (1, 1, 2)]
7671
)
7772

7873
if post_array.has_rankseg:
79-
TEST_CASES.append(
80-
[
81-
{"rankseg": True},
82-
p([[[0.3, 0.6]], [[0.7, 0.4]]]),
83-
p([[[1.0, 1.0]]]),
84-
(1, 1, 2),
85-
]
86-
)
74+
TEST_CASES.append([{"rankseg": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)])
8775

8876

8977
class TestAsDiscrete(unittest.TestCase):

tests/transforms/test_as_discreted.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,8 @@
9191
TEST_CASES.append(
9292
[
9393
{"keys": ["pred", "label"], "rankseg": [True, False]},
94-
{
95-
"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]),
96-
"label": p([[[0.0, 1.0]]]),
97-
},
98-
{
99-
"pred": p([[[1.0, 1.0]]]),
100-
"label": p([[[0.0, 1.0]]]),
101-
},
94+
{"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]), "label": p([[[0.0, 1.0]]])},
95+
{"pred": p([[[1.0, 1.0]]]), "label": p([[[0.0, 1.0]]])},
10296
(1, 1, 2),
10397
]
10498
)
@@ -116,9 +110,7 @@ def test_value_shape(self, input_param, test_input, output, expected_shape):
116110

117111
def test_rankseg_argmax_incompatible(self):
118112
with self.assertRaises(ValueError):
119-
AsDiscreted(keys="pred", argmax=True, rankseg=True)(
120-
{"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]}
121-
)
113+
AsDiscreted(keys="pred", argmax=True, rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]})
122114

123115
def test_rankseg_missing_dependency(self):
124116
with mock.patch("monai.transforms.post.array.has_rankseg", False):

0 commit comments

Comments
 (0)