-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate-adaptive-k-retrieval-metrisc-changed.py
More file actions
238 lines (197 loc) · 7.36 KB
/
Copy pathgenerate-adaptive-k-retrieval-metrisc-changed.py
File metadata and controls
238 lines (197 loc) · 7.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import pathlib
import logging
import json
import math
from typing import Dict, List, Tuple
from beir import LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from sentence_transformers import SentenceTransformer
import chromadb
import numpy as np
from tools.vector_store_search_fallback import vector_store_search
# =========================
# Configuration
# =========================
# BEIR dataset name; also used for datasets/<name> and the Chroma collection name.
dataset: str = "fiqa"
# True -> vector_store_search picks k per query; False -> fixed k = top_k.
use_adaptive_k: bool = True
# SentenceTransformer model used to embed queries.
embedding_model_name: str = "all-MiniLM-L6-v2"
# Name of the Chroma collection to query (expected to hold the corpus embeddings).
collection_name: str = "beir-" + dataset + "-corpus"
# Size of the candidate pool fetched from the vector store.
top_k: int = 100
# Evaluate only the first N query ids (sorted); 0/None disables the limit.
limit_questions: int = 20
# =========================
# Logging
# =========================
# Route INFO-level records through BEIR's LoggingHandler.
# The date format must not start with "-": combined with the
# "%(asctime)s - ..." format it produced lines like "-2024-01-01 12:00:00 - ...".
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
# =========================
# Metrics @k_i (binary qrels)
# =========================
def precision_at_k(retrieved_list: List[str], relevant_set: set, k: int) -> float:
    """Return precision@k: the fraction of the first ``k`` retrieved ids that are relevant.

    The denominator is always ``k``, even when fewer than ``k`` ids were
    retrieved. Every occurrence of a relevant id in the top ``k`` counts as a hit.
    Returns 0.0 for ``k <= 0``.
    """
    if k > 0:
        relevant_hits = len([doc for doc in retrieved_list[:k] if doc in relevant_set])
        return relevant_hits / k
    return 0.0
def recall_at_k(retrieved_list: List[str], relevant_set: set, k: int) -> float:
    """Return recall@k: the share of relevant ids found among the first ``k`` retrieved.

    Args:
        retrieved_list: Ranked document ids, best first.
        relevant_set: Ids judged relevant (binary relevance).
        k: Cut-off depth.

    Returns:
        A value in [0, 1]; 0.0 when there are no relevant ids or ``k <= 0``.
    """
    # Guard k <= 0 like precision_at_k does: a negative k would otherwise slice
    # from the end (e.g. retrieved_list[:-1]) and report spurious hits.
    if len(relevant_set) == 0 or k <= 0:
        return 0.0
    # Count each relevant id once so duplicated ids cannot push recall above 1.
    found = {doc for doc in retrieved_list[:k] if doc in relevant_set}
    return len(found) / len(relevant_set)
def average_precision_at_k(
    retrieved_list: List[str], relevant_set: set, k: int
) -> float:
    """Return AP@k with binary relevance.

    Precision is taken at every rank (1..k) that holds a relevant id. These
    values are summed and divided by ``min(len(relevant_set), k)``. Returns 0.0
    when ``relevant_set`` is empty or ``k <= 0``.
    """
    if len(relevant_set) == 0 or k <= 0:
        return 0.0
    normaliser = min(len(relevant_set), k)
    running_hits = 0
    precision_sum = 0.0
    for rank, doc in enumerate(retrieved_list[:k], 1):
        if doc not in relevant_set:
            continue
        running_hits += 1
        precision_sum += running_hits / rank
    return precision_sum / normaliser if normaliser > 0 else 0.0
def ndcg_at_k(retrieved_list: List[str], relevant_set: set, k: int) -> float:
    """Return nDCG@k with binary gains (1 for a relevant id, 0 otherwise).

    DCG sums ``1 / log2(rank + 1)`` over relevant ids in the top ``k``. It is
    normalised by the ideal DCG of ``min(len(relevant_set), k)`` hits placed at
    the top ranks. Returns 0.0 for ``k <= 0`` or an empty ``relevant_set``.
    """
    if k <= 0:
        return 0.0
    n_ideal = min(len(relevant_set), k)
    if n_ideal == 0:
        return 0.0
    dcg = 0.0
    for rank, doc in enumerate(retrieved_list[:k], start=1):
        if doc in relevant_set:
            # Binary gain: 2**1 - 1 == 1.
            dcg += 1.0 / math.log2(rank + 1)
    idcg = sum(1 / math.log2(rank + 1) for rank in range(1, n_ideal + 1))
    return dcg / idcg if idcg > 0 else 0.0
# =========================
# Main script
# =========================
def _retrieve(query_text: str, model, collection) -> Tuple[Dict, int]:
    """Fetch candidate hits for one query and return them with its cut-off k_i.

    With ``use_adaptive_k`` the external ``vector_store_search`` supplies both the
    hits and k_i. Otherwise the collection is queried directly and k_i = top_k.
    A negative k_i is clamped to 0.
    """
    if use_adaptive_k:
        query_hits, k_i = vector_store_search(
            query_text, embedding_model=model, collection=collection, top_k=top_k
        )
    else:
        # Fixed top_k (no adaptive cut-off).
        query_emb = model.encode(query_text, convert_to_numpy=True).tolist()
        query_hits = collection.query(query_embeddings=[query_emb], n_results=top_k)
        k_i = top_k
    k_i = int(k_i)
    if k_i < 0:
        # list[:negative] would drop items from the END of the ranking and
        # produce spurious recall, so evaluate nothing for this query instead.
        logging.warning(f"Negative k_i={k_i} returned; clamping to 0.")
        k_i = 0
    return query_hits, k_i


def _rank_doc_ids(query_hits: Dict, pool_size: int) -> List[str]:
    """Return doc ids for the first query in ``query_hits``, ranked best-first.

    Reads ``query_hits["ids"][0]`` and ``query_hits["distances"][0]`` and keeps the
    ``pool_size`` smallest distances. Each distance is mapped to the score
    ``1 / (1 + distance)``, and the best score is kept per doc id, so duplicates
    collapse into one entry.
    """
    pairs: List[Tuple[str, float]] = sorted(
        zip(query_hits["ids"][0], query_hits["distances"][0]),
        key=lambda pair: pair[1],
    )[:pool_size]
    doc_scores: Dict[str, float] = {}
    for doc_id, dist in pairs:
        score = 1.0 / (1.0 + dist)
        doc_scores[doc_id] = max(doc_scores.get(doc_id, score), score)
    ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)
    return [doc_id for doc_id, _ in ranked]


def _score_query(ranked_doc_ids: List[str], rel_set: set, k_i: int) -> Dict[str, float]:
    """Compute P/R/AP/nDCG at the query's own cut-off k_i."""
    return {
        "precision": precision_at_k(ranked_doc_ids, rel_set, k_i),
        "recall": recall_at_k(ranked_doc_ids, rel_set, k_i),
        "map": average_precision_at_k(ranked_doc_ids, rel_set, k_i),
        "ndcg": ndcg_at_k(ranked_doc_ids, rel_set, k_i),
    }


def _save_results(
    base_dir, n_queries: int, k_per_query: Dict, per_query_metrics: Dict, agg: Dict
) -> str:
    """Write config, per-query k_i/metrics and aggregates to results/*.json; return the path."""
    results_dir = os.path.join(base_dir, "results")
    os.makedirs(results_dir, exist_ok=True)
    suffix = "_adaptive_k_applied" if use_adaptive_k else "_fixed_k"
    # Hub-style model ids ("org/model") would otherwise create a bogus subdirectory.
    safe_model_name = embedding_model_name.replace("/", "_")
    output_path = os.path.join(
        results_dir,
        f"only_k{dataset}-{safe_model_name}_first_{n_queries}q{suffix}.json",
    )
    payload = {
        "config": {
            "dataset": dataset,
            "embedding_model": embedding_model_name,
            "top_k_pool": top_k,
            "use_adaptive_k": use_adaptive_k,
            "limit_questions": limit_questions,
        },
        "k_per_query": k_per_query,
        "per_query_metrics": per_query_metrics,
        "aggregation": agg,
    }
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    return output_path


def main():
    """Evaluate retrieval on the BEIR test split with a per-query cut-off k_i.

    Steps:
    - Load ``datasets/<dataset>`` located next to this file.
    - Optionally keep only the first ``limit_questions`` query ids (sorted).
    - Retrieve candidates from the Chroma collection ``collection_name``.
    - Compute P/R/AP/nDCG at each query's own k_i.
    - Print per-query and mean scores and write everything to ``results/`` as JSON.
    """
    base_dir = pathlib.Path(__file__).parent.absolute()

    # --- Load the BEIR dataset (the corpus itself is not needed here) ---
    data_path = os.path.join(base_dir, "datasets", dataset)
    _, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

    # Optional question limit (deterministic: first N sorted query ids).
    if limit_questions and limit_questions > 0:
        sel_qids = sorted(queries)[:limit_questions]
        queries = {qid: queries[qid] for qid in sel_qids}
        qrels = {qid: qrels.get(qid, {}) for qid in sel_qids}
        logging.info(f"Używam tylko {len(queries)} pytań (limit={limit_questions}).")

    # --- Embedding model and vector store (Chroma) ---
    model = SentenceTransformer(embedding_model_name)
    client = chromadb.PersistentClient(path="chroma_store")
    collection = client.get_collection(name=collection_name)

    k_per_query: Dict[str, int] = {}
    per_query_metrics: Dict[str, Dict] = {}
    sums: Dict[str, float] = {"precision": 0.0, "recall": 0.0, "map": 0.0, "ndcg": 0.0}
    n_q = 0
    logging.info("Retrieval + ocena per-query (variable k) ...")

    for qid, query_text in queries.items():
        print(f"\n🔎 Query[{qid}]: {query_text}")
        query_hits, k_i = _retrieve(query_text, model, collection)
        # Rank the whole candidate pool, then cut it to this query's own k_i.
        ranked_doc_ids = _rank_doc_ids(query_hits, top_k)[:k_i]
        rel_set = {did for did, rel in qrels.get(qid, {}).items() if rel > 0}
        metrics = _score_query(ranked_doc_ids, rel_set, k_i)

        k_per_query[qid] = k_i
        per_query_metrics[qid] = {"k_i": k_i, **metrics}
        for name, value in metrics.items():
            sums[name] += value
        n_q += 1
        print(
            f" k_i={k_i} | P={metrics['precision']:.3f} R={metrics['recall']:.3f} "
            f"MAP={metrics['map']:.3f} NDCG={metrics['ndcg']:.3f}"
        )

    # --- Aggregates (macro-average over queries) ---
    agg = {
        "count_questions": n_q,
        "mean": {name: (total / n_q) if n_q else 0.0 for name, total in sums.items()},
        "sum": dict(sums),
    }
    print("\n📊 Średnie (variable k per query):")
    for m, v in agg["mean"].items():
        print(f" - {m.upper()}: {v:.4f}")

    # --- Save JSON ---
    output_path = _save_results(
        base_dir, len(queries), k_per_query, per_query_metrics, agg
    )
    print(f"\n✅ Wyniki zapisane do pliku: {output_path}")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()