-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathbge_m3.py
More file actions
113 lines (95 loc) · 3.97 KB
/
bge_m3.py
File metadata and controls
113 lines (95 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging
from typing import Dict, List, Optional

import numpy as np
from scipy.sparse import csr_array

from pymilvus.model.base import BaseEmbeddingFunction
from pymilvus.model.sparse.utils import stack_sparse_embeddings
from pymilvus.model.utils import import_FlagEmbedding, import_datasets
# Module-level logger for this embedding wrapper.
# NOTE(review): calling setLevel(DEBUG) at import time in a library overrides
# the host application's logging configuration — confirm this is intentional.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class BGEM3EmbeddingFunction(BaseEmbeddingFunction):
    """Embed texts with BAAI's BGE-M3 model via FlagEmbedding.

    Wraps ``FlagEmbedding.BGEM3FlagModel`` and returns, per call, any
    combination of dense vectors, sparse lexical-weight vectors (stacked
    into one ``scipy.sparse`` CSR matrix), and ColBERT multi-vectors,
    depending on the ``return_*`` flags chosen at construction.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        batch_size: int = 16,
        device: Optional[str] = None,
        normalize_embeddings: bool = True,
        use_fp16: bool = False,
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = False,
        **kwargs,
    ):
        """Load the BGE-M3 model and record encode-time settings.

        Args:
            model_name: HuggingFace model id or local path.
            batch_size: Batch size forwarded to ``BGEM3FlagModel.encode``.
            device: Device string (e.g. ``"cpu"``); ``None`` lets
                FlagEmbedding choose. The legacy ``devices`` kwarg is
                also honored and takes precedence (backward compat).
            normalize_embeddings: Forwarded to ``BGEM3FlagModel``.
            use_fp16: Run the model in fp16 (warned against on CPU).
            return_dense: Include ``"dense"`` in the output dict.
            return_sparse: Include ``"sparse"`` in the output dict.
            return_colbert_vecs: Include ``"colbert_vecs"`` in the output.
            **kwargs: Extra options forwarded to ``BGEM3FlagModel``.
        """
        import_datasets()
        import_FlagEmbedding()
        try:
            from FlagEmbedding import BGEM3FlagModel
        except AttributeError as e:
            import sys

            # Known Google Colab incompatibility surfaces as an
            # AttributeError mentioning "ListView"; only a session restart
            # resolves it (milvus-io/milvus-model#32). Always re-raise.
            if "google.colab" in sys.modules and "ListView" in str(e):
                print(
                    "\033[91mIt looks like you're running on Google Colab. Please restart the session to resolve this issue.\033[0m")
                print(
                    "\033[91mFor further details, visit: https://github.com/milvus-io/milvus-model/issues/32.\033[0m")
            raise

        # Bug fix: resolve the legacy "devices" alias *before* recording
        # state, so self.device, the CPU/fp16 warning, and the model config
        # all agree on the device actually used. Previously the alias was
        # applied only to the model config, leaving self.device stale and
        # the warning based on the ignored `device` argument.
        if "devices" in kwargs:
            device = kwargs.pop("devices")

        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings
        self.device = device
        self.use_fp16 = use_fp16

        if device == "cpu" and use_fp16 is True:
            logger.warning(
                "Using fp16 with CPU can lead to runtime errors such as 'LayerNormKernelImpl', It's recommended to set 'use_fp16 = False' when using cpu. "
            )

        self._model_config = dict(
            {
                "model_name_or_path": model_name,
                "device": device,
                "normalize_embeddings": normalize_embeddings,
                "use_fp16": use_fp16,
            },
            **kwargs,
        )
        self._encode_config = {
            "batch_size": batch_size,
            "return_dense": return_dense,
            "return_sparse": return_sparse,
            "return_colbert_vecs": return_colbert_vecs,
        }
        self.model = BGEM3FlagModel(**self._model_config)

    def __call__(self, texts: List[str]) -> Dict:
        """Encode ``texts``; equivalent to the encode_* methods."""
        return self._encode(texts)

    @property
    def dim(self) -> Dict:
        """Dimensionality of each embedding kind.

        Returns a dict with:
            dense: hidden size of the underlying transformer config.
            colbert_vecs: output width of the ColBERT linear projection.
            sparse: tokenizer vocabulary size (sparse vectors are
                indexed by token id).
        """
        return {
            "dense": self.model.model.model.config.hidden_size,
            "colbert_vecs": self.model.model.colbert_linear.out_features,
            "sparse": len(self.model.tokenizer),
        }

    def _encode(self, texts: List[str]) -> Dict:
        """Run the model and post-process its raw output into a dict.

        Sparse lexical weights (per-text dicts mapping token id -> weight)
        are converted to one stacked CSR matrix of shape
        ``(len(texts), vocab_size)``.
        """
        output = self.model.encode(sentences=texts, **self._encode_config)
        results = {}
        if self._encode_config["return_dense"] is True:
            results["dense"] = list(output["dense_vecs"])
        if self._encode_config["return_sparse"] is True:
            sparse_dim = self.dim["sparse"]
            rows = []
            for sparse_vec in output["lexical_weights"]:
                # Keys may be str token ids; coerce to int column indices.
                indices = [int(k) for k in sparse_vec]
                values = np.array(list(sparse_vec.values()), dtype=np.float64)
                row_indices = [0] * len(indices)
                rows.append(
                    csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
                )
            results["sparse"] = stack_sparse_embeddings(rows).tocsr()
        if self._encode_config["return_colbert_vecs"] is True:
            results["colbert_vecs"] = output["colbert_vecs"]
        return results

    def encode_queries(self, queries: List[str]) -> Dict:
        """Encode query strings (same pipeline as documents for BGE-M3)."""
        return self._encode(queries)

    def encode_documents(self, documents: List[str]) -> Dict:
        """Encode document strings."""
        return self._encode(documents)