Skip to content

Commit 1fc3ba8

Browse files
committed
Modified docstrings to match order of variable appearance. Updated docstring to give more description for the from_model_output method. Added beta value into the names for the test cases. Order test cases based on beta value for readability.
1 parent 20d6067 commit 1fc3ba8

File tree

2 files changed

+19
-20
lines changed

2 files changed

+19
-20
lines changed

src/metrax/classification_metrics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -578,10 +578,8 @@ def compute(self) -> jax.Array:
578578

579579
@flax.struct.dataclass
580580
class FBetaScore(clu_metrics.Metric):
581-
582581
"""
583582
F-Beta score Metric class
584-
585583
Computes the F-Beta score for the binary classification given 'predictions' and 'labels'.
586584
587585
Formula for F-Beta Score:
@@ -591,13 +589,13 @@ class FBetaScore(clu_metrics.Metric):
591589
F-Beta turns into the F1 Score when beta = 1.0
592590
593591
Attributes:
594-
beta: The beta value used in the F-Score metric
595592
true_positives: The count of true positive instances from the given data,
596593
label, and threshold.
597594
false_positives: The count of false positive instances from the given data,
598595
label, and threshold.
599596
false_negatives: The count of false negative instances from the given data,
600597
label, and threshold.
598+
beta: The beta value used in the F-Score metric
601599
"""
602600

603601
true_positives: jax.Array
@@ -609,10 +607,10 @@ class FBetaScore(clu_metrics.Metric):
609607
@classmethod
610608
def empty(cls) -> 'FBetaScore':
611609
return cls(
612-
beta = 1.0,
613610
true_positives = jnp.array(0, jnp.float32),
614611
false_positives = jnp.array(0, jnp.float32),
615612
false_negatives = jnp.array(0, jnp.float32),
613+
beta=1.0,
616614
)
617615

618616
@classmethod
@@ -623,17 +621,19 @@ def from_model_output(
623621
beta = beta,
624622
threshold = 0.5,) -> 'FBetaScore':
625623
"""Updates the metric.
624+
Note: When only predictions and labels are given, the score calculated
625+
is the F1 score if the FBetaScore beta value has not been previously modified.
626626
627627
Args:
628-
threshold: threshold value to use in the F-Score metric a floating number.
629-
beta: beta value to use in the F-Score metric. A floating number.
630628
predictions: A floating point 1D vector whose values are in the range [0,
631629
1]. The shape should be (batch_size,).
632630
labels: True value. The value is expected to be 0 or 1. The shape should
633631
be (batch_size,).
632+
beta: beta value to use in the F-Score metric. A floating number.
633+
threshold: threshold value to use in the F-Score metric. A floating number.
634634
635635
Returns:
636-
The true positives, false positives, and false negatives.
636+
The updated FBetaScore object.
637637
638638
Raises:
639639
ValueError: If type of `labels` is wrong or the shapes of `predictions`

src/metrax/classification_metrics_test.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import os
1818
os.environ['KERAS_BACKEND'] = 'jax'
1919

20-
from src.metrax.classification_metrics import FBetaScore
2120
from absl.testing import absltest
2221
from absl.testing import parameterized
22+
from metrax import FBetaScore
2323
import jax.numpy as jnp
2424
import keras
2525
import metrax
@@ -83,10 +83,10 @@ def test_aucroc_empty(self):
8383
def test_fbeta_empty(self):
8484
"""Tests the `empty` method of the `FBetaScore` class."""
8585
m = metrax.FBetaScore.empty()
86-
self.assertEqual(m.beta, 1.0)
8786
self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
8887
self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))
8988
self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
89+
self.assertEqual(m.beta, 1.0)
9090

9191
@parameterized.named_parameters(
9292
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
@@ -277,19 +277,18 @@ def test_aucroc(self, inputs, dtype):
277277
# Testing function for F-Beta classification
278278
# name, output true, output prediction, threshold, beta
279279
@parameterized.named_parameters(
280-
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0),
281-
('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7, 2.0),
282-
('low_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
283-
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5, 1.0),
284-
('high_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7, 2.0),
285-
('low_threshold_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1, 1.0),
286-
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
287-
('high_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7, 1.0),
288-
('low_threshold_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0),
289-
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 1.0),
280+
('basic_f16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5, 1.0),
281+
('basic_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.5, 1.0),
282+
('low_threshold_f32_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.1, 1.0),
283+
('high_threshold_bf16_beta_1.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.7, 1.0),
284+
('batch_size_one_beta_1.0', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, 0.5, 1.0),
285+
('high_threshold_f16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7, 2.0),
286+
('high_threshold_f32_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_F32, 0.7, 2.0),
287+
('low_threshold_bf16_beta_2.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.1, 2.0),
288+
('low_threshold_f16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.1, 3.0),
289+
('basic_bf16_beta_3.0', OUTPUT_LABELS, OUTPUT_PREDS_BF16, 0.5, 3.0),
290290
)
291291
def test_fbetascore(self, y_true, y_pred, threshold, beta):
292-
293292
# Define the Keras FBeta class to be tested against
294293
keras_fbeta = keras.metrics.FBetaScore(beta=beta, threshold=threshold)
295294
keras_fbeta.update_state(y_true, y_pred)

0 commit comments

Comments
 (0)