-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate-adaptive-k-retrieval-metrisc.py
More file actions
120 lines (87 loc) · 3.44 KB
/
Copy pathgenerate-adaptive-k-retrieval-metrisc.py
File metadata and controls
120 lines (87 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import pathlib
import logging
from beir import LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from sentence_transformers import SentenceTransformer
import chromadb
import json
from tools.vector_store_search_fallback import vector_store_search
# ==== Experiment configuration ====
# BEIR dataset name. The rest of the script uses it to locate
# ./datasets/<dataset>, the Chroma collection "beir-<dataset>-corpus",
# and the results file name.
dataset = "fiqa"
# NOTE(review): not referenced anywhere in this script; presumably a switch
# for a question-routing variant of the experiment. Confirm or remove.
route_questions = False
# True: retrieve via vector_store_search (adaptive k).
# False: plain Chroma query with n_results=top_k.
use_adaptive_k = True
# Hugging Face id of a cross-encoder re-ranking model.
# NOTE(review): not referenced anywhere in this script.
cross_encoder_model = "cross-encoder/quora-distilroberta-base"
# ========================================================
# ---- Logging, dataset, embedding model and vector-store setup ----
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[LoggingHandler()],
)

# Sentence-transformers model used to embed the queries.
embedding_model_name = "all-MiniLM-L6-v2"
# Chroma collection opened below; its name is derived from the dataset.
collection_name = f"beir-{dataset}-corpus"
# Number of hits requested per query.
top_k = 32

# The BEIR dataset lives in ./datasets/<dataset> next to this script;
# only its "test" split is loaded.
_script_dir = pathlib.Path(__file__).parent.absolute()
data_path = os.path.join(_script_dir, "datasets", dataset)
_loader = GenericDataLoader(data_folder=data_path)
corpus, queries, qrels = _loader.load(split="test")

model = SentenceTransformer(embedding_model_name)
# Persistent on-disk Chroma store in ./chroma_store (relative to the CWD).
client = chromadb.PersistentClient(path="chroma_store")
collection = client.get_collection(name=collection_name)
def _hits_to_doc_scores(query_hits, limit):
    """Turn one Chroma-style query response into a ``{doc_id: score}`` dict.

    The hits are sorted by ascending distance and truncated to the ``limit``
    closest. Each distance ``d`` is mapped to the similarity ``1 / (1 + d)``,
    so closer documents score higher. If a doc id appears more than once,
    its highest score is kept.

    Assumes ``query_hits["ids"]`` and ``query_hits["distances"]`` are lists
    whose first element holds the hits for the single query that was sent.
    This matches how the response was read before; TODO confirm for
    vector_store_search.
    """
    hits = zip(query_hits["ids"][0], query_hits["distances"][0])
    closest = sorted(hits, key=lambda hit: hit[1])[:limit]
    doc_scores = {}
    for doc_id, dist in closest:
        score = 1 / (1 + dist)
        doc_scores[doc_id] = max(doc_scores.get(doc_id, score), score)
    return doc_scores


results = {}
logging.info("Wyszukiwanie przez Chroma + budowanie results...")
# Only the first 20 test queries are processed. The results file name
# ("..._first_20_q...") assumes this limit.
number_of_questions_to_process = 0
# NOTE(review): nothing in this script increments this counter, so the
# summary below always reports 0. Presumably a leftover from a
# filtering/re-ranking step; confirm.
number_of_irrelevant_answers = 0
for qid, query_text in queries.items():
    if number_of_questions_to_process == 20:
        break
    print(f"\nCurrent query: {query_text}")
    if use_adaptive_k:
        # NOTE(review): the adaptive cut-off returned here is not applied
        # below (hits are cut at top_k). Confirm that vector_store_search
        # already truncates its hits to it.
        query_hits, adaptive_k = vector_store_search(
            query_text, embedding_model=model, collection=collection, top_k=top_k
        )
    else:
        query_emb = model.encode(query_text, convert_to_numpy=True).tolist()
        query_hits = collection.query(query_embeddings=[query_emb], n_results=top_k)
    # BEIR's evaluator expects {query_id: {doc_id: score}}.
    # Without this assignment it would receive an empty run.
    results[qid] = _hits_to_doc_scores(query_hits, top_k)
    number_of_questions_to_process += 1
print(f"Ilosc odrzuconych odpowiedzi: {number_of_irrelevant_answers}")
logging.info(f"Gotowe — oceniam {len(results)} zapytań.")
# === 5. BEIR evaluation ===
number_of_k_values_to_check = [4, 8, 16, 32]
evaluator = EvaluateRetrieval()
ndcg, _map, recall, precision = evaluator.evaluate(
    qrels, results, number_of_k_values_to_check
)

# Each metric dict is keyed "<NAME>@<k>"; that key is also the printed label.
_metric_tables = (
    ("NDCG", ndcg),
    ("MAP", _map),
    ("Recall", recall),
    ("P", precision),
)
print("\n📊 Wyniki oceny:")
for cutoff in number_of_k_values_to_check:
    for metric_name, table in _metric_tables:
        metric_key = f"{metric_name}@{cutoff}"
        print(f"{metric_key}: {table[metric_key]:.4f}")
    print()
# === 6. Save results to JSON ===
results_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "results")
os.makedirs(results_dir, exist_ok=True)

# Merge all metric dicts into a single mapping for the output file.
full_metrics = dict(NDCG=ndcg, MAP=_map, Recall=recall, Precision=precision)

# The file name records the dataset, the embedding model, the 20-query
# subset, and whether adaptive k was enabled.
suffix_name = "_adaptive_k_applied" if use_adaptive_k else ""
_output_name = f"{dataset}-{embedding_model_name}_first_20_q{suffix_name}.json"
output_path = os.path.join(results_dir, _output_name)

with open(output_path, "w") as out_file:
    json.dump(full_metrics, out_file, indent=4)
print(f"✅ Wyniki zapisane do pliku: {output_path}")