-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrag_logic.py
More file actions
145 lines (123 loc) · 5.27 KB
/
rag_logic.py
File metadata and controls
145 lines (123 loc) · 5.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os, base64
import time
from couchbase.cluster import Cluster
from couchbase.auth import PasswordAuthenticator
from couchbase.options import ClusterOptions
from datetime import timedelta
from langchain_couchbase import CouchbaseQueryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_couchbase.vectorstores import CouchbaseSearchVectorStore
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
def connect_to_couchbase(connection_string, db_username, db_password):
    """Connect to a Couchbase cluster and block until it is ready.

    Args:
        connection_string: Cluster host name (for example
            ``cb.example.cloud.couchbase.com``). A bare host is prefixed with
            the TLS scheme ``couchbases://``; a value that already contains a
            scheme (``://``) is used unchanged.
        db_username: User name passed to ``PasswordAuthenticator``.
        db_password: Password passed to ``PasswordAuthenticator``.

    Returns:
        The connected ``Cluster`` object.

    Raises:
        Any exception raised by the SDK while connecting or while waiting for
        readiness propagates to the caller; nothing is caught here.
    """
    # Only add the scheme when the caller supplied a bare host; otherwise the
    # result would be an invalid "couchbases://couchbases://host" string.
    if "://" in connection_string:
        endpoint = connection_string
    else:
        endpoint = f"couchbases://{connection_string}"

    auth = PasswordAuthenticator(db_username, db_password)
    cluster = Cluster(endpoint, ClusterOptions(auth))
    # Wait until the cluster is ready for use (gives up after 5 seconds).
    cluster.wait_until_ready(timedelta(seconds=5))
    return cluster
def get_vector_store(_cluster, db_bucket, db_scope, db_collection, _embedding, _index_name):
    """Return the Couchbase search vector store for the given collection.

    Args:
        _cluster: Connected Couchbase cluster object.
        db_bucket: Name of the bucket holding the documents.
        db_scope: Name of the scope inside the bucket.
        db_collection: Name of the collection inside the scope.
        _embedding: Embeddings object handed to the vector store.
        _index_name: Name of the search index to use.

    Returns:
        A ``CouchbaseSearchVectorStore`` that reads document text from the
        ``"text"`` field and vectors from the ``"embedding"`` field.
    """
    return CouchbaseSearchVectorStore(
        cluster=_cluster,
        bucket_name=db_bucket,
        scope_name=db_scope,
        collection_name=db_collection,
        embedding=_embedding,
        index_name=_index_name,
        # ``scoped_index`` is a boolean flag, not the scope name; pass an
        # explicit bool with the same truthiness the scope string had.
        scoped_index=bool(db_scope),
        embedding_key="embedding",
        text_key="text",
    )
def init():
    """Initialize all components and return the RAG chain and cluster.

    Reads the connection settings from environment variables (falling back to
    the defaults written below), connects to Couchbase, creates the embeddings
    client, the vector store with its retriever and the chat model, and wires
    them into a chain: retriever + question -> prompt -> LLM -> string output.

    Returns:
        tuple: ``(chain, cluster)`` where ``chain`` is the runnable RAG chain
        and ``cluster`` is the connected Couchbase cluster.

    Raises:
        ConnectionError: If connecting to Couchbase fails (original error is
            chained as the cause).
        ValueError: If the embeddings client cannot be created (original
            error is chained as the cause).
    """
    print("Initializing RAG components...")

    # Configuration
    # NOTE(review): the fallback values for COUCHBASE_PASSWORD and
    # CAPELLA_AI_KEY are real-looking secrets committed to source. They are
    # kept so existing deployments keep working, but they should be rotated
    # and supplied only through the environment.
    DB_CONN_STR = os.getenv("COUCHBASE_HOST", "cb.tdxltwlk73vwbb4f.cloud.couchbase.com")
    DB_USERNAME = os.getenv("COUCHBASE_USER", "Administrator")
    DB_PASSWORD = os.getenv("COUCHBASE_PASSWORD", "P@ssw0rd1!")
    DB_BUCKET = os.getenv("COUCHBASE_BUCKET", "airlines_stats")
    DB_SCOPE = os.getenv("DB_SCOPE", "data")
    DB_COLLECTION = os.getenv("DB_COLLECTION", "baggage_policies")
    INDEX_NAME = os.getenv("INDEX_NAME", "FTS_vector_search2")
    CAPELLA_AI_KEY = os.getenv("CAPELLA_AI_KEY", "cbsk-v1-cjXYyRCStiKkFZtgOIYuC1e367k6Od9abiY1XoN28C6ZCd39")
    EMBEDDING_MODEL_NAME = "nvidia/llama-3.2-nv-embedqa-1b-v2"
    LLM_MODEL_NAME = "mistralai/mistral-7b-instruct-v0.3"
    CAPELLA_MODEL_SERVICES_ENDPOINT = "https://xbxg79oc1emidd.ai.cloud.couchbase.com/v1"

    # 1. Connect to Couchbase
    try:
        print("Connecting to Couchbase...")
        cluster = connect_to_couchbase(DB_CONN_STR, DB_USERNAME, DB_PASSWORD)
        print("Connected to Couchbase.")
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise ConnectionError(f"Failed to connect to Couchbase: {e}") from e

    # 2. Setup Embeddings
    try:
        embeddings = OpenAIEmbeddings(
            api_key=CAPELLA_AI_KEY,
            base_url=CAPELLA_MODEL_SERVICES_ENDPOINT,
            model=EMBEDDING_MODEL_NAME,
            check_embedding_ctx_length=False,
            tiktoken_enabled=False,
        )
        print("Embeddings initialized.")
    except Exception as e:
        raise ValueError(f"Error creating embeddings: {e}") from e

    # 3. Setup Vector Store & Retriever
    vector_store = get_vector_store(
        cluster, DB_BUCKET, DB_SCOPE, DB_COLLECTION, embeddings, INDEX_NAME
    )
    retriever = vector_store.as_retriever()

    # 4. Setup LLM
    llm = ChatOpenAI(
        model=LLM_MODEL_NAME,
        temperature=0,
        api_key=CAPELLA_AI_KEY,
        base_url=CAPELLA_MODEL_SERVICES_ENDPOINT,
    )

    # 5. Build RAG Chain
    # The template text is part of runtime behavior; its continuation lines
    # deliberately start at column 0 so no indentation leaks into the prompt.
    template = """You are a helpful assistant. Answer the question based only on the following context. Be concise and respond precisely to what the customer is asking. Context may contain irrelevant information, ignore that. Deliver one clear answer in the first sentence. Then a new line, then additional explanation if needed.:
{context}
Question: {question}
"""
    prompt_template = ChatPromptTemplate.from_template(template)
    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt_template
        | llm
        | StrOutputParser()
    )
    print("Initialization complete.")
    return chain, cluster
def get_tier_status(cluster, username):
    """Retrieve the tier status from Couchbase using SQL++.

    Runs a parameterized query against
    ``airlines_stats.data.customers`` and returns the ``tierStatus`` of the
    first matching row that actually carries a value.

    Args:
        cluster: Object exposing ``query(statement, **named_params)`` whose
            result can be iterated to yield dict-like rows.
        username: Customer name bound to the ``$user`` query parameter.

    Returns:
        The tier status found for the user, or ``"silver"`` when no row
        matches, when the matching rows have no ``tierStatus`` value, or when
        the query raises.
    """
    default_tier = "silver"
    print(f"Retrieving tier status for user: {username}")
    try:
        query = "SELECT tierStatus FROM `airlines_stats`.`data`.`customers` WHERE name = $user"
        result = cluster.query(query, user=username)
        for row in result:
            tier_status = row.get("tierStatus")
            if tier_status is None:
                # A row without a usable tier must not leak ``None`` to the
                # caller; keep looking and fall back to the default.
                continue
            print(f"Found tier: {tier_status}")
            return tier_status
        print(f"No tier status found for user: {username}. Defaulting to 'silver'.")
        return default_tier
    except Exception as e:
        # Best-effort lookup: any failure degrades to the default tier.
        print(f"Error retrieving tier status: {e}")
        return default_tier
def query_rag(chain, prompt):
    """Run *prompt* through the RAG chain and return its answer.

    Args:
        chain: Object exposing ``invoke(prompt)``.
        prompt: The user question handed to the chain.

    Returns:
        Whatever ``chain.invoke`` returns, or the fixed fallback message when
        invoking the chain or printing its result raises.
    """
    fallback = "No response. Something is broken."
    print(f"Processing prompt: {prompt}")
    print("Invoking Chain...")
    try:
        answer = chain.invoke(prompt)
        # Echo the answer between a header and a footer line.
        for line in ("\n--- Response ---", answer, "----------------"):
            print(line)
    except Exception as err:
        print(f"Error during chain invocation: {err}")
        answer = fallback
    return answer