Skip to content

Commit aeaccaa

Browse files
authored
feat: make Metric the basic unit for metrics, not str (#44)
* feat: make Metric the basic unit for metrics, not str
* Remove old type
* feat: use Metric internally everywhere
1 parent 7a74946 commit aeaccaa

8 files changed

Lines changed: 45 additions & 52 deletions

File tree

vicinity/backends/annoy.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@dataclass
1717
class AnnoyArgs(BaseArgs):
1818
dim: int = 0
19-
metric: str = "cosine"
19+
metric: Metric = Metric.COSINE
2020
internal_metric: str = "dot"
2121
trees: int = 100
2222
length: int | None = None
@@ -25,7 +25,7 @@ class AnnoyArgs(BaseArgs):
2525
class AnnoyBackend(AbstractBackend[AnnoyArgs]):
2626
argument_class = AnnoyArgs
2727
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN}
28-
inverse_metric_mapping = {
28+
inverse_metric_mapping: dict[Metric, str] = {
2929
Metric.COSINE: "dot",
3030
Metric.EUCLIDEAN: "euclidean",
3131
}
@@ -56,7 +56,6 @@ def from_vectors(
5656
if metric_enum not in cls.supported_metrics:
5757
raise ValueError(f"Metric '{metric_enum.value}' is not supported by AnnoyBackend.")
5858

59-
metric_string = metric_enum.value
6059
internal_metric = cls._map_metric_to_string(metric_enum)
6160

6261
if metric_enum == Metric.COSINE:
@@ -68,9 +67,7 @@ def from_vectors(
6867
index.add_item(i, vector)
6968
index.build(trees)
7069

71-
arguments = AnnoyArgs(
72-
dim=dim, metric=metric_string, trees=trees, length=len(vectors), internal_metric=internal_metric
73-
) # type: ignore
70+
arguments = AnnoyArgs(dim=dim, metric=metric, trees=trees, length=len(vectors), internal_metric=internal_metric) # type: ignore
7471
return AnnoyBackend(index, arguments=arguments)
7572

7673
@property
@@ -91,8 +88,10 @@ def __len__(self) -> int:
9188
def load(cls: type[AnnoyBackend], base_path: Path) -> AnnoyBackend:
9289
"""Load the vectors from a path."""
9390
path = Path(base_path) / "index.bin"
91+
9492
arguments = AnnoyArgs.load(base_path / "arguments.json")
95-
index = AnnoyIndex(arguments.dim, arguments.internal_metric) # type: ignore
93+
metric = cls._map_metric_to_string(arguments.metric)
94+
index = AnnoyIndex(arguments.dim, metric) # type: ignore
9695
index.load(str(path))
9796

9897
return cls(index, arguments=arguments)
@@ -109,11 +108,11 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
109108
"""Query the backend."""
110109
out = []
111110
for vec in vectors:
112-
if self.arguments.metric == "cosine":
111+
if self.arguments.metric == Metric.COSINE:
113112
vec = normalize(vec)
114113
indices, scores = self.index.get_nns_by_vector(vec, k, include_distances=True)
115114
scores_array = np.asarray(scores)
116-
if self.arguments.metric == "cosine":
115+
if self.arguments.metric == Metric.COSINE:
117116
# Convert cosine similarity to cosine distance
118117
scores_array = 1 - scores_array
119118
out.append((np.asarray(indices), scores_array))

vicinity/backends/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,25 @@
1414

1515
@dataclass
1616
class BaseArgs:
17+
metric: Metric
18+
1719
def dump(self, file: Path) -> None:
1820
"""Dump the arguments to a file."""
1921
with open(file, "w") as f:
20-
json.dump(asdict(self), f)
22+
d = self.dict()
23+
d["metric"] = d["metric"].value
24+
json.dump(d, f)
2125

2226
@classmethod
2327
def load(cls: type[ArgType], file: Path) -> ArgType:
2428
"""Load the arguments from a file."""
2529
with open(file, "r") as f:
26-
return cls(**json.load(f))
30+
data = json.load(f)
31+
data["metric"] = Metric.from_string(data["metric"])
32+
return cls(**data)
2733

2834
def dict(self) -> dict[str, Any]:
29-
"""Dump the arguments to a string."""
35+
"""Dump the arguments to a dict."""
3036
return asdict(self)
3137

3238

vicinity/backends/basic.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@dataclass
1717
class BasicArgs(BaseArgs):
18-
metric: str = "cosine"
18+
metric: Metric = Metric.COSINE
1919

2020

2121
class BasicBackend(AbstractBackend[BasicArgs], ABC):
@@ -72,11 +72,10 @@ def from_vectors(cls, vectors: npt.NDArray, metric: Union[str, Metric] = "cosine
7272
if metric_enum not in cls.supported_metrics:
7373
raise ValueError(f"Metric '{metric_enum.value}' is not supported by BasicBackend.")
7474

75-
metric = metric_enum.value
76-
arguments = BasicArgs(metric=metric)
77-
if metric == "cosine":
75+
arguments = BasicArgs(metric=metric_enum)
76+
if metric_enum == Metric.COSINE:
7877
return CosineBasicBackend(vectors, arguments)
79-
elif metric == "euclidean":
78+
elif metric_enum == Metric.EUCLIDEAN:
8079
return EuclideanBasicBackend(vectors, arguments)
8180
else:
8281
raise ValueError(f"Unsupported metric: {metric}")
@@ -88,9 +87,9 @@ def load(cls, folder: Path) -> BasicBackend:
8887
arguments = BasicArgs.load(folder / "arguments.json")
8988
with open(path, "rb") as f:
9089
vectors = np.load(f)
91-
if arguments.metric == "cosine":
90+
if arguments.metric == Metric.COSINE:
9291
return CosineBasicBackend(vectors, arguments)
93-
elif arguments.metric == "euclidean":
92+
elif arguments.metric == Metric.EUCLIDEAN:
9493
return EuclideanBasicBackend(vectors, arguments)
9594
else:
9695
raise ValueError(f"Unsupported metric: {arguments.metric}")

vicinity/backends/faiss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
class FaissArgs(BaseArgs):
3737
dim: int = 0
3838
index_type: str = "flat"
39-
metric: str = "cosine"
39+
metric: Metric = Metric.COSINE
4040
nlist: int = 100
4141
m: int = 8
4242
nbits: int = 8
@@ -122,7 +122,7 @@ def from_vectors( # noqa: C901
122122
arguments = FaissArgs(
123123
dim=dim,
124124
index_type=index_type,
125-
metric=metric_enum.value,
125+
metric=metric_enum,
126126
nlist=nlist,
127127
m=m,
128128
nbits=nbits,

vicinity/backends/hnsw.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
@dataclass
1616
class HNSWArgs(BaseArgs):
1717
dim: int = 0
18-
metric: str = "cosine"
18+
metric: Metric = Metric.COSINE
1919
ef_construction: int = 200
2020
m: int = 16
2121

@@ -58,7 +58,7 @@ def from_vectors(
5858
index = HnswIndex(space=metric, dim=dim)
5959
index.init_index(max_elements=vectors.shape[0], ef_construction=ef_construction, M=m)
6060
index.add_items(vectors)
61-
arguments = HNSWArgs(dim=dim, metric=metric, ef_construction=ef_construction, m=m)
61+
arguments = HNSWArgs(dim=dim, metric=metric_enum, ef_construction=ef_construction, m=m)
6262
return HNSWBackend(index, arguments=arguments)
6363

6464
@property
@@ -80,7 +80,8 @@ def load(cls: type[HNSWBackend], base_path: Path) -> HNSWBackend:
8080
"""Load the vectors from a path."""
8181
path = Path(base_path) / "index.bin"
8282
arguments = HNSWArgs.load(base_path / "arguments.json")
83-
index = HnswIndex(space=arguments.metric, dim=arguments.dim)
83+
mapped_metric = cls.inverse_metric_mapping[arguments.metric]
84+
index = HnswIndex(space=mapped_metric, dim=arguments.dim)
8485
index.load_index(str(path))
8586
return cls(index, arguments=arguments)
8687

vicinity/backends/pynndescent.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@dataclass
1717
class PyNNDescentArgs(BaseArgs):
1818
n_neighbors: int = 15
19-
metric: str = "cosine"
19+
metric: Metric = Metric.COSINE
2020

2121

2222
class PyNNDescentBackend(AbstractBackend[PyNNDescentArgs]):
@@ -49,7 +49,7 @@ def from_vectors(
4949
metric = metric_enum.value
5050

5151
index = NNDescent(vectors, n_neighbors=n_neighbors, metric=metric, **kwargs)
52-
arguments = PyNNDescentArgs(n_neighbors=n_neighbors, metric=metric)
52+
arguments = PyNNDescentArgs(n_neighbors=n_neighbors, metric=metric_enum)
5353
return cls(index=index, arguments=arguments)
5454

5555
def __len__(self) -> int:
@@ -105,10 +105,7 @@ def load(cls: type[PyNNDescentBackend], base_path: Path) -> PyNNDescentBackend:
105105
arguments = PyNNDescentArgs.load(base_path / "arguments.json")
106106
vectors = np.load(Path(base_path) / "vectors.npy")
107107

108-
metric_enum = Metric.from_string(arguments.metric)
109-
pynndescent_metric = metric_enum.value
110-
111-
index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=pynndescent_metric)
108+
index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=arguments.metric.value)
112109

113110
# Load the neighbor graph if it was saved
114111
neighbor_graph_path = base_path / "neighbor_graph.npy"

vicinity/backends/usearch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@dataclass
1717
class UsearchArgs(BaseArgs):
1818
dim: int = 0
19-
metric: str = "cos"
19+
metric: Metric = Metric.COSINE
2020
connectivity: int = 16
2121
expansion_add: int = 128
2222
expansion_search: int = 64
@@ -67,10 +67,10 @@ def from_vectors(
6767
expansion_add=expansion_add,
6868
expansion_search=expansion_search,
6969
)
70-
index.add(keys=None, vectors=vectors) # type: ignore
70+
index.add(keys=None, vectors=vectors) # type: ignore # None keys are allowed but not typed
7171
arguments = UsearchArgs(
7272
dim=dim,
73-
metric=metric,
73+
metric=metric_enum,
7474
connectivity=connectivity,
7575
expansion_add=expansion_add,
7676
expansion_search=expansion_search,
@@ -99,7 +99,7 @@ def load(cls: type[UsearchBackend], base_path: Path) -> UsearchBackend:
9999

100100
index = UsearchIndex(
101101
ndim=arguments.dim,
102-
metric=arguments.metric,
102+
metric=cls._map_metric_to_string(arguments.metric),
103103
connectivity=arguments.connectivity,
104104
expansion_add=arguments.expansion_add,
105105
expansion_search=arguments.expansion_search,
@@ -122,7 +122,7 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
122122

123123
def insert(self, vectors: npt.NDArray) -> None:
124124
"""Insert vectors into the backend."""
125-
self.index.add(None, vectors) # type: ignore
125+
self.index.add(None, vectors) # type: ignore # None keys are allowed, but not typed.
126126

127127
def delete(self, indices: list[int]) -> None:
128128
"""Delete vectors from the index (not supported by Usearch)."""

vicinity/backends/voyager.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,28 @@
44
from pathlib import Path
55
from typing import Any, Union
66

7-
import numpy as np
87
from numpy import typing as npt
98
from voyager import Index, Space
109

1110
from vicinity.backends.base import AbstractBackend, BaseArgs
1211
from vicinity.datatypes import Backend, QueryResult
13-
from vicinity.utils import Metric, normalize
12+
from vicinity.utils import Metric
1413

1514

1615
@dataclass
1716
class VoyagerArgs(BaseArgs):
1817
dim: int = 0
19-
metric: str = "cosine"
18+
metric: Metric = Metric.COSINE
2019
ef_construction: int = 200
2120
m: int = 16
2221

2322

2423
class VoyagerBackend(AbstractBackend[VoyagerArgs]):
2524
argument_class = VoyagerArgs
2625
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN}
27-
inverse_metric_mapping = {
28-
Metric.COSINE: "cosine",
29-
Metric.EUCLIDEAN: "l2",
30-
}
31-
32-
metric_int_mapping = {
33-
"l2": 0,
34-
"cosine": 2,
26+
_metric_to_space = {
27+
Metric.COSINE: Space.Cosine,
28+
Metric.EUCLIDEAN: Space.Euclidean,
3529
}
3630

3731
def __init__(
@@ -56,13 +50,10 @@ def from_vectors(
5650
metric_enum = Metric.from_string(metric)
5751

5852
if metric_enum not in cls.supported_metrics:
59-
raise ValueError(
60-
f"Metric '{metric_enum.value}' is not supported by VoyagerBackend."
61-
)
53+
raise ValueError(f"Metric '{metric_enum.value}' is not supported by VoyagerBackend.")
6254

63-
metric = cls._map_metric_to_string(metric_enum)
55+
space = cls._metric_to_space[metric_enum]
6456
dim = vectors.shape[1]
65-
space = Space(value=cls.metric_int_mapping[metric])
6657
index = Index(
6758
space=space,
6859
num_dimensions=dim,
@@ -72,7 +63,7 @@ def from_vectors(
7263
index.add_items(vectors)
7364
return cls(
7465
index,
75-
VoyagerArgs(dim=dim, metric=metric, ef_construction=ef_construction, m=m),
66+
VoyagerArgs(dim=dim, metric=metric_enum, ef_construction=ef_construction, m=m),
7667
)
7768

7869
def query(self, query: npt.NDArray, k: int) -> QueryResult:

0 commit comments

Comments
 (0)