11from __future__ import annotations
22
3- from typing import Any
3+ from typing import Any , Literal
44
55from ..comparators .distances import MahalanobisDistance
66from ..dataset import (
@@ -24,10 +24,12 @@ def __init__(
2424 test_pairs : bool = False ,
2525 grouping_role : ABCRole | None = None ,
2626 key : Any = "" ,
27+ faiss_mode : Literal ["base" , "fast" , "auto" ] = "auto" ,
2728 ):
2829 self .n_neighbors = n_neighbors
2930 self .two_sides = two_sides
3031 self .test_pairs = test_pairs
32+ self .faiss_mode = faiss_mode
3133 super ().__init__ (
3234 grouping_role = grouping_role , target_role = FeatureRole (), key = key
3335 )
@@ -40,13 +42,15 @@ def _execute_inner_function(
4042 n_neighbors : int | None = None ,
4143 two_sides : bool | None = None ,
4244 test_pairs : bool | None = None ,
45+ faiss_mode : Literal ["base" , "fast" , "auto" ] = "auto" ,
4346 ** kwargs ,
4447 ) -> dict :
4548 if test_pairs is not True :
4649 data = cls ._inner_function (
4750 data = grouping_data [0 ][1 ],
4851 test_data = grouping_data [1 ][1 ],
4952 n_neighbors = n_neighbors or 1 ,
53+ faiss_mode = faiss_mode ,
5054 ** kwargs ,
5155 )
5256 if two_sides is not True :
@@ -57,13 +61,15 @@ def _execute_inner_function(
5761 data = grouping_data [1 ][1 ],
5862 test_data = grouping_data [0 ][1 ],
5963 n_neighbors = n_neighbors or 1 ,
64+ faiss_mode = faiss_mode ,
6065 ** kwargs ,
6166 ),
6267 }
6368 data = cls ._inner_function (
6469 data = grouping_data [1 ][1 ],
6570 test_data = grouping_data [0 ][1 ],
6671 n_neighbors = n_neighbors or 1 ,
72+ faiss_mode = faiss_mode ,
6773 ** kwargs ,
6874 )
6975 if two_sides is not True :
@@ -74,6 +80,7 @@ def _execute_inner_function(
7480 data = grouping_data [1 ][1 ],
7581 test_data = grouping_data [0 ][1 ],
7682 n_neighbors = n_neighbors or 1 ,
83+ faiss_mode = faiss_mode ,
7784 ** kwargs ,
7885 ),
7986 }
@@ -85,14 +92,15 @@ def _inner_function(
8592 test_data : Dataset | None = None ,
8693 target_data : Dataset | None = None ,
8794 n_neighbors : int | None = None ,
95+ faiss_mode : Literal ["base" , "fast" , "auto" ] = "auto" ,
8896 ** kwargs ,
8997 ) -> Any :
90- return FaissExtension (n_neighbors = n_neighbors or 1 ).calc (
98+ return FaissExtension (n_neighbors = n_neighbors or 1 , faiss_mode = faiss_mode ).calc (
9199 data = data , test_data = test_data
92100 )
93101
94102 def fit (self , X : Dataset , Y : Dataset | None = None ) -> MLExecutor :
95- return FaissExtension (self .n_neighbors ).fit (X = X , Y = Y )
103+ return FaissExtension (self .n_neighbors , self . faiss_mode ).fit (X = X , Y = Y )
96104
97105 def predict (self , X : Dataset ) -> Dataset :
98106 return FaissExtension ().predict (X )
@@ -114,6 +122,7 @@ def execute(self, data: ExperimentData) -> ExperimentData:
114122 grouping_data = grouping_data ,
115123 features_fields = features_fields ,
116124 n_neighbors = self .n_neighbors ,
125+ faiss_mode = self .faiss_mode ,
117126 two_sides = self .two_sides ,
118127 test_pairs = self .test_pairs ,
119128 )
0 commit comments