Skip to content

Commit 0c32530

Browse files
authored
Change order of call args in RemoveAccidentalHits layer. (#45)
The order is now `logits`, `labels`, which is more consistent with the other layers and Keras in general.
1 parent 40f0ac9 commit 0c32530

File tree

2 files changed

+27
-26
lines changed

2 files changed

+27
-26
lines changed

keras_rs/src/layers/retrieval/remove_accidental_hits.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class RemoveAccidentalHits(keras.layers.Layer):
1919

2020
def call(
2121
self,
22-
labels: types.Tensor,
2322
logits: types.Tensor,
23+
labels: types.Tensor,
2424
candidate_ids: types.Tensor,
2525
) -> types.Tensor:
2626
"""Zeroes selected logits.
@@ -29,16 +29,16 @@ def call(
2929
have the same ID as the positive candidate in that row.
3030
3131
Args:
32-
labels: one-hot labels tensor, typically
33-
`[batch_size, num_candidates]` but can have more dimensions or be
34-
1D as `[num_candidates]`.
35-
logits: logits tensor. Must have the same shape as `labels`.
36-
candidate_ids: candidate identifiers tensor, can be `[num_candidates]`
37-
or `[batch_size, num_candidates]` or have more dimensions as long
38-
as they match the last dimensions of `labels`.
32+
logits: logits tensor, typically `[batch_size, num_candidates]` but
33+
can have more dimensions or be 1D as `[num_candidates]`.
34+
labels: one-hot labels tensor, must be the same shape as `logits`.
35+
candidate_ids: candidate identifiers tensor, can be
36+
`[num_candidates]` or `[batch_size, num_candidates]` or have
37+
more dimensions as long as they match the last dimensions of
38+
`labels`.
3939
4040
Returns:
41-
logits: Modified logits.
41+
logits: Modified logits with the same shape as the input logits.
4242
"""
4343
# A more principled way is to implement
4444
# `softmax_cross_entropy_with_logits` with a input mask. Here we

keras_rs/src/layers/retrieval/remove_accidental_hits_test.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,24 @@
1010

1111
class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase):
1212
def create_inputs(self, logits_rank=2, candidate_ids_rank=1):
13-
num_candidates = 10
14-
shape_3d = (15, 20, num_candidates)
13+
shape_3d = (15, 20, 10)
1514
shape = shape_3d[-logits_rank:]
1615
candidate_ids_shape = shape_3d[-candidate_ids_rank:]
17-
rng = keras.random.SeedGenerator(42)
16+
num_candidates = shape[-1]
1817

18+
rng = keras.random.SeedGenerator(42)
19+
logits = keras.random.uniform(shape, seed=rng)
1920
labels = keras.ops.one_hot(
2021
keras.random.randint(
2122
shape[:-1], minval=0, maxval=num_candidates, seed=rng
2223
),
2324
num_candidates,
2425
)
25-
logits = keras.random.uniform(shape, seed=rng)
2626
candidate_ids = keras.random.randint(
2727
candidate_ids_shape, minval=0, maxval=num_candidates, seed=rng
2828
)
29-
return labels, logits, candidate_ids
29+
30+
return logits, labels, candidate_ids
3031

3132
@parameterized.named_parameters(
3233
[
@@ -63,12 +64,12 @@ def create_inputs(self, logits_rank=2, candidate_ids_rank=1):
6364
]
6465
)
6566
def test_call(self, logits_rank, candidate_ids_rank):
66-
labels, logits, candidate_ids = self.create_inputs(
67+
logits, labels, candidate_ids = self.create_inputs(
6768
logits_rank=logits_rank, candidate_ids_rank=candidate_ids_rank
6869
)
6970

7071
out_logits = remove_accidental_hits.RemoveAccidentalHits()(
71-
labels, logits, candidate_ids
72+
logits, labels, candidate_ids
7273
)
7374

7475
# Logits of labels are unchanged.
@@ -148,36 +149,36 @@ def test_mismatched_labels_candidates_shapes(self):
148149

149150
def test_predict(self):
150151
# Note: for predict, we test with probabilities that have a batch dim.
151-
labels, logits, candidate_ids = self.create_inputs(candidate_ids_rank=2)
152+
logits, labels, candidate_ids = self.create_inputs(candidate_ids_rank=2)
152153

153154
layer = remove_accidental_hits.RemoveAccidentalHits()
154-
in_labels = keras.layers.Input(labels.shape[1:])
155155
in_logits = keras.layers.Input(logits.shape[1:])
156+
in_labels = keras.layers.Input(labels.shape[1:])
156157
in_candidate_ids = keras.layers.Input(labels.shape[1:])
157-
out_logits = layer(in_labels, in_logits, in_candidate_ids)
158+
out_logits = layer(in_logits, in_labels, in_candidate_ids)
158159
model = keras.Model(
159-
[in_labels, in_logits, in_candidate_ids], out_logits
160+
[in_logits, in_labels, in_candidate_ids], out_logits
160161
)
161162

162-
model.predict([labels, logits, candidate_ids], batch_size=8)
163+
model.predict([logits, labels, candidate_ids], batch_size=8)
163164

164165
def test_serialization(self):
165166
layer = remove_accidental_hits.RemoveAccidentalHits()
166167
restored = deserialize(serialize(layer))
167168
self.assertDictEqual(layer.get_config(), restored.get_config())
168169

169170
def test_model_saving(self):
170-
labels, logits, candidate_ids = self.create_inputs()
171+
logits, labels, candidate_ids = self.create_inputs()
171172

172173
layer = remove_accidental_hits.RemoveAccidentalHits()
173-
in_labels = keras.layers.Input(labels.shape[1:])
174174
in_logits = keras.layers.Input(logits.shape[1:])
175+
in_labels = keras.layers.Input(labels.shape[1:])
175176
in_candidate_ids = keras.layers.Input(batch_shape=candidate_ids.shape)
176-
out_logits = layer(in_labels, in_logits, in_candidate_ids)
177+
out_logits = layer(in_logits, in_labels, in_candidate_ids)
177178
model = keras.Model(
178-
[in_labels, in_logits, in_candidate_ids], out_logits
179+
[in_logits, in_labels, in_candidate_ids], out_logits
179180
)
180181

181182
self.run_model_saving_test(
182-
model=model, input_data=[labels, logits, candidate_ids]
183+
model=model, input_data=[logits, labels, candidate_ids]
183184
)

0 commit comments

Comments
 (0)