main.py
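# MiniRAG demo script: index the LiHua-World dataset with a Hugging Face LLM
# and run a toy query. Uncomment the login lines below if the chosen models
# require Hugging Face Hub authentication.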
# from huggingface_hub import login
# your_token = "INPUT YOUR TOKEN HERE"
# login(your_token)
import argparse
import os

from transformers import AutoModel, AutoTokenizer

from minirag import MiniRAG, QueryParam
from minirag.llm.hf import (
    hf_model_complete,
    hf_embed,
)
from minirag.utils import EmbeddingFunc

EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
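
# Command-line options: LLM choice plus dataset, query, and output paths.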
def get_args():
parser = argparse.ArgumentParser(description="MiniRAG")
parser.add_argument("--model", type=str, default="PHI")
parser.add_argument("--outputpath", type=str, default="./logs/Default_output.csv")
parser.add_argument("--workingdir", type=str, default="./LiHua-World")
parser.add_argument("--datapath", type=str, default="./dataset/LiHua-World/data/")
parser.add_argument(
"--querypath", type=str, default="./dataset/LiHua-World/qa/query_set.csv"
)
args = parser.parse_args()
return args
args = get_args()
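# Map the CLI model name to its Hugging Face model identifier.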
if args.model == "PHI":
    LLM_MODEL = "microsoft/Phi-3.5-mini-instruct"
elif args.model == "GLM":
    LLM_MODEL = "THUDM/glm-edge-1.5b-chat"
elif args.model == "MiniCPM":
    LLM_MODEL = "openbmb/MiniCPM3-4B"
elif args.model == "qwen":
    LLM_MODEL = "Qwen/Qwen2.5-3B-Instruct"
else:
    raise SystemExit(f"Invalid model name: {args.model}")
WORKING_DIR = args.workingdir
DATA_PATH = args.datapath
QUERY_PATH = args.querypath
OUTPUT_PATH = args.outputpath
print("USING LLM:", LLM_MODEL)
print("USING WORKING DIR:", WORKING_DIR)
os.makedirs(WORKING_DIR, exist_ok=True)
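
# Set up MiniRAG with the chosen Hugging Face chat model; all-MiniLM-L6-v2
# (384-dimensional) serves as the embedding model.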
# Load the embedding tokenizer and model once so they are not re-instantiated
# on every embedding call.
embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL)

rag = MiniRAG(
    working_dir=WORKING_DIR,
    llm_model_func=hf_model_complete,
    llm_model_max_token_size=200,
    llm_model_name=LLM_MODEL,
    embedding_func=EmbeddingFunc(
        embedding_dim=384,  # hidden size of all-MiniLM-L6-v2
        max_token_size=1000,
        func=lambda texts: hf_embed(
            texts,
            tokenizer=embed_tokenizer,
            embed_model=embed_model,
        ),
    ),
)
# Now indexing
def find_txt_files(root_path):
    # Recursively collect every .txt file under root_path.
    txt_files = []
    for root, _dirs, files in os.walk(root_path):
        for file in files:
            if file.endswith(".txt"):
                txt_files.append(os.path.join(root, file))
    return txt_files
WEEK_LIST = find_txt_files(DATA_PATH)
for idx, week_file in enumerate(WEEK_LIST):
    print(f"{idx}/{len(WEEK_LIST)}")
    with open(week_file, encoding="utf-8") as f:
        rag.insert(f.read())
# A toy query
query = 'What does LiHua predict will happen in "The Rings of Power"?'
answer = (
rag.query(query, param=QueryParam(mode="mini")).replace("\n", "").replace("\r", "")
)
print(answer)
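
# Example invocation (using the default paths defined above):
#   python main.py --model PHI --workingdir ./LiHua-World --datapath ./dataset/LiHua-World/data/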