1616@dataclass
1717class AnnoyArgs (BaseArgs ):
1818 dim : int = 0
19- metric : str = "cosine"
19+ metric : Metric = Metric . COSINE
2020 internal_metric : str = "dot"
2121 trees : int = 100
2222 length : int | None = None
@@ -25,7 +25,7 @@ class AnnoyArgs(BaseArgs):
2525class AnnoyBackend (AbstractBackend [AnnoyArgs ]):
2626 argument_class = AnnoyArgs
2727 supported_metrics = {Metric .COSINE , Metric .EUCLIDEAN }
28- inverse_metric_mapping = {
28+ inverse_metric_mapping : dict [ Metric , str ] = {
2929 Metric .COSINE : "dot" ,
3030 Metric .EUCLIDEAN : "euclidean" ,
3131 }
@@ -56,7 +56,6 @@ def from_vectors(
5656 if metric_enum not in cls .supported_metrics :
5757 raise ValueError (f"Metric '{ metric_enum .value } ' is not supported by AnnoyBackend." )
5858
59- metric_string = metric_enum .value
6059 internal_metric = cls ._map_metric_to_string (metric_enum )
6160
6261 if metric_enum == Metric .COSINE :
@@ -68,9 +67,7 @@ def from_vectors(
6867 index .add_item (i , vector )
6968 index .build (trees )
7069
71- arguments = AnnoyArgs (
72- dim = dim , metric = metric_string , trees = trees , length = len (vectors ), internal_metric = internal_metric
73- ) # type: ignore
70+ arguments = AnnoyArgs (dim = dim , metric = metric , trees = trees , length = len (vectors ), internal_metric = internal_metric ) # type: ignore
7471 return AnnoyBackend (index , arguments = arguments )
7572
7673 @property
@@ -91,8 +88,10 @@ def __len__(self) -> int:
9188 def load (cls : type [AnnoyBackend ], base_path : Path ) -> AnnoyBackend :
9289 """Load the vectors from a path."""
9390 path = Path (base_path ) / "index.bin"
91+
9492 arguments = AnnoyArgs .load (base_path / "arguments.json" )
95- index = AnnoyIndex (arguments .dim , arguments .internal_metric ) # type: ignore
93+ metric = cls ._map_metric_to_string (arguments .metric )
94+ index = AnnoyIndex (arguments .dim , metric ) # type: ignore
9695 index .load (str (path ))
9796
9897 return cls (index , arguments = arguments )
@@ -109,11 +108,11 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
109108 """Query the backend."""
110109 out = []
111110 for vec in vectors :
112- if self .arguments .metric == "cosine" :
111+ if self .arguments .metric == Metric . COSINE :
113112 vec = normalize (vec )
114113 indices , scores = self .index .get_nns_by_vector (vec , k , include_distances = True )
115114 scores_array = np .asarray (scores )
116- if self .arguments .metric == "cosine" :
115+ if self .arguments .metric == Metric . COSINE :
117116 # Convert cosine similarity to cosine distance
118117 scores_array = 1 - scores_array
119118 out .append ((np .asarray (indices ), scores_array ))
0 commit comments