Skip to content

Commit ec41518

Browse files
authored
Implement empty function (#14)
1 parent 9d50348 commit ec41518

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/metrax/metrics.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,15 @@ class RSQUARED(clu_metrics.Metric):
165165
count: jax.Array
166166
sum_of_squared_error: jax.Array
167167
sum_of_squared_label: jax.Array
168+
169+
170+
@classmethod
171+
def empty(cls) -> 'RSQUARED':
172+
return cls(
173+
total=jnp.array(0, jnp.float32),
174+
count=jnp.array(0, jnp.float32),
175+
sum_of_squared_error=jnp.array(0, jnp.float32),
176+
sum_of_squared_label=jnp.array(0, jnp.float32))
168177

169178
@classmethod
170179
def from_model_output(
@@ -259,6 +268,12 @@ class Precision(clu_metrics.Metric):
259268
true_positives: jax.Array
260269
false_positives: jax.Array
261270

271+
@classmethod
272+
def empty(cls) -> 'Precision':
273+
return cls(
274+
true_positives=jnp.array(0, jnp.float32),
275+
false_positives=jnp.array(0, jnp.float32))
276+
262277
@classmethod
263278
def from_model_output(
264279
cls,
@@ -326,6 +341,12 @@ class Recall(clu_metrics.Metric):
326341
true_positives: jax.Array
327342
false_negatives: jax.Array
328343

344+
@classmethod
345+
def empty(cls) -> 'Recall':
346+
return cls(
347+
true_positives=jnp.array(0, jnp.float32),
348+
false_negatives=jnp.array(0, jnp.float32))
349+
329350
@classmethod
330351
def from_model_output(
331352
cls, predictions: jax.Array, labels: jax.Array, threshold: float = 0.5
@@ -415,6 +436,14 @@ class AUCPR(clu_metrics.Metric):
415436
false_positives: jax.Array
416437
false_negatives: jax.Array
417438
num_thresholds: int
439+
440+
@classmethod
441+
def empty(cls) -> 'AUCPR':
442+
return cls(
443+
true_positives=jnp.array(0, jnp.float32),
444+
false_positives=jnp.array(0, jnp.float32),
445+
false_negatives=jnp.array(0, jnp.float32),
446+
num_thresholds=0)
418447

419448
@classmethod
420449
def from_model_output(
@@ -589,6 +618,15 @@ class AUCROC(clu_metrics.Metric):
589618
false_negatives: jax.Array
590619
num_thresholds: int
591620

621+
@classmethod
622+
def empty(cls) -> 'AUCROC':
623+
return cls(
624+
true_positives=jnp.array(0, jnp.float32),
625+
true_negatives=jnp.array(0, jnp.float32),
626+
false_positives=jnp.array(0, jnp.float32),
627+
false_negatives=jnp.array(0, jnp.float32),
628+
num_thresholds=0)
629+
592630
@classmethod
593631
def from_model_output(
594632
cls,
@@ -698,6 +736,12 @@ class Perplexity(clu_metrics.Metric):
698736
aggregate_crossentropy: jax.Array
699737
num_samples: jax.Array
700738

739+
@classmethod
740+
def empty(cls) -> 'Perplexity':
741+
return cls(
742+
aggregate_crossentropy=jnp.array(0, jnp.float32),
743+
num_samples=jnp.array(0, jnp.float32))
744+
701745
@classmethod
702746
def from_model_output(
703747
cls,

src/metrax/metrics_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,61 @@ def compute_aucroc(self, model_outputs, sample_weights=None):
117117
metric = update if metric is None else metric.merge(update)
118118
return metric.compute()
119119

120+
def test_mse_empty(self):
121+
"""Tests the `empty` method of the `MSE` class."""
122+
m = metrax.MSE.empty()
123+
self.assertEqual(m.total, jnp.array(0, jnp.float32))
124+
self.assertEqual(m.count, jnp.array(0, jnp.int32))
125+
126+
def test_rmse_empty(self):
127+
"""Tests the `empty` method of the `RMSE` class."""
128+
m = metrax.RMSE.empty()
129+
self.assertEqual(m.total, jnp.array(0, jnp.float32))
130+
self.assertEqual(m.count, jnp.array(0, jnp.int32))
131+
132+
def test_rsquared_empty(self):
133+
"""Tests the `empty` method of the `RSQUARED` class."""
134+
m = metrax.RSQUARED.empty()
135+
self.assertEqual(m.total, jnp.array(0, jnp.float32))
136+
self.assertEqual(m.count, jnp.array(0, jnp.float32))
137+
self.assertEqual(m.sum_of_squared_error, jnp.array(0, jnp.float32))
138+
self.assertEqual(m.sum_of_squared_label, jnp.array(0, jnp.float32))
139+
140+
def test_precision_empty(self):
141+
"""Tests the `empty` method of the `Precision` class."""
142+
m = metrax.Precision.empty()
143+
self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
144+
self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))
145+
146+
def test_recall_empty(self):
147+
"""Tests the `empty` method of the `Recall` class."""
148+
m = metrax.Recall.empty()
149+
self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
150+
self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
151+
152+
def test_aucpr_empty(self):
153+
"""Tests the `empty` method of the `AUCPR` class."""
154+
m = metrax.AUCPR.empty()
155+
self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
156+
self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))
157+
self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
158+
self.assertEqual(m.num_thresholds, 0)
159+
160+
def test_aucroc_empty(self):
161+
"""Tests the `empty` method of the `AUCROC` class."""
162+
m = metrax.AUCROC.empty()
163+
self.assertEqual(m.true_positives, jnp.array(0, jnp.float32))
164+
self.assertEqual(m.true_negatives, jnp.array(0, jnp.float32))
165+
self.assertEqual(m.false_positives, jnp.array(0, jnp.float32))
166+
self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
167+
self.assertEqual(m.num_thresholds, 0)
168+
169+
def test_perplexity_empty(self):
170+
"""Tests the `empty` method of the `Perplexity` class."""
171+
m = metrax.Perplexity.empty()
172+
self.assertEqual(m.aggregate_crossentropy, jnp.array(0, jnp.float32))
173+
self.assertEqual(m.num_samples, jnp.array(0, jnp.float32))
174+
120175
@parameterized.named_parameters(
121176
('basic', OUTPUT_LABELS, OUTPUT_PREDS, 0.5),
122177
('high_threshold', OUTPUT_LABELS, OUTPUT_PREDS, 0.7),

0 commit comments

Comments
 (0)