+import warnings
+
+from typing import Any, Callable, List
+
+from skimage import measure
+from scipy import ndimage
+
 import tensorflow as tf
 import tensorflow_addons as tfa
+import numpy as np


 class InitializableMetric(tf.keras.metrics.Metric):
@@ -156,3 +164,144 @@ def get_config(self):
         self.assert_initialized()

         return self.metric.get_config()
+
+
+class GeometricMetrics(InitializableMetric):
+    """
+    Implementation of geometric error metrics: oversegmentation, undersegmentation, border and fragmentation errors.
+
+    The error metrics are based on the paper by C. Persello, "A Novel Protocol for Accuracy Assessment in
+    Classification of Very High Resolution Images" (https://ieeexplore.ieee.org/document/5282610).
+    """
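+    # For each reference object T and the measurement object M with the largest overlap, the code below computes:
+    #   oversegmentation error = 1 - |T ∩ M| / |T|
+    #   undersegmentation error = 1 - |T ∩ M| / |M|
+    #   border error = 1 - |edges(T) ∩ edges(M)| / |edges(T)|
+    #   fragmentation error = (r - 1) / (|T| - pixel_size), where r is the number of measurement objects intersecting T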
+
+    @staticmethod
+    def _detect_edges(im: np.ndarray, thr: float = 0) -> np.ndarray:
+        """ Edge detection function using the Sobel operator. """
+        sx = ndimage.sobel(im, axis=0, mode='constant')
+        sy = ndimage.sobel(im, axis=1, mode='constant')
+        sob = np.hypot(sx, sy)
+        return sob > thr
+
+    @staticmethod
+    def _segmentation_error(intersection_area: float, object_area: float) -> float:
+        return 1. - intersection_area / object_area
+
+    @staticmethod
+    def _intersection(mask1: np.ndarray, mask2: np.ndarray) -> float:
+        return np.sum(np.logical_and(mask1, mask2))
+
+    def _border_err(self, border_ref_edge: np.ndarray, border_meas_edge: np.ndarray) -> float:
+        ref_edge_size = np.sum(border_ref_edge)
+        intersection = self._intersection(border_ref_edge, border_meas_edge)
+        err = intersection / ref_edge_size if ref_edge_size != 0 else 0
+        be = 1. - err
+        return be
+
+    def _fragmentation_err(self, r: int, reference_mask: np.ndarray) -> float:
+        if r <= 1:
+            return 0
+        den = np.sum(reference_mask) - self.pixel_size
+        err = (r - 1.) / den if den > 0 else 0
+        return err
+
+    @staticmethod
+    def _validate_input(reference, measurement):
+        if np.ndim(reference) != np.ndim(measurement):
+            raise ValueError("Reference and measurement inputs must have the same number of dimensions.")
+
+    def __init__(self, pixel_size: int = 1, edge_func: Callable = None, **edge_func_params: Any):
+
+        super().__init__(name='geometric_metrics', dtype=tf.float64)
+
+        self.oversegmentation_error = []
+        self.undersegmentation_error = []
+        self.border_error = []
+        self.fragmentation_error = []
+
+        self.edge_func = self._detect_edges if edge_func is None else edge_func
+        self.edge_func_params = edge_func_params
+        self.pixel_size = pixel_size
+
+    def update_state(self, reference: np.ndarray, measurement: np.ndarray, encode_reference: bool = True,
+                     background_value: int = 0) -> None:
+        """ Calculate the error metrics for the given reference and measurement arrays.
+
+        If encode_reference is set to True, connected components will be used to label objects in the reference;
+        connected components are always used to label objects in the measurement.
+        """
+
+        if not tf.executing_eagerly():
+            warnings.warn("Geometric metrics must be run with eager execution. If running as a compiled Keras model, "
+                          "enable eager execution with model.run_eagerly = True")
+
+        reference = reference.numpy() if isinstance(reference, tf.Tensor) else reference
+        measurement = measurement.numpy() if isinstance(measurement, tf.Tensor) else measurement
+
+        self._validate_input(reference, measurement)
+
+        for ref, meas in zip(reference, measurement):
+            # Label connected components in the reference unless it already contains encoded object labels.
+            if encode_reference:
+                cc_reference = measure.label(ref, background=background_value)
+            else:
+                cc_reference = ref
+
+            cc_measurement = measure.label(meas, background=background_value)
+            components_reference = set(np.unique(cc_reference)).difference([background_value])
+
+            ref_edges = self.edge_func(cc_reference, **self.edge_func_params)
+            meas_edges = self.edge_func(cc_measurement, **self.edge_func_params)
+            for component in components_reference:
+                reference_mask = cc_reference == component
+
+                # Measurement objects overlapping the current reference object and the sizes of the overlaps.
+                uniq, count = np.unique(cc_measurement[reference_mask & (cc_measurement != background_value)],
+                                        return_counts=True)
+                ref_area = np.sum(reference_mask)
+
+                # The measurement object with the largest overlap is matched to the reference object.
+                max_intersecting_measurement = uniq[count.argmax()] if len(count) > 0 else background_value
+                meas_mask = cc_measurement == max_intersecting_measurement
+                meas_area = np.count_nonzero(meas_mask)
+                intersection_area = count.max() if len(count) > 0 else 0
+
+                self.oversegmentation_error.append(self._segmentation_error(intersection_area, ref_area))
+                self.undersegmentation_error.append(self._segmentation_error(intersection_area, meas_area))
+
+                border_ref_edge = ref_edges.squeeze() & reference_mask.squeeze()
+                border_meas_edge = meas_edges.squeeze() & meas_mask.squeeze()
+
+                self.border_error.append(self._border_err(border_ref_edge, border_meas_edge))
+                self.fragmentation_error.append(self._fragmentation_err(len(uniq), reference_mask))
+
+    def get_oversegmentation_error(self) -> float:
+        """ Return the mean oversegmentation error. """
+        return np.array(self.oversegmentation_error).mean()
+
+    def get_undersegmentation_error(self) -> float:
+        """ Return the mean undersegmentation error. """
+        return np.array(self.undersegmentation_error).mean()
+
+    def get_border_error(self) -> float:
+        """ Return the mean border error. """
+        return np.array(self.border_error).mean()
+
+    def get_fragmentation_error(self) -> float:
+        """ Return the mean fragmentation error. """
+        return np.array(self.fragmentation_error).mean()
+
+    def result(self) -> List[float]:
+        """ Return a list of the oversegmentation, undersegmentation, border and fragmentation errors. """
+        return [self.get_oversegmentation_error(),
+                self.get_undersegmentation_error(),
+                self.get_border_error(),
+                self.get_fragmentation_error()]
+
+    def reset_states(self) -> None:
+        """ Empty all the error arrays. """
+        self.oversegmentation_error = []
+        self.undersegmentation_error = []
+        self.border_error = []
+        self.fragmentation_error = []
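
Below is a minimal usage sketch, not part of the diff above: it assumes binary reference and measurement masks shaped as a batch of 2D arrays and the default Sobel edge detector; the toy arrays and names are illustrative only.

import numpy as np

# Hypothetical toy batch with a single 8x8 binary mask: the reference holds one square object,
# while the measurement splits the same object into two disconnected pieces.
reference = np.zeros((1, 8, 8), dtype=np.uint8)
reference[0, 2:6, 2:6] = 1

measurement = np.zeros((1, 8, 8), dtype=np.uint8)
measurement[0, 2:6, 2:3] = 1   # left sliver of the object
measurement[0, 2:6, 4:6] = 1   # right part, separated from the sliver by a background column

metric = GeometricMetrics()
metric.update_state(reference, measurement)   # plain NumPy inputs, so no tf.Tensor conversion is needed

overseg, underseg, border, fragmentation = metric.result()
print(overseg, underseg, border, fragmentation)
metric.reset_states()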