Resolve compatibility issues with sentence-transformers 3.3.1 #127

Open · wants to merge 3 commits into main

113 changes: 63 additions & 50 deletions InstructorEmbedding/instructor.py
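
For reference, a minimal way to sanity-check this fix (a sketch, not part of the PR): it assumes sentence-transformers 3.3.1 is installed together with this branch of InstructorEmbedding, and uses the public hkunlp/instructor-base checkpoint from the project README.

# Sanity check for loading and encoding under sentence-transformers 3.3.1.
# Assumptions: pip install sentence-transformers==3.3.1, plus this branch of
# InstructorEmbedding installed from source; the model name is the public checkpoint.
from InstructorEmbedding import INSTRUCTOR

model = INSTRUCTOR("hkunlp/instructor-base")

# INSTRUCTOR takes [instruction, text] pairs rather than bare strings.
embeddings = model.encode(
    [["Represent the Science title:", "3D ActionSLAM: wearable person tracking in multi-floor environments"]]
)
print(embeddings.shape)  # expected (1, 768) for instructor-base
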
@@ -3,7 +3,7 @@
import json
import os
from collections import OrderedDict
- from typing import Union
+ from typing import Union, Any

import numpy as np
import torch
@@ -43,15 +43,15 @@ class INSTRUCTORPooling(nn.Module):
"""

def __init__(
self,
word_embedding_dimension: int,
pooling_mode: Union[str, None] = None,
pooling_mode_cls_token: bool = False,
pooling_mode_max_tokens: bool = False,
pooling_mode_mean_tokens: bool = True,
pooling_mode_mean_sqrt_len_tokens: bool = False,
pooling_mode_weightedmean_tokens: bool = False,
pooling_mode_lasttoken: bool = False,
):
super().__init__()

@@ -93,7 +93,7 @@ def __init__(
]
)
self.pooling_output_dimension = (
pooling_mode_multiplier * word_embedding_dimension
)

def __repr__(self):
@@ -137,7 +137,7 @@ def forward(self, features):
)
token_embeddings[
input_mask_expanded == 0
] = -1e9 # Set padding tokens to large negative value
max_over_time = torch.max(token_embeddings, 1)[0]
output_vectors.append(max_over_time)
if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
@@ -199,7 +199,7 @@ def forward(self, features):
# argmin gives us the index of the first 0 in the attention mask;
# We get the last 1 index by subtracting 1
gather_indices = (
torch.argmin(attention_mask, 1, keepdim=False) - 1
) # Shape [bs]

# There are empty sequences, where the index would become -1 which will crash
@@ -234,14 +234,14 @@ def get_config_dict(self):

def save(self, output_path):
with open(
os.path.join(output_path, "config.json"), "w", encoding="UTF-8"
) as config_file:
json.dump(self.get_config_dict(), config_file, indent=2)

@staticmethod
def load(input_path):
with open(
os.path.join(input_path, "config.json"), encoding="UTF-8"
) as config_file:
config = json.load(config_file)

@@ -273,15 +273,16 @@ def import_from_string(dotted_path):

class INSTRUCTORTransformer(Transformer):
def __init__(
self,
model_name_or_path: str,
max_seq_length=None,
model_args=None,
cache_dir=None,
tokenizer_args=None,
do_lower_case: bool = False,
tokenizer_name_or_path: Union[str, None] = None,
load_model: bool = True,
+ backend: str = "torch",
):
super().__init__(model_name_or_path)
if model_args is None:
@@ -307,7 +308,7 @@ def __init__(
)

if load_model:
- self._load_model(self.model_name_or_path, config, cache_dir, **model_args)
+ self._load_model(self.model_name_or_path, config, cache_dir, backend, **model_args)
self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path
if tokenizer_name_or_path is not None
@@ -318,9 +319,9 @@

if max_seq_length is None:
if (
hasattr(self.auto_model, "config")
and hasattr(self.auto_model.config, "max_position_embeddings")
and hasattr(self.tokenizer, "model_max_length")
):
max_seq_length = min(
self.auto_model.config.max_position_embeddings,
@@ -352,7 +353,7 @@ def forward(self, features):
if self.auto_model.config.output_hidden_states:
all_layer_idx = 2
if (
len(output_states) < 3
): # Some models only output last_hidden_states and all_hidden_states
all_layer_idx = 1
hidden_states = output_states[all_layer_idx]
@@ -404,7 +405,7 @@ def tokenize(self, texts):
elif isinstance(texts[0], list):
assert isinstance(texts[0][1], str)
assert (
len(texts[0]) == 2
), "The input should have both instruction and input text"

instructions = []
@@ -433,7 +434,7 @@ def tokenize(self, texts):
class INSTRUCTOR(SentenceTransformer):
@staticmethod
def prepare_input_features(
input_features, instruction_features, return_data_type: str = "pt"
):
if return_data_type == "np":
input_features["attention_mask"] = torch.from_numpy(
@@ -464,7 +465,7 @@ def prepare_input_features(
# [1,1,0,0],
# [1,0,0,0]]
expanded_instruction_attention_mask[
: instruction_attention_mask.size(0), : instruction_attention_mask.size(1)
] = instruction_attention_mask

# In the pooling layer we want to consider only the tokens corresponding to the input text
@@ -495,7 +496,7 @@ def smart_batching_collate(self, batch):
for idx in range(num_texts):
assert isinstance(texts[idx][0], list)
assert (
len(texts[idx][0]) == 2
), "The input should have both instruction and input text"

num = len(texts[idx])
@@ -517,7 +518,16 @@ def smart_batching_collate(self, batch):

return batched_input_features, labels

- def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False):
+ def _load_sbert_model(
+ self,
+ model_path: str,
+ token: bool | str | None,
+ cache_folder: str | None,
+ revision: str | None = None,
+ trust_remote_code: bool = False,
+ local_files_only: bool = False,
+ model_kwargs: dict[str, Any] | None = None,
+ tokenizer_kwargs: dict[str, Any] | None = None,
+ config_kwargs: dict[str, Any] | None = None,
+ ):
"""
Loads a full sentence-transformers model
"""
@@ -543,7 +553,7 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
)
if os.path.exists(config_sentence_transformers_json_path):
with open(
config_sentence_transformers_json_path, encoding="UTF-8"
) as config_file:
self._model_config = json.load(config_file)

@@ -562,6 +572,8 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
modules_config = json.load(config_file)

modules = OrderedDict()
+ module_kwargs = OrderedDict()

for module_config in modules_config:
if module_config["idx"] == 0:
module_class = INSTRUCTORTransformer
@@ -571,19 +583,20 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
module_class = import_from_string(module_config["type"])
module = module_class.load(os.path.join(model_path, module_config["path"]))
modules[module_config["name"]] = module
+ module_kwargs[module_config["name"]] = module_config.get("kwargs", [])

- return modules
+ return modules, module_kwargs

def encode(
self,
sentences,
batch_size: int = 32,
show_progress_bar: Union[bool, None] = None,
output_value: str = "sentence_embedding",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
device: Union[str, None] = None,
normalize_embeddings: bool = False,
):
"""
Computes sentence embeddings
@@ -618,7 +631,7 @@ def encode(

input_was_string = False
if isinstance(sentences, str) or not hasattr(
sentences, "__len__"
sentences, "__len__"
): # Cast an individual sentence to a list with length 1
sentences = [sentences]
input_was_string = True
@@ -641,9 +654,9 @@ def encode(
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

for start_index in trange(
0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
):
- sentences_batch = sentences_sorted[start_index : start_index + batch_size]
+ sentences_batch = sentences_sorted[start_index: start_index + batch_size]
features = self.tokenize(sentences_batch)
features = batch_to_device(features, device)

@@ -653,13 +666,13 @@ def encode(
if output_value == "token_embeddings":
embeddings = []
for token_emb, attention in zip(
out_features[output_value], out_features["attention_mask"]
):
last_mask_id = len(attention) - 1
while last_mask_id > 0 and attention[last_mask_id].item() == 0:
last_mask_id -= 1

- embeddings.append(token_emb[0 : last_mask_id + 1])
+ embeddings.append(token_emb[0: last_mask_id + 1])
elif output_value is None: # Return all outputs
embeddings = []
for sent_idx in range(len(out_features["sentence_embedding"])):
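
The asserts in tokenize and smart_batching_collate above guard the two-element [instruction, text] input format. A short usage sketch of that format under the patched loader (the instructions follow the public INSTRUCTOR README; scikit-learn's cosine_similarity is just one convenient way to compare the embeddings):

# Retrieval-style usage of the [instruction, text] format enforced above.
# Same environment assumptions as the earlier sanity check.
from sklearn.metrics.pairwise import cosine_similarity
from InstructorEmbedding import INSTRUCTOR

model = INSTRUCTOR("hkunlp/instructor-base")

query = [["Represent the Wikipedia question for retrieving supporting documents: ",
          "where is the food stored in a yam plant"]]
corpus = [
    ["Represent the Wikipedia document for retrieval: ",
     "Yams are perennial herbaceous vines cultivated for the consumption of their starchy tubers."],
    ["Represent the Wikipedia document for retrieval: ",
     "The Amazon rainforest covers most of the Amazon basin of South America."],
]

similarities = cosine_similarity(model.encode(query), model.encode(corpus))
print(similarities)  # the yam document should score highest for the yam question
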