Skip to content

Commit 0fa4e13

Browse files
authored
Add base retrieval class (#49)
* Add base retrieval class * Restore doc-string * doc-string edit * Add unit tests * Typo * Address comments
1 parent ec94626 commit 0fa4e13

File tree

4 files changed

+205
-60
lines changed

4 files changed

+205
-60
lines changed

keras_rs/api/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from keras_rs.src.layers.retrieval.remove_accidental_hits import (
1818
RemoveAccidentalHits,
1919
)
20+
from keras_rs.src.layers.retrieval.retrieval import Retrieval
2021
from keras_rs.src.layers.retrieval.sampling_probability_correction import (
2122
SamplingProbabilityCorrection,
2223
)

keras_rs/src/layers/retrieval/brute_force_retrieval.py

Lines changed: 11 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
from keras_rs.src import types
66
from keras_rs.src.api_export import keras_rs_export
7+
from keras_rs.src.layers.retrieval.retrieval import Retrieval
78

89

910
@keras_rs_export("keras_rs.layers.BruteForceRetrieval")
10-
class BruteForceRetrieval(keras.layers.Layer):
11+
class BruteForceRetrieval(Retrieval):
1112
"""Brute force top-k retrieval.
1213
1314
This layer maintains a set of candidates and is able to exactly retrieve the
@@ -60,11 +61,13 @@ def __init__(
6061
return_scores: bool = True,
6162
**kwargs: Any,
6263
) -> None:
63-
super().__init__(**kwargs)
64+
# Keep `k`, `return_scores` as separately passed args instead of keeping
65+
# them in `kwargs`. This is to ensure the user does not have to hop
66+
# to the base class to check which other args can be passed.
67+
super().__init__(k=k, return_scores=return_scores, **kwargs)
68+
6469
self.candidate_embeddings = None
6570
self.candidate_ids = None
66-
self.k = k
67-
self.return_scores = return_scores
6871

6972
if candidate_embeddings is None:
7073
if candidate_ids is not None:
@@ -84,36 +87,12 @@ def update_candidates(
8487
8588
Args:
8689
candidate_embeddings: The candidate embeddings.
87-
candidate_ids: The identifiers for the candidates. If `None` the
90+
candidate_ids: The identifiers for the candidates. If `None`, the
8891
indices of the candidates are returned instead.
8992
"""
90-
if candidate_embeddings is None:
91-
raise ValueError("`candidate_embeddings` is required")
92-
93-
if len(candidate_embeddings.shape) != 2:
94-
raise ValueError(
95-
"`candidate_embeddings` must be a tensor of rank 2 "
96-
"(num_candidates, embedding_size), received "
97-
"`candidate_embeddings` with shape "
98-
f"{candidate_embeddings.shape}"
99-
)
100-
101-
if candidate_embeddings.shape[0] < self.k:
102-
raise ValueError(
103-
"The number of candidates provided "
104-
f"({candidate_embeddings.shape[0]}) is less than the number of "
105-
f"candidates to retrieve (k={self.k})."
106-
)
107-
108-
if (
109-
candidate_ids is not None
110-
and candidate_ids.shape[0] != candidate_embeddings.shape[0]
111-
):
112-
raise ValueError(
113-
"The `candidate_embeddings` and `candidate_is` tensors must "
114-
"have the same number of rows, got tensors of shape "
115-
f"{candidate_embeddings.shape} and {candidate_ids.shape}."
116-
)
93+
self._validate_candidate_embeddings_and_ids(
94+
candidate_embeddings, candidate_ids
95+
)
11796

11897
if self.candidate_embeddings is not None:
11998
# Update of existing variables.
@@ -167,31 +146,3 @@ def call(
167146
return top_scores, top_ids
168147
else:
169148
return top_ids
170-
171-
def compute_score(
172-
self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
173-
) -> types.Tensor:
174-
"""Computes the standard dot product score from queries and candidates.
175-
176-
Args:
177-
query_embedding: Tensor of query embedding corresponding to the
178-
queries for which to retrieve top candidates.
179-
candidate_embedding: Tensor of candidate embeddings.
180-
181-
Returns:
182-
The dot product of queries and candidates.
183-
"""
184-
185-
return keras.ops.matmul(
186-
query_embedding, keras.ops.transpose(candidate_embedding)
187-
)
188-
189-
def get_config(self) -> dict[str, Any]:
190-
config: dict[str, Any] = super().get_config()
191-
config.update(
192-
{
193-
"k": self.k,
194-
"return_scores": self.compute_score,
195-
}
196-
)
197-
return config
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import abc
2+
from typing import Any, Optional, Union
3+
4+
import keras
5+
6+
from keras_rs.src import types
7+
from keras_rs.src.api_export import keras_rs_export
8+
9+
10+
@keras_rs_export("keras_rs.layers.Retrieval")
class Retrieval(keras.layers.Layer, abc.ABC):
    """Retrieval base abstract class.

    This layer provides a common interface for all retrieval layers. In order
    to implement a custom retrieval layer, this abstract class should be
    subclassed, and both `update_candidates()` and `call()` must be
    overridden.

    Args:
        k: int. Number of candidates to retrieve.
        return_scores: bool. When `True`, this layer returns a tuple with the
            top scores and the top identifiers. When `False`, this layer returns
            a single tensor with the top identifiers.
        **kwargs: Additional keyword arguments forwarded to the parent layer
            constructor.
    """

    def __init__(
        self,
        k: int = 10,
        return_scores: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.k = k
        self.return_scores = return_scores

    def _validate_candidate_embeddings_and_ids(
        self,
        candidate_embeddings: types.Tensor,
        candidate_ids: Optional[types.Tensor] = None,
    ) -> None:
        """Validates inputs to `update_candidates()`.

        Args:
            candidate_embeddings: The candidate embeddings. Must be a rank 2
                tensor with at least `k` rows.
            candidate_ids: Optional identifiers for the candidates. When
                provided, its first dimension must match the first dimension
                of `candidate_embeddings`.

        Raises:
            ValueError: If `candidate_embeddings` is `None`, is not of rank 2,
                has fewer than `k` rows, or if `candidate_ids` has a different
                number of rows than `candidate_embeddings`.
        """

        if candidate_embeddings is None:
            raise ValueError("`candidate_embeddings` is required.")

        if len(candidate_embeddings.shape) != 2:
            raise ValueError(
                "`candidate_embeddings` must be a tensor of rank 2 "
                "(num_candidates, embedding_size), received "
                "`candidate_embeddings` with shape "
                f"{candidate_embeddings.shape}"
            )

        # Retrieving the top `k` requires at least `k` candidates.
        if candidate_embeddings.shape[0] < self.k:
            raise ValueError(
                "The number of candidates provided "
                f"({candidate_embeddings.shape[0]}) is less than the number of "
                f"candidates to retrieve (k={self.k})."
            )

        if (
            candidate_ids is not None
            and candidate_ids.shape[0] != candidate_embeddings.shape[0]
        ):
            # NOTE: the message spells `candidate_is`; the wording is kept
            # verbatim because existing tests match on this exact text.
            raise ValueError(
                "The `candidate_embeddings` and `candidate_is` tensors must "
                "have the same number of rows, got tensors of shape "
                f"{candidate_embeddings.shape} and {candidate_ids.shape}."
            )

    @abc.abstractmethod
    def update_candidates(
        self,
        candidate_embeddings: types.Tensor,
        candidate_ids: Optional[types.Tensor] = None,
    ) -> None:
        """Update the set of candidates and optionally their candidate IDs.

        Args:
            candidate_embeddings: The candidate embeddings.
            candidate_ids: The identifiers for the candidates. If `None`, the
                indices of the candidates are returned instead.
        """
        pass

    @abc.abstractmethod
    def call(
        self, inputs: types.Tensor
    ) -> Union[types.Tensor, tuple[types.Tensor, types.Tensor]]:
        """Returns the top candidates for the query passed as input.

        Args:
            inputs: the query for which to return top candidates.

        Returns:
            A tuple with the top scores and the top identifiers if
            `return_scores` is True, otherwise a tensor with the top
            identifiers.
        """
        pass

    def compute_score(
        self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
    ) -> types.Tensor:
        """Computes the standard dot product score from queries and candidates.

        Args:
            query_embedding: Tensor of query embedding corresponding to the
                queries for which to retrieve top candidates.
            candidate_embedding: Tensor of candidate embeddings.

        Returns:
            The dot product of queries and candidates.
        """

        return keras.ops.matmul(
            query_embedding, keras.ops.transpose(candidate_embedding)
        )

    def get_config(self) -> dict[str, Any]:
        """Returns the layer configuration, including `k` and `return_scores`.

        Returns:
            The parent layer's config dictionary extended with the constructor
            arguments of this class.
        """
        config: dict[str, Any] = super().get_config()
        config.update(
            {
                "k": self.k,
                # Store the boolean flag itself, not the `compute_score`
                # method, so the config mirrors the constructor arguments.
                "return_scores": self.return_scores,
            }
        )
        return config
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import keras
2+
from absl.testing import parameterized
3+
4+
from keras_rs.src import testing
5+
from keras_rs.src.layers.retrieval.retrieval import Retrieval
6+
7+
8+
class DummyRetrieval(Retrieval):
    """Minimal concrete `Retrieval` whose abstract methods are no-ops."""

    def update_candidates(self, candidate_embeddings, candidate_ids=None):
        # Intentionally does nothing: only satisfies the abstract interface.
        return None

    def call(self, inputs):
        # Intentionally does nothing: only satisfies the abstract interface.
        return None
14+
15+
16+
class RetrievalTest(testing.TestCase, parameterized.TestCase):
    """Tests for the behavior provided by the `Retrieval` base class."""

    def setUp(self):
        # Run the base-class fixtures before creating the layer under test.
        super().setUp()
        self.layer = DummyRetrieval(k=5)

    @parameterized.named_parameters(
        ("embeddings_none", None, None, "`candidate_embeddings` is required."),
        (
            "embeddings_rank_1",
            keras.random.normal(shape=(10,)),
            None,
            "`candidate_embeddings` must be a tensor of rank 2",
        ),
        (
            "embeddings_smaller_than_k",
            keras.random.normal(shape=(3, 10)),
            None,
            # Raw string: the pattern is a regex and `\(` is not a valid
            # escape sequence in a regular string literal.
            r"The number of candidates provided \(3\) is less than",
        ),
        (
            "embeddings_ids_shape",
            keras.random.normal(shape=(6, 10)),
            keras.random.randint(shape=(4,), minval=0, maxval=3),
            "The `candidate_embeddings` and `candidate_is` tensors must have "
            "the same number of rows",
        ),
    )
    def test_validate_candidate_embeddings_and_ids(
        self, candidate_embeddings, candidate_ids, error_msg
    ):
        """Each invalid input combination raises the expected `ValueError`."""
        with self.assertRaisesRegex(ValueError, error_msg):
            self.layer._validate_candidate_embeddings_and_ids(
                candidate_embeddings, candidate_ids
            )

    def test_call_not_overridden(self):
        """A subclass that does not override `call()` cannot be created."""

        # Named distinctly so it does not shadow the module-level stub.
        class RetrievalWithoutCall(Retrieval):
            def update_candidates(
                self, candidate_embeddings, candidate_ids=None
            ):
                pass

        with self.assertRaises(TypeError):
            RetrievalWithoutCall(k=5)

    def test_update_candidates_not_overridden(self):
        """A subclass without `update_candidates()` cannot be created."""

        # Named distinctly so it does not shadow the module-level stub.
        class RetrievalWithoutUpdateCandidates(Retrieval):
            def call(self, inputs):
                pass

        with self.assertRaises(TypeError):
            RetrievalWithoutUpdateCandidates(k=5)

0 commit comments

Comments
 (0)