Skip to content

Commit f892292

Browse files
committed
add DCGAtK to metrax
1 parent cae0c2a commit f892292

File tree

6 files changed

+155
-5
lines changed

6 files changed

+155
-5
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Average = base.Average
2424
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
2525
BLEU = nlp_metrics.BLEU
26+
DCGAtK = ranking_metrics.DCGAtK
2627
MSE = regression_metrics.MSE
2728
Perplexity = nlp_metrics.Perplexity
2829
Precision = classification_metrics.Precision
@@ -42,6 +43,7 @@
4243
"Average",
4344
"AveragePrecisionAtK",
4445
"BLEU",
46+
"DCGAtK",
4547
"MSE",
4648
"Perplexity",
4749
"Precision",

src/metrax/metrax_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,15 @@ class MetraxTest(parameterized.TestCase):
7070
'ks': KS,
7171
},
7272
),
73+
(
    'dcgAtK',
    metrax.DCGAtK,
    {
        # Fix transposed arguments: prediction scores must go to
        # 'predictions' and relevance labels to 'labels' (the original
        # entry fed OUTPUT_LABELS as predictions and vice versa).
        'predictions': OUTPUT_PREDS,
        'labels': OUTPUT_LABELS,
        'ks': KS,
    },
),
7382
(
7483
'mse',
7584
metrax.MSE,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
Average = nnx_metrics.Average
2020
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
2121
BLEU = nnx_metrics.BLEU
22+
DCGAtK = nnx_metrics.DCGAtK
2223
MSE = nnx_metrics.MSE
2324
Perplexity = nnx_metrics.Perplexity
2425
Precision = nnx_metrics.Precision
@@ -38,6 +39,7 @@
3839
"Average",
3940
"AveragePrecisionAtK",
4041
"BLEU",
42+
"DCGAtK",
4143
"MSE",
4244
"Perplexity",
4345
"Precision",

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def __init__(self):
5353
super().__init__(metrax.BLEU)
5454

5555

56+
class DCGAtK(NnxWrapper):
  """NNX wrapper around the Metrax ranking metric ``DCGAtK``."""

  def __init__(self):
    # All state handling is delegated to the generic NNX wrapper; this
    # class only binds it to the DCGAtK metric implementation.
    super().__init__(metrax.DCGAtK)
61+
62+
5663
class MSE(NnxWrapper):
5764
"""An NNX class for the Metrax metric MSE."""
5865

src/metrax/ranking_metrics.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,92 @@
2121
from metrax import base
2222

2323

24+
@flax.struct.dataclass
class DCGAtK(base.Average):
  r"""Computes DCG@k (Discounted Cumulative Gain at k) metrics.

  For labels $y$ and scores $s$:

  $DCG@k(y, s) = \sum_{i | \text{rank}(s_i) \le k} \text{gain}(y_i) \times
  \text{rank\_discount}(\text{rank}(s_i))$

  where $\text{rank}(s_i)$ is the 1-based rank of item $i$ under descending
  scores, the gain is $\text{gain}(y_i) = 2^{y_i} - 1$ ('exp2' gain), and the
  rank discount is $1 / \log_2(\text{rank} + 1)$.
  """

  @classmethod
  def _calculate_dcg_at_ks(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      ks: jax.Array,
  ) -> jax.Array:
    """Computes per-example DCG@k for every k in `ks`, using 'exp2' gain.

    Args:
      predictions: A floating point 2D array (batch_size, vocab_size) of
        prediction scores. Higher scores mean higher rank.
      labels: A 2D array (batch_size, vocab_size) of graded relevance scores.
      ks: A 1D integer array of cut-off points for which DCG is computed
        (e.g., jnp.array([1, 5, 10])). Shape: (num_ks,).

    Returns:
      A 2D array (batch_size, num_ks) containing DCG@k values.
    """
    # 1-based rank of each item under descending score order; the double
    # argsort turns a sort permutation into per-item ranks.
    ranks = jnp.argsort(jnp.argsort(-predictions, axis=1), axis=1) + 1
    # Per-item contribution: exp2 gain scaled by the logarithmic discount.
    gains = jnp.power(2.0, labels.astype(jnp.float32)) - 1.0
    discounts = 1.0 / jnp.log2(ranks.astype(jnp.float32) + 1.0)
    contributions = gains * discounts

    # Broadcast to (batch_size, num_ks, vocab_size): an item contributes to
    # DCG@k exactly when its rank is within the top k. Reducing over the
    # vocab axis yields one DCG value per (example, k) pair.
    within_top_k = ranks[:, None, :] <= ks[None, :, None]
    return jnp.sum(contributions[:, None, :] * within_top_k, axis=-1)

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      ks: jax.Array,
  ) -> 'DCGAtK':
    """Creates a DCGAtK metric instance from model output.

    Args:
      predictions: A floating point 2D array (batch_size, vocab_size) of
        prediction scores.
      labels: A 2D array (batch_size, vocab_size) of graded relevance scores.
      ks: A 1D integer array of cut-off points. Shape: (num_ks,).

    Returns:
      A `DCGAtK` accumulating the per-k DCG sum over the batch as `total`
      and the number of examples as `count` (averaged via `base.Average`).
    """
    per_example_dcg = cls._calculate_dcg_at_ks(predictions, labels, ks)
    batch_size = jnp.array(labels.shape[0], dtype=jnp.float32)
    return cls(
        total=per_example_dcg.sum(axis=0),
        count=batch_size,
    )
108+
109+
24110
@flax.struct.dataclass
25111
class AveragePrecisionAtK(base.Average):
26112
r"""Computes AP@k (average precision at k) metrics.
@@ -151,10 +237,10 @@ def _get_relevant_at_k(
151237
predictions: A floating point 2D array representing the prediction scores
152238
from the model. Higher scores indicate higher relevance. The shape
153239
should be (batch_size, vocab_size).
154-
labels: A multi-hot encoding (0 or 1, or counts) of the true labels. The
155-
shape should be (batch_size, vocab_size).
156-
ks: A 1D array of integers representing the k's (cut-off points) for which
157-
to compute metrics. The shape should be (|ks|).
240+
labels: A multi-hot encoding (0 or 1) of the true labels. The shape should
241+
be (batch_size, vocab_size).
242+
ks: A 1D array of integers representing the k's to compute the P@k
243+
metrics. The shape should be (|ks|).
158244
159245
Returns:
160246
relevant_at_k: A 2D array of shape (batch_size, |ks|). Each element [i, j]
@@ -279,7 +365,7 @@ class to get the number of relevant items at each k, and then divides
279365

280366
@flax.struct.dataclass
281367
class RecallAtK(TopKRankingMetric):
282-
r"""Computes R@k (recall at k) metrics in JAX.
368+
r"""Computes R@k (recall at k) metrics.
283369
284370
Recall at k (R@k) is a metric that measures the proportion of
285371
relevant items that are found in the top k recommendations, out of the

src/metrax/ranking_metrics_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,22 @@
3232
OUTPUT_PREDS = np.random.uniform(size=(BATCH_SIZE, VOCAB_SIZE)).astype(
3333
np.float32
3434
)
35+
OUTPUT_RELEVANCES = np.random.randint(
36+
0,
37+
2,
38+
size=(BATCH_SIZE, VOCAB_SIZE),
39+
).astype(np.float32)
3540
OUTPUT_LABELS_VS1 = np.random.randint(
3641
0,
3742
2,
3843
size=(BATCH_SIZE, 1),
3944
).astype(np.float32)
4045
OUTPUT_PREDS_VS1 = np.random.uniform(size=(BATCH_SIZE, 1)).astype(np.float32)
46+
OUTPUT_RELEVANCES_VS1 = np.random.randint(
47+
0,
48+
2,
49+
size=(BATCH_SIZE, 1),
50+
).astype(np.float32)
4151
# TODO(jiwonshin): Replace with keras metric once it is available in OSS.
4252
MAP_FROM_KERAS = np.array([
4353
0.2083333432674408,
@@ -59,6 +69,15 @@
5969
0.75,
6070
])
6171
R_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
72+
DCG_FROM_KERAS = np.array([
73+
0.25,
74+
0.880929708480835,
75+
1.255929708480835,
76+
1.5789371728897095,
77+
1.8690768480300903,
78+
2.04718017578125,
79+
])
80+
DCG_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
6281

6382

6483
class RankingMetricsTest(parameterized.TestCase):
@@ -150,6 +169,31 @@ def test_recallatk(self, y_true, y_pred, map_from_keras):
150169
atol=1e-05,
151170
)
152171

172+
@parameterized.named_parameters(
    ('basic', OUTPUT_RELEVANCES, OUTPUT_PREDS, DCG_FROM_KERAS),
    (
        'vocab_size_one',
        OUTPUT_RELEVANCES_VS1,
        OUTPUT_PREDS_VS1,
        DCG_FROM_KERAS_VS1,
    ),
)
def test_dcgatk(self, y_true, y_pred, map_from_keras):
  """Checks `DCGAtK` against reference values computed with Keras."""
  cutoffs = jnp.array([1, 2, 3, 4, 5, 6])
  metric = metrax.DCGAtK.from_model_output(
      predictions=y_pred, labels=y_true, ks=cutoffs
  )
  # Loose float tolerances to absorb accelerator-dependent rounding.
  np.testing.assert_allclose(
      metric.compute(), map_from_keras, rtol=1e-05, atol=1e-05
  )
196+
153197

154198
if __name__ == '__main__':
155199
absltest.main()

0 commit comments

Comments
 (0)