|
17 | 17 | import os |
18 | 18 | os.environ['KERAS_BACKEND'] = 'jax' |
19 | 19 |
|
20 | | -from src.metrax.classification_metrics import FBetaScore |
21 | 20 | from absl.testing import absltest |
22 | 21 | from absl.testing import parameterized |
| 22 | +from metrax import FBetaScore |
23 | 23 | import jax.numpy as jnp |
24 | 24 | import keras |
25 | 25 | import metrax |
@@ -83,10 +83,10 @@ def test_aucroc_empty(self): |
83 | 83 | def test_fbeta_empty(self): |
84 | 84 | """Tests the `empty` method of the `FBetaScore` class.""" |
85 | 85 | m = metrax.FBetaScore.empty() |
86 | | - self.assertEqual(m.beta, 1.0) |
87 | 86 | self.assertEqual(m.true_positives, jnp.array(0, jnp.float32)) |
88 | 87 | self.assertEqual(m.false_positives, jnp.array(0, jnp.float32)) |
89 | 88 | self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32)) |
| 89 | + self.assertEqual(m.beta, 1.0) |
90 | 90 |
|
91 | 91 | @parameterized.named_parameters( |
92 | 92 | ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS), |
@@ -277,19 +277,18 @@ def test_aucroc(self, inputs, dtype): |
277 | 277 | # Testing function for F-Beta classification |
278 | 278 | # name, output true, output prediction, threshold, beta |
279 | 279 | @parameterized.named_parameters( |
280 | | - ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0), |
281 | | - ('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7, 2.0), |
282 | | - ('low_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0), |
283 | | - ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5, 1.0), |
284 | | - ('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7, 2.0), |
285 | | - ('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1, 1.0), |
286 | | - ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0), |
287 | | - ('high_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7, 1.0), |
288 | | - ('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0), |
289 | | - ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 1.0), |
| 280 | + ('basic_f16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0), |
| 281 | + ('basic_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5, 1.0), |
| 282 | + ('low_threshold_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1, 1.0), |
| 283 | + ('high_threshold_bf16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7, 1.0), |
| 284 | + ('batch_size_one_beta_1.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 1.0), |
| 285 | + ('high_threshold_f16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7, 2.0), |
| 286 | + ('high_threshold_f32_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7, 2.0), |
| 287 | + ('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0), |
| 288 | + ('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0), |
| 289 | + ('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0), |
290 | 290 | ) |
291 | 291 | def test_fbetascore(self, y_true, y_pred, threshold, beta): |
292 | | - |
293 | 292 | # Define the Keras FBeta class to be tested against |
294 | 293 | keras_fbeta = keras.metrics.FBetaScore(beta=beta, threshold=threshold) |
295 | 294 | keras_fbeta.update_state(y_true, y_pred) |
|
0 commit comments