Skip to content

Commit b0768c8

Browse files
committed
change base_metric to base
1 parent aa9431f commit b0768c8

File tree

7 files changed

+22
-26
lines changed

7 files changed

+22
-26
lines changed

src/metrax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from metrax.base_metrics import (
15+
from metrax.base import (
1616
Average,
1717
)
1818
from metrax.classification_metrics import (

src/metrax/classification_metrics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import flax
1919
import jax
2020
import jax.numpy as jnp
21-
from metrax import base_metrics
21+
from metrax import base
2222

2323

2424
def _default_threshold(num_thresholds: int) -> jax.Array:
@@ -112,7 +112,7 @@ def merge(self, other: 'Precision') -> 'Precision':
112112
)
113113

114114
def compute(self) -> jax.Array:
115-
return base_metrics.divide_no_nan(
115+
return base.divide_no_nan(
116116
self.true_positives, (self.true_positives + self.false_positives)
117117
)
118118

@@ -183,7 +183,7 @@ def merge(self, other: 'Recall') -> 'Recall':
183183
)
184184

185185
def compute(self) -> jax.Array:
186-
return base_metrics.divide_no_nan(
186+
return base.divide_no_nan(
187187
self.true_positives, (self.true_positives + self.false_negatives)
188188
)
189189

@@ -361,20 +361,20 @@ def interpolate_pr_auc(self) -> jax.Array:
361361
)
362362
p = self.true_positives + self.false_positives
363363
dp = p[: self.num_thresholds - 1] - p[1:]
364-
prec_slope = base_metrics.divide_no_nan(dtp, jnp.maximum(dp, 0))
364+
prec_slope = base.divide_no_nan(dtp, jnp.maximum(dp, 0))
365365
intercept = self.true_positives[1:] - prec_slope * p[1:]
366366

367367
# recall_relative_ratio
368368
safe_p_ratio = jnp.where(
369369
jnp.multiply(p[: self.num_thresholds - 1] > 0, p[1:] > 0),
370-
base_metrics.divide_no_nan(
370+
base.divide_no_nan(
371371
p[: self.num_thresholds - 1],
372372
jnp.maximum(p[1:], 0),
373373
),
374374
jnp.ones_like(p[1:]),
375375
)
376376
# pr_auc_increment
377-
pr_auc_increment = base_metrics.divide_no_nan(
377+
pr_auc_increment = base.divide_no_nan(
378378
prec_slope * (dtp + intercept * jnp.log(safe_p_ratio)),
379379
jnp.maximum(self.true_positives[1:] + self.false_negatives[1:], 0),
380380
)
@@ -502,10 +502,10 @@ def merge(self, other: 'AUCROC') -> 'AUCROC':
502502
)
503503

504504
def compute(self) -> jax.Array:
505-
tp_rate = base_metrics.divide_no_nan(
505+
tp_rate = base.divide_no_nan(
506506
self.true_positives, self.true_positives + self.false_negatives
507507
)
508-
fp_rate = base_metrics.divide_no_nan(
508+
fp_rate = base.divide_no_nan(
509509
self.false_positives, self.false_positives + self.true_negatives
510510
)
511511
# Threshold goes from 0 to 1, so trapezoid is negative.

src/metrax/nlp_metrics.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import flax
1919
import jax
2020
import jax.numpy as jnp
21-
from metrax import base_metrics
21+
from metrax import base
2222

2323

2424
@flax.struct.dataclass
@@ -86,7 +86,7 @@ def from_model_output(
8686
ValueError: If type of `labels` is wrong or the shapes of `predictions`
8787
and `labels` are incompatible.
8888
"""
89-
predictions = base_metrics.divide_no_nan(
89+
predictions = base.divide_no_nan(
9090
predictions, jnp.sum(predictions, axis=-1, keepdims=True)
9191
)
9292
labels_one_hot = jax.nn.one_hot(labels, predictions.shape[-1], axis=-1)
@@ -97,7 +97,7 @@ def from_model_output(
9797
if sample_weights is not None:
9898
crossentropy = crossentropy * sample_weights
9999
# Normalize by the sum of weights for each sequence.
100-
crossentropy = base_metrics.divide_no_nan(
100+
crossentropy = base.divide_no_nan(
101101
jnp.sum(crossentropy), jnp.sum(sample_weights)
102102
)
103103
else:
@@ -119,14 +119,12 @@ def merge(self, other: 'Perplexity') -> 'Perplexity':
119119

120120
def compute(self) -> jax.Array:
121121
return jnp.exp(
122-
base_metrics.divide_no_nan(
123-
self.aggregate_crossentropy, self.num_samples
124-
)
122+
base.divide_no_nan(self.aggregate_crossentropy, self.num_samples)
125123
)
126124

127125

128126
@flax.struct.dataclass
129-
class WER(base_metrics.Average):
127+
class WER(base.Average):
130128
r"""Computes Word Error Rate (WER) for speech recognition or text generation tasks.
131129
132130
Word Error Rate measures the edit distance between reference texts and

src/metrax/ranking_metrics.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import flax
1818
import jax
1919
import jax.numpy as jnp
20-
from metrax import base_metrics
20+
from metrax import base
2121

2222

2323
@flax.struct.dataclass
24-
class AveragePrecisionAtK(base_metrics.Average):
24+
class AveragePrecisionAtK(base.Average):
2525
r"""Computes AP@k (average precision at k) metrics in JAX.
2626
2727
Average precision at k (AP@k) is a metric used to evaluate the performance of
@@ -63,16 +63,14 @@ def average_precision_at_ks(
6363
def compute_ap_at_k_single(relevant_labels, total_relevant, ks):
6464
cumulative_precision = jnp.where(
6565
relevant_labels,
66-
base_metrics.divide_no_nan(
66+
base.divide_no_nan(
6767
jnp.cumsum(relevant_labels),
6868
jnp.arange(1, len(relevant_labels) + 1),
6969
),
7070
0,
7171
)
7272
return jnp.array([
73-
base_metrics.divide_no_nan(
74-
jnp.sum(cumulative_precision[:k]), total_relevant
75-
)
73+
base.divide_no_nan(jnp.sum(cumulative_precision[:k]), total_relevant)
7674
for k in ks
7775
])
7876

src/metrax/regression_metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import flax
1919
import jax
2020
import jax.numpy as jnp
21-
from metrax import base_metrics
21+
from metrax import base
2222

2323

2424
@flax.struct.dataclass
25-
class MSE(base_metrics.Average):
25+
class MSE(base.Average):
2626
r"""Computes the mean squared error for regression problems given `predictions` and `labels`.
2727
2828
The mean squared error without sample weights is defined as:
@@ -216,6 +216,6 @@ def compute(self) -> jax.Array:
216216
Returns:
217217
The r-squared score.
218218
"""
219-
mean = base_metrics.divide_no_nan(self.total, self.count)
219+
mean = base.divide_no_nan(self.total, self.count)
220220
sst = self.sum_of_squared_label - self.count * jnp.power(mean, 2)
221-
return 1 - base_metrics.divide_no_nan(self.sum_of_squared_error, sst)
221+
return 1 - base.divide_no_nan(self.sum_of_squared_error, sst)

0 commit comments

Comments
 (0)