diff --git a/haystack/components/embedders/backends/sentence_transformers_backend.py b/haystack/components/embedders/backends/sentence_transformers_backend.py
index 9d86b80ca0..cff9135c86 100644
--- a/haystack/components/embedders/backends/sentence_transformers_backend.py
+++ b/haystack/components/embedders/backends/sentence_transformers_backend.py
@@ -29,6 +29,7 @@ def get_embedding_backend(
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"
 
@@ -42,6 +43,7 @@ def get_embedding_backend(
             truncate_dim=truncate_dim,
             model_kwargs=model_kwargs,
             tokenizer_kwargs=tokenizer_kwargs,
+            config_kwargs=config_kwargs,
         )
         _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
         return embedding_backend
@@ -61,6 +63,7 @@ def __init__(
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
     ):
         sentence_transformers_import.check()
         self.model = SentenceTransformer(
@@ -71,6 +74,7 @@ def __init__(
             truncate_dim=truncate_dim,
             model_kwargs=model_kwargs,
             tokenizer_kwargs=tokenizer_kwargs,
+            config_kwargs=config_kwargs,
         )
 
     def embed(self, data: List[str], **kwargs) -> List[List[float]]:
diff --git a/haystack/components/embedders/sentence_transformers_text_embedder.py b/haystack/components/embedders/sentence_transformers_text_embedder.py
index e29b2d439c..2d9f5efeb7 100644
--- a/haystack/components/embedders/sentence_transformers_text_embedder.py
+++ b/haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -34,7 +34,7 @@ class SentenceTransformersTextEmbedder:
     ```
     """
 
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         model: str = "sentence-transformers/all-mpnet-base-v2",
         device: Optional[ComponentDevice] = None,
@@ -48,6 +48,7 @@ def __init__(
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
     ):
         """
@@ -86,6 +87,8 @@ def __init__(
         :param tokenizer_kwargs:
             Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
             Refer to specific model documentation for available kwargs.
+        :param config_kwargs:
+            Additional keyword arguments for model configuration parameters
         :param precision:
             The precision to use for the embeddings.
             All non-float32 precisions are quantized embeddings.
@@ -105,6 +108,7 @@ def __init__(
         self.truncate_dim = truncate_dim
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
+        self.config_kwargs = config_kwargs
         self.embedding_backend = None
         self.precision = precision
 
@@ -135,6 +139,7 @@ def to_dict(self) -> Dict[str, Any]:
             truncate_dim=self.truncate_dim,
             model_kwargs=self.model_kwargs,
             tokenizer_kwargs=self.tokenizer_kwargs,
+            config_kwargs=self.config_kwargs,
             precision=self.precision,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
@@ -172,6 +177,7 @@ def warm_up(self):
                 truncate_dim=self.truncate_dim,
                 model_kwargs=self.model_kwargs,
                 tokenizer_kwargs=self.tokenizer_kwargs,
+                config_kwargs=self.config_kwargs,
             )
             if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                 self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
diff --git a/releasenotes/notes/sentence-transformers-text-embedder-config_kwargs-11f10429e25a3a6e.yaml b/releasenotes/notes/sentence-transformers-text-embedder-config_kwargs-11f10429e25a3a6e.yaml
new file mode 100644
index 0000000000..9409b09e6d
--- /dev/null
+++ b/releasenotes/notes/sentence-transformers-text-embedder-config_kwargs-11f10429e25a3a6e.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    SentenceTransformersTextEmbedder now supports config_kwargs for additional parameters when loading the model configuration
diff --git a/test/components/embedders/test_sentence_transformers_embedding_backend.py b/test/components/embedders/test_sentence_transformers_embedding_backend.py
index 7ca42aab91..55014183b2 100644
--- a/test/components/embedders/test_sentence_transformers_embedding_backend.py
+++ b/test/components/embedders/test_sentence_transformers_embedding_backend.py
@@ -42,6 +42,7 @@ def test_model_initialization(mock_sentence_transformer):
         truncate_dim=256,
         model_kwargs=None,
         tokenizer_kwargs=None,
+        config_kwargs=None,
     )
 
 
diff --git a/test/components/embedders/test_sentence_transformers_text_embedder.py b/test/components/embedders/test_sentence_transformers_text_embedder.py
index 2f043de237..9325c481ca 100644
--- a/test/components/embedders/test_sentence_transformers_text_embedder.py
+++ b/test/components/embedders/test_sentence_transformers_text_embedder.py
@@ -70,6 +70,7 @@ def test_to_dict(self):
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "config_kwargs": None,
                 "precision": "float32",
             },
         }
@@ -88,6 +89,7 @@ def test_to_dict_with_custom_init_parameters(self):
             truncate_dim=256,
             model_kwargs={"torch_dtype": torch.float32},
             tokenizer_kwargs={"model_max_length": 512},
+            config_kwargs={"use_memory_efficient_attention": False},
             precision="int8",
         )
         data = component.to_dict()
@@ -106,6 +108,7 @@ def test_to_dict_with_custom_init_parameters(self):
                 "truncate_dim": 256,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "config_kwargs": {"use_memory_efficient_attention": False},
                 "precision": "int8",
             },
         }
@@ -131,6 +134,7 @@ def test_from_dict(self):
                 "truncate_dim": None,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "config_kwargs": {"use_memory_efficient_attention": False},
                 "precision": "float32",
             },
         }
@@ -147,6 +151,7 @@ def test_from_dict(self):
         assert component.truncate_dim is None
         assert component.model_kwargs == {"torch_dtype": torch.float32}
         assert component.tokenizer_kwargs == {"model_max_length": 512}
+        assert component.config_kwargs == {"use_memory_efficient_attention": False}
         assert component.precision == "float32"
 
     def test_from_dict_no_default_parameters(self):
@@ -218,6 +223,7 @@ def test_warmup(self, mocked_factory):
             truncate_dim=None,
             model_kwargs=None,
            tokenizer_kwargs={"model_max_length": 512},
+            config_kwargs=None,
         )
 
     @patch(
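
For context, a minimal usage sketch of the new parameter (not part of the patch): the model name is the component's default and the config key mirrors the value used in the tests above; which keys are actually accepted depends on the configuration of the underlying Hugging Face model.

    from haystack.components.embedders import SentenceTransformersTextEmbedder

    # config_kwargs is forwarded through the backend to SentenceTransformer,
    # which applies it when loading the model's configuration.
    embedder = SentenceTransformersTextEmbedder(
        model="sentence-transformers/all-mpnet-base-v2",
        config_kwargs={"use_memory_efficient_attention": False},
    )
    embedder.warm_up()  # the model is loaded here, with the extra config parameters applied
    result = embedder.run(text="I love pizza!")
    print(len(result["embedding"]))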