|
| 1 | +# Copyright (c) OpenMMLab. All rights reserved. |
| 2 | +import numpy as np |
| 3 | +from typing import TYPE_CHECKING, Dict, List, Sequence |
| 4 | + |
| 5 | +from mmeval.core import BaseMetric |
| 6 | +from mmeval.utils import try_import |
| 7 | + |
| 8 | +if TYPE_CHECKING: |
| 9 | + import cv2 |
| 10 | +else: |
| 11 | + cv2 = try_import('cv2') |
| 12 | + |
| 13 | + |
class ConnectivityError(BaseMetric):
    """Connectivity error for evaluating alpha matte prediction.

    Args:
        step (float): Step of threshold when computing intersection between
            `alpha` and `pred_alpha`. Defaults to 0.1.
        norm_const (int): Divide the result to reduce its magnitude.
            Defaults to 1000.
        **kwargs: Keyword parameters passed to :class:`BaseMetric`.

    Note:
        The alpha mattes are compared against thresholds taken from
        ``(0, 1]``, so ``pred_alpha`` and ``gt_alpha`` are expected to be
        scaled to ``[0, 1]``. Mattes stored as 8-bit images (0 to 255)
        should be divided by 255 before being passed to this metric.

        The pred_alpha should be masked by trimap before passing
        into this metric.

        The trimap is the most commonly used prior knowledge. As the
        name implies, trimap is a ternary graph and each pixel
        takes one of {0, 128, 255}. This metric only inspects the value
        128 (the unknown region); the error is accumulated over those
        pixels only. NOTE(review): 0 and 255 are presumably the known
        background and foreground respectively - confirm against the
        dataset convention.

    Examples:

        >>> from mmeval import ConnectivityError
        >>> import numpy as np
        >>>
        >>> connectivity_error = ConnectivityError()
        >>> pred_alpha = np.zeros((32, 32), dtype=np.float32)
        >>> gt_alpha = np.ones((32, 32), dtype=np.float32)
        >>> trimap = np.zeros((32, 32), dtype=np.uint8)
        >>> trimap[:16, :16] = 128
        >>> trimap[16:, 16:] = 255
        >>> connectivity_error([pred_alpha], [gt_alpha], [trimap])
        {'connectivity_error': ...}
    """

    def __init__(self,
                 step: float = 0.1,
                 norm_const: int = 1000,
                 **kwargs) -> None:
        """Store the threshold step and normalisation constant.

        Raises:
            ImportError: If OpenCV (``cv2``) could not be imported.
        """
        super().__init__(**kwargs)
        self.step = step
        self.norm_const = norm_const

        if cv2 is None:
            raise ImportError(f'For availability of {self.__class__.__name__},'
                              ' please pip install opencv-python first.')

    def add(self, pred_alphas: Sequence[np.ndarray], gt_alphas: Sequence[np.ndarray], trimaps: Sequence[np.ndarray]) -> None:  # type: ignore # yapf: disable # noqa: E501
        """Add ConnectivityError score of batch to ``self._results``

        One score is appended per (pred_alpha, gt_alpha, trimap) triple.

        Args:
            pred_alphas (Sequence[np.ndarray]): Predict the probability
                that pixels belong to the foreground.
            gt_alphas (Sequence[np.ndarray]): Probability that the actual
                pixel belongs to the foreground.
            trimaps (Sequence[np.ndarray]): Broadly speaking, the trimap
                consists of foreground and unknown region. Only pixels
                equal to 128 contribute to the score.

        Raises:
            AssertionError: If a prediction and its ground truth differ
                in shape.
        """

        # The threshold ladder only depends on ``self.step``, so build it
        # once instead of once per sample.
        thresh_steps = np.arange(0, 1 + self.step, self.step)

        for pred_alpha, gt_alpha, trimap in zip(pred_alphas, gt_alphas,
                                                trimaps):
            assert pred_alpha.shape == gt_alpha.shape, 'The shape of ' \
                '`pred_alpha` and `gt_alpha` should be the same, but got: ' \
                f'{pred_alpha.shape} and {gt_alpha.shape}'

            # The map holds a -1 sentinel and fractional thresholds, so it
            # must be floating point. Inheriting an integer dtype from
            # ``gt_alpha`` would wrap -1 (e.g. to 255 for uint8) and
            # truncate every threshold below 1 to 0.
            if np.issubdtype(gt_alpha.dtype, np.floating):
                map_dtype = gt_alpha.dtype
            else:
                map_dtype = np.float64
            round_down_map = np.full(gt_alpha.shape, -1, dtype=map_dtype)

            for lower, upper in zip(thresh_steps[:-1], thresh_steps[1:]):
                # pixels where both mattes reach the current threshold
                intersection = (gt_alpha >= upper) & (pred_alpha >= upper)
                intersection = intersection.astype(np.uint8)

                # connected components
                _, labels, stats, _ = cv2.connectedComponentsWithStats(
                    intersection, connectivity=4)
                # start from 1 in dim 0 to exclude background
                areas = stats[1:, -1]

                # largest connected component of the intersection
                if areas.size:
                    # plus one because label 0 is the background
                    in_largest = labels == np.argmax(areas) + 1
                else:
                    in_largest = np.zeros(gt_alpha.shape, dtype=bool)

                # A pixel that leaves the largest component for the first
                # time records the previous threshold, i.e. the last level
                # at which it was still connected.
                newly_dropped = (round_down_map == -1) & ~in_largest
                round_down_map[newly_dropped] = lower

            # pixels that never left the largest component
            round_down_map[round_down_map == -1] = 1

            gt_alpha_diff = gt_alpha - round_down_map
            pred_alpha_diff = pred_alpha - round_down_map
            # only calculate difference larger than or equal to 0.15
            gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15)
            pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15)

            # accumulate the error over the unknown region only, then
            # divide by norm_const to reduce the magnitude of the result
            connectivity_error = np.sum(
                np.abs(gt_alpha_phi - pred_alpha_phi) *
                (trimap == 128)) / self.norm_const

            self._results.append(connectivity_error)

    def compute_metric(self, results: List) -> Dict[str, float]:
        """Compute the ConnectivityError metric.

        Args:
            results (List): A list of the per-sample ConnectivityError
                scores. This list has already been synced across all ranks.

        Returns:
            Dict[str, float]: The computed ConnectivityError metric.
            The keys are the names of the metrics,
            and the values are corresponding results.
        """

        return {'connectivity_error': float(np.array(results).mean())}
0 commit comments