
Commit 9e94b5e

Merge pull request #91 from BillFarber/task/extendExamples
Adds a Vector/BM25 example.
2 parents 65a8c47 + 9f8571d commit 9e94b5e

7 files changed: +207, -18 lines

examples/langchain/README.md (+18, -1)
@@ -104,7 +104,9 @@ loaded via `load_data.py`.
 ### MarkLogic 12EA Setup
 
 To try this functionality out, you will need access to an instance of MarkLogic 12
-(currently internal or Early Access only). You may use docker
+(currently internal or Early Access only).
+<TODO>Add info to get ML12</TODO>
+You may use
 [docker-compose](https://docs.docker.com/compose/) to instantiate a new MarkLogic
 instance with port 8003 available (you can use your own MarkLogic instance too, just be
 sure that port 8003 is available):
@@ -147,3 +149,18 @@ into different collections.
 ```
 python load_data_with_embeddings.py
 ```
+
+### Running the Vector Query
+
+You are now ready to test the example vector retriever. Run the following to ask a
+question with the results augmented via the `marklogic_vector_query_retriever.py` module
+in this project:
+
+    python ask_vector_query.py "What is task decomposition?" posts_with_embeddings
+
+This retriever searches MarkLogic for candidate documents, defaulting to the new
+score-bm25 scoring method in MarkLogic 12EA. If preferred, you can adjust this to one of
+the other scoring methods. After retrieving candidate documents with the CTS search, the
+retriever uses the new vector functionality to sort the documents by cosine similarity
+to the user question, and then returns the top N documents, which it packages up as
+LangChain documents.
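For a concrete sense of the scoring knob described above, here is a minimal sketch, not part of this commit, of overriding the default through the `create()` factory that `marklogic_vector_query_retriever.py` adds; `score-logtfidf` is one of MarkLogic's long-standing scoring options:

```python
import os

from langchain_openai import AzureOpenAIEmbeddings
from marklogic import Client

from marklogic_vector_query_retriever import MarkLogicVectorQueryRetriever

# scoring_method defaults to "score-bm25" (new in MarkLogic 12EA); any other
# cts scoring method, such as "score-logtfidf", can be passed instead.
retriever = MarkLogicVectorQueryRetriever.create(
    Client("http://localhost:8003", digest=("langchain-user", "password")),
    embedding_generator=AzureOpenAIEmbeddings(
        azure_deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"]
    ),
    scoring_method="score-logtfidf",
)
```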
examples/langchain/ask_vector_query.py (new file, +53)
```python
# Based on the example at
# https://python.langchain.com/docs/use_cases/question_answering/quickstart .

import os
import sys
from dotenv import load_dotenv
from langchain import hub
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from marklogic import Client
from marklogic_vector_query_retriever import MarkLogicVectorQueryRetriever


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


load_dotenv()
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"]
)
retriever = MarkLogicVectorQueryRetriever.create(
    Client("http://localhost:8003", digest=("langchain-user", "password")),
    embedding_generator=embeddings,
)
retriever.collections = [sys.argv[2]]
retriever.max_results = int(sys.argv[3]) if len(sys.argv) > 3 else 10
if len(sys.argv) > 4:
    # The optional fourth argument selects the MarkLogic scoring method;
    # scoring_method is the field the retriever class declares.
    retriever.scoring_method = sys.argv[4]

question = sys.argv[1]

prompt = hub.pull("rlm/rag-prompt")
# The Azure OpenAI API key, endpoint, and API version are all read from the
# environment automatically.
llm = AzureChatOpenAI(
    model_name=os.getenv("AZURE_LLM_DEPLOYMENT_NAME"),
    azure_deployment=os.getenv("AZURE_LLM_DEPLOYMENT_NAME"),
    temperature=0,
    max_tokens=None,
    timeout=None,
)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
print(rag_chain.invoke(question))
```
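The script's positional arguments are the question, the target collection, an optional maximum number of candidate documents (defaulting to 10), and an optional scoring method, so an invocation along these lines (illustrative, not from the commit) exercises all four:

```
python ask_vector_query.py "What is task decomposition?" posts_with_embeddings 20 score-logtfidf
```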

examples/langchain/load_data.py (+21, -13)
@@ -17,34 +17,42 @@
 )
 docs = loader.load()
 
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=1000, chunk_overlap=100
+)
 splits = text_splitter.split_documents(docs)
 
 client = Client("http://localhost:8003", digest=("langchain-user", "password"))
 
-marklogic_docs = [
-    DefaultMetadata(collections="posts")
-]
+marklogic_docs = [DefaultMetadata(collections="posts")]
 for split in splits:
-    doc = Document(None, split.page_content, extension=".txt", directory="/post/")
+    doc = Document(
+        None, split.page_content, extension=".txt", directory="/post/"
+    )
     marklogic_docs.append(doc)
 
 client.documents.write(marklogic_docs)
-print(f"Number of documents written to collection 'posts': {len(marklogic_docs)-1}")
+print(
+    f"Number of documents written to collection 'posts': {len(marklogic_docs)-1}"
+)
 
 loader = WebBaseLoader(
-    web_paths=("https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt",)
+    web_paths=(["https://www.whitehouse.gov/state-of-the-union-2022/"])
 )
 docs = loader.load()
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=1000, chunk_overlap=100
+)
 splits = text_splitter.split_documents(docs)
 
-marklogic_docs = [
-    DefaultMetadata(collections="sotu")
-]
+marklogic_docs = [DefaultMetadata(collections="sotu")]
 for split in splits:
-    doc = Document(None, split.page_content, extension=".txt", directory="/sotu/")
+    doc = Document(
+        None, split.page_content, extension=".txt", directory="/sotu/"
+    )
     marklogic_docs.append(doc)
 
 client.documents.write(marklogic_docs)
-print(f"Number of documents written to collection 'sotu': {len(marklogic_docs)-1}")
+print(
+    f"Number of documents written to collection 'sotu': {len(marklogic_docs)-1}"
+)

examples/langchain/load_data_with_embeddings.py (+1, -3)
@@ -50,9 +50,7 @@
 )
 
 loader = WebBaseLoader(
-    web_paths=(
-        "https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt",
-    )
+    web_paths=(["https://www.whitehouse.gov/state-of-the-union-2022/"])
 )
 docs = loader.load()
 text_splitter = RecursiveCharacterTextSplitter(
examples/langchain/marklogic_vector_query_retriever.py (new file, +92)
```python
from typing import List

from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_openai import AzureOpenAIEmbeddings
from marklogic import Client


class MarkLogicVectorQueryRetriever(BaseRetriever):

    client: Client
    embedding_generator: AzureOpenAIEmbeddings
    max_results: int = 10
    collections: List[str] = []
    tde_schema: str
    tde_view: str
    scoring_method: str

    @classmethod
    def create(
        cls,
        client: Client,
        embedding_generator: AzureOpenAIEmbeddings,
        tde_schema: str = None,
        tde_view: str = None,
        scoring_method: str = "score-bm25",
    ):
        return cls(
            client=client,
            embedding_generator=embedding_generator,
            tde_schema=tde_schema or "demo",
            tde_view=tde_view or "posts",
            scoring_method=scoring_method,
        )

    def _build_javascript_query(self, query, query_embedding):
        # Return the first self.max_results documents, based on token limitations.
        #
        # If limits are hit, consider different models:
        # gpt-35-turbo (0125): 16,385/4,096
        # gpt-35-turbo (1106): 16,385/4,096
        # gpt-35-turbo-16k (0613):

        # This JavaScript consists of two queries. The first is a standard cts
        # search, searching for words that match those used in the chat question.
        # The second query is an Optic query that uses the top documents from the
        # CTS query to do a vector search to re-order the results.

        search_words = []
        for word in query.split():
            search_words.append(word.lower().replace("?", ""))
        return """
        const op = require('/MarkLogic/optic');
        const ovec = require('/MarkLogic/optic/optic-vec.xqy');
        const result =
            fn.subsequence(cts.search(cts.andQuery([
                cts.wordQuery({}),
                cts.collectionQuery({})
            ]), ["{}"]), 1, {});
        let uris = [];
        for (const doc of result) {{
            uris.push(xdmp.nodeUri(doc))
        }}
        const qv = vec.vector({})

        const rows = op.fromView('{}', '{}', '')
            .where(op.in(op.col('uri'), uris))
            .bind(op.as('summaryCosineSim',
                op.vec.cosineSimilarity(op.vec.vector(op.col('embedding')), qv)))
            .orderBy(op.desc(op.col('summaryCosineSim')))
            .result();
        rows;
        """.format(
            search_words,
            self.collections,
            self.scoring_method,
            self.max_results,
            query_embedding,
            self.tde_schema,
            self.tde_view,
        )

    def _get_relevant_documents(self, query: str) -> List[Document]:
        print(f"Searching with query: {query}")

        query_embedding = self.embedding_generator.embed_query(query)
        javascript_query = self._build_javascript_query(query, query_embedding)
        results = self.client.eval(javascript=javascript_query)

        print(f"Count of matching MarkLogic documents: {len(results)}")
        # Package each result row as a LangChain Document (a list, to match the
        # declared return type).
        return [Document(page_content=doc["text"]) for doc in results]
```
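The Optic half of the generated query assumes a TDE view (schema `demo`, view `posts` by default) exposing at least `uri` and `embedding` columns, and each row handed back by `client.eval` must carry a `text` field for the page content. Here is a minimal sketch, not part of this commit, of exercising the retriever on its own, assuming the same local MarkLogic instance and Azure environment variables used in the other examples:

```python
import os

from langchain_openai import AzureOpenAIEmbeddings
from marklogic import Client

from marklogic_vector_query_retriever import MarkLogicVectorQueryRetriever

retriever = MarkLogicVectorQueryRetriever.create(
    Client("http://localhost:8003", digest=("langchain-user", "password")),
    embedding_generator=AzureOpenAIEmbeddings(
        azure_deployment=os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"]
    ),
)
retriever.collections = ["posts_with_embeddings"]

# BaseRetriever implements the LangChain runnable interface, so the retriever
# can be invoked directly to inspect the documents it would hand to a chain.
for doc in retriever.invoke("What is task decomposition?"):
    print(doc.page_content[:80])
```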
examples/langchain/src/main/ml-config/security/roles/langchain-eval-role.json (new file, +20)
```json
{
  "role-name": "langchain-eval-role",
  "privilege": [
    {
      "privilege-name": "xdmp:eval",
      "action": "http://marklogic.com/xdmp/privileges/xdmp-eval",
      "kind": "execute"
    },
    {
      "privilege-name": "xdmp:eval-in",
      "action": "http://marklogic.com/xdmp/privileges/xdmp-eval-in",
      "kind": "execute"
    },
    {
      "privilege-name": "xdbc:eval",
      "action": "http://marklogic.com/xdmp/privileges/xdbc-eval",
      "kind": "execute"
    }
  ]
}
```
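These eval privileges are required because `MarkLogicVectorQueryRetriever` runs its generated JavaScript through `client.eval`. ml-gradle picks this file up automatically from `src/main/ml-config/security/roles`; for setups without ml-gradle, the role could instead be created through MarkLogic's Management API, as in this sketch (the port 8002 endpoint and admin credentials are assumptions, not part of this commit):

```python
import json

import requests
from requests.auth import HTTPDigestAuth

with open("langchain-eval-role.json") as f:
    role = json.load(f)

# POST /manage/v2/roles creates a role from its JSON representation.
response = requests.post(
    "http://localhost:8002/manage/v2/roles",
    json=role,
    auth=HTTPDigestAuth("admin", "admin"),
)
response.raise_for_status()
print(f"Created role: {role['role-name']}")
```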

examples/langchain/src/main/ml-config/security/users/langchain-user.json (+2, -1)
@@ -4,6 +4,7 @@
   "role": [
     "rest-reader",
     "rest-writer",
-    "qconsole-user"
+    "qconsole-user",
+    "langchain-eval-role"
   ]
 }
