1212from faiss .contrib import factory_tools
1313from faiss .contrib import datasets
1414
15+
1516class TestFactory (unittest .TestCase ):
1617
1718 def test_factory_1 (self ):
@@ -40,7 +41,6 @@ def test_factory_2(self):
4041 index = faiss .index_factory (12 , "SQ8" )
4142 assert index .code_size == 12
4243
43-
4444 def test_factory_3 (self ):
4545
4646 index = faiss .index_factory (12 , "IVF10,PQ4" )
@@ -73,7 +73,8 @@ def test_factory_HNSW(self):
7373 def test_factory_HNSW_newstyle (self ):
7474 index = faiss .index_factory (12 , "HNSW32,Flat" )
7575 assert index .storage .sa_code_size () == 12 * 4
76- index = faiss .index_factory (12 , "HNSW32,SQ8" , faiss .METRIC_INNER_PRODUCT )
76+ index = faiss .index_factory (12 , "HNSW32,SQ8" ,
77+ faiss .METRIC_INNER_PRODUCT )
7778 assert index .storage .sa_code_size () == 12
7879 assert index .metric_type == faiss .METRIC_INNER_PRODUCT
7980 index = faiss .index_factory (12 , "HNSW,PQ4" )
@@ -131,7 +132,8 @@ def test_factory_fast_scan(self):
131132 self .assertEqual (index .pq .nbits , 4 )
132133 index = faiss .index_factory (56 , "PQ28x4fs_64" )
133134 self .assertEqual (index .bbs , 64 )
134- index = faiss .index_factory (56 , "IVF50,PQ28x4fs_64" , faiss .METRIC_INNER_PRODUCT )
135+ index = faiss .index_factory (56 , "IVF50,PQ28x4fs_64" ,
136+ faiss .METRIC_INNER_PRODUCT )
135137 self .assertEqual (index .bbs , 64 )
136138 self .assertEqual (index .nlist , 50 )
137139 self .assertTrue (index .cp .spherical )
@@ -158,7 +160,6 @@ def test_parenthesis_refine(self):
158160 self .assertEqual (rf .pq .M , 25 )
159161 self .assertEqual (rf .pq .nbits , 12 )
160162
161-
162163 def test_parenthesis_refine_2 (self ):
163164 # Refine applies on the whole index including pre-transforms
164165 index = faiss .index_factory (50 , "PCA32,IVF32,Flat,Refine(PQ25x12)" )
@@ -264,6 +265,19 @@ def test_idmap2_prefix(self):
264265 index = faiss .downcast_index (index )
265266 self .assertEqual (index .__class__ , faiss .IndexIDMap2 )
266267
268+ def test_idmap_refine (self ):
269+ index = faiss .index_factory (8 , "IDMap,PQ4x4fs,RFlat" )
270+ self .assertEqual (index .__class__ , faiss .IndexIDMap )
271+ refine_index = faiss .downcast_index (index .index )
272+ self .assertEqual (refine_index .__class__ , faiss .IndexRefineFlat )
273+ base_index = faiss .downcast_index (refine_index .base_index )
274+ self .assertEqual (base_index .__class__ , faiss .IndexPQFastScan )
275+
276+ # Index now works with add_with_ids, but not with add
277+ index .train (np .zeros ((16 , 8 )))
278+ index .add_with_ids (np .zeros ((16 , 8 )), np .arange (16 ))
279+ self .assertRaises (RuntimeError , index .add , np .zeros ((16 , 8 )))
280+
267281 def test_ivf_hnsw (self ):
268282 index = faiss .index_factory (123 , "IVF100_HNSW,Flat" )
269283 quantizer = faiss .downcast_index (index .quantizer )
@@ -337,4 +351,4 @@ def test_replace_vt(self):
337351 index = faiss .IndexIVFSpectralHash (faiss .IndexFlat (10 ), 10 , 20 , 10 , 1 )
338352 index .replace_vt (faiss .ITQTransform (10 , 10 ))
339353 gc .collect ()
340- index .vt .d_out # this should not crash
354+ index .vt .d_out # this should not crash
0 commit comments