Skip to content

Commit cbaddbc

Browse files
committed
Merge branch 'dev/add_faiss_mode'
# Conflicts: # hypex/matching.py # tests/test_tutorials.py
2 parents 1575426 + faf5b76 commit cbaddbc

File tree

4 files changed

+30
-14
lines changed

4 files changed

+30
-14
lines changed

hypex/extensions/faiss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212

1313
class FaissExtension(MLExtension):
14-
def __init__(self, n_neighbors: int = 1):
14+
def __init__(self, n_neighbors: int = 1, faiss_mode: Literal["base", "fast", "auto"] = "auto"):
1515
self.n_neighbors = n_neighbors
16+
self.faiss_mode = faiss_mode
1617
super().__init__()
1718

1819
@staticmethod
@@ -52,7 +53,7 @@ def _calc_pandas(
5253
X = data.data.values
5354
if mode in ["auto", "fit"]:
5455
self.index = faiss.IndexFlatL2(X.shape[1])
55-
if len(X) > 1_000_000:
56+
if (len(X) > 1_000_000 and self.faiss_mode == "auto") or self.faiss_mode == "fast":
5657
self.index = faiss.IndexIVFFlat(self.index, 1, 1000)
5758
self.index.train(X)
5859
self.index.add(X)

hypex/matching.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,15 @@ class Matching(ExperimentShell):
5959

6060
@staticmethod
6161
def _make_experiment(
62-
group_match: bool = False,
63-
distance: Literal["mahalanobis", "l2"] = "mahalanobis",
64-
metric: Literal["atc", "att", "ate"] = "ate",
65-
bias_estimation: bool = True,
66-
quality_tests: (
67-
Literal["smd", "psi", "ks-test", "repeats", "t-test", "auto"]
68-
| list[Literal["smd", "psi", "ks-test", "repeats", "t-test", "auto"]]
69-
) = "auto",
62+
group_match: bool = False,
63+
distance: Literal["mahalanobis", "l2"] = "mahalanobis",
64+
metric: Literal["atc", "att", "ate"] = "ate",
65+
bias_estimation: bool = True,
66+
quality_tests: (
67+
Literal["smd", "psi", "ks-test", "repeats", "t-test", "auto"]
68+
| list[Literal["smd", "psi", "ks-test", "repeats", "t-test", "auto"]]
69+
) = "auto",
70+
faiss_mode: Literal["base", "fast", "auto"] = "auto",
7071
) -> Experiment:
7172
"""Creates an experiment configuration with specified matching parameters.
7273
@@ -109,6 +110,7 @@ def _make_experiment(
109110
grouping_role=TreatmentRole(),
110111
two_sides=two_sides,
111112
test_pairs=test_pairs,
113+
faiss_mode=faiss_mode,
112114
)
113115
]
114116
if bias_estimation:
@@ -160,10 +162,12 @@ def __init__(
160162
Literal["smd", "psi", "ks-test", "repeats", "t-test", "auto"]
161163
| list[Literal["smd", "psi", "ks-test", "repeats", "t-test", "auto"]]
162164
) = "auto",
165+
faiss_mode: Literal["base", "fast", "auto"] = "auto",
166+
163167
):
164168
super().__init__(
165169
experiment=self._make_experiment(
166-
group_match, distance, metric, bias_estimation, quality_tests
170+
group_match, distance, metric, bias_estimation, quality_tests, faiss_mode
167171
),
168172
output=MatchingOutput(GroupExperiment if group_match else MatchingAnalyzer),
169173
)

hypex/ml/faiss.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any
3+
from typing import Any, Literal
44

55
from ..comparators.distances import MahalanobisDistance
66
from ..dataset import (
@@ -24,10 +24,12 @@ def __init__(
2424
test_pairs: bool = False,
2525
grouping_role: ABCRole | None = None,
2626
key: Any = "",
27+
faiss_mode: Literal["base", "fast", "auto"] = "auto",
2728
):
2829
self.n_neighbors = n_neighbors
2930
self.two_sides = two_sides
3031
self.test_pairs = test_pairs
32+
self.faiss_mode = faiss_mode
3133
super().__init__(
3234
grouping_role=grouping_role, target_role=FeatureRole(), key=key
3335
)
@@ -40,13 +42,15 @@ def _execute_inner_function(
4042
n_neighbors: int | None = None,
4143
two_sides: bool | None = None,
4244
test_pairs: bool | None = None,
45+
faiss_mode: Literal["base", "fast", "auto"] = "auto",
4346
**kwargs,
4447
) -> dict:
4548
if test_pairs is not True:
4649
data = cls._inner_function(
4750
data=grouping_data[0][1],
4851
test_data=grouping_data[1][1],
4952
n_neighbors=n_neighbors or 1,
53+
faiss_mode=faiss_mode,
5054
**kwargs,
5155
)
5256
if two_sides is not True:
@@ -57,13 +61,15 @@ def _execute_inner_function(
5761
data=grouping_data[1][1],
5862
test_data=grouping_data[0][1],
5963
n_neighbors=n_neighbors or 1,
64+
faiss_mode=faiss_mode,
6065
**kwargs,
6166
),
6267
}
6368
data = cls._inner_function(
6469
data=grouping_data[1][1],
6570
test_data=grouping_data[0][1],
6671
n_neighbors=n_neighbors or 1,
72+
faiss_mode=faiss_mode,
6773
**kwargs,
6874
)
6975
if two_sides is not True:
@@ -74,6 +80,7 @@ def _execute_inner_function(
7480
data=grouping_data[1][1],
7581
test_data=grouping_data[0][1],
7682
n_neighbors=n_neighbors or 1,
83+
faiss_mode=faiss_mode,
7784
**kwargs,
7885
),
7986
}
@@ -85,14 +92,15 @@ def _inner_function(
8592
test_data: Dataset | None = None,
8693
target_data: Dataset | None = None,
8794
n_neighbors: int | None = None,
95+
faiss_mode: Literal["base", "fast", "auto"] = "auto",
8896
**kwargs,
8997
) -> Any:
90-
return FaissExtension(n_neighbors=n_neighbors or 1).calc(
98+
return FaissExtension(n_neighbors=n_neighbors or 1, faiss_mode=faiss_mode).calc(
9199
data=data, test_data=test_data
92100
)
93101

94102
def fit(self, X: Dataset, Y: Dataset | None = None) -> MLExecutor:
95-
return FaissExtension(self.n_neighbors).fit(X=X, Y=Y)
103+
return FaissExtension(self.n_neighbors, self.faiss_mode).fit(X=X, Y=Y)
96104

97105
def predict(self, X: Dataset) -> Dataset:
98106
return FaissExtension().predict(X)
@@ -114,6 +122,7 @@ def execute(self, data: ExperimentData) -> ExperimentData:
114122
grouping_data=grouping_data,
115123
features_fields=features_fields,
116124
n_neighbors=self.n_neighbors,
125+
faiss_mode=self.faiss_mode,
117126
two_sides=self.two_sides,
118127
test_pairs=self.test_pairs,
119128
)

tests/test_tutorials.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def test_matchingtest(matching_data):
170170
"matching-atc": Matching(metric="atc"),
171171
"matching-att": Matching(metric="att"),
172172
"matching-l2": Matching(distance="l2", metric="att"),
173+
"matching-faiss-auto": Matching(distance="l2", faiss_mode="auto"),
174+
"matching-faiss_base": Matching(distance="mahalanobis", faiss_mode="base"),
173175
}
174176

175177
for test_name in mapping.keys():

0 commit comments

Comments
 (0)