Skip to content

Commit d06401a

Browse files
authored
elasticsearch embeddings (run-llama#7914)
1 parent 31ae464 commit d06401a

File tree

8 files changed

+337
-3
lines changed

8 files changed

+337
-3
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Updated `KeywordNodePostprocessor` to use spacy to support more languages (#7894)
77
- `LocalAI` supporting global or per-query `/chat/completions` vs `/completions` (#7921)
88
- Added notebook on using REBEL + Wikipedia filtering for knowledge graphs (#7919)
9+
- Added support for `ElasticsearchEmbeddings` (#7914)
910

1011
## [0.8.37] - 2023-09-30
1112

data_requirements.txt

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,7 @@ google-auth-httplib2
1414
google-auth-oauthlib
1515

1616
# vellum
17-
vellum-ai==0.0.15
17+
vellum-ai==0.0.15
18+
19+
# elasticsearch
20+
elasticsearch==8.9.0

docs/core_modules/model_modules/embeddings/modules.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ maxdepth: 1
1111
/examples/customization/llms/AzureOpenAI.ipynb
1212
/examples/embeddings/custom_embeddings.ipynb
1313
/examples/embeddings/huggingface.ipynb
14+
/examples/embeddings/elasticsearch.ipynb
1415
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
{
2+
"cells": [
3+
{
4+
"attachments": {},
5+
"cell_type": "markdown",
6+
"metadata": {},
7+
"source": [
8+
"# Elasticsearch Embeddings"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": null,
14+
"metadata": {},
15+
"outputs": [],
16+
"source": [
17+
"# imports\n",
18+
"\n",
19+
"from llama_index.embeddings.elasticsearch import ElasticsearchEmbeddings\n",
20+
"from llama_index.vector_stores import ElasticsearchStore\n",
21+
"from llama_index import ServiceContext, StorageContext, VectorStoreIndex"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"# get credentials and create embeddings\n",
31+
"\n",
32+
"import os\n",
33+
"\n",
34+
"host = os.environ.get(\"ES_HOST\", \"http://localhost:9200\")\n",
35+
"username = os.environ.get(\"ES_USERNAME\", \"elastic\")\n",
36+
"password = os.environ.get(\"ES_PASSWORD\", \"changeme\")\n",
37+
"index_name = os.environ.get(\"INDEX_NAME\", \"your-index-name\")\n",
38+
"model_id = os.environ.get(\"MODEL_ID\", \"your-model-id\")\n",
39+
"\n",
40+
"\n",
41+
"embeddings = ElasticsearchEmbeddings.from_credentials(\n",
42+
" model_id=model_id, es_url=host, es_username=username, es_password=password\n",
43+
")"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {},
50+
"outputs": [],
51+
"source": [
52+
"# create service context using the embeddings\n",
53+
"\n",
54+
"service_context = ServiceContext.from_defaults(embed_model=embeddings, chunk_size=512)"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"# usage with elasticsearch vector store\n",
64+
"\n",
65+
"vector_store = ElasticsearchStore(\n",
66+
" index_name=index_name, es_url=host, es_user=username, es_password=password\n",
67+
")\n",
68+
"\n",
69+
"storage_context = StorageContext.from_defaults(vector_store=vector_store)\n",
70+
"\n",
71+
"index = VectorStoreIndex.from_vector_store(\n",
72+
" vector_store=vector_store,\n",
73+
" storage_context=storage_context,\n",
74+
" service_context=service_context,\n",
75+
")\n",
76+
"\n",
77+
"query_engine = index.as_query_engine()\n",
78+
"\n",
79+
"\n",
80+
"response = query_engine.query(\"hello world\")"
81+
]
82+
}
83+
],
84+
"metadata": {
85+
"kernelspec": {
86+
"display_name": "Python 3",
87+
"language": "python",
88+
"name": "python3"
89+
},
90+
"language_info": {
91+
"codemirror_mode": {
92+
"name": "ipython",
93+
"version": 3
94+
},
95+
"file_extension": ".py",
96+
"mimetype": "text/x-python",
97+
"name": "python",
98+
"nbconvert_exporter": "python",
99+
"pygments_lexer": "ipython3",
100+
"version": "3.11.3"
101+
},
102+
"orig_nbformat": 4
103+
},
104+
"nbformat": 4,
105+
"nbformat_minor": 2
106+
}

llama_index/embeddings/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from llama_index.embeddings.instructor import InstructorEmbedding
1414
from llama_index.embeddings.utils import resolve_embed_model
1515
from llama_index.embeddings.base import SimilarityMode
16-
16+
from llama_index.embeddings.elasticsearch import ElasticsearchEmbeddings
1717

1818
__all__ = [
1919
"GoogleUnivSentEncoderEmbedding",
@@ -27,4 +27,5 @@
2727
"resolve_embed_model",
2828
"DEFAULT_HUGGINGFACE_EMBEDDING_MODEL",
2929
"SimilarityMode",
30+
"ElasticsearchEmbeddings",
3031
]
+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from typing import List, Any
2+
from llama_index.embeddings.base import BaseEmbedding
3+
from llama_index.bridge.pydantic import PrivateAttr
4+
5+
6+
class ElasticsearchEmbeddings(BaseEmbedding):
    """Elasticsearch embedding models.

    This class provides an interface to generate embeddings using a model deployed
    in an Elasticsearch cluster. It requires an Elasticsearch ML client object
    and the model_id of the model deployed in the cluster.

    In Elasticsearch you need to have an embedding model loaded and deployed.
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    # Client exposing ``infer_trained_model``; kept private so it is not
    # treated as a model field.
    _client: Any = PrivateAttr()
    # Identifier of the trained model to run inference against.
    model_id: str
    # Key under which the input text is sent in each inference document.
    input_field: str

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this embedding class."""
        return "ElasticsearchEmbeddings"

    def __init__(
        self,
        client: Any,
        model_id: str,
        input_field: str = "text_field",
        **kwargs: Any,
    ):
        """Create the embedding model around an existing ML client.

        Args:
            client (Any): Object exposing ``infer_trained_model(model_id=...,
                docs=...)``, e.g. an ``elasticsearch.client.MlClient``.
            model_id (str): The model_id of the model deployed in the
                Elasticsearch cluster.
            input_field (str): The name of the key for the input text field in
                the document. Defaults to 'text_field'.
            **kwargs (Any): Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        self._client = client
        super().__init__(model_id=model_id, input_field=input_field, **kwargs)

    @classmethod
    def from_es_connection(
        cls,
        model_id: str,
        es_connection: Any,
        input_field: str = "text_field",
    ) -> "ElasticsearchEmbeddings":
        """
        Instantiate embeddings from an existing Elasticsearch connection.

        This method provides a way to create an instance of the
        ElasticsearchEmbeddings class using an existing Elasticsearch connection.
        The connection object is used to create an MlClient, which is then used
        to initialize the ElasticsearchEmbeddings instance.

        Args:
            model_id (str): The model_id of the model deployed in the
                Elasticsearch cluster.
            es_connection (elasticsearch.Elasticsearch): An existing
                Elasticsearch connection object.
            input_field (str, optional): The name of the key for the input text
                field in the document. Defaults to 'text_field'.

        Returns:
            ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings
                class.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.

        Example:
            .. code-block:: python

                from elasticsearch import Elasticsearch

                from llama_index.embeddings import ElasticsearchEmbeddings

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                # Create Elasticsearch connection
                es_connection = Elasticsearch(
                    hosts=["http://localhost:9200"],
                    basic_auth=("user", "password"),
                )

                # Instantiate ElasticsearchEmbeddings using the existing connection
                embeddings = ElasticsearchEmbeddings.from_es_connection(
                    model_id,
                    es_connection,
                    input_field=input_field,
                )
        """
        # Imported lazily so the package stays an optional dependency.
        try:
            from elasticsearch.client import MlClient
        except ImportError as exc:
            raise ImportError(
                "elasticsearch package not found, install with "
                "'pip install elasticsearch'"
            ) from exc

        ml_client = MlClient(es_connection)
        return cls(ml_client, model_id, input_field=input_field)

    @classmethod
    def from_credentials(
        cls,
        model_id: str,
        es_url: str,
        es_username: str,
        es_password: str,
        input_field: str = "text_field",
    ) -> "ElasticsearchEmbeddings":
        """Instantiate embeddings from Elasticsearch credentials.

        Args:
            model_id (str): The model_id of the model deployed in the
                Elasticsearch cluster.
            es_url (str): The Elasticsearch url to connect to.
            es_username (str): Elasticsearch username.
            es_password (str): Elasticsearch password.
            input_field (str): The name of the key for the input text field in
                the document. Defaults to 'text_field'.

        Returns:
            ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings
                class.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.

        Example:
            .. code-block:: python

                from llama_index.embeddings import ElasticsearchEmbeddings

                # Define the model ID and input field name (if different from default)
                model_id = "your_model_id"
                # Optional, only if different from 'text_field'
                input_field = "your_input_field"

                embeddings = ElasticsearchEmbeddings.from_credentials(
                    model_id,
                    input_field=input_field,
                    es_url="http://localhost:9200",
                    es_username="bar",
                    es_password="baz",
                )
        """
        # Imported lazily so the package stays an optional dependency.
        try:
            from elasticsearch import Elasticsearch
        except ImportError as exc:
            raise ImportError(
                "elasticsearch package not found, install with "
                "'pip install elasticsearch'"
            ) from exc

        es_connection = Elasticsearch(
            hosts=[es_url],
            basic_auth=(es_username, es_password),
        )

        # Reuse the connection-based constructor so the MlClient wiring lives
        # in exactly one place.
        return cls.from_es_connection(
            model_id, es_connection, input_field=input_field
        )

    def _get_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text.

        Args:
            text (str): The text to generate an embedding for.

        Returns:
            List[float]: The embedding for the input text.
        """
        # One document per call, keyed by the configured input field.
        response = self._client.infer_trained_model(
            model_id=self.model_id,
            docs=[{self.input_field: text}],
        )

        # Exactly one document was sent, so read the first inference result.
        return response["inference_results"][0]["predicted_value"]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding for a piece of document text."""
        return self._get_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding for a query string."""
        return self._get_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Return the embedding for a query string from async code.

        NOTE: this delegates to the synchronous implementation, so the
        inference call is not awaited and runs inline in the calling
        coroutine.
        """
        return self._get_query_embedding(query)
+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pytest
2+
from llama_index.embeddings.elasticsearch import ElasticsearchEmbeddings
3+
4+
5+
@pytest.fixture
def model_id() -> str:
    """Provide a placeholder model id; replace with your actual model_id."""
    placeholder_model_id = "your_model_id"
    return placeholder_model_id
9+
10+
11+
@pytest.fixture
def es_url() -> str:
    """Provide a placeholder URL; replace with your actual Elasticsearch URL."""
    placeholder_url = "http://localhost:9200"
    return placeholder_url
15+
16+
17+
@pytest.fixture
def es_username() -> str:
    """Provide a placeholder; replace with your actual Elasticsearch username."""
    placeholder_username = "foo"
    return placeholder_username
21+
22+
23+
@pytest.fixture
def es_password() -> str:
    """Provide a placeholder; replace with your actual Elasticsearch password."""
    placeholder_password = "bar"
    return placeholder_password
27+
28+
29+
def test_elasticsearch_embedding_constructor(
    model_id: str, es_url: str, es_username: str, es_password: str
) -> None:
    """Test building ElasticsearchEmbeddings from credentials.

    Only construction is exercised here; no embedding is requested.
    """
    # from_credentials raises ImportError when the optional client package is
    # missing, so skip rather than fail in environments without it.
    pytest.importorskip("elasticsearch")

    embeddings = ElasticsearchEmbeddings.from_credentials(
        model_id=model_id,
        es_url=es_url,
        es_username=es_username,
        es_password=es_password,
    )

    # Pin what the constructor is expected to produce.
    assert isinstance(embeddings, ElasticsearchEmbeddings)
    assert embeddings.model_id == model_id
    # input_field was not passed, so the documented default applies.
    assert embeddings.input_field == "text_field"

tests/vector_stores/test_elasticsearch.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,9 @@ async def perform_request(self, *args, **kwargs): # type: ignore
463463

464464
es_store.add(node_embeddings)
465465

466-
user_agent = es_client_instance.transport.requests[0]["headers"]["user-agent"]
466+
user_agent = es_client_instance.transport.requests[0]["headers"][ # type: ignore
467+
"user-agent"
468+
]
467469
pattern = r"^llama_index-py-vs/\d+\.\d+\.\d+$"
468470
match = re.match(pattern, user_agent)
469471

0 commit comments

Comments
 (0)