Skip to content

Commit bcd6304

Browse files
authored
feat: Support for bge-reranker-v2-m3 (#118)
* adds support for bge-reranker-v2-m3 * impl Into<OnnxSource> to avoid breaking existing code * adds bge-reranker-v2-m3 model to README
1 parent 18cad72 commit bcd6304

File tree

7 files changed

+124
-11
lines changed

7 files changed

+124
-11
lines changed

Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ authors = [
1212
"Timon Vonk <[email protected]>",
1313
"Luya Wang <[email protected]>",
1414
15-
"Denny Wong <[email protected]>"
15+
"Denny Wong <[email protected]>",
16+
"Alex Rozgo <[email protected]>"
1617
]
1718
documentation = "https://docs.rs/fastembed"
1819
repository = "https://github.com/Anush008/fastembed-rs"

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ The default model is Flag Embedding, which is top of the [MTEB](https://huggingf
6262
### Reranking
6363

6464
- [**BAAI/bge-reranker-base**](https://huggingface.co/BAAI/bge-reranker-base)
65+
- [**BAAI/bge-reranker-v2-m3**](https://huggingface.co/BAAI/bge-reranker-v2-m3)
6566
- [**jinaai/jina-reranker-v1-turbo-en**](https://huggingface.co/jinaai/jina-reranker-v1-turbo-en)
6667
- [**jinaai/jina-reranker-v2-base-multilingual**](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual)
6768

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pub use crate::models::{
7878
pub use crate::output::{EmbeddingOutput, OutputKey, OutputPrecedence, SingleBatchOutput};
7979
pub use crate::pooling::Pooling;
8080
pub use crate::reranking::{
81-
RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
81+
OnnxSource, RerankInitOptions, RerankInitOptionsUserDefined, RerankResult, TextRerank,
8282
UserDefinedRerankingModel,
8383
};
8484
pub use crate::sparse_text_embedding::{

src/models/reranking.rs

+13
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::fmt::Display;
44
pub enum RerankerModel {
55
/// BAAI/bge-reranker-base
66
BGERerankerBase,
7+
/// rozgo/bge-reranker-v2-m3
8+
BGERerankerV2M3,
79
/// jinaai/jina-reranker-v1-turbo-en
810
JINARerankerV1TurboEn,
911
/// jinaai/jina-reranker-v2-base-multilingual
@@ -17,18 +19,28 @@ pub fn reranker_model_list() -> Vec<RerankerModelInfo> {
1719
description: String::from("reranker model for English and Chinese"),
1820
model_code: String::from("BAAI/bge-reranker-base"),
1921
model_file: String::from("onnx/model.onnx"),
22+
additional_files: vec![],
23+
},
24+
RerankerModelInfo {
25+
model: RerankerModel::BGERerankerV2M3,
26+
description: String::from("reranker model for multilingual"),
27+
model_code: String::from("rozgo/bge-reranker-v2-m3"),
28+
model_file: String::from("model.onnx"),
29+
additional_files: vec![String::from("model.onnx.data")],
2030
},
2131
RerankerModelInfo {
2232
model: RerankerModel::JINARerankerV1TurboEn,
2333
description: String::from("reranker model for English"),
2434
model_code: String::from("jinaai/jina-reranker-v1-turbo-en"),
2535
model_file: String::from("onnx/model.onnx"),
36+
additional_files: vec![],
2637
},
2738
RerankerModelInfo {
2839
model: RerankerModel::JINARerankerV2BaseMultiligual,
2940
description: String::from("reranker model for multilingual"),
3041
model_code: String::from("jinaai/jina-reranker-v2-base-multilingual"),
3142
model_file: String::from("onnx/model.onnx"),
43+
additional_files: vec![],
3244
},
3345
];
3446
reranker_model_list
@@ -41,6 +53,7 @@ pub struct RerankerModelInfo {
4153
pub description: String,
4254
pub model_code: String,
4355
pub model_file: String,
56+
pub additional_files: Vec<String>,
4457
}
4558

4659
impl Display for RerankerModel {

src/reranking/impl.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use tokenizers::Tokenizer;
1717
#[cfg(feature = "online")]
1818
use super::RerankInitOptions;
1919
use super::{
20-
RerankInitOptionsUserDefined, RerankResult, TextRerank, UserDefinedRerankingModel,
20+
OnnxSource, RerankInitOptionsUserDefined, RerankResult, TextRerank, UserDefinedRerankingModel,
2121
DEFAULT_BATCH_SIZE,
2222
};
2323

@@ -70,6 +70,13 @@ impl TextRerank {
7070
let model_file_reference = model_repo
7171
.get(&model_file_name)
7272
.unwrap_or_else(|_| panic!("Failed to retrieve model file: {}", model_file_name));
73+
let additional_files = TextRerank::get_model_info(&model_name).additional_files;
74+
for additional_file in additional_files {
75+
let _additional_file_reference =
76+
model_repo.get(&additional_file).unwrap_or_else(|_| {
77+
panic!("Failed to retrieve additional file: {}", additional_file)
78+
});
79+
}
7380

7481
let session = Session::builder()?
7582
.with_execution_providers(execution_providers)?
@@ -98,8 +105,12 @@ impl TextRerank {
98105
let session = Session::builder()?
99106
.with_execution_providers(execution_providers)?
100107
.with_optimization_level(GraphOptimizationLevel::Level3)?
101-
.with_intra_threads(threads)?
102-
.commit_from_memory(&model.onnx_file)?;
108+
.with_intra_threads(threads)?;
109+
110+
let session = match &model.onnx_source {
111+
OnnxSource::Memory(bytes) => session.commit_from_memory(bytes)?,
112+
OnnxSource::File(path) => session.commit_from_file(path)?,
113+
};
103114

104115
let tokenizer = load_tokenizer(model.tokenizer_files, max_length)?;
105116
Ok(Self::new(tokenizer, session))

src/reranking/init.rs

+24-3
Original file line numberDiff line numberDiff line change
@@ -99,20 +99,41 @@ impl From<RerankInitOptions> for RerankInitOptionsUserDefined {
9999
}
100100
}
101101

102+
/// Enum for the source of the onnx file
103+
///
104+
/// User-defined models can either be in memory or on disk
105+
#[derive(Debug, Clone, PartialEq, Eq)]
106+
pub enum OnnxSource {
107+
Memory(Vec<u8>),
108+
File(PathBuf),
109+
}
110+
111+
impl From<Vec<u8>> for OnnxSource {
112+
fn from(bytes: Vec<u8>) -> Self {
113+
OnnxSource::Memory(bytes)
114+
}
115+
}
116+
117+
impl From<PathBuf> for OnnxSource {
118+
fn from(path: PathBuf) -> Self {
119+
OnnxSource::File(path)
120+
}
121+
}
122+
102123
/// Struct for "bring your own" reranking models
103124
///
104125
/// The onnx_file and tokenizer_files are expecting the files' bytes
105126
#[derive(Debug, Clone, PartialEq, Eq)]
106127
#[non_exhaustive]
107128
pub struct UserDefinedRerankingModel {
108-
pub onnx_file: Vec<u8>,
129+
pub onnx_source: OnnxSource,
109130
pub tokenizer_files: TokenizerFiles,
110131
}
111132

112133
impl UserDefinedRerankingModel {
113-
pub fn new(onnx_file: Vec<u8>, tokenizer_files: TokenizerFiles) -> Self {
134+
pub fn new(onnx_source: impl Into<OnnxSource>, tokenizer_files: TokenizerFiles) -> Self {
114135
Self {
115-
onnx_file,
136+
onnx_source: onnx_source.into(),
116137
tokenizer_files,
117138
}
118139
}

tests/embeddings.rs

+69-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
88

99
use fastembed::{
1010
read_file_to_bytes, Embedding, EmbeddingModel, ImageEmbedding, ImageInitOptions, InitOptions,
11-
InitOptionsUserDefined, Pooling, QuantizationMode, RerankInitOptions,
11+
InitOptionsUserDefined, OnnxSource, Pooling, QuantizationMode, RerankInitOptions,
1212
RerankInitOptionsUserDefined, RerankerModel, SparseInitOptions, SparseTextEmbedding,
1313
TextEmbedding, TextRerank, TokenizerFiles, UserDefinedEmbeddingModel,
1414
UserDefinedRerankingModel, DEFAULT_CACHE_DIR,
@@ -284,6 +284,8 @@ fn test_rerank() {
284284
.par_iter()
285285
.for_each(|supported_model| {
286286

287+
println!("supported_model: {:?}", supported_model);
288+
287289
let result = TextRerank::try_new(RerankInitOptions::new(supported_model.model.clone()))
288290
.unwrap();
289291

@@ -300,14 +302,78 @@ fn test_rerank() {
300302
.unwrap();
301303

302304
assert_eq!(results.len(), documents.len(), "rerank model {:?} failed", supported_model);
303-
assert_eq!(results[0].document.as_ref().unwrap(), "panda is an animal");
304-
assert_eq!(results[1].document.as_ref().unwrap(), "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China.");
305+
306+
let option_a = "panda is an animal";
307+
let option_b = "The giant panda, sometimes called a panda bear or simply panda, is a bear species endemic to China.";
308+
309+
assert!(
310+
results[0].document.as_ref().unwrap() == option_a ||
311+
results[0].document.as_ref().unwrap() == option_b
312+
);
313+
assert!(
314+
results[1].document.as_ref().unwrap() == option_a ||
315+
results[1].document.as_ref().unwrap() == option_b
316+
);
317+
assert_ne!(results[0].document, results[1].document, "The top two results should be different");
305318

306319
// Clear the model cache to avoid running out of space on GitHub Actions.
307320
clean_cache(supported_model.model_code.clone())
308321
});
309322
}
310323

324+
#[test]
325+
fn test_user_defined_reranking_large_model() {
326+
// Setup model to download from Hugging Face
327+
let cache = hf_hub::Cache::new(std::path::PathBuf::from(fastembed::DEFAULT_CACHE_DIR));
328+
let api = hf_hub::api::sync::ApiBuilder::from_cache(cache)
329+
.with_progress(true)
330+
.build()
331+
.expect("Failed to build API from cache");
332+
let model_repo = api.model("rozgo/bge-reranker-v2-m3".to_string());
333+
334+
// Download the onnx model file
335+
let onnx_file = model_repo.download("model.onnx").unwrap();
336+
// Onnx model exceeds the limit of 2GB for a file, so we need to download the data file separately
337+
let _onnx_data_file = model_repo.get("model.onnx.data").unwrap();
338+
339+
// OnnxSource::File is used to load the onnx file using onnx session builder commit_from_file
340+
let onnx_source = OnnxSource::File(onnx_file);
341+
342+
// Load the tokenizer files
343+
let tokenizer_files: TokenizerFiles = TokenizerFiles {
344+
tokenizer_file: read_file_to_bytes(&model_repo.get("tokenizer.json").unwrap()).unwrap(),
345+
config_file: read_file_to_bytes(&model_repo.get("config.json").unwrap()).unwrap(),
346+
special_tokens_map_file: read_file_to_bytes(
347+
&model_repo.get("special_tokens_map.json").unwrap(),
348+
)
349+
.unwrap(),
350+
351+
tokenizer_config_file: read_file_to_bytes(
352+
&model_repo.get("tokenizer_config.json").unwrap(),
353+
)
354+
.unwrap(),
355+
};
356+
357+
let model = UserDefinedRerankingModel::new(onnx_source, tokenizer_files);
358+
359+
let user_defined_reranker =
360+
TextRerank::try_new_from_user_defined(model, Default::default()).unwrap();
361+
362+
let documents = vec![
363+
"Hello, World!",
364+
"This is an example passage.",
365+
"fastembed-rs is licensed under Apache-2.0",
366+
"Some other short text here blah blah blah",
367+
];
368+
369+
let results = user_defined_reranker
370+
.rerank("Ciao, Earth!", documents.clone(), false, None)
371+
.unwrap();
372+
373+
assert_eq!(results.len(), documents.len());
374+
assert_eq!(results.first().unwrap().index, 0);
375+
}
376+
311377
#[test]
312378
fn test_user_defined_reranking_model() {
313379
// Constitute the model in order to ensure it's downloaded and cached

0 commit comments

Comments
 (0)