Skip to content

Commit 1fc7f09

Browse files
committed
feat: implement rag in serverless
1 parent 04c2df1 commit 1fc7f09

File tree

4 files changed

+104
-4
lines changed

4 files changed

+104
-4
lines changed

deploy/init_kb_serverless.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
"""
Initialize Cloudflare vector database.

After initializing the vector database in init_kb.py, run this script to dump the vectorized result to jsonl file for wrangler vector insert.

https://developers.cloudflare.com/vectorize/best-practices/insert-vectors/
Target jsonl line: {id: <node_id>, values: <embedding>, metadata: {text: <text>, ...metadata_}}
"""
from config import config
import psycopg2
import json
import os
import subprocess

print("Clearing existing Cloudflare vector store...")
# Best-effort delete: the index may not exist on a fresh setup, so no check=True.
subprocess.run(["npx", "wrangler", "vectorize", "delete", "ppedt-embed", "--force"], cwd="ppedt-serverless")
print("Creating new Cloudflare vector store, press 'n' if prompted to confirm...")
subprocess.run(["npx", "wrangler", "vectorize", "create", "ppedt-embed", "--preset", "@cf/baai/bge-small-en-v1.5"], cwd="ppedt-serverless", check=True)


print("Dumping vectors from Postgres...")
conn = psycopg2.connect(config.POSTGRES_URI, dbname="ppedt")
conn.autocommit = True

# create table public.data_embed (
#     id bigint primary key not null default nextval('data_embed_id_seq'::regclass),
#     text character varying not null,
#     metadata_ json,
#     node_id character varying,
#     embedding vector(384)
# );
# create index embed_idx_1 on data_embed using btree (((metadata_ ->> 'ref_doc_id'::text)));
# create index data_embed_embedding_idx on data_embed using hnsw (embedding);

try:
    with conn.cursor() as c:
        c.execute("SELECT id, text, metadata_ -> 'file_name', node_id, embedding FROM data_embed;")
        with open(os.path.join("ppedt-serverless", "data_embed.jsonl"), "w", encoding="utf-8") as f:
            for row in c.fetchall():
                # Row id is unused; the jsonl id is the node_id (avoid shadowing builtin `id`).
                _row_id, text, file_name, node_id, embedding = row
                # pgvector's text representation is "[0.1,0.2,...]", which is valid
                # JSON — parse with json.loads instead of eval() (eval on database
                # content is an injection risk and slower).
                embedding_list = json.loads(embedding)
                json_line = {
                    "id": str(node_id),
                    "values": embedding_list,
                    "metadata": {
                        "text": text,
                        "file_name": file_name
                    }
                }
                f.write(json.dumps(json_line) + "\n")
finally:
    # Release the Postgres connection even if the dump fails mid-way.
    conn.close()


print("Inserting vectors to Cloudflare vector store...")
subprocess.run(["npx", "wrangler", "vectorize", "insert", "ppedt-embed", "--file", "data_embed.jsonl"], cwd="ppedt-serverless", check=True)

deploy/ppedt-serverless/.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,4 +166,6 @@ dist
166166
!.env.example
167167
.wrangler/
168168

169-
package-lock.json
169+
package-lock.json
170+
# Dumped embedding data
171+
data_embed.jsonl

deploy/ppedt-serverless/src/index.js

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,49 @@ export default {
2323
if (request.method === "POST") {
2424
const { code, prompt } = JSON.parse(await request.text());
2525

26-
// messages - chat style input
27-
let messages = [
26+
// If the prompt contains characters outside the ASCII range, ask the LLM to rewrite it in English.
27+
let final_prompt = prompt;
28+
if (/[^\x00-\x7F]+/.test(prompt)) {
29+
const response = await env.AI.run(
30+
'@cf/meta/llama-3-8b-instruct',
31+
{
32+
messages: [
33+
{ role: 'system', content: 'Translate the query to English and modify mathematics unicode symbols to LaTeX commands if necessary without any explanation.' },
34+
{ role: 'user', content: prompt }
35+
]
36+
});
37+
final_prompt = await response.response;
38+
}
39+
40+
// Get the embedding of final_prompt
41+
const embedding_response = await env.AI.run(
42+
'@cf/baai/bge-small-en-v1.5',
43+
{
44+
text: final_prompt
45+
});
46+
const query_vector = await embedding_response.data[0];
47+
48+
// Retrieve context from Cloudflare embedding DB
49+
let matches = await env.VECTORIZE.query(query_vector, {
50+
topK: 3,
51+
returnMetadata: 'all'
52+
});
53+
matches = matches.matches;
54+
// drop low-relevance matches: keep only results with score >= 0.75
55+
matches = matches.filter(m => m.score >= 0.75);
56+
const context_str = matches.map(m => "File: " + m.metadata.file_name + "\n" + m.metadata.text).join('\n\n');
57+
final_prompt = "Context information is below.\n" +
58+
"---------------------\n" +
59+
context_str + "\n" +
60+
"---------------------\n" +
61+
"Answer the query.\n" +
62+
"Query: " + final_prompt + "\n" +
63+
"Answer: ";
64+
65+
// Final Stream output
66+
const messages = [
2867
{ role: 'system', content: 'You are a LaTeX code helper, especially for the code of package pgfplots. Return only the modified version of the following code without any additional text or explanation. You have to make sure the code could compile successfully and don\'t omit the code of documentclass.' },
29-
{ role: 'user', content: prompt + ':\n' + code }
68+
{ role: 'user', content: final_prompt + ':\n' + code }
3069
];
3170
const response = await env.AI.run(
3271
'@cf/meta/llama-3-8b-instruct',

deploy/ppedt-serverless/wrangler.jsonc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@
3838
"directory": "./public/",
3939
"binding": "ASSETS"
4040
},
41+
"vectorize": [
42+
{
43+
"binding": "VECTORIZE",
44+
"index_name": "ppedt-embed",
45+
"remote": true
46+
}
47+
],
4148
"ai": {
4249
"binding": "AI",
4350
}

0 commit comments

Comments
 (0)