Skip to content

Commit da32556

Browse files
authored
add IoU to metrax (#87)
* add IoU to metrax
* remove unused imports
* modify image_metrics_test
* remove extra line
* use keras instead of tf.keras in test
* add description to global variables
* add space
1 parent 664732d commit da32556

File tree

6 files changed

+463
-33
lines changed

6 files changed

+463
-33
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
2727
BLEU = nlp_metrics.BLEU
2828
DCGAtK = ranking_metrics.DCGAtK
29+
IoU = image_metrics.IoU
2930
MAE = regression_metrics.MAE
3031
MRR = ranking_metrics.MRR
3132
MSE = regression_metrics.MSE
@@ -51,6 +52,7 @@
5152
"AveragePrecisionAtK",
5253
"BLEU",
5354
"DCGAtK",
55+
"IoU",
5456
"MAE",
5557
"MRR",
5658
"MSE",

src/metrax/image_metrics.py

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,11 @@ def _calculate_ssim_for_channel(x_ch, y_ch, conv_kernel, c1, c2):
314314
return jnp.mean(ssim_scores_stacked, axis=-1) # (batch,)
315315

316316
@classmethod
317-
def from_model_output( # type: ignore[override]
317+
def from_model_output(
318318
cls,
319-
predictions: jax.Array, # Represents predicted images (y_pred)
320-
targets: jax.Array, # Represents ground truth images (y_true)
321-
max_val: float, # Dynamic range of pixel values
319+
predictions: jax.Array,
320+
targets: jax.Array,
321+
max_val: float,
322322
filter_size: int = 11,
323323
filter_sigma: float = 1.5,
324324
k1: float = 0.01,
@@ -360,3 +360,151 @@ def from_model_output( # type: ignore[override]
360360
k2=k2,
361361
)
362362
return super().from_model_output(values=batch_ssim_values)
363+
364+
365+
@flax.struct.dataclass
class IoU(base.Average):
  r"""Measures Intersection over Union (IoU) for semantic segmentation.

  The general formula for IoU for a single class is:
  $IoU_{class} = \frac{TP}{TP + FP + FN}$
  where TP, FP, FN are True Positives, False Positives, and False Negatives.

  **Per-Batch Processing:**
  For each input batch, a mean IoU is calculated. This involves:
  1. Aggregating TP, FP, and FN pixel counts for each requested class (from
     the required `target_class_ids`) across all samples within the batch.
  2. Computing IoU for each of these classes using the batch-aggregated
     counts: $IoU_{class} = \frac{TP}{TP + FP + FN + \epsilon}$.
  3. Averaging these per-class IoU scores to get a single scalar for the
     batch. A class that appears in neither the targets nor the predictions
     has an IoU of 0 and is still included in this average.
     - If `target_class_ids` is empty, the scalar `0.0` is produced by
       `_calculate_iou`.
     - Otherwise, a scalar `jnp.ndarray` (shape `()`) representing the mean
       IoU is produced.

  **Accumulation & Final Metric:**
  This class inherits from `base.Average`. Every call to `from_model_output`
  hands one per-batch mean IoU value to `base.Average.from_model_output`;
  accumulation across batches and the final `compute()` are provided by that
  base class.
  """

  @staticmethod
  def _to_label_mask(mask: jax.Array, error_prefix: str) -> jnp.ndarray:
    """Normalizes a label mask to shape `(B, H, W)` with dtype `int32`.

    Args:
      mask: Label mask of shape `(B, H, W)` or `(B, H, W, 1)`.
      error_prefix: Text placed at the start of the error message so that the
        caller can tell which input was malformed.

    Returns:
      The mask as an `int32` array of shape `(B, H, W)`.

    Raises:
      ValueError: If `mask` is neither 3D nor 4D with a trailing dim of 1.
    """
    if mask.ndim == 4 and mask.shape[-1] == 1:
      return jnp.squeeze(mask, axis=-1).astype(jnp.int32)
    if mask.ndim == 3:
      return mask.astype(jnp.int32)
    raise ValueError(
        f'{error_prefix} must be 3D (batch, H, W) or '
        f'4D (batch, H, W, 1). Got shape {mask.shape}'
    )

  @staticmethod
  def _calculate_iou(
      targets: jnp.ndarray,
      predictions: jnp.ndarray,
      target_class_ids: jnp.ndarray,
      epsilon: float = 1e-7,
  ) -> jnp.ndarray:
    r"""Computes mean IoU for a processed batch by class-wise aggregation.

    For each class in `target_class_ids`, the intersection (TP) and the union
    (TP + FP + FN) of the target and prediction masks are summed over every
    element of the batch. The IoU for that class is
    $TP / (TP + FP + FN + \epsilon)$, and the per-class scores are averaged.
    The per-class computation is vectorized with `jax.vmap`.

    Args:
      targets: Ground truth segmentation masks. Shape is `(B, H, W)`, integer
        class labels. (B: batch size, H: height, W: width)
      predictions: Predicted segmentation masks. Shape is `(B, H, W)`, integer
        class labels.
      target_class_ids: Integer class IDs for which to compute IoU. An array,
        a Python sequence of ints, or a single int are all accepted.
      epsilon: Small float added to the denominator for numerical stability.
        Default is `1e-7`.

    Returns:
      scalar `jnp.ndarray` (shape `()`) mean IoU for the batch. Returns 0.0
      if `target_class_ids` is empty.
    """
    # Accept lists/tuples/scalars as well as arrays; `.shape` and `jax.vmap`
    # below both need an array with a leading axis.
    class_ids = jnp.atleast_1d(jnp.asarray(target_class_ids))
    if class_ids.shape[0] == 0:
      return jnp.array(0.0, dtype=jnp.float32)

    def _calculate_iou_for_single_class(
        class_id: jnp.ndarray,
    ) -> jnp.ndarray:
      target_is_class = targets == class_id
      pred_is_class = predictions == class_id
      # intersection == TP, union == TP + FP + FN for this class.
      intersection = jnp.sum(jnp.logical_and(target_is_class, pred_is_class))
      union = jnp.sum(jnp.logical_or(target_is_class, pred_is_class))
      # epsilon keeps the division defined when the class is absent from both
      # masks (union == 0); the score is then 0.
      return intersection / (union + epsilon)

    iou_scores_per_class = jax.vmap(_calculate_iou_for_single_class)(
        class_ids
    )
    return jnp.mean(iou_scores_per_class)

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      targets: jax.Array,
      num_classes: int,
      target_class_ids: jax.Array,
      from_logits: bool = False,
      epsilon: float = 1e-7,
  ) -> 'IoU':
    """Creates an `IoU` instance from a batch of model outputs.

    Per-batch processing:
    1. Preprocesses `predictions` and `targets` into integer label masks of
       shape `(B, H, W)`. (B: batch size, H: height, W: width).
    2. Calls `_calculate_iou` using the provided `target_class_ids` to compute
       the batch's mean IoU.

    Args:
      predictions: `jax.Array`. Model predictions.
        - If `from_logits` is `True`: shape `(B, H, W, C)` (C: `num_classes`).
        - If `from_logits` is `False`: shape `(B, H, W)` or `(B, H, W, 1)`.
      targets: `jax.Array`. Ground truth segmentation masks. Shape `(B, H, W)`
        or `(B, H, W, 1)`, integer class labels.
      num_classes: Total number of distinct classes (`C`). Integer. Only used
        to validate the last dimension of `predictions` when `from_logits` is
        `True`.
      target_class_ids: Integer class IDs for which to compute IoU.
      from_logits: `bool`. If `True`, `predictions` are logits and argmax is
        applied. Default is `False`.
      epsilon: `float`. Small value for stable IoU calculation. Default is
        `1e-7`.

    Returns:
      An `IoU` metric instance updated with the IoU score from this batch.

    Raises:
      ValueError: If `predictions` or `targets` have an unsupported shape, or
        if their shapes differ after preprocessing.
    """
    # Preprocessing predictions and targets to be (batch, H, W) integer labels
    if from_logits:
      if predictions.ndim != 4 or predictions.shape[-1] != num_classes:
        raise ValueError(
            'Logit predictions must be 4D (batch, H, W, num_classes) with last'
            f' dim matching num_classes. Got shape {predictions.shape} and'
            f' num_classes {num_classes}'
        )
      processed_predictions = jnp.argmax(predictions, axis=-1).astype(jnp.int32)
    else:
      processed_predictions = cls._to_label_mask(
          predictions, 'Predictions (if not from_logits)'
      )
    processed_targets = cls._to_label_mask(targets, 'Targets')

    # Mismatched masks would otherwise broadcast silently inside the
    # element-wise comparisons and yield a meaningless score.
    if processed_predictions.shape != processed_targets.shape:
      raise ValueError(
          'Predictions and targets must have the same (batch, H, W) shape'
          f' after preprocessing. Got {processed_predictions.shape} and'
          f' {processed_targets.shape}'
      )

    iou_score = cls._calculate_iou(
        targets=processed_targets,
        predictions=processed_predictions,
        target_class_ids=target_class_ids,
        epsilon=epsilon,
    )
    return super().from_model_output(values=iou_score)

0 commit comments

Comments
 (0)