@@ -314,11 +314,11 @@ def _calculate_ssim_for_channel(x_ch, y_ch, conv_kernel, c1, c2):
314314 return jnp .mean (ssim_scores_stacked , axis = - 1 ) # (batch,)
315315
316316 @classmethod
317- def from_model_output ( # type: ignore[override]
317+ def from_model_output (
318318 cls ,
319- predictions : jax .Array , # Represents predicted images (y_pred)
320- targets : jax .Array , # Represents ground truth images (y_true)
321- max_val : float , # Dynamic range of pixel values
319+ predictions : jax .Array ,
320+ targets : jax .Array ,
321+ max_val : float ,
322322 filter_size : int = 11 ,
323323 filter_sigma : float = 1.5 ,
324324 k1 : float = 0.01 ,
@@ -360,3 +360,151 @@ def from_model_output( # type: ignore[override]
360360 k2 = k2 ,
361361 )
362362 return super ().from_model_output (values = batch_ssim_values )
363+
364+
@flax.struct.dataclass
class IoU(base.Average):
  r"""Measures Intersection over Union (IoU) for semantic segmentation.

  The general formula for IoU for a single class is:
  $IoU_{class} = \frac{TP}{TP + FP + FN}$
  where TP, FP, FN are True Positives, False Positives, and False Negatives.

  **Per-Batch Processing:**
  For each input batch, a mean IoU is calculated. This involves:
  1. Aggregating TP, FP, and FN pixel counts for each specified target class
     (from the required `target_class_ids`) across all samples within the
     batch.
  2. Computing IoU for each of these classes using the batch-aggregated counts:
     $IoU_{class} = \frac{TP}{TP + FP + FN + \epsilon}$.
  3. Averaging these per-class IoU scores to get a single value for the batch.
     - If `target_class_ids` is empty, `_calculate_iou` produces a scalar
       `0.0` (shape `()`, dtype `float32`).
     - Otherwise, a scalar `jnp.ndarray` (shape `()`) representing the mean
       IoU is produced.

  A target class that appears in neither the targets nor the predictions of a
  batch has an empty union, so its IoU evaluates to `0 / epsilon = 0` and it
  lowers the batch mean.

  **Accumulation & Final Metric:**
  This class inherits from `base.Average`. It accumulates the results from
  per-batch processing and `compute()` returns the final mean IoU as a scalar
  `jnp.ndarray` (shape `()`).
  """

  @staticmethod
  def _calculate_iou(
      targets: jnp.ndarray,
      predictions: jnp.ndarray,
      target_class_ids: jnp.ndarray,
      epsilon: float = 1e-7,
  ) -> jnp.ndarray:
    r"""Computes the mean IoU of one batch over the requested classes.

    For each class in `target_class_ids`, the intersection (TP) and the union
    (TP + FP + FN) of the target and prediction masks are counted over every
    element of the batch at once. The IoU for that class is
    $TP / (TP + FP + FN + \epsilon)$. The per-class scores, computed with
    `jax.vmap` over the class IDs, are then averaged.

    Args:
      targets: Ground truth segmentation masks. Shape is `(B, H, W)`, integer
        class labels. (B: batch size, H: height, W: width)
      predictions: Predicted segmentation masks. Shape is `(B, H, W)`, integer
        class labels.
      target_class_ids: 1D array (or a list/tuple) of integer class IDs for
        which to compute IoU.
      epsilon: Small float added to the denominator for numerical stability.
        Default is `1e-7`.

    Returns:
      scalar `jnp.ndarray` (shape `()`) mean IoU for the batch. Returns 0.0
      if `target_class_ids` is empty.
    """
    # Plain Python lists/tuples have no `.shape` and cannot be mapped over by
    # `jax.vmap`, so normalise to an array first. Arrays pass through as-is.
    target_class_ids = jnp.asarray(target_class_ids)

    # The mean of zero per-class scores would be NaN; return 0.0 instead.
    if target_class_ids.shape[0] == 0:
      return jnp.array(0.0, dtype=jnp.float32)

    def _calculate_iou_for_single_class(
        class_id: jnp.ndarray,
    ) -> jnp.ndarray:
      target_is_class = targets == class_id
      pred_is_class = predictions == class_id
      # Pixel counts are summed over the whole batch, not per sample.
      intersection = jnp.sum(jnp.logical_and(target_is_class, pred_is_class))
      union = jnp.sum(jnp.logical_or(target_is_class, pred_is_class))
      # epsilon keeps the division finite when the class is absent from both
      # masks (union == 0); the score is then 0.
      return intersection / (union + epsilon)

    iou_scores_per_class = jax.vmap(_calculate_iou_for_single_class)(
        target_class_ids
    )

    return jnp.mean(iou_scores_per_class)

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      targets: jax.Array,
      num_classes: int,
      target_class_ids: jax.Array,
      from_logits: bool = False,
      epsilon: float = 1e-7,
  ) -> 'IoU':
    """Creates an `IoU` instance from a batch of model outputs.

    Per-batch processing:
    1. Preprocesses `predictions` and `targets` into integer label masks of
       shape `(B, H, W)`. (B: batch size, H: height, W: width).
    2. Calls `_calculate_iou` using the provided `target_class_ids` to compute
       the batch's mean IoU.

    Args:
      predictions: `jax.Array`. Model predictions.
        - If `from_logits` is `True`: shape `(B, H, W, C)` (C: `num_classes`).
        - If `from_logits` is `False`: shape `(B, H, W)` or `(B, H, W, 1)`.
      targets: `jax.Array`. Ground truth segmentation masks. Shape `(B, H, W)`
        or `(B, H, W, 1)`, integer class labels.
      num_classes: Total number of distinct classes (`C`). Integer. Only used
        to validate the last dimension of `predictions` when `from_logits` is
        `True`.
      target_class_ids: 1D array (or a list/tuple) of integer class IDs for
        which to compute IoU.
      from_logits: `bool`. If `True`, `predictions` are logits and argmax is
        applied over the last axis. Default is `False`.
      epsilon: `float`. Small value for stable IoU calculation. Default is
        `1e-7`.

    Returns:
      An `IoU` metric instance updated with the IoU score from this batch.

    Raises:
      ValueError: If `predictions` or `targets` do not have one of the
        supported shapes described above.
    """
    # Preprocessing predictions and targets to be (batch, H, W) integer labels
    if from_logits:
      if predictions.ndim != 4 or predictions.shape[-1] != num_classes:
        raise ValueError(
            'Logit predictions must be 4D (batch, H, W, num_classes) with last'
            f' dim matching num_classes. Got shape {predictions.shape} and'
            f' num_classes {num_classes}'
        )
      # Collapse the class axis to the most likely label per pixel.
      processed_predictions = jnp.argmax(predictions, axis=-1).astype(jnp.int32)
    else:
      if predictions.ndim == 4 and predictions.shape[-1] == 1:
        processed_predictions = jnp.squeeze(predictions, axis=-1).astype(
            jnp.int32
        )
      elif predictions.ndim == 3:
        processed_predictions = predictions.astype(jnp.int32)
      else:
        raise ValueError(
            'Predictions (if not from_logits) must be 3D (batch, H, W) or '
            f'4D (batch, H, W, 1). Got shape {predictions.shape}'
        )

    if targets.ndim == 4 and targets.shape[-1] == 1:
      processed_targets = jnp.squeeze(targets, axis=-1).astype(jnp.int32)
    elif targets.ndim == 3:
      processed_targets = targets.astype(jnp.int32)
    else:
      raise ValueError(
          'Targets must be 3D (batch, H, W) or 4D (batch, H, W, 1). '
          f'Got shape {targets.shape}'
      )

    iou_score = cls._calculate_iou(
        targets=processed_targets,
        predictions=processed_predictions,
        target_class_ids=target_class_ids,
        epsilon=epsilon,
    )
    # The scalar batch score is handed to the parent class for accumulation.
    return super().from_model_output(values=iou_score)
0 commit comments