@@ -8,7 +8,7 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use fastembed::{
    read_file_to_bytes, Embedding, EmbeddingModel, ImageEmbedding, ImageInitOptions, InitOptions,
-    InitOptionsUserDefined, Pooling, QuantizationMode, RerankInitOptions,
+    InitOptionsUserDefined, OnnxSource, Pooling, QuantizationMode, RerankInitOptions,
    RerankInitOptionsUserDefined, RerankerModel, SparseInitOptions, SparseTextEmbedding,
    TextEmbedding, TextRerank, TokenizerFiles, UserDefinedEmbeddingModel,
    UserDefinedRerankingModel, DEFAULT_CACHE_DIR,
@@ -284,6 +284,8 @@ fn test_rerank() {
        .par_iter()
        .for_each(|supported_model| {
+            println!("supported_model: {:?}", supported_model);
+
            let result = TextRerank::try_new(RerankInitOptions::new(supported_model.model.clone()))
                .unwrap();
@@ -300,14 +302,78 @@ fn test_rerank() {
                .unwrap();

            assert_eq!(results.len(), documents.len(), "rerank model {:?} failed", supported_model);
-            assert_eq!(results[0].document.as_ref().unwrap(), "panda is an animal");
-            assert_eq!(results[1].document.as_ref().unwrap(), "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China.");
+
+            let option_a = "panda is an animal";
+            let option_b = "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China.";
+
+            assert!(
+                results[0].document.as_ref().unwrap() == option_a ||
+                results[0].document.as_ref().unwrap() == option_b
+            );
+            assert!(
+                results[1].document.as_ref().unwrap() == option_a ||
+                results[1].document.as_ref().unwrap() == option_b
+            );
+            assert_ne!(results[0].document, results[1].document, "The top two results should be different");

            // Clear the model cache to avoid running out of space on GitHub Actions.
            clean_cache(supported_model.model_code.clone())
        });
}

+#[test]
+fn test_user_defined_reranking_large_model() {
+    // Setup model to download from Hugging Face
+    let cache = hf_hub::Cache::new(std::path::PathBuf::from(fastembed::DEFAULT_CACHE_DIR));
+    let api = hf_hub::api::sync::ApiBuilder::from_cache(cache)
+        .with_progress(true)
+        .build()
+        .expect("Failed to build API from cache");
+    let model_repo = api.model("rozgo/bge-reranker-v2-m3".to_string());
+
+    // Download the onnx model file
+    let onnx_file = model_repo.download("model.onnx").unwrap();
+    // Onnx model exceeds the limit of 2GB for a file, so we need to download the data file separately
+    let _onnx_data_file = model_repo.get("model.onnx.data").unwrap();
+
+    // OnnxSource::File is used to load the onnx file using onnx session builder commit_from_file
+    let onnx_source = OnnxSource::File(onnx_file);
+
+    // Load the tokenizer files
+    let tokenizer_files: TokenizerFiles = TokenizerFiles {
+        tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json").unwrap()).unwrap(),
+        config_file: read_file_to_bytes(&model_repo.get("config.json").unwrap()).unwrap(),
+        special_tokens_map_file: read_file_to_bytes(
+            &model_repo.get("special_tokens_map.json").unwrap(),
+        )
+        .unwrap(),
+
+        tokenizer_config_file: read_file_to_bytes(
+            &model_repo.get("tokenizer_config.json").unwrap(),
+        )
+        .unwrap(),
+    };
+
+    let model = UserDefinedRerankingModel::new(onnx_source, tokenizer_files);
+
+    let user_defined_reranker =
+        TextRerank::try_new_from_user_defined(model, Default::default()).unwrap();
+
+    let documents = vec![
+        "Hello, World!",
+        "This is an example passage.",
+        "fastembed-rs is licensed under Apache-2.0",
+        "Some other short text here blah blah blah",
+    ];
+
+    let results = user_defined_reranker
+        .rerank("Ciao, Earth!", documents.clone(), false, None)
+        .unwrap();
+
+    assert_eq!(results.len(), documents.len());
+    assert_eq!(results.first().unwrap().index, 0);
+}
+
#[test]
fn test_user_defined_reranking_model() {
    // Constitute the model in order to ensure it's downloaded and cached
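For reference, the pattern exercised by the new `test_user_defined_reranking_large_model` test can be distilled into the sketch below. This is a minimal illustration rather than part of the patch: the local directory `./models/bge-reranker-v2-m3`, the query, and the documents are placeholder assumptions, and the parameter names `return_documents` and `batch_size` in the comment are descriptive rather than quoted from the crate. It assumes the model files were already downloaded and that the external `model.onnx.data` file (required because the ONNX protobuf format caps a single file at 2 GB) sits next to `model.onnx`, which is why the ONNX graph is passed as a file path via `OnnxSource::File` instead of as in-memory bytes.

```rust
use std::path::PathBuf;

use fastembed::{
    read_file_to_bytes, OnnxSource, TextRerank, TokenizerFiles, UserDefinedRerankingModel,
};

fn main() {
    // Hypothetical local directory holding a previously downloaded reranker.
    // For models over 2 GB, `model.onnx.data` must sit next to `model.onnx`.
    let dir = PathBuf::from("./models/bge-reranker-v2-m3");

    // Point the ONNX session at the file on disk so the runtime can resolve the
    // external data file, instead of handing it a byte buffer.
    let onnx_source = OnnxSource::File(dir.join("model.onnx"));

    // Tokenizer files are still read into memory.
    let tokenizer_files = TokenizerFiles {
        tokenizer_file: read_file_to_bytes(&dir.join("tokenizer.json")).unwrap(),
        config_file: read_file_to_bytes(&dir.join("config.json")).unwrap(),
        special_tokens_map_file: read_file_to_bytes(&dir.join("special_tokens_map.json")).unwrap(),
        tokenizer_config_file: read_file_to_bytes(&dir.join("tokenizer_config.json")).unwrap(),
    };

    let model = UserDefinedRerankingModel::new(onnx_source, tokenizer_files);
    let reranker = TextRerank::try_new_from_user_defined(model, Default::default()).unwrap();

    let documents = vec![
        "panda is an animal",
        "i don't know",
        "The giant panda is a bear species endemic to China.",
    ];

    // rerank(query, documents, return_documents, batch_size)
    let results = reranker
        .rerank("what is a panda?", documents, true, None)
        .unwrap();

    // Each result carries the original index and, when requested, the document text.
    for r in results {
        println!("{} -> {:?}", r.index, r.document);
    }
}
```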