Skip to content

Commit 0deced5

Browse files
authored
add Dice to metrax (#64) (#91)
* add Dice to metrax (#64) * fix: added sample_weights parameter for DICE * removed sample_weights from DICE * moved DICE form classification_metrics.py to image_metrics.py
1 parent 583edd9 commit 0deced5

File tree

8 files changed

+159
-4
lines changed

8 files changed

+159
-4
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
2727
BLEU = nlp_metrics.BLEU
2828
DCGAtK = ranking_metrics.DCGAtK
29+
Dice = image_metrics.Dice
2930
IoU = image_metrics.IoU
3031
MAE = regression_metrics.MAE
3132
MRR = ranking_metrics.MRR
@@ -53,6 +54,7 @@
5354
"AveragePrecisionAtK",
5455
"BLEU",
5556
"DCGAtK",
57+
"Dice",
5658
"IoU",
5759
"MAE",
5860
"MRR",

src/metrax/classification_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,4 +576,4 @@ def compute(self) -> jax.Array:
576576
self.false_positives, self.false_positives + self.true_negatives
577577
)
578578
# Threshold goes from 0 to 1, so trapezoid is negative.
579-
return jnp.trapezoid(tp_rate, fp_rate) * -1
579+
return jnp.trapezoid(tp_rate, fp_rate) * -1

src/metrax/classification_metrics_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
(BATCHES, 1),
4848
).astype(np.float32)
4949

50-
5150
class ClassificationMetricsTest(parameterized.TestCase):
5251

5352
def test_precision_empty(self):
@@ -78,7 +77,7 @@ def test_aucroc_empty(self):
7877
self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))
7978
self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
8079
self.assertEqual(m.num_thresholds, 0)
81-
80+
8281
@parameterized.named_parameters(
8382
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
8483
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),

src/metrax/image_metrics.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""A collection of different metrics for image models."""
1616

17+
from clu import metrics as clu_metrics
1718
import flax
1819
import jax
1920
from jax import lax
@@ -584,3 +585,82 @@ def from_model_output(
584585
"""
585586
batch_psnr = cls._calculate_psnr(predictions, targets, max_val=max_val)
586587
return super().from_model_output(values=batch_psnr)
588+
589+
590+
@flax.struct.dataclass
591+
class Dice(clu_metrics.Metric):
592+
r"""Computes the Dice coefficient between `y_true` and `y_pred`.
593+
594+
Dice is a similarity measure used to measure overlap between two samples.
595+
A Dice score of 1 indicates perfect overlap; 0 indicates no overlap.
596+
597+
The formula is:
598+
599+
.. math::
600+
601+
\text{Dice} = \frac{2 \cdot \sum (y_{\text{true}} \cdot y_{\text{pred}})}
602+
{\sum y_{\text{true}} + \sum y_{\text{pred}} + \epsilon}
603+
604+
Attributes:
605+
intersection: Sum of element-wise product between `y_true` and `y_pred`.
606+
sum_true: Sum of y_true across all examples.
607+
sum_pred: Sum of y_pred across all examples.
608+
609+
"""
610+
611+
intersection: jax.Array
612+
sum_pred: jax.Array
613+
sum_true: jax.Array
614+
615+
@classmethod
616+
def empty(cls) -> 'Dice':
617+
return cls(
618+
intersection=jnp.array(0.0, jnp.float32),
619+
sum_pred=jnp.array(0.0, jnp.float32),
620+
sum_true=jnp.array(0.0, jnp.float32),
621+
)
622+
623+
@classmethod
624+
def from_model_output(
625+
cls,
626+
predictions: jax.Array,
627+
labels: jax.Array,
628+
) -> 'Dice':
629+
"""Updates the metric.
630+
631+
Args:
632+
predictions: A floating point vector whose values are in the range [0,
633+
1]. The shape should be (batch_size,).
634+
labels: True value. The value is expected to be 0 or 1. The shape should
635+
be (batch_size,).
636+
637+
Returns:
638+
Updated Dice metric.
639+
640+
"""
641+
predictions = jnp.asarray(predictions, jnp.float32)
642+
labels = jnp.asarray(labels, jnp.float32)
643+
644+
intersection = jnp.sum(predictions * labels)
645+
sum_pred = jnp.sum(predictions)
646+
sum_true = jnp.sum(labels)
647+
648+
return cls(
649+
intersection=intersection,
650+
sum_pred=sum_pred,
651+
sum_true=sum_true,
652+
)
653+
654+
def merge(self, other: 'Dice') -> 'Dice':
655+
return type(self)(
656+
intersection=self.intersection + other.intersection,
657+
sum_pred=self.sum_pred + other.sum_pred,
658+
sum_true=self.sum_true + other.sum_true,
659+
)
660+
661+
def compute(self) -> jax.Array:
662+
"""Returns the final Dice coefficient."""
663+
epsilon = 1e-7
664+
return (2.0 * self.intersection) / (
665+
self.sum_pred + self.sum_true + epsilon
666+
)

src/metrax/image_metrics_test.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,38 @@
190190
NUM_CLASSES_IOU_6 = NUM_CLASSES_IOU_2
191191
TARGET_CLASS_IDS_IOU_6 = np.array(range(NUM_CLASSES_IOU_6))
192192

193+
# Test data for Dice
194+
BATCHES = 4
195+
BATCH_SIZE = 8
196+
OUTPUT_LABELS = np.random.randint(
197+
0,
198+
2,
199+
size=(BATCHES, BATCH_SIZE),
200+
).astype(np.float32)
201+
OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE))
202+
OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(jnp.float16)
203+
OUTPUT_PREDS_F32 = OUTPUT_PREDS.astype(jnp.float32)
204+
OUTPUT_PREDS_BF16 = OUTPUT_PREDS.astype(jnp.bfloat16)
205+
OUTPUT_LABELS_BS1 = np.random.randint(
206+
0,
207+
2,
208+
size=(BATCHES, 1),
209+
).astype(np.float32)
210+
OUTPUT_PREDS_BS1 = np.random.uniform(size=(BATCHES, 1)).astype(np.float32)
193211

194-
class ImageMetricsTest(parameterized.TestCase):
212+
DICE_ALL_ONES = (jnp.array([1, 1, 1, 1]), jnp.array([1, 1, 1, 1]))
213+
DICE_ALL_ZEROS = (jnp.array([0, 0, 0, 0]), jnp.array([0, 0, 0, 0]))
214+
DICE_NO_OVERLAP = (jnp.array([1, 1, 0, 0]), jnp.array([0, 0, 1, 1]))
195215

216+
class ImageMetricsTest(parameterized.TestCase):
217+
218+
def test_dice_empty(self):
219+
"""Tests the `empty` method of the `Dice` class."""
220+
m = metrax.Dice.empty()
221+
self.assertEqual(m.intersection, jnp.array(0, jnp.float32))
222+
self.assertEqual(m.sum_true, jnp.array(0, jnp.float32))
223+
self.assertEqual(m.sum_pred, jnp.array(0, jnp.float32))
224+
196225
@parameterized.named_parameters(
197226
(
198227
'ssim_basic_norm_single_channel',
@@ -500,5 +529,36 @@ def test_psnr_against_tensorflow(
500529
atol=1e-4,
501530
err_msg="PSNR mismatch",
502531
)
532+
533+
@parameterized.named_parameters(
534+
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
535+
('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32),
536+
('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32),
537+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
538+
('all_ones', *DICE_ALL_ONES),
539+
('all_zeros', *DICE_ALL_ZEROS),
540+
('no_overlap', *DICE_NO_OVERLAP),
541+
)
542+
def test_dice(self, y_true, y_pred):
543+
"""Test that Dice metric computes expected values."""
544+
y_true = jnp.asarray(y_true, jnp.float32)
545+
y_pred = jnp.asarray(y_pred, jnp.float32)
546+
547+
# Manually compute expected Dice
548+
eps = 1e-7
549+
intersection = jnp.sum(y_true * y_pred)
550+
sum_pred = jnp.sum(y_pred)
551+
sum_true = jnp.sum(y_true)
552+
expected = (2.0 * intersection) / (sum_pred + sum_true + eps)
553+
554+
# Compute using the metric class
555+
metric = metrax.Dice.from_model_output(
556+
predictions=y_pred,
557+
labels=y_true
558+
)
559+
result = metric.compute()
560+
561+
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
562+
503563
if __name__ == '__main__':
504564
absltest.main()

src/metrax/metrax_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ class MetraxTest(parameterized.TestCase):
9898
'ks': KS,
9999
},
100100
),
101+
(
102+
'dice',
103+
metrax.Dice,
104+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
105+
),
101106
(
102107
'iou',
103108
metrax.IoU,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
2222
BLEU = nnx_metrics.BLEU
2323
DCGAtK = nnx_metrics.DCGAtK
24+
Dice = nnx_metrics.Dice
2425
IoU = nnx_metrics.IoU
2526
MAE = nnx_metrics.MAE
2627
MRR = nnx_metrics.MRR
@@ -47,6 +48,7 @@
4748
"AveragePrecisionAtK",
4849
"BLEU",
4950
"DCGAtK",
51+
"Dice",
5052
"IoU",
5153
"MRR",
5254
"MAE"

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def __init__(self):
6666
super().__init__(metrax.DCGAtK)
6767

6868

69+
class Dice(NnxWrapper):
70+
"""An NNX class for the Metrax metric Dice."""
71+
72+
def __init__(self):
73+
super().__init__(metrax.Dice)
74+
75+
6976
class IoU(NnxWrapper):
7077
"""An NNX class for the Metrax metric IoU."""
7178

0 commit comments

Comments
 (0)