forked from cocoindex-io/cocoindex
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
123 lines (107 loc) · 4.07 KB
/
main.py
File metadata and controls
123 lines (107 loc) · 4.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from dotenv import load_dotenv
from psycopg_pool import ConnectionPool
import cocoindex
import os
from typing import Any
@cocoindex.transform_flow()
def text_to_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[list[float]]:
    """
    Turn a text slice into its embedding using a SentenceTransformer model.

    The same transform is applied to chunks while indexing and to the user
    query while searching, so it is kept as a standalone transform flow.
    """
    embedder = cocoindex.functions.SentenceTransformerEmbed(
        model="sentence-transformers/all-MiniLM-L6-v2"
    )
    return text.transform(embedder)
@cocoindex.flow_def(name="AmazonS3TextEmbedding")
def amazon_s3_text_embedding_flow(
    flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
) -> None:
    """
    Define a flow that embeds text files from Amazon S3 and exports them to Postgres.

    Configuration is read from the environment:
      - AMAZON_S3_BUCKET_NAME (required; a missing value raises KeyError).
      - AMAZON_S3_PREFIX (optional, forwarded as the source's `prefix`).
      - AMAZON_S3_SQS_QUEUE_URL (optional, forwarded as `sqs_queue_url`).
    """
    # NOTE(review): "*.docx" is matched while binary=False, so those files are
    # read as text -- confirm this is intended.
    s3_source = cocoindex.sources.AmazonS3(
        bucket_name=os.environ["AMAZON_S3_BUCKET_NAME"],
        prefix=os.environ.get("AMAZON_S3_PREFIX"),
        included_patterns=["*.md", "*.mdx", "*.txt", "*.docx"],
        binary=False,
        sqs_queue_url=os.environ.get("AMAZON_S3_SQS_QUEUE_URL"),
    )
    data_scope["documents"] = flow_builder.add_source(s3_source)

    collector = data_scope.add_collector()

    with data_scope["documents"].row() as document:
        # Split each document's content into overlapping chunks.
        document["chunks"] = document["content"].transform(
            cocoindex.functions.SplitRecursively(),
            language="markdown",
            chunk_size=2000,
            chunk_overlap=500,
        )

        with document["chunks"].row() as piece:
            # Embed every chunk and collect one row per (filename, location).
            piece["embedding"] = text_to_embedding(piece["text"])
            collector.collect(
                filename=document["filename"],
                location=piece["location"],
                text=piece["text"],
                embedding=piece["embedding"],
            )

    embedding_index = cocoindex.VectorIndexDef(
        field_name="embedding",
        metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
    )
    collector.export(
        "doc_embeddings",
        cocoindex.targets.Postgres(),
        primary_key_fields=["filename", "location"],
        vector_indexes=[embedding_index],
    )
def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
    """
    Return up to `top_k` stored chunks closest to `query`.

    The query is embedded with `text_to_embedding` and compared against the
    `embedding` column of the table exported by `amazon_s3_text_embedding_flow`
    using the `<=>` operator (aliased as `distance` in the query). Each hit is
    a dict with keys `filename`, `text` and `score`, where `score` is
    `1.0 - distance`; hits come back ordered by ascending distance.
    """
    # Table backing the "doc_embeddings" export target of the flow above.
    target_table = cocoindex.utils.get_target_default_name(
        amazon_s3_text_embedding_flow, "doc_embeddings"
    )
    # Embed the query with the same transform flow used for indexing.
    embedded_query = text_to_embedding.eval(query)

    statement = f"""
        SELECT filename, text, embedding <=> %s::vector AS distance
        FROM {target_table} ORDER BY distance LIMIT %s
    """
    with pool.connection() as conn, conn.cursor() as cur:
        cur.execute(statement, (embedded_query, top_k))
        records = cur.fetchall()

    return [
        {"filename": record[0], "text": record[1], "score": 1.0 - record[2]}
        for record in records
    ]
def _main() -> None:
    """
    Set up the flow, keep it live-updating, and serve an interactive search prompt.

    Each non-empty line read from stdin is passed to `search` and the hits are
    printed. An empty line or end-of-input ends the loop.
    """
    # Use the pool as a context manager so it is closed when the prompt loop
    # ends, including when an exception escapes.
    with ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL")) as pool:
        amazon_s3_text_embedding_flow.setup()
        with cocoindex.FlowLiveUpdater(amazon_s3_text_embedding_flow):
            # Run queries in a loop to demonstrate the query capabilities.
            while True:
                try:
                    query = input("Enter search query (or Enter to quit): ")
                except EOFError:
                    # stdin was closed (e.g. Ctrl-D or piped input ran out):
                    # quit the same way an empty line does.
                    break
                if query == "":
                    break
                # Run the query function with the database connection pool and the query.
                results = search(pool, query)
                print("\nSearch results:")
                for result in results:
                    print(f"[{result['score']:.3f}] {result['filename']}")
                    print(f" {result['text']}")
                    print("---")
                print()
if __name__ == "__main__":
    # Load the .env file before anything reads the environment: the flow
    # definition and _main() both look up settings via os.environ / os.getenv.
    load_dotenv()
    cocoindex.init()
    _main()