We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b10b7e1 commit 1866926Copy full SHA for 1866926
3 files changed
monai/transforms/post/dictionary.py
@@ -196,7 +196,7 @@ def __init__(
196
super().__init__(keys, allow_missing_keys)
197
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
198
self.rankseg = ensure_tuple_rep(rankseg, len(self.keys))
199
- if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg)):
+ if any(argmax_ and rankseg_ for argmax_, rankseg_ in zip(self.argmax, self.rankseg, strict=True)):
200
raise ValueError("`rankseg=True` is incompatible with `argmax=True`.")
201
self.to_onehot = []
202
for flag in ensure_tuple_rep(to_onehot, len(self.keys)):
tests/transforms/test_as_discrete.py
@@ -67,23 +67,11 @@
67
)
68
69
TEST_CASES.append(
70
- [
71
- {"rankseg": False, "argmax": True},
72
- p([[[0.3, 0.6]], [[0.7, 0.4]]]),
73
- p([[[1.0, 0.0]]]),
74
- (1, 1, 2),
75
- ]
+ [{"rankseg": False, "argmax": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 0.0]]]), (1, 1, 2)]
76
77
78
if post_array.has_rankseg:
79
- TEST_CASES.append(
80
81
- {"rankseg": True},
82
83
- p([[[1.0, 1.0]]]),
84
85
86
- )
+ TEST_CASES.append([{"rankseg": True}, p([[[0.3, 0.6]], [[0.7, 0.4]]]), p([[[1.0, 1.0]]]), (1, 1, 2)])
87
88
89
class TestAsDiscrete(unittest.TestCase):
tests/transforms/test_as_discreted.py
@@ -91,14 +91,8 @@
91
92
[
93
{"keys": ["pred", "label"], "rankseg": [True, False]},
94
- {
95
- "pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]),
96
- "label": p([[[0.0, 1.0]]]),
97
- },
98
99
- "pred": p([[[1.0, 1.0]]]),
100
101
+ {"pred": p([[[0.3, 0.6]], [[0.7, 0.4]]]), "label": p([[[0.0, 1.0]]])},
+ {"pred": p([[[1.0, 1.0]]]), "label": p([[[0.0, 1.0]]])},
102
(1, 1, 2),
103
]
104
@@ -116,9 +110,7 @@ def test_value_shape(self, input_param, test_input, output, expected_shape):
116
110
117
111
def test_rankseg_argmax_incompatible(self):
118
112
with self.assertRaises(ValueError):
119
- AsDiscreted(keys="pred", argmax=True, rankseg=True)(
120
- {"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]}
121
113
+ AsDiscreted(keys="pred", argmax=True, rankseg=True)({"pred": [[[0.3, 0.6]], [[0.7, 0.4]]]})
122
114
123
115
def test_rankseg_missing_dependency(self):
124
with mock.patch("monai.transforms.post.array.has_rankseg", False):
0 commit comments