Skip to content

Commit d7203fb

Browse files
Merge pull request #37 from sentinel-hub/develop
Add resunet-a architecture used for field delineation
2 parents b47a3f4 + 7d8ea35 commit d7203fb

File tree

5 files changed

+638
-3
lines changed

5 files changed

+638
-3
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Project also contains other folders:
6060
Segmentation models for land cover semantic segmentation:
6161
* **Fully-Convolutional-Network (FCN, a.k.a. U-net)**, vanilla implementation of method described in this [paper](https://arxiv.org/abs/1505.04597). This network expects 2D MSI images as inputs and predicts 2D label maps as output.
6262
* **Temporal FCN**, where the whole time-series is considered as a 3D MSI volume and convolutions are performed along the temporal dimension as well spatial dimension. The output of the network is a 2D label map as in previous cases. More details can be found in this [paper](https://www.researchgate.net/publication/333262625_Spatio-Temporal_Deep_Learning_An_Application_to_Land_Cover_Classification).
63+
* **ResUNet-a**, architecture proposed in Diakogiannis et al. ["ResUNet-a: A deep learning framework for semantic segmetnation of remotely sensed data"](https://www.sciencedirect.com/science/article/abs/pii/S0924271620300149). Original `mxnet` implementation can be found [here](https://github.com/feevos/resuneta).
6364

6465
Classification models for crop classification using time-series:
6566
* **TCN**: Implementation of the TCN network taken from the [keras-TCN implementation by Philippe Remy](https://github.com/philipperemy/keras-tcn).

eoflow/models/metrics.py

+149
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1+
import warnings
2+
3+
from typing import Any, Callable, List
4+
5+
from skimage import measure
6+
from scipy import ndimage
7+
18
import tensorflow as tf
29
import tensorflow_addons as tfa
10+
import numpy as np
311

412

513
class InitializableMetric(tf.keras.metrics.Metric):
@@ -156,3 +164,144 @@ def get_config(self):
156164
self.assert_initialized()
157165

158166
return self.metric.get_config()
167+
168+
169+
class GeometricMetrics(InitializableMetric):
170+
""""
171+
Implementation of Geometric error metrics. Oversegmentation, Undersegmentation, Border, Fragmentation errors.
172+
173+
The error metrics are based on a paper by C. Persello, A Novel Protocol for Accuracy Assessment in Classification of
174+
Very High Resolution Images (https://ieeexplore.ieee.org/document/5282610)
175+
"""
176+
177+
@staticmethod
178+
def _detect_edges(im: np.ndarray, thr: float = 0) -> np.ndarray:
179+
""" Edge detection function using the sobel operator. """
180+
sx = ndimage.sobel(im, axis=0, mode='constant')
181+
sy = ndimage.sobel(im, axis=1, mode='constant')
182+
sob = np.hypot(sx, sy)
183+
return sob > thr
184+
185+
@staticmethod
186+
def _segmentation_error(intersection_area: float, object_area: float) -> float:
187+
return 1. - intersection_area / object_area
188+
189+
@staticmethod
190+
def _intersection(mask1: np.ndarray, mask2: np.ndarray) -> float:
191+
return np.sum(np.logical_and(mask1, mask2))
192+
193+
def _border_err(self, border_ref_edge: np.ndarray, border_meas_edge: np.ndarray) -> float:
194+
ref_edge_size = np.sum(border_ref_edge)
195+
intersection = self._intersection(border_ref_edge, border_meas_edge)
196+
err = intersection / ref_edge_size if ref_edge_size != 0 else 0
197+
be = 1. - err
198+
return be
199+
200+
def _fragmentation_err(self, r: int, reference_mask: np.ndarray) -> float:
201+
if r <= 1:
202+
return 0
203+
den = np.sum(reference_mask) - self.pixel_size
204+
err = (r - 1.) / den if den > 0 else 0
205+
return err
206+
207+
@staticmethod
208+
def _validate_input(reference, measurement):
209+
if np.ndim(reference) != np.ndim(measurement):
210+
raise ValueError("Reference and measurement input shapes must match.")
211+
212+
def __init__(self, pixel_size: int = 1, edge_func: Callable = None, **edge_func_params: Any):
213+
214+
super().__init__(name='geometric_metrics', dtype=tf.float64)
215+
216+
self.oversegmentation_error = []
217+
self.undersegmentation_error = []
218+
self.border_error = []
219+
self.fragmentation_error = []
220+
221+
self.edge_func = self._detect_edges if edge_func is None else edge_func
222+
self.edge_func_params = edge_func_params
223+
self.pixel_size = pixel_size
224+
225+
def update_state(self, reference: np.ndarray, measurement: np.ndarray, encode_reference: bool = True,
226+
background_value: int = 0) -> None:
227+
""" Calculate the error metrics for a measurement and reference arrays. For each .
228+
229+
If encode_reference is set to True, connected components will be used to label objects in the reference and
230+
measurements.
231+
"""
232+
233+
if not tf.executing_eagerly():
234+
warnings.warn("Geometric metrics must be run with eager execution. If running as a compiled Keras model, "
235+
"enable eager execution with model.run_eagerly = True")
236+
237+
reference = reference.numpy() if isinstance(reference, tf.Tensor) else reference
238+
measurement = measurement.numpy() if isinstance(reference, tf.Tensor) else measurement
239+
240+
self._validate_input(reference, measurement)
241+
242+
for ref, meas in zip(reference, measurement):
243+
ref = ref
244+
meas = meas
245+
246+
if encode_reference:
247+
cc_reference = measure.label(ref, background=background_value)
248+
else:
249+
cc_reference = ref
250+
251+
cc_measurement = measure.label(meas, background=background_value)
252+
components_reference = set(np.unique(cc_reference)).difference([background_value])
253+
254+
ref_edges = self.edge_func(cc_reference)
255+
meas_edges = self.edge_func(cc_measurement)
256+
for component in components_reference:
257+
reference_mask = cc_reference == component
258+
259+
uniq, count = np.unique(cc_measurement[reference_mask & (cc_measurement != background_value)],
260+
return_counts=True)
261+
ref_area = np.sum(reference_mask)
262+
263+
max_interecting_measurement = uniq[count.argmax()] if len(count) > 0 else background_value
264+
meas_mask = cc_measurement == max_interecting_measurement
265+
meas_area = np.count_nonzero(cc_measurement == max_interecting_measurement)
266+
intersection_area = count.max() if len(count) > 0 else 0
267+
268+
self.oversegmentation_error.append(self._segmentation_error(intersection_area, ref_area))
269+
self.undersegmentation_error.append(self._segmentation_error(intersection_area, meas_area))
270+
border_ref_edge = ref_edges.squeeze() & reference_mask.squeeze()
271+
border_meas_edge = meas_edges.squeeze() & meas_mask.squeeze()
272+
273+
self.border_error.append(self._border_err(border_ref_edge, border_meas_edge))
274+
self.fragmentation_error.append(self._fragmentation_err(len(uniq), reference_mask))
275+
276+
def get_oversegmentation_error(self) -> float:
277+
""" Return oversegmentation error. """
278+
return np.array(self.oversegmentation_error).mean()
279+
280+
def get_undersegmentation_error(self) -> float:
281+
""" Return undersegmentation error. """
282+
283+
return np.array(self.undersegmentation_error).mean()
284+
285+
def get_border_error(self) -> float:
286+
""" Return border error. """
287+
288+
return np.array(self.border_error).mean()
289+
290+
def get_fragmentation_error(self) -> float:
291+
""" Return fragmentation error. """
292+
293+
return np.array(self.fragmentation_error).mean()
294+
295+
def result(self) -> List[float]:
296+
""" Return a list of values representing oversegmentation, undersegmentation, border, fragmentation errors. """
297+
298+
return [self.get_oversegmentation_error(),
299+
self.get_undersegmentation_error(),
300+
self.get_border_error(), self.get_fragmentation_error()]
301+
302+
def reset_states(self) -> None:
303+
""" Empty all the error arrays. """
304+
self.oversegmentation_error = []
305+
self.undersegmentation_error = []
306+
self.border_error = []
307+
self.fragmentation_error = []

0 commit comments

Comments
 (0)