
Commit 1503929

Merge pull request #267 from JdeRobot/dph/issue-240
Update metrics & improvements in datasets
2 parents e16439b + 0a3d19f

22 files changed with 651 additions and 424 deletions

detectionmetrics/cli/evaluate.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -33,8 +33,9 @@ def get_dataset(
     split,
 ):
     # Check if required data is available
-    if dataset_format == "gaia" and dataset_fname is None:
-        raise ValueError("--dataset is required for 'gaia' format")
+    if dataset_format == "gaia":
+        if dataset_fname is None:
+            raise ValueError("--dataset is required for 'gaia' format")

     elif dataset_format == "rellis3d":
         if dataset_dir is None:
```
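The change above simply nests the `gaia` filename check so every format branch follows the same shape as the `rellis3d` branch. A standalone sketch of the resulting validation pattern; the function name is illustrative, and the `rellis3d` error message is a guess since the diff truncates before it:

```python
def check_dataset_args(dataset_format, dataset_fname=None, dataset_dir=None):
    # Each format branch nests its own requirement checks
    if dataset_format == "gaia":
        if dataset_fname is None:
            raise ValueError("--dataset is required for 'gaia' format")
    elif dataset_format == "rellis3d":
        if dataset_dir is None:
            # Hypothetical message: the original is cut off in this diff
            raise ValueError("--dataset_dir is required for 'rellis3d' format")
```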

detectionmetrics/datasets/dataset.py

Lines changed: 67 additions & 10 deletions

```diff
@@ -28,6 +28,7 @@ def __init__(self, dataset: pd.DataFrame, dataset_dir: str, ontology: dict):
         self.dataset = dataset
         self.dataset_dir = os.path.abspath(dataset_dir)
         self.ontology = ontology
+        self.has_label_count = all("label_count" in v for v in self.ontology.values())

     def __len__(self):
         return len(self.dataset)
@@ -54,6 +55,16 @@ def append(self, new_dataset: Self):
             [self.dataset, new_dataset.dataset], verify_integrity=True
         )

+    def get_label_count(self, splits: List[str] = ["train", "val"]) -> np.ndarray:
+        """Get label count for each class in the dataset
+
+        :param splits: Dataset splits to consider, defaults to ["train", "val"]
+        :type splits: List[str], optional
+        :return: Label count for the dataset
+        :rtype: np.ndarray
+        """
+        raise NotImplementedError
+

 class ImageSegmentationDataset(SegmentationDataset):
     """Parent image segmentation dataset class
@@ -94,7 +105,9 @@ def export(
         outdir: str,
         new_ontology: Optional[dict] = None,
         ontology_translation: Optional[dict] = None,
-        ignored_classes: Optional[List[str]] = [],
+        ignored_classes: Optional[List[str]] = None,
+        resize: Optional[Tuple[int, int]] = None,
+        include_label_count: bool = True,
     ):
         """Export dataset dataframe and image files in SemanticKITTI format. Optionally, modify ontology before exporting.

@@ -106,6 +119,10 @@
         :type ontology_translation: Optional[dict], optional
         :param ignored_classes: Classes to ignore from the old ontology, defaults to []
         :type ignored_classes: Optional[List[str]], optional
+        :param resize: Resize images and labels to the given dimensions, defaults to None
+        :type resize: Optional[Tuple[int, int]], optional
+        :param include_label_count: Whether to include class weights in the dataset, defaults to True
+        :type include_label_count: bool, optional
         """
         os.makedirs(outdir, exist_ok=True)

@@ -117,7 +134,7 @@
         if ontology_translation is not None and new_ontology is None:
             raise ValueError("New ontology must be provided")

-        # Create ontology conversion lookup table
+        # Create ontology conversion lookup table if needed and get number of classes
         ontology_conversion_lut = None
         if new_ontology is not None:
             ontology_conversion_lut = uc.get_ontology_conversion_lut(
@@ -126,6 +143,16 @@
                 ontology_translation=ontology_translation,
                 ignored_classes=ignored_classes,
             )
+            n_classes = max(c["idx"] for c in new_ontology.values()) + 1
+        else:
+            n_classes = max(c["idx"] for c in self.ontology.values()) + 1
+
+        # Check if label count is missing and create empty array if needed
+        label_count_missing = include_label_count and (
+            not self.has_label_count or new_ontology is not None
+        )
+        if label_count_missing:
+            label_count = np.zeros(n_classes, dtype=np.uint64)

         # Export each sample
         for sample_name, row in pbar:
@@ -149,20 +176,29 @@
                 label_fname = os.path.join(self.dataset_dir, label_fname)

             # If image mode is not appropriate: read, convert, and rewrite image
-            if uio.get_image_mode(image_fname) != "RGB":
+            if uio.get_image_mode(image_fname) != "RGB" or resize is not None:
                 image = cv2.imread(image_fname, 1)  # convert to RGB
+
+                # Resize image if needed
+                if resize is not None:
+                    image = cv2.resize(image, resize, interpolation=cv2.INTER_CUBIC)
                 cv2.imwrite(os.path.join(outdir, rel_image_fname), image)
-            # if image mode is appropriate simply copy image to new location
+
+            # If image mode is appropriate simply copy image to new location
             else:
                 shutil.copy2(image_fname, os.path.join(outdir, rel_image_fname))
             self.dataset.at[sample_name, "image"] = rel_image_fname

-            # Same for labels (plus ontology conversion if needed)
+            # Same for labels (plus ontology conversion and label count if needed)
             if label_fname:
                 image_mode = uio.get_image_mode(label_fname)
-                if image_mode == "L" and ontology_conversion_lut is None:
-                    shutil.copy2(label_fname, os.path.join(outdir, rel_label_fname))
-                else:
+                if (
+                    image_mode != "L"
+                    or ontology_conversion_lut is not None
+                    or resize is not None
+                    or label_count_missing
+                ):
+                    # Read and convert label from RGB to L
                     if self.is_label_rgb:
                         label_rgb = cv2.imread(label_fname)[:, :, ::-1]
                         label = np.zeros(label_rgb.shape[:2], dtype=np.uint8)
@@ -172,16 +208,37 @@
                             label[(label_rgb == rgb).all(axis=2)] = idx
                     else:
                         label = cv2.imread(label_fname, 0)  # convert to L
-                    if ontology_conversion_lut is not None:
-                        label = ontology_conversion_lut[label]
+
+                    # Convert label to new ontology if needed
+                    if ontology_conversion_lut is not None:
+                        label = ontology_conversion_lut[label]
+
+                    # Resize label if needed
+                    if resize is not None:
+                        label = cv2.resize(
+                            label, resize, interpolation=cv2.INTER_NEAREST
+                        )
+
+                    # Update label count if needed
+                    if label_count_missing:
+                        indices, counts = np.unique(label, return_counts=True)
+                        label_count[indices] += counts.astype(np.uint64)
+
                     cv2.imwrite(os.path.join(outdir, rel_label_fname), label)
+                else:
+                    shutil.copy2(label_fname, os.path.join(outdir, rel_label_fname))
+
                 self.dataset.at[sample_name, "label"] = rel_label_fname

         # Update dataset directory and ontology if needed
         self.dataset_dir = outdir
         self.ontology = new_ontology if new_ontology is not None else self.ontology

         # Write ontology and store relative path in dataset attributes
+        if label_count_missing:
+            for class_data in self.ontology.values():
+                class_data["label_count"] = int(label_count[class_data["idx"]])
+
         ontology_fname = "ontology.json"
         self.dataset.attrs = {"ontology_fname": ontology_fname}
         uio.write_json(os.path.join(outdir, ontology_fname), self.ontology)
```
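For context, the label-count bookkeeping added to `export` reduces to accumulating `np.unique` histograms over every exported label and writing the totals into the ontology. A minimal sketch with a made-up `n_classes` and toy label maps; only the `np.unique`/`label_count` lines mirror the diff, the rest is illustrative:

```python
import numpy as np

n_classes = 4  # illustrative; the diff derives this from the ontology's max "idx" + 1
label_count = np.zeros(n_classes, dtype=np.uint64)

# Toy label maps standing in for the label images read inside the export loop
for label in (np.array([[0, 1], [1, 3]]), np.array([[2, 2], [3, 0]])):
    indices, counts = np.unique(label, return_counts=True)
    label_count[indices] += counts.astype(np.uint64)

print(label_count)  # [2 2 2 2] -> stored as ontology[class]["label_count"]
```

Note also the interpolation choice in the new `resize` path: images are resized with `cv2.INTER_CUBIC`, but labels with `cv2.INTER_NEAREST`, so class indices are never blended by interpolation.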

detectionmetrics/models/tensorflow.py

Lines changed: 64 additions & 26 deletions

```diff
@@ -1,3 +1,4 @@
+from collections import defaultdict
 import os
 import time
 from typing import List, Optional, Tuple, Union
@@ -94,6 +95,8 @@ class ImageSegmentationTensorflowDataset:
     :type split: str, optional
     :param lut_ontology: LUT to transform label classes, defaults to None
     :type lut_ontology: dict, optional
+    :param normalization: Parameters for normalizing input images, defaults to None
+    :type normalization: dict, optional
     """

     def __init__(
@@ -103,8 +106,14 @@ def __init__(
         batch_size: int = 1,
         split: str = "all",
         lut_ontology: Optional[dict] = None,
+        normalization: Optional[dict] = None,
     ):
         self.image_size = image_size
+        self.normalization = None
+        if normalization is not None:
+            mean = tf.constant(normalization["mean"], dtype=tf.float32)
+            std = tf.constant(normalization["std"], dtype=tf.float32)
+            self.normalization = {"mean": mean, "std": std}

         # Filter split and make filenames global
         if split != "all":
@@ -155,9 +164,17 @@ def read_image(self, fname: str, label=False) -> tf.Tensor:
         # Resize (use NN to avoid interpolation when dealing with labels)
         method = "nearest" if label else "bilinear"
         image = tf_image.resize(images=image, size=self.image_size, method=method)
+
+        # If label, round values to avoid interpolation artifacts
         if label:
             image = tf.round(image)

+        # If normalization parameters are provided, normalize image
+        else:
+            if self.normalization is not None:
+                image = tf.cast(image, tf.float32) / 255.0
+                image = (image - self.normalization["mean"]) / self.normalization["std"]
+
         return image

     def load_data(
@@ -217,6 +234,11 @@ def t_in(image):
             tensor = tf.convert_to_tensor(image)
             tensor = tf_image.resize(images=tensor, size=self.model_cfg["image_size"])
             tensor = tf.expand_dims(tensor, axis=0)
+            if "normalization" in self.model_cfg:
+                mean = tf.constant(self.model_cfg["normalization"]["mean"])
+                std = tf.constant(self.model_cfg["normalization"]["std"])
+                tensor = tf.cast(tensor, tf.float32) / 255.0
+                tensor = (tensor - mean) / std
             return tensor

         self.t_in = t_in
```
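Both the `tf.data` pipeline (`read_image`) and the native-model input transform (`t_in`) now apply the same normalization recipe: scale pixel values to [0, 1], then standardize per channel. A minimal sketch of that recipe; the ImageNet mean/std below are stand-in values, the real ones come from the model config's "normalization" entry:

```python
import tensorflow as tf

# Stand-in parameters (ImageNet statistics); real values come from model_cfg["normalization"]
normalization = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
mean = tf.constant(normalization["mean"], dtype=tf.float32)
std = tf.constant(normalization["std"], dtype=tf.float32)

image = tf.random.uniform((64, 64, 3), maxval=255.0)  # toy RGB image
image = tf.cast(image, tf.float32) / 255.0            # scale to [0, 1]
image = (image - mean) / std                          # channel-wise standardization
```

The `eval` changes below continue in the same file.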
```diff
@@ -275,18 +297,23 @@ def eval(
             batch_size=self.model_cfg.get("batch_size", 1),
             split=split,
             lut_ontology=lut_ontology,
+            normalization=self.model_cfg.get("normalization", None),
         )

+        # Retrieve ignored label indices
+        ignored_label_indices = []
+        for ignored_class in self.model_cfg.get("ignored_classes", []):
+            ignored_label_indices.append(dataset.ontology[ignored_class]["idx"])
+
         # Init metrics
         results = {}
-        iou = um.IoU(self.n_classes)
-        cm = um.ConfusionMatrix(self.n_classes)
+        metrics_factory = um.MetricsFactory(self.n_classes)

         # Evaluation loop
         pbar = tqdm(dataset.dataset)
         for image, label in pbar:
             if self.model_type == "native":
-                pred = self.model(image)
+                pred = self.model(image, training=False)
             elif self.model_type == "compiled":
                 pred = self.model.signatures["serving_default"](image)
             else:
@@ -295,37 +322,48 @@
             if isinstance(pred, dict):
                 pred = list(pred.values())[0]

+            # Get valid points masks depending on ignored label indices
+            if ignored_label_indices:
+                valid_mask = tf.ones_like(label, dtype=tf.bool)
+                for idx in ignored_label_indices:
+                    valid_mask *= label != idx
+            else:
+                valid_mask = None
+
             label = tf.squeeze(label, axis=3)
             pred = tf.argmax(pred, axis=3)
-            cm.update(pred.numpy(), label.numpy())
-
-            pred = tf.one_hot(pred, self.n_classes)
-            pred = tf.transpose(pred, perm=[0, 3, 1, 2])
+            if valid_mask is not None:
+                valid_mask = tf.squeeze(valid_mask, axis=3)
+            metrics_factory.update(
+                pred.numpy(),
+                label.numpy(),
+                valid_mask.numpy() if valid_mask is not None else None,
+            )

-            label = tf.one_hot(label, self.n_classes)
-            label = tf.transpose(label, perm=[0, 3, 1, 2])
+        # Build results dataframe
+        results = defaultdict(dict)

-            iou.update(pred.numpy(), label.numpy())
+        # Add per class and global metrics
+        for metric in metrics_factory.get_metric_names():
+            per_class = metrics_factory.get_metric_per_name(metric, per_class=True)

-        # Get metrics results
-        iou_per_class, iou = iou.compute()
-        acc_per_class, acc = cm.get_accuracy()
-        iou_per_class = [float(n) for n in iou_per_class]
-        acc_per_class = [float(n) for n in acc_per_class]
+            for class_name, class_data in self.ontology.items():
+                results[class_name][metric] = float(per_class[class_data["idx"]])

-        # Build results dataframe
-        results = {}
-        for class_name, class_data in self.ontology.items():
-            results[class_name] = {
-                "iou": iou_per_class[class_data["idx"]],
-                "acc": acc_per_class[class_data["idx"]],
-            }
-        results["global"] = {"iou": iou, "acc": acc}
+            if metric not in ["tp", "fp", "fn", "tn"]:
+                for avg_method in ["macro", "micro"]:
+                    results[avg_method][metric] = metrics_factory.get_averaged_metric(
+                        metric, avg_method
+                    )

-        results = pd.DataFrame(results)
-        results.index.name = "metric"
+        # Add confusion matrix
+        for class_name_a, class_data_a in self.ontology.items():
+            for class_name_b, class_data_b in self.ontology.items():
+                results[class_name_a][class_name_b] = metrics_factory.confusion_matrix[
+                    class_data_a["idx"], class_data_b["idx"]
+                ]

-        return results
+        return pd.DataFrame(results)

     def get_computational_cost(self, runs: int = 30, warm_up_runs: int = 5) -> dict:
         """Get different metrics related to the computational cost of the model
```
