Skip to content

Commit 0e8fce2

Browse files
xuan07472ice-tongzhouzaida
authored
[Feature] Add MattingMSE metric (#71)
* add mattingmse * modify config.yml * Modify some details * Add test_matting_mse.py and simplify matting_mse.py. * Synchronize config.yml in the main branch * improve the document * Fine-tune docstring * Apply suggestions from code review Modify format statement Co-authored-by: yancong <32220263+ice-tong@users.noreply.github.com> * update * Update mmeval/metrics/matting_mse.py * Update matting_mse.py update * Change metrics to lowercase Co-authored-by: yancong <32220263+ice-tong@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
1 parent d14a5ef commit 0e8fce2

File tree

5 files changed

+109
-1
lines changed

5 files changed

+109
-1
lines changed

docs/en/api/metrics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ Metrics
4444
BLEU
4545
SAD
4646
GradientError
47+
MattingMSE

docs/zh_cn/api/metrics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ Metrics
4444
BLEU
4545
SAD
4646
GradientError
47+
MattingMSE

mmeval/metrics/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .gradient_error import GradientError
1010
from .hmean_iou import HmeanIoU
1111
from .mae import MAE
12+
from .matting_mse import MattingMSE
1213
from .mean_iou import MeanIoU
1314
from .mse import MSE
1415
from .multi_label import AveragePrecision, MultiLabelMetric
@@ -27,5 +28,6 @@
2728
'F1Metric', 'HmeanIoU', 'SingleLabelMetric', 'COCODetectionMetric',
2829
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
2930
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
30-
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'SAD', 'GradientError'
31+
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'SAD', 'GradientError',
32+
'MattingMSE'
3133
]

mmeval/metrics/matting_mse.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import numpy as np
3+
from typing import Dict, List, Sequence
4+
5+
from mmeval.core import BaseMetric
6+
7+
8+
class MattingMSE(BaseMetric):
9+
"""Mean Squared Error metric for image matting.
10+
11+
This metric computes the per-pixel squared error average across all
12+
pixels.
13+
i.e. mean((a-b)^2)
14+
15+
Args:
16+
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
17+
18+
Note:
19+
The current implementation assumes the image / alpha / trimap
20+
a numpy array with pixel values ranging from 0 to 255.
21+
22+
The pred_alpha should be masked by trimap before passing
23+
into this metric.
24+
25+
The trimap is the most commonly used prior knowledge. As the
26+
name implies, trimap is a ternary graph and each pixel
27+
takes one of {0, 128, 255}, representing the foreground, the
28+
unknown and the background respectively.
29+
30+
Examples:
31+
32+
>>> from mmeval import MattingMSE
33+
>>> import numpy as np
34+
>>>
35+
>>> matting_mse = MattingMSE()
36+
>>> pred_alpha = np.zeros((32, 32), dtype=np.uint8)
37+
>>> gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
38+
>>> trimap = np.zeros((32, 32), dtype=np.uint8)
39+
>>> trimap[:16, :16] = 128
40+
>>> trimap[16:, 16:] = 255
41+
>>> matting_mse(pred_alpha, gt_alpha, trimap) # doctest: +ELLIPSIS
42+
{'matting_mse': ...}
43+
"""
44+
45+
def __init__(self, **kwargs) -> None:
46+
super().__init__(**kwargs)
47+
48+
def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501
49+
"""Add MattingMSE score of batch to ``self._results``
50+
51+
Args:
52+
pred_alphas (Sequence[np.ndarray]): Predict the probability
53+
that pixels belong to the foreground.
54+
gt_alphas (Sequence[np.ndarray]): Probability that the actual
55+
pixel belongs to the foreground.
56+
trimaps (Sequence[np.ndarray]): Broadly speaking, the trimap
57+
consists of foreground and unknown region.
58+
"""
59+
60+
for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas,
61+
trimaps):
62+
assert pred_alpha.shape == gt_alpha.shape, 'The shape of ' \
63+
'`pred_alpha` and `gt_alpha` should be the same, but got: ' \
64+
f'{pred_alpha.shape} and {gt_alpha.shape}'
65+
66+
weight_sum = (trimap == 128).sum()
67+
if weight_sum != 0:
68+
mse_result = ((pred_alpha - gt_alpha)**2).sum() / weight_sum
69+
else:
70+
mse_result = 0
71+
72+
self._results.append(mse_result)
73+
74+
def compute_metric(self, results: List) -> Dict[str, float]:
75+
"""Compute the MattingMSE metric.
76+
77+
Args:
78+
results (List): A list that consisting the MattingMSE score.
79+
This list has already been synced across all ranks.
80+
81+
Returns:
82+
Dict[str, float]: The computed MattingMSE metric.
83+
The keys are the names of the metrics,
84+
and the values are corresponding results.
85+
"""
86+
87+
return {'matting_mse': float(np.array(results).mean())}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import numpy as np
3+
4+
from mmeval.metrics import MattingMSE
5+
6+
7+
def test_matting_mse():
8+
pred_alpha = np.zeros((32, 32), dtype=np.uint8)
9+
gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
10+
trimap = np.zeros((32, 32), dtype=np.uint8)
11+
trimap[:16, :16] = 128
12+
trimap[16:, 16:] = 255
13+
14+
matting_mse = MattingMSE()
15+
metric_results = matting_mse(pred_alpha, gt_alpha, trimap)
16+
assert isinstance(metric_results, dict)
17+
np.testing.assert_almost_equal(metric_results['matting_mse'], 1.0)

0 commit comments

Comments
 (0)