diff --git a/InstructorEmbedding/instructor.py b/InstructorEmbedding/instructor.py
index 72b3df2..01d288d 100644
--- a/InstructorEmbedding/instructor.py
+++ b/InstructorEmbedding/instructor.py
@@ -3,7 +3,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Union
+from typing import Union, Any
 
 import numpy as np
 import torch
@@ -43,15 +43,15 @@ class INSTRUCTORPooling(nn.Module):
     """
 
     def __init__(
-        self,
-        word_embedding_dimension: int,
-        pooling_mode: Union[str, None] = None,
-        pooling_mode_cls_token: bool = False,
-        pooling_mode_max_tokens: bool = False,
-        pooling_mode_mean_tokens: bool = True,
-        pooling_mode_mean_sqrt_len_tokens: bool = False,
-        pooling_mode_weightedmean_tokens: bool = False,
-        pooling_mode_lasttoken: bool = False,
+            self,
+            word_embedding_dimension: int,
+            pooling_mode: Union[str, None] = None,
+            pooling_mode_cls_token: bool = False,
+            pooling_mode_max_tokens: bool = False,
+            pooling_mode_mean_tokens: bool = True,
+            pooling_mode_mean_sqrt_len_tokens: bool = False,
+            pooling_mode_weightedmean_tokens: bool = False,
+            pooling_mode_lasttoken: bool = False,
     ):
         super().__init__()
 
@@ -93,7 +93,7 @@ def __init__(
             ]
         )
         self.pooling_output_dimension = (
-            pooling_mode_multiplier * word_embedding_dimension
+                pooling_mode_multiplier * word_embedding_dimension
         )
 
     def __repr__(self):
@@ -137,7 +137,7 @@ def forward(self, features):
             )
             token_embeddings[
                 input_mask_expanded == 0
-            ] = -1e9  # Set padding tokens to large negative value
+                ] = -1e9  # Set padding tokens to large negative value
             max_over_time = torch.max(token_embeddings, 1)[0]
             output_vectors.append(max_over_time)
         if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
@@ -199,7 +199,7 @@ def forward(self, features):
             # argmin gives us the index of the first 0 in the attention mask;
             # We get the last 1 index by subtracting 1
             gather_indices = (
-                torch.argmin(attention_mask, 1, keepdim=False) - 1
+                    torch.argmin(attention_mask, 1, keepdim=False) - 1
             )  # Shape [bs]
 
             # There are empty sequences, where the index would become -1 which will crash
@@ -234,14 +234,14 @@ def get_config_dict(self):
 
     def save(self, output_path):
         with open(
-            os.path.join(output_path, "config.json"), "w", encoding="UTF-8"
+                os.path.join(output_path, "config.json"), "w", encoding="UTF-8"
         ) as config_file:
             json.dump(self.get_config_dict(), config_file, indent=2)
 
     @staticmethod
     def load(input_path):
         with open(
-            os.path.join(input_path, "config.json"), encoding="UTF-8"
+                os.path.join(input_path, "config.json"), encoding="UTF-8"
         ) as config_file:
             config = json.load(config_file)
 
@@ -273,15 +273,16 @@ def import_from_string(dotted_path):
 
 class INSTRUCTORTransformer(Transformer):
     def __init__(
-        self,
-        model_name_or_path: str,
-        max_seq_length=None,
-        model_args=None,
-        cache_dir=None,
-        tokenizer_args=None,
-        do_lower_case: bool = False,
-        tokenizer_name_or_path: Union[str, None] = None,
-        load_model: bool = True,
+            self,
+            model_name_or_path: str,
+            max_seq_length=None,
+            model_args=None,
+            cache_dir=None,
+            tokenizer_args=None,
+            do_lower_case: bool = False,
+            tokenizer_name_or_path: Union[str, None] = None,
+            load_model: bool = True,
+            backend: str = "torch",
     ):
         super().__init__(model_name_or_path)
         if model_args is None:
@@ -307,7 +308,7 @@ def __init__(
         )
 
         if load_model:
-            self._load_model(self.model_name_or_path, config, cache_dir, **model_args)
+            self._load_model(self.model_name_or_path, config, cache_dir, backend, **model_args)
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name_or_path if tokenizer_name_or_path is not None
@@ -318,9 +319,9 @@ def __init__(
 
         if max_seq_length is None:
             if (
-                hasattr(self.auto_model, "config")
-                and hasattr(self.auto_model.config, "max_position_embeddings")
-                and hasattr(self.tokenizer, "model_max_length")
+                    hasattr(self.auto_model, "config")
+                    and hasattr(self.auto_model.config, "max_position_embeddings")
+                    and hasattr(self.tokenizer, "model_max_length")
             ):
                 max_seq_length = min(
                     self.auto_model.config.max_position_embeddings,
@@ -352,7 +353,7 @@ def forward(self, features):
         if self.auto_model.config.output_hidden_states:
             all_layer_idx = 2
             if (
-                len(output_states) < 3
+                    len(output_states) < 3
             ):  # Some models only output last_hidden_states and all_hidden_states
                 all_layer_idx = 1
             hidden_states = output_states[all_layer_idx]
@@ -404,7 +405,7 @@ def tokenize(self, texts):
         elif isinstance(texts[0], list):
             assert isinstance(texts[0][1], str)
             assert (
-                len(texts[0]) == 2
+                    len(texts[0]) == 2
             ), "The input should have both instruction and input text"
 
             instructions = []
@@ -433,7 +434,7 @@ def tokenize(self, texts):
 class INSTRUCTOR(SentenceTransformer):
     @staticmethod
     def prepare_input_features(
-        input_features, instruction_features, return_data_type: str = "pt"
+            input_features, instruction_features, return_data_type: str = "pt"
     ):
         if return_data_type == "np":
             input_features["attention_mask"] = torch.from_numpy(
@@ -464,7 +465,7 @@ def prepare_input_features(
         # [1,1,0,0],
         # [1,0,0,0]]
         expanded_instruction_attention_mask[
-            : instruction_attention_mask.size(0), : instruction_attention_mask.size(1)
+                : instruction_attention_mask.size(0), : instruction_attention_mask.size(1)
         ] = instruction_attention_mask
 
         # In the pooling layer we want to consider only the tokens corresponding to the input text
@@ -495,7 +496,7 @@ def smart_batching_collate(self, batch):
         for idx in range(num_texts):
             assert isinstance(texts[idx][0], list)
             assert (
-                len(texts[idx][0]) == 2
+                    len(texts[idx][0]) == 2
             ), "The input should have both instruction and input text"
 
             num = len(texts[idx])
@@ -517,7 +518,16 @@ def smart_batching_collate(self, batch):
 
         return batched_input_features, labels
 
-    def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=None, trust_remote_code=False):
+    def _load_sbert_model(self,
+                          model_path: str,
+                          token: bool | str | None,
+                          cache_folder: str | None,
+                          revision: str | None = None,
+                          trust_remote_code: bool = False,
+                          local_files_only: bool = False,
+                          model_kwargs: dict[str, Any] | None = None,
+                          tokenizer_kwargs: dict[str, Any] | None = None,
+                          config_kwargs: dict[str, Any] | None = None, ):
         """
         Loads a full sentence-transformers model
         """
@@ -543,7 +553,7 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
         )
         if os.path.exists(config_sentence_transformers_json_path):
             with open(
-                config_sentence_transformers_json_path, encoding="UTF-8"
+                    config_sentence_transformers_json_path, encoding="UTF-8"
             ) as config_file:
                 self._model_config = json.load(config_file)
 
@@ -562,6 +572,8 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
             modules_config = json.load(config_file)
 
         modules = OrderedDict()
+        module_kwargs = OrderedDict()
+
         for module_config in modules_config:
             if module_config["idx"] == 0:
                 module_class = INSTRUCTORTransformer
@@ -571,19 +583,20 @@ def _load_sbert_model(self, model_path, token=None, cache_folder=None, revision=
                 module_class = import_from_string(module_config["type"])
             module = module_class.load(os.path.join(model_path, module_config["path"]))
             modules[module_config["name"]] = module
+            module_kwargs[module_config["name"]] = module_config.get("kwargs", [])
 
-        return modules
+        return modules, module_kwargs
 
     def encode(
-        self,
-        sentences,
-        batch_size: int = 32,
-        show_progress_bar: Union[bool, None] = None,
-        output_value: str = "sentence_embedding",
-        convert_to_numpy: bool = True,
-        convert_to_tensor: bool = False,
-        device: Union[str, None] = None,
-        normalize_embeddings: bool = False,
+            self,
+            sentences,
+            batch_size: int = 32,
+            show_progress_bar: Union[bool, None] = None,
+            output_value: str = "sentence_embedding",
+            convert_to_numpy: bool = True,
+            convert_to_tensor: bool = False,
+            device: Union[str, None] = None,
+            normalize_embeddings: bool = False,
     ):
         """
         Computes sentence embeddings
@@ -618,7 +631,7 @@ def encode(
 
         input_was_string = False
         if isinstance(sentences, str) or not hasattr(
-            sentences, "__len__"
+                sentences, "__len__"
        ):  # Cast an individual sentence to a list with length 1
             sentences = [sentences]
             input_was_string = True
@@ -641,9 +654,9 @@ def encode(
         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
 
         for start_index in trange(
-            0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
+                0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar
         ):
-            sentences_batch = sentences_sorted[start_index : start_index + batch_size]
+            sentences_batch = sentences_sorted[start_index: start_index + batch_size]
             features = self.tokenize(sentences_batch)
             features = batch_to_device(features, device)
 
@@ -653,13 +666,13 @@ def encode(
                 if output_value == "token_embeddings":
                     embeddings = []
                     for token_emb, attention in zip(
-                        out_features[output_value], out_features["attention_mask"]
+                            out_features[output_value], out_features["attention_mask"]
                     ):
                         last_mask_id = len(attention) - 1
                         while last_mask_id > 0 and attention[last_mask_id].item() == 0:
                             last_mask_id -= 1
 
-                        embeddings.append(token_emb[0 : last_mask_id + 1])
+                        embeddings.append(token_emb[0: last_mask_id + 1])
                 elif output_value is None:  # Return all outputs
                     embeddings = []
                     for sent_idx in range(len(out_features["sentence_embedding"])):
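
Usage sketch (illustrative, not part of the patch): the snippet below assumes the public "hkunlp/instructor-large" checkpoint and a sentence-transformers release that unpacks the return value of _load_sbert_model() as (modules, module_kwargs); it simply exercises the [instruction, text] input path that tokenize(), smart_batching_collate(), and the updated encode() above are expected to keep supporting.

    from InstructorEmbedding import INSTRUCTOR

    # Loading goes through the overridden _load_sbert_model() shown in the diff.
    model = INSTRUCTOR("hkunlp/instructor-large")

    # Each input is an [instruction, text] pair, matching the asserts in tokenize().
    embeddings = model.encode(
        [
            ["Represent the science title:", "Parton distributions with LHC data"],
            ["Represent the science title:", "Measurement of the top quark mass"],
        ],
        batch_size=2,
        normalize_embeddings=True,
    )
    print(embeddings.shape)  # e.g. (2, 768) for instructor-large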