-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemo.py
More file actions
64 lines (53 loc) · 2.08 KB
/
demo.py
File metadata and controls
64 lines (53 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from typing import List
import json
import argparse
import logging
from src.hipporag import HippoRAG
def main() -> None:
    """Run a small end-to-end HippoRAG demo and print the QA output.

    Builds a ``HippoRAG`` instance, indexes nine short passages, then calls
    ``rag_qa`` on three questions together with their gold supporting
    passages and gold answers, and prints whatever ``rag_qa`` returns.
    """
    # Toy corpus to index: a few unrelated facts plus two groups of passages
    # (Cinderella; Erik Hort / Montebello) that must be combined to answer
    # the corresponding questions below.
    docs: List[str] = [
        "Oliver Badman is a politician.",
        "George Rankin is a politician.",
        "Thomas Marwick is a politician.",
        "Cinderella attended the royal ball.",
        "The prince used the lost glass slipper to search the kingdom.",
        "When the slipper fit perfectly, Cinderella was reunited with the prince.",
        "Erik Hort's birthplace is Montebello.",
        "Marina is born in Minsk.",
        "Montebello is a part of Rockland County.",
    ]

    # Save directory for HippoRAG objects (each LLM/embedding model
    # combination will create a new subdirectory).
    save_dir = 'outputs'
    # Any OpenAI model name.
    llm_model_name = 'gpt-4o-mini'
    # Embedding model name (NV-Embed, GritLM or Contriever for now).
    embedding_model_name = 'GritLM/GritLM-7B'

    # Start up a HippoRAG instance.
    hipporag = HippoRAG(
        save_dir=save_dir,
        llm_model_name=llm_model_name,
        embedding_model_name=embedding_model_name,
    )

    # Run indexing.
    hipporag.index(docs=docs)

    # Retrieval & QA inputs. `queries`, `answers` and `gold_docs` are
    # parallel lists: entry i of each one belongs to the same question.
    queries: List[str] = [
        "What is George Rankin's occupation?",
        "How did Cinderella reach her happy ending?",
        "What county is Erik Hort's birthplace a part of?",
    ]

    # For evaluation: accepted gold answers per query.
    answers: List[List[str]] = [
        ["Politician"],
        ["By going to the ball."],
        ["Rockland County"],
    ]

    # For evaluation: the passages from `docs` that support each answer.
    gold_docs: List[List[str]] = [
        ["George Rankin is a politician."],
        [
            "Cinderella attended the royal ball.",
            "The prince used the lost glass slipper to search the kingdom.",
            "When the slipper fit perfectly, Cinderella was reunited with the prince.",
        ],
        [
            "Erik Hort's birthplace is Montebello.",
            "Montebello is a part of Rockland County.",
        ],
    ]

    print(hipporag.rag_qa(queries=queries,
                          gold_docs=gold_docs,
                          gold_answers=answers))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()