@@ -35,6 +35,7 @@ def __init__(
3535 path : Optional [str ] = None ,
3636 distance_strategy : str = "euclidean" ,
3737 normalize_L2 : bool = False ,
38+ embedding_model_dims : int = 1536 ,
3839 ):
3940 """
4041 Initialize the FAISS vector store.
@@ -51,6 +52,7 @@ def __init__(
5152 self .path = path or f"/tmp/faiss/{ collection_name } "
5253 self .distance_strategy = distance_strategy
5354 self .normalize_L2 = normalize_L2
55+ self .embedding_model_dims = embedding_model_dims
5456
5557 # Initialize storage structures
5658 self .index = None
@@ -145,13 +147,12 @@ def _parse_output(self, scores, ids, limit=None) -> List[OutputData]:
145147
146148 return results
147149
148- def create_col (self , name : str , vector_size : int = 1536 , distance : str = None ):
150+ def create_col (self , name : str , distance : str = None ):
149151 """
150152 Create a new collection.
151153
152154 Args:
153155 name (str): Name of the collection.
154- vector_size (int, optional): Dimensionality of vectors. Defaults to 1536.
155156 distance (str, optional): Distance metric to use. Overrides the distance_strategy
156157 passed during initialization. Defaults to None.
157158
@@ -162,9 +163,9 @@ def create_col(self, name: str, vector_size: int = 1536, distance: str = None):
162163
163164 # Create index based on distance strategy
164165 if distance_strategy .lower () == "inner_product" or distance_strategy .lower () == "cosine" :
165- self .index = faiss .IndexFlatIP (vector_size )
166+ self .index = faiss .IndexFlatIP (self . embedding_model_dims )
166167 else :
167- self .index = faiss .IndexFlatL2 (vector_size )
168+ self .index = faiss .IndexFlatL2 (self . embedding_model_dims )
168169
169170 self .collection_name = name
170171
0 commit comments