diff --git a/InstructorEmbedding/instructor.py b/InstructorEmbedding/instructor.py
index 72b3df2..f58709f 100644
--- a/InstructorEmbedding/instructor.py
+++ b/InstructorEmbedding/instructor.py
@@ -1,4 +1,3 @@
-# This script is based on the modifications from https://github.com/UKPLab/sentence-transformers
import importlib
import json
import os
@@ -24,24 +23,6 @@ def batch_to_device(batch, target_device: str):
class INSTRUCTORPooling(nn.Module):
- """Performs pooling (max or mean) on the token embeddings.
-
- Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
- This layer also allows to use the CLS token if it is returned by the underlying word embedding model.
- You can concatenate multiple poolings together.
-
- :param word_embedding_dimension: Dimensions for the word embeddings
- :param pooling_mode: Can be a string: mean/max/cls. If set, overwrites the other pooling_mode_* settings
- :param pooling_mode_cls_token: Use the first token (CLS token) as text representations
- :param pooling_mode_max_tokens: Use max in each dimension over all tokens.
- :param pooling_mode_mean_tokens: Perform mean-pooling
- :param pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but divide by sqrt(input_length).
- :param pooling_mode_weightedmean_tokens: Perform (position) weighted mean pooling,
- see https://arxiv.org/abs/2202.08904
- :param pooling_mode_lasttoken: Perform last token pooling,
- see https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005
- """
-
def __init__(
self,
word_embedding_dimension: int,
@@ -65,7 +46,7 @@ def __init__(
"pooling_mode_lasttoken",
]
- if pooling_mode is not None: # Set pooling mode by string
+ if pooling_mode is not None:
pooling_mode = pooling_mode.lower()
assert pooling_mode in ["mean", "max", "cls", "weightedmean", "lasttoken"]
pooling_mode_cls_token = pooling_mode == "cls"
@@ -100,9 +81,6 @@ def __repr__(self):
return f"Pooling({self.get_config_dict()})"
def get_pooling_mode_str(self) -> str:
- """
- Returns the pooling mode as string
- """
modes = []
if self.pooling_mode_cls_token:
modes.append("cls")
@@ -120,16 +98,14 @@ def get_pooling_mode_str(self) -> str:
return "+".join(modes)
def forward(self, features):
- # print(features.keys())
token_embeddings = features["token_embeddings"]
attention_mask = features["attention_mask"]
- ## Pooling strategy
output_vectors = []
if self.pooling_mode_cls_token:
cls_token = features.get(
"cls_token_embeddings", token_embeddings[:, 0]
- ) # Take first token by default
+ )
output_vectors.append(cls_token)
if self.pooling_mode_max_tokens:
input_mask_expanded = (
@@ -137,7 +113,7 @@ def forward(self, features):
)
token_embeddings[
input_mask_expanded == 0
- ] = -1e9 # Set padding tokens to large negative value
+ ] = -1e9
max_over_time = torch.max(token_embeddings, 1)[0]
output_vectors.append(max_over_time)
if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
@@ -146,7 +122,6 @@ def forward(self, features):
)
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
- # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
if "token_weights_sum" in features:
sum_mask = (
features["token_weights_sum"]
@@ -166,7 +141,6 @@ def forward(self, features):
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
- # token_embeddings shape: bs, seq, hidden_dim
weights = (
torch.arange(start=1, end=token_embeddings.shape[1] + 1)
.unsqueeze(0)
@@ -180,7 +154,6 @@ def forward(self, features):
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
- # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
if "token_weights_sum" in features:
sum_mask = (
features["token_weights_sum"]
@@ -194,26 +167,16 @@ def forward(self, features):
output_vectors.append(sum_embeddings / sum_mask)
if self.pooling_mode_lasttoken:
batch_size, _, hidden_dim = token_embeddings.shape
- # attention_mask shape: (bs, seq_len)
- # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
- # argmin gives us the index of the first 0 in the attention mask;
- # We get the last 1 index by subtracting 1
gather_indices = (
torch.argmin(attention_mask, 1, keepdim=False) - 1
- ) # Shape [bs]
+ )
- # There are empty sequences, where the index would become -1 which will crash
gather_indices = torch.clamp(gather_indices, min=0)
- # Turn indices from shape [bs] --> [bs, 1, hidden_dim]
gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
gather_indices = gather_indices.unsqueeze(1)
assert gather_indices.shape == (batch_size, 1, hidden_dim)
- # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
- # Actually no need for the attention mask as we gather the last token where attn_mask = 1
- # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we
- # use the attention mask to ignore them again
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
@@ -249,10 +212,6 @@ def load(input_path):
def import_from_string(dotted_path):
- """
- Import a dotted module path and return the attribute/class designated by the
- last name in the path. Raise ImportError if the import failed.
- """
try:
module_path, class_name = dotted_path.rsplit(".", 1)
except ValueError:
@@ -307,6 +266,9 @@ def __init__(
)
if load_model:
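+            # Newer sentence-transformers versions add a 'backend' argument to _load_model; default it to torch when supported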
+ import inspect
+ if 'backend' in inspect.signature(self._load_model).parameters:
+ model_args['backend'] = 'torch'
self._load_model(self.model_name_or_path, config, cache_dir, **model_args)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path
@@ -331,6 +293,18 @@ def __init__(
if tokenizer_name_or_path is not None:
self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
+ def _load_model(self, model_name_or_path: str, config, cache_dir=None, backend=None, is_peft_model=False, **model_args):
+ """Loads the transformers model into the `auto_model` attribute"""
+ import inspect
+ parent_load_model = super()._load_model
+
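+        # Only forward arguments that the parent _load_model actually accepts (signatures differ across sentence-transformers versions)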
+ if 'is_peft_model' in inspect.signature(parent_load_model).parameters:
+ model_args['is_peft_model'] = is_peft_model
+ if 'backend' in inspect.signature(parent_load_model).parameters:
+ model_args['backend'] = backend or 'torch'
+
+ return parent_load_model(model_name_or_path, config, cache_dir, **model_args)
+
def forward(self, features):
input_features = {
"input_ids": features["input_ids"],
@@ -353,7 +327,7 @@ def forward(self, features):
all_layer_idx = 2
if (
len(output_states) < 3
- ): # Some models only output last_hidden_states and all_hidden_states
+ ):
all_layer_idx = 1
hidden_states = output_states[all_layer_idx]
features.update({"all_layer_embeddings": hidden_states})
@@ -362,7 +336,6 @@ def forward(self, features):
@staticmethod
def load(input_path: str):
- # Old classes used other config names than 'sentence_bert_config.json'
for config_name in [
"sentence_bert_config.json",
"sentence_roberta_config.json",
@@ -381,15 +354,11 @@ def load(input_path: str):
return INSTRUCTORTransformer(model_name_or_path=input_path, **config)
def tokenize(self, texts):
- """
- Tokenizes a text and maps tokens to token-ids
- """
output = {}
if isinstance(texts[0], str):
to_tokenize = [texts]
to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]
- # Lowercase
if self.do_lower_case:
to_tokenize = [[s.lower() for s in col] for col in to_tokenize]
@@ -446,30 +415,15 @@ def prepare_input_features(
input_attention_mask_shape = input_features["attention_mask"].shape
instruction_attention_mask = instruction_features["attention_mask"]
- # reducing the attention length by 1 in order to omit the attention corresponding to the end_token
instruction_attention_mask = instruction_attention_mask[:, 1:]
- # creating instruction attention matrix equivalent to the size of the input attention matrix
expanded_instruction_attention_mask = torch.zeros(
input_attention_mask_shape, dtype=torch.int64
)
- # assigning the the actual instruction attention matrix to the expanded_instruction_attention_mask
- # eg:
- # instruction_attention_mask: 3x3
- # [[1,1,1],
- # [1,1,0],
- # [1,0,0]]
- # expanded_instruction_attention_mask: 3x4
- # [[1,1,1,0],
- # [1,1,0,0],
- # [1,0,0,0]]
expanded_instruction_attention_mask[
: instruction_attention_mask.size(0), : instruction_attention_mask.size(1)
] = instruction_attention_mask
- # In the pooling layer we want to consider only the tokens corresponding to the input text
- # and not the instruction. This is achieved by inverting the
- # attention_mask corresponding to the instruction.
expanded_instruction_attention_mask = 1 - expanded_instruction_attention_mask
input_features["instruction_mask"] = expanded_instruction_attention_mask
if return_data_type == "np":
@@ -517,62 +471,62 @@ def smart_batching_collate(self, batch):
return batched_input_features, labels
- def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False):
- """
- Loads a full sentence-transformers model
- """
- # copied from https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L559
- # because we need to get files outside of the allow_patterns too
- # If file is local
- if os.path.isdir(model_path):
- model_path = str(model_path)
- else:
- # If model_path is a Hugging Face repository ID, download the model
- download_kwargs = {
- "repo_id": model_path,
- "revision": revision,
- "library_name": "InstructorEmbedding",
- "token": token,
- "cache_dir": cache_folder,
- "tqdm_class": disabled_tqdm,
- }
-
- # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
- config_sentence_transformers_json_path = os.path.join(
- model_path, "config_sentence_transformers.json"
- )
- if os.path.exists(config_sentence_transformers_json_path):
- with open(
- config_sentence_transformers_json_path, encoding="UTF-8"
- ) as config_file:
- self._model_config = json.load(config_file)
-
- # Check if a readme exists
- model_card_path = os.path.join(model_path, "README.md")
- if os.path.exists(model_card_path):
- try:
- with open(model_card_path, encoding="utf8") as config_file:
- self._model_card_text = config_file.read()
- except:
- pass
-
- # Load the modules of sentence transformer
- modules_json_path = os.path.join(model_path, "modules.json")
- with open(modules_json_path, encoding="UTF-8") as config_file:
- modules_config = json.load(config_file)
-
- modules = OrderedDict()
- for module_config in modules_config:
- if module_config["idx"] == 0:
- module_class = INSTRUCTORTransformer
- elif module_config["idx"] == 1:
- module_class = INSTRUCTORPooling
- else:
- module_class = import_from_string(module_config["type"])
- module = module_class.load(os.path.join(model_path, module_config["path"]))
- modules[module_config["name"]] = module
-
- return modules
+ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False, local_files_only=False, model_kwargs=None, tokenizer_kwargs=None, config_kwargs=None):
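+        """Load a full sentence-transformers model, swapping in the INSTRUCTOR-specific Transformer and Pooling modules."""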
+ import inspect
+ base_signature = inspect.signature(SentenceTransformer.__init__)
+
+ if os.path.isdir(model_path):
+ model_path = str(model_path)
+ else:
+ download_kwargs = {
+ "repo_id": model_path,
+ "revision": revision,
+ "library_name": "sentence-transformers",
+ "token": token,
+ "cache_dir": cache_folder,
+ "tqdm_class": disabled_tqdm,
+ "local_files_only": local_files_only,
+ }
+ model_path = snapshot_download(**download_kwargs)
+
+ config_sentence_transformers_json_path = os.path.join(
+ model_path, "config_sentence_transformers.json"
+ )
+ if os.path.exists(config_sentence_transformers_json_path):
+ with open(
+ config_sentence_transformers_json_path, encoding="UTF-8"
+ ) as config_file:
+ self._model_config = json.load(config_file)
+
+ model_card_path = os.path.join(model_path, "README.md")
+ if os.path.exists(model_card_path):
+ try:
+ with open(model_card_path, encoding="utf8") as config_file:
+ self._model_card_text = config_file.read()
+ except:
+ pass
+
+ modules_json_path = os.path.join(model_path, "modules.json")
+ with open(modules_json_path, encoding="UTF-8") as config_file:
+ modules_config = json.load(config_file)
+
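+        # Newer sentence-transformers versions (whose __init__ accepts 'backend') expect (modules, module_kwargs) back from _load_sbert_model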
+ modules = OrderedDict()
+ if 'backend' in base_signature.parameters:
+ module_kwargs = {}
+
+ for module_config in modules_config:
+ if module_config["idx"] == 0:
+ module_class = INSTRUCTORTransformer
+ elif module_config["idx"] == 1:
+ module_class = INSTRUCTORPooling
+ else:
+ module_class = import_from_string(module_config["type"])
+ module = module_class.load(os.path.join(model_path, module_config["path"]))
+ modules[module_config["name"]] = module
+
+ if 'backend' in base_signature.parameters:
+ return modules, module_kwargs
+ return modules
def encode(
self,
@@ -585,26 +539,6 @@ def encode(
device: Union[str, None] = None,
normalize_embeddings: bool = False,
):
- """
- Computes sentence embeddings
-
- :param sentences: the sentences to embed
- :param batch_size: the batch size used for the computation
- :param show_progress_bar: Output a progress bar when encode sentences
- :param output_value: Default sentence_embedding, to get sentence embeddings.
- Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
- :param convert_to_numpy: If true, the output is a list of numpy vectors.
- Else, it is a list of pytorch tensors.
- :param convert_to_tensor: If true, you get one large tensor as return.
- Overwrites any setting from convert_to_numpy
- :param device: Which torch.device to use for the computation
- :param normalize_embeddings: If set to true, returned vectors will have length 1.
- In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
- :return:
- By default, a list of tensors is returned. If convert_to_tensor,
- a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
- """
self.eval()
if show_progress_bar is None:
show_progress_bar = False
@@ -619,7 +553,7 @@ def encode(
input_was_string = False
if isinstance(sentences, str) or not hasattr(
sentences, "__len__"
- ): # Cast an individual sentence to a list with length 1
+ ):
sentences = [sentences]
input_was_string = True
@@ -660,14 +594,14 @@ def encode(
last_mask_id -= 1
embeddings.append(token_emb[0 : last_mask_id + 1])
- elif output_value is None: # Return all outputs
+ elif output_value is None:
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
row = {
name: out_features[name][sent_idx] for name in out_features
}
embeddings.append(row)
- else: # Sentence embeddings
+ else:
embeddings = out_features[output_value]
embeddings = embeddings.detach()
if normalize_embeddings:
@@ -675,7 +609,6 @@ def encode(
embeddings, p=2, dim=1
)
- # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
if convert_to_numpy:
embeddings = embeddings.cpu()
diff --git a/README.md b/README.md
index 8e3d3b3..b3b32de 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,78 @@
-## My Personal Fork
+## Updated 1/29/2025
-This is a fork for the Instructor model becuase the original repository isn't kept up anymore. I've also made some improvements to their source code:
+This is a fork of the Instructor model because the original repository is rarely (if ever) updated. I've also made some improvements to their source code above and beyond making sure it functions at a basic level with libraries such as ```sentence-transformers```, ```langchain```, etc. As of 1/29/2025, it works with the most recent versions of those libraries, and it should be backwards compatible as well.
-1) Fixing it to work with the ```sentence-transformers``` library above 2.2.2.
-2) Properly download the models from huggingface using the new "snapshot download" API.
-3) Ability to specify where you want the model donwloaded with the "cache_dir" parameter.
+* Instructions on how to use it directly via the ```sentence-transformers``` library [are located here](https://sbert.net/docs/sentence_transformer/pretrained_models.html#instructor-models)
+* If using Langchain, the [instructions are here](https://python.langchain.com/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceInstructEmbeddings.html#langchain_community.embeddings.huggingface.HuggingFaceInstructEmbeddings.model_kwargs)
-## What follows is the original repository's readme file. Ignore the quantization section, however, becuase pytorch has changed its API since then.
+Feel free to ask questions, because getting this working with newer versions of ```sentence-transformers``` can be a little tricky. A minimal example of loading the model straight through ```sentence-transformers``` is sketched below, followed by the LangChain examples.
+
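+Example using the model directly through ```sentence-transformers``` (a minimal sketch, assuming ```sentence-transformers>=3.2.0``` as noted in requirements.txt; there, the instruction is passed via the ```prompt``` argument to ```encode()```):
+
+```python
+from sentence_transformers import SentenceTransformer
+
+model = SentenceTransformer("hkunlp/instructor-large")  # or any other instructor model
+
+# The instruction is supplied as a prompt that gets prepended to each sentence
+embeddings = model.encode(
+    ["Dynamical scaling of surface growth in simple lattice models"],
+    prompt="Represent the Science title:",
+    normalize_embeddings=True,
+)
+print(embeddings.shape)
+```
+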
+Example: creating embeddings to be stored in a vector database
+
+```python
+# Creating embeddings to be put into a vector database
+import torch
+from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+
+model_kwargs = {
+ "device": "cuda", # or "cpu"
+ "trust_remote_code": True,
+ "model_kwargs": {
+ "torch_dtype": torch.float16 # or torch.bfloat16 or torch.float32 depending on your needs
+ }
+}
+
+encode_kwargs = {
+ "normalize_embeddings": True,
+ "batch_size": 2
+}
+
+instructor_embeddings = HuggingFaceInstructEmbeddings(
+ model_name="hkunlp/instructor-xl", # or any other instructor model
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs,
+ embed_instruction="Represent the document for retrieval:",
+ show_progress=True
+)
+
+texts = ["The solar system has eight planets."]  # replace with your own list of document strings
+embeddings = instructor_embeddings.embed_documents(texts)
+```
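+
+The ```encode_kwargs``` above are passed straight through to ```encode()```: ```normalize_embeddings=True``` returns unit-length vectors (so a plain dot product behaves like cosine similarity), and a small ```batch_size``` keeps GPU memory usage down with the larger instructor models.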
+
+
+Example: creating embeddings from a query
+
+```python
+import torch
+from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+
+model_kwargs = {
+ "device": "cuda", # or "cpu"
+ "trust_remote_code": True,
+ "model_kwargs": {
+ "torch_dtype": torch.float16 # or torch.bfloat16 or torch.float32 depending on your needs
+ }
+}
+
+encode_kwargs = {
+ "normalize_embeddings": True,
+ "batch_size": 1 # For queries, batch size is always 1
+}
+
+query_embeddings = HuggingFaceInstructEmbeddings(
+ model_name="hkunlp/instructor-xl", # or any other instructor model
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs,
+ embed_instruction="Represent the question for retrieving supporting documents:",
+ show_progress=False
+)
+
+# Generate embedding for your query
+query_text = "How many planets are in the solar system?"  # replace with your own query
+query_embedding = query_embeddings.embed_query(query_text)
+```
+
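+Note on quantization: the quantization snippet in the original readme below uses the old ```torch.quantization``` API and no longer works as written. If you still want a dynamically quantized CPU model, a minimal sketch using the current ```torch.ao.quantization.quantize_dynamic``` entry point is shown here; treat it as an untested starting point rather than a supported path.
+
+```python
+import torch
+from InstructorEmbedding import INSTRUCTOR
+
+model = INSTRUCTOR("hkunlp/instructor-large")
+
+# Dynamically quantize only the Linear layers to int8; everything else stays in float32
+quantized_model = torch.ao.quantization.quantize_dynamic(
+    model, {torch.nn.Linear}, dtype=torch.qint8
+)
+
+embeddings = quantized_model.encode(
+    [["Represent the document for retrieval:", "Dynamic quantization runs inference on the CPU."]]
+)
+```
+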
+
+
+## Below is the original repository's readme file. Ignore the quantization section, however, because pytorch has changed its API since then.
+
+Original Repository Readme
# One Embedder, Any Task: Instruction-Finetuned Text Embeddings
@@ -257,7 +323,8 @@ You can evaluate your trained model checkpoints by specifying `--model_name` and
## Quantization
To [**Quantize**](https://pytorch.org/docs/stable/quantization.html) the Instructor embedding model, run the following code:
-```python
+```python
+
# imports
import torch
from InstructorEmbedding import INSTRUCTOR
diff --git a/requirements.txt b/requirements.txt
index 64d9c9a..6033047 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,13 @@
transformers>=4.20,<5.0
-datasets>=2.20,<3.0
-pyarrow>=17.0,<18.0
-numpy>=1.0,<=1.26.4
+datasets>=2.20,<=3.2.0 # personally verified up to this version as of the 1/29/2025 update
+pyarrow>=17.0,<=18.1.0 # personally verified up to this version as of the 1/29/2025 update
+pyarrow-hotfix==0.6 # rarely needed, but resolves occasional pyarrow errors
+numpy>=1.0,<=1.26.4 # keep pinned below 2.0 until numpy 2.x is more widely supported
requests>=2.26,<3.0
scikit_learn>=1.0.2,<2.0
scipy>=1.14,<2.0
-sentence-transformers>=3.0.1,<4.0
+sentence-transformers>=3.0.1,<4.0 # this range works if you're using my pull request; if you use either of the other two approaches, change it to "sentence-transformers>=3.2.0,<4.0"
torch>=2.0
tqdm>=4.0,<5.0
rich>=13.0,<14.0
-huggingface-hub>=0.24.1
\ No newline at end of file
+huggingface-hub>=0.24.1
diff --git a/setup.py b/setup.py
index 6f17fb0..ac4dd1b 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
setup(
name='InstructorEmbedding',
packages=['InstructorEmbedding'],
- version='1.0.2',
+ version='1.0.3',
license='Apache License 2.0',
description='Text embedding tool',
long_description=readme,