|
190 | 190 | NUM_CLASSES_IOU_6 = NUM_CLASSES_IOU_2 |
191 | 191 | TARGET_CLASS_IDS_IOU_6 = np.array(range(NUM_CLASSES_IOU_6)) |
192 | 192 |
|
| 193 | +# Test data for Dice |
| 194 | +BATCHES = 4 |
| 195 | +BATCH_SIZE = 8 |
| 196 | +OUTPUT_LABELS = np.random.randint( |
| 197 | + 0, |
| 198 | + 2, |
| 199 | + size=(BATCHES, BATCH_SIZE), |
| 200 | +).astype(np.float32) |
| 201 | +OUTPUT_PREDS = np.random.uniform(size=(BATCHES, BATCH_SIZE)) |
| 202 | +OUTPUT_PREDS_F16 = OUTPUT_PREDS.astype(jnp.float16) |
| 203 | +OUTPUT_PREDS_F32 = OUTPUT_PREDS.astype(jnp.float32) |
| 204 | +OUTPUT_PREDS_BF16 = OUTPUT_PREDS.astype(jnp.bfloat16) |
| 205 | +OUTPUT_LABELS_BS1 = np.random.randint( |
| 206 | + 0, |
| 207 | + 2, |
| 208 | + size=(BATCHES, 1), |
| 209 | +).astype(np.float32) |
| 210 | +OUTPUT_PREDS_BS1 = np.random.uniform(size=(BATCHES, 1)).astype(np.float32) |
193 | 211 |
|
194 | | -class ImageMetricsTest(parameterized.TestCase): |
| 212 | +DICE_ALL_ONES = (jnp.array([1, 1, 1, 1]), jnp.array([1, 1, 1, 1])) |
| 213 | +DICE_ALL_ZEROS = (jnp.array([0, 0, 0, 0]), jnp.array([0, 0, 0, 0])) |
| 214 | +DICE_NO_OVERLAP = (jnp.array([1, 1, 0, 0]), jnp.array([0, 0, 1, 1])) |
195 | 215 |
|
| 216 | +class ImageMetricsTest(parameterized.TestCase): |
| 217 | + |
| 218 | + def test_dice_empty(self): |
| 219 | + """Tests the `empty` method of the `Dice` class.""" |
| 220 | + m = metrax.Dice.empty() |
| 221 | + self.assertEqual(m.intersection, jnp.array(0, jnp.float32)) |
| 222 | + self.assertEqual(m.sum_true, jnp.array(0, jnp.float32)) |
| 223 | + self.assertEqual(m.sum_pred, jnp.array(0, jnp.float32)) |
| 224 | + |
196 | 225 | @parameterized.named_parameters( |
197 | 226 | ( |
198 | 227 | 'ssim_basic_norm_single_channel', |
@@ -500,5 +529,36 @@ def test_psnr_against_tensorflow( |
500 | 529 | atol=1e-4, |
501 | 530 | err_msg="PSNR mismatch", |
502 | 531 | ) |
| 532 | + |
| 533 | + @parameterized.named_parameters( |
| 534 | + ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32), |
| 535 | + ('low_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32), |
| 536 | + ('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS_F32), |
| 537 | + ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1), |
| 538 | + ('all_ones', *DICE_ALL_ONES), |
| 539 | + ('all_zeros', *DICE_ALL_ZEROS), |
| 540 | + ('no_overlap', *DICE_NO_OVERLAP), |
| 541 | + ) |
| 542 | + def test_dice(self, y_true, y_pred): |
| 543 | + """Test that Dice metric computes expected values.""" |
| 544 | + y_true = jnp.asarray(y_true, jnp.float32) |
| 545 | + y_pred = jnp.asarray(y_pred, jnp.float32) |
| 546 | + |
| 547 | + # Manually compute expected Dice |
| 548 | + eps = 1e-7 |
| 549 | + intersection = jnp.sum(y_true * y_pred) |
| 550 | + sum_pred = jnp.sum(y_pred) |
| 551 | + sum_true = jnp.sum(y_true) |
| 552 | + expected = (2.0 * intersection) / (sum_pred + sum_true + eps) |
| 553 | + |
| 554 | + # Compute using the metric class |
| 555 | + metric = metrax.Dice.from_model_output( |
| 556 | + predictions=y_pred, |
| 557 | + labels=y_true |
| 558 | + ) |
| 559 | + result = metric.compute() |
| 560 | + |
| 561 | + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5) |
| 562 | + |
503 | 563 | if __name__ == '__main__': |
504 | 564 | absltest.main() |
0 commit comments