Skip to content

Commit e635ba6

Browse files
xuan07472ice-tongzhouzaida
authored
[Feature] Add ConnectivityError metric (#79)
* add connectivity error * Apply suggestions from code review Co-authored-by: yancong <32220263+ice-tong@users.noreply.github.com> * Correct the Args * Change metrics to lowercase * Change metrics to lowercase Co-authored-by: yancong <32220263+ice-tong@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
1 parent 0e8fce2 commit e635ba6

File tree

5 files changed

+155
-1
lines changed

5 files changed

+155
-1
lines changed

docs/en/api/metrics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@ Metrics
4545
SAD
4646
GradientError
4747
MattingMSE
48+
ConnectivityError

docs/zh_cn/api/metrics.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@ Metrics
4545
SAD
4646
GradientError
4747
MattingMSE
48+
ConnectivityError

mmeval/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .ava_map import AVAMeanAP
55
from .bleu import BLEU
66
from .coco_detection import COCODetectionMetric
7+
from .connectivity_error import ConnectivityError
78
from .end_point_error import EndPointError
89
from .f_metric import F1Metric
910
from .gradient_error import GradientError
@@ -29,5 +30,5 @@
2930
'PCKAccuracy', 'MpiiPCKAccuracy', 'JhmdbPCKAccuracy', 'ProposalRecall',
3031
'PSNR', 'MAE', 'MSE', 'SSIM', 'SNR', 'MultiLabelMetric',
3132
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'SAD', 'GradientError',
32-
'MattingMSE'
33+
'MattingMSE', 'ConnectivityError'
3334
]
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.Dict
2+
import numpy as np
3+
from typing import TYPE_CHECKING, Dict, List, Sequence
4+
5+
from mmeval.core import BaseMetric
6+
from mmeval.utils import try_import
7+
8+
if TYPE_CHECKING:
9+
import cv2
10+
else:
11+
cv2 = try_import('cv2')
12+
13+
14+
class ConnectivityError(BaseMetric):
15+
"""Connectivity error for evaluating alpha matte prediction.
16+
17+
Args:
18+
step (float): Step of threshold when computing intersection between
19+
`alpha` and `pred_alpha`. Default to 0.1 .
20+
norm_const (int): Divide the result to reduce its magnitude.
21+
Defaults to 1000 .
22+
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
23+
24+
Note:
25+
The current implementation assumes the image / alpha / trimap
26+
a numpy array with pixel values ranging from 0 to 255.
27+
28+
The pred_alpha should be masked by trimap before passing
29+
into this metric.
30+
31+
The trimap is the most commonly used prior knowledge. As the
32+
name implies, trimap is a ternary graph and each pixel
33+
takes one of {0, 128, 255}, representing the foreground, the
34+
unknown and the background respectively.
35+
36+
Examples:
37+
38+
>>> from mmeval import ConnectivityError
39+
>>> import numpy as np
40+
>>>
41+
>>> connectivity_error = ConnectivityError()
42+
>>> pred_alpha = np.zeros((32, 32), dtype=np.uint8)
43+
>>> gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
44+
>>> trimap = np.zeros((32, 32), dtype=np.uint8)
45+
>>> trimap[:16, :16] = 128
46+
>>> trimap[16:, 16:] = 255
47+
>>> connectivity_error(pred_alpha, gt_alpha, trimap)
48+
{'connectivity_error': ...}
49+
"""
50+
51+
def __init__(self,
52+
step: float = 0.1,
53+
norm_const: int = 1000,
54+
**kwargs) -> None:
55+
super().__init__(**kwargs)
56+
self.step = step
57+
self.norm_const = norm_const
58+
59+
if cv2 is None:
60+
raise ImportError(f'For availability of {self.__class__.__name__},'
61+
' please pip install opencv-python first.')
62+
63+
def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None: # type: ignore # yapf: disable # noqa: E501
64+
"""Add ConnectivityError score of batch to ``self._results``
65+
66+
Args:
67+
pred_alphas (Sequence[np.ndarray]): Predict the probability
68+
that pixels belong to the foreground.
69+
gt_alphas (Sequence[np.ndarray]): Probability that the actual
70+
pixel belongs to the foreground.
71+
trimaps (Sequence[np.ndarray]): Broadly speaking, the trimap
72+
consists of foreground and unknown region.
73+
"""
74+
75+
for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas,
76+
trimaps):
77+
assert pred_alpha.shape == gt_alpha.shape, 'The shape of ' \
78+
'`pred_alpha` and `gt_alpha` should be the same, but got: ' \
79+
f'{pred_alpha.shape} and {gt_alpha.shape}'
80+
81+
thresh_steps = np.arange(0, 1 + self.step, self.step)
82+
round_down_map = -np.ones_like(gt_alpha)
83+
for i in range(1, len(thresh_steps)):
84+
gt_alpha_thresh = gt_alpha >= thresh_steps[i]
85+
pred_alpha_thresh = pred_alpha >= thresh_steps[i]
86+
intersection = gt_alpha_thresh & pred_alpha_thresh
87+
intersection = intersection.astype(np.uint8)
88+
89+
# connected components
90+
_, output, stats, _ = cv2.connectedComponentsWithStats(
91+
intersection, connectivity=4)
92+
# start from 1 in dim 0 to exclude background
93+
size = stats[1:, -1]
94+
95+
# largest connected component of the intersection
96+
omega = np.zeros_like(gt_alpha)
97+
if len(size) != 0:
98+
max_id = np.argmax(size)
99+
# plus one to include background
100+
omega[output == max_id + 1] = 1
101+
102+
mask = (round_down_map == -1) & (omega == 0)
103+
round_down_map[mask] = thresh_steps[i - 1]
104+
105+
round_down_map[round_down_map == -1] = 1
106+
107+
gt_alpha_diff = gt_alpha - round_down_map
108+
pred_alpha_diff = pred_alpha - round_down_map
109+
# only calculate difference larger than or equal to 0.15
110+
gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15)
111+
pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15)
112+
113+
connectivity_error = np.sum(
114+
np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128))
115+
116+
# divide by norm_const to reduce the magnitude of the result
117+
connectivity_error /= self.norm_const
118+
119+
self._results.append(connectivity_error)
120+
121+
def compute_metric(self, results: List) -> Dict[str, float]:
122+
"""Compute the ConnectivityError metric.
123+
124+
Args:
125+
results (List): A list that consisting the ConnectivityError score.
126+
This list has already been synced across all ranks.
127+
128+
Returns:
129+
Dict[str, float]: The computed ConnectivityError metric.
130+
The keys are the names of the metrics,
131+
and the values are corresponding results.
132+
"""
133+
134+
return {'connectivity_error': float(np.array(results).mean())}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import numpy as np
3+
4+
from mmeval.metrics import ConnectivityError
5+
6+
7+
def test_connectivity_error():
8+
pred_alpha = np.zeros((32, 32), dtype=np.uint8)
9+
gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
10+
trimap = np.zeros((32, 32), dtype=np.uint8)
11+
trimap[:16, :16] = 128
12+
trimap[16:, 16:] = 255
13+
14+
connectivity_error = ConnectivityError()
15+
metric_results = connectivity_error(pred_alpha, gt_alpha, trimap)
16+
assert isinstance(metric_results, dict)
17+
np.testing.assert_almost_equal(metric_results['connectivity_error'], 0.008)

0 commit comments

Comments
 (0)