Skip to content

Commit 8271a19

Browse files
committed
5381 and deprecate compute_meandice compute_meaniou (#5382)
Signed-off-by: Wenqi Li <[email protected]> Fixes #5381 - non-breaking changes to rename `compute_meandice` to `compute_dice` - non-breaking changes to rename `compute_meaniou` to `compute_iou` <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <[email protected]>
1 parent bc09a4c commit 8271a19

7 files changed

+32
-18
lines changed

Diff for: CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3636
* Automatically infer device from the first item in random elastic deformation dict
3737
* Add channel dim in `ComputeHoVerMaps` and `ComputeHoVerMapsd`
3838
* Remove batch dim in `SobelGradients` and `SobelGradientsd`
39+
### Deprecated
40+
* Deprecating `compute_meandice`, `compute_meaniou` in `monai.metrics`, in favor of
41+
`compute_dice` and `compute_iou` respectively
3942

4043
## [1.0.0] - 2022-09-16
4144
### Added

Diff for: monai/metrics/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
1616
from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
1717
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
18-
from .meandice import DiceMetric, compute_meandice
19-
from .meaniou import MeanIoU, compute_meaniou
18+
from .meandice import DiceMetric, compute_dice, compute_meandice
19+
from .meaniou import MeanIoU, compute_iou, compute_meaniou
2020
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
2121
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric, SSIMMetric
2222
from .rocauc import ROCAUCMetric, compute_roc_auc

Diff for: monai/metrics/active_learning_metrics.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
class VarianceMetric(Metric):
2424
"""
25-
Compute the Variance of a given T-repeats N-dimensional array/tensor. The primary usage is as a uncertainty based
25+
Compute the Variance of a given T-repeats N-dimensional array/tensor. The primary usage is as an uncertainty based
2626
metric for Active Learning.
2727
2828
It can return the spatial variance/uncertainty map based on user choice or a single scalar value via mean/sum of the

Diff for: monai/metrics/meandice.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515

1616
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
17-
from monai.utils import MetricReduction
17+
from monai.utils import MetricReduction, deprecated
1818

1919
from .metric import CumulativeIterationMetric
2020

@@ -80,7 +80,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
8080
if dims < 3:
8181
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
8282
# compute dice (BxC) for each channel for each batch
83-
return compute_meandice(
83+
return compute_dice(
8484
y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty
8585
)
8686

@@ -103,10 +103,10 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
103103
return (f, not_nans) if self.get_not_nans else f
104104

105105

106-
def compute_meandice(
106+
def compute_dice(
107107
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
108108
) -> torch.Tensor:
109-
"""Computes Dice score metric from full size Tensor and collects average.
109+
"""Computes Dice score metric for a batch of predictions.
110110
111111
Args:
112112
y_pred: input data to compute, typical segmentation model output.
@@ -146,6 +146,11 @@ def compute_meandice(
146146
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
147147
denominator = y_o + y_pred_o
148148

149-
if ignore_empty is True:
149+
if ignore_empty:
150150
return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
151151
return torch.where(denominator > 0, (2.0 * intersection) / denominator, torch.tensor(1.0, device=y_o.device))
152+
153+
154+
@deprecated(since="1.0.0", msg_suffix="use `compute_dice` instead.")
155+
def compute_meandice(*args, **kwargs):
156+
return compute_dice(*args, **kwargs)

Diff for: monai/metrics/meaniou.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414
import torch
1515

1616
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
17-
from monai.utils import MetricReduction
17+
from monai.utils import MetricReduction, deprecated
1818

1919
from .metric import CumulativeIterationMetric
2020

2121

2222
class MeanIoU(CumulativeIterationMetric):
2323
"""
24-
Compute average IoU score between two tensors. It can support both multi-classes and multi-labels tasks.
24+
Compute average Intersection over Union (IoU) score between two tensors.
25+
It supports both multi-classes and multi-labels tasks.
2526
Input `y_pred` is compared with ground truth `y`.
2627
`y_pred` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms
2728
in ``monai.transforms.post`` first to achieve binarized values.
@@ -80,7 +81,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
8081
if dims < 3:
8182
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
8283
# compute IoU (BxC) for each channel for each batch
83-
return compute_meaniou(
84+
return compute_iou(
8485
y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty
8586
)
8687

@@ -103,10 +104,10 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
103104
return (f, not_nans) if self.get_not_nans else f
104105

105106

106-
def compute_meaniou(
107+
def compute_iou(
107108
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
108109
) -> torch.Tensor:
109-
"""Computes IoU score metric from full size Tensor and collects average.
110+
"""Computes Intersection over Union (IoU) score metric from a batch of predictions.
110111
111112
Args:
112113
y_pred: input data to compute, typical segmentation model output.
@@ -146,6 +147,11 @@ def compute_meaniou(
146147
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
147148
union = y_o + y_pred_o - intersection
148149

149-
if ignore_empty is True:
150+
if ignore_empty:
150151
return torch.where(y_o > 0, (intersection) / union, torch.tensor(float("nan"), device=y_o.device))
151152
return torch.where(union > 0, (intersection) / union, torch.tensor(1.0, device=y_o.device))
153+
154+
155+
@deprecated(since="1.0.0", msg_suffix="use `compute_iou` instead.")
156+
def compute_meaniou(*args, **kwargs):
157+
return compute_iou(*args, **kwargs)

Diff for: tests/test_compute_meandice.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
from parameterized import parameterized
1717

18-
from monai.metrics import DiceMetric, compute_meandice
18+
from monai.metrics import DiceMetric, compute_dice, compute_meandice
1919

2020
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
2121
# keep background
@@ -187,7 +187,7 @@
187187
class TestComputeMeanDice(unittest.TestCase):
188188
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
189189
def test_value(self, input_data, expected_value):
190-
result = compute_meandice(**input_data)
190+
result = compute_dice(**input_data)
191191
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
192192

193193
@parameterized.expand([TEST_CASE_3])

Diff for: tests/test_compute_meaniou.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
from parameterized import parameterized
1717

18-
from monai.metrics import MeanIoU, compute_meaniou
18+
from monai.metrics import MeanIoU, compute_iou, compute_meaniou
1919

2020
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
2121
# keep background
@@ -192,7 +192,7 @@ def test_value(self, input_data, expected_value):
192192

193193
@parameterized.expand([TEST_CASE_3])
194194
def test_nans(self, input_data, expected_value):
195-
result = compute_meaniou(**input_data)
195+
result = compute_iou(**input_data)
196196
self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))
197197

198198
# MeanIoU class tests

0 commit comments

Comments
 (0)