-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathjust_evaluate_generation_phase_fixed_k.py
More file actions
174 lines (128 loc) · 4.94 KB
/
Copy pathjust_evaluate_generation_phase_fixed_k.py
File metadata and controls
174 lines (128 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import pathlib
import logging
from beir import LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from sentence_transformers import SentenceTransformer
import chromadb
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from llm_client import together_client
from modul_decyzujny.decompose_query import check_if_query_is_complex, decompose_query
from modul_decyzujny.first_router import query_router, ToolChoice
from modul_decyzujny.knowledge_summarizer import (
summarize_knowledge_with_most_occurring_words,
)
from query_analizer.hyde_generator import HyDEGenerator
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import re
import json
from tools.vector_store_search_fallback import vector_store_search
# ===================== parameters to adjust =====================
# Dataset to evaluate; it must exist as a folder under ./datasets.
#   "fiqa"          -> contains natural-language questions
#   "trec-covid-v2" -> also works well
dataset: str = "fiqa"

# Pipeline switches. In the code visible here only `route_questions` and
# `use_adaptive_k` are actually read; the others look reserved for
# variants of this experiment.
#
# LLM-as-a-judge: ask an LLM whether a retrieved document is relevant to
# the question. Either the LLM judge OR the reranker can be used.
use_ll_aaj: bool = False
use_reranker: bool = False
use_hyde: bool = False
split_documents: bool = False
route_questions: bool = False
use_adaptive_k: bool = False

# Model for the LLM judge (Mistral Small 3.1 could be tried here instead).
judge_model: str = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# =================================================================
# ============================ Logging ============================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[LoggingHandler()],
)

# =========================== Parameters ===========================
embedding_model_name = "all-MiniLM-L6-v2"
# Existing Chroma collection that holds this dataset's corpus
# (opened with get_collection below, so it is not created here).
collection_name = f"beir-{dataset}-corpus"
# Number of documents retrieved per query.
top_k_to_check = 4

# ============== 1. Load the BEIR dataset (test split) ==============
data_path = os.path.join(
    pathlib.Path(__file__).parent.absolute(), "datasets", dataset
)
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
if route_questions:
    # Keyword summary of the corpus.
    # NOTE(review): not read anywhere in the visible code.
    docs_summary = summarize_knowledge_with_most_occurring_words(corpus, top_k=20)

# ======================= 2. Embedding model =======================
model = SentenceTransformer(embedding_model_name)

# ===================== 3. Chroma vector store =====================
client = chromadb.PersistentClient(path="chroma_store")
collection = client.get_collection(name=collection_name)

# `results` maps each query id to a dict of {retrieved document id: score}
# (the run format BEIR's evaluator consumes).
results = {}
# Number of queries handled so far by the main loop.
number_of_questions_to_process = 0
logging.info("Wyszukiwanie przez Chroma + budowanie results...")
# Prompt template for answer generation. The bare words
# `context_from_retrieval` and `query_from_dataset` are placeholders that
# generate_model_answer fills in.
GENERATE_ANSWERS = """
You are a helpful assistant. Given a question and some context, generate a precise and factual answer.
Context:
context_from_retrieval
Question: query_from_dataset
Answer:
"""

# Matches either placeholder so both are substituted in ONE pass over the
# template: text inserted for one placeholder is never re-scanned for the other.
_PROMPT_PLACEHOLDERS = re.compile(r"context_from_retrieval|query_from_dataset")


def generate_model_answer(
    context,
    query_from_dataset,
    *,
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    max_tokens=400,
    client=None,
):
    """Ask the answer-generation LLM to answer a question from the given context.

    Args:
        context: Text inserted into the prompt's "Context:" section.
        query_from_dataset: Question inserted into the prompt's "Question:" line.
        model: Chat model identifier passed to the completions API.
            Alternatives tried before include
            "mistralai/Mistral-7B-Instruct-v0.1" and
            "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8".
        max_tokens: Upper bound on the number of generated tokens.
        client: Object exposing ``chat.completions.create``; defaults to the
            project's ``together_client``.

    Returns:
        The ``content`` of the first returned choice's message.
    """
    if client is None:
        client = together_client

    values = {
        "context_from_retrieval": context,
        "query_from_dataset": query_from_dataset,
    }
    # A callable replacement inserts the values verbatim: no backslash-escape
    # processing and no re-scan of inserted text, unlike chained str.replace.
    prompt_for_model_answer = _PROMPT_PLACEHOLDERS.sub(
        lambda match: values[match.group(0)], GENERATE_ANSWERS
    )

    response_model_answer = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt_for_model_answer}],
        temperature=0,  # deterministic decoding for reproducible evaluation
        max_tokens=max_tokens,
    )
    return response_model_answer.choices[0].message.content
# Main loop: retrieve documents for each query, record the ranking for BEIR,
# and generate an answer from the gold (qrels) documents.
for qid, query_text in queries.items():
    # Only the first 20 queries are processed (the output file name says
    # "first_20_q").
    if number_of_questions_to_process >= 20:
        break
    print(f"\nCurrent query: {query_text}")

    # --- Retrieval: adaptive-k helper or a plain fixed-k Chroma query ---
    if use_adaptive_k:
        query_hits = vector_store_search(
            query_text,
            embedding_model=model,
            collection=collection,
            top_k=top_k_to_check,
        )
    else:
        query_emb = model.encode(query_text, convert_to_numpy=True).tolist()
        query_hits = collection.query(
            query_embeddings=[query_emb], n_results=top_k_to_check
        )

    # Record this query's ranking as {doc_id: score}. Without this, `results`
    # stays empty and the BEIR evaluation after the loop has nothing to score.
    # collection.query returns one list per query embedding, hence the [0].
    # Chroma distances are "lower is closer"; 1 - distance flips them to
    # "higher is better" without changing the order.
    # NOTE(review): assumes vector_store_search returns the same shape as
    # collection.query (keys "ids" and "distances"); confirm.
    results[qid] = {
        doc_id: 1.0 - float(distance)
        for doc_id, distance in zip(query_hits["ids"][0], query_hits["distances"][0])
    }

    # --- Generation: answer from the gold documents only ---
    # Graded qrels (e.g. trec-covid) can also list documents judged
    # non-relevant with relevance 0, so keep only positive judgements.
    relevant_docs = [
        corpus[doc_id]["text"]
        for doc_id, relevance in qrels[qid].items()
        if relevance > 0 and doc_id in corpus
    ]
    full_context = "\n\n".join(relevant_docs)
    model_answer = generate_model_answer(full_context, query_text)
    # TODO: score `model_answer` (e.g. with the LLM judge); it is not evaluated yet.
    number_of_questions_to_process += 1
logging.info(f"Gotowe — oceniam {len(results)} zapytań.")

# ======================= 5. BEIR evaluation =======================
evaluator = EvaluateRetrieval()
number_of_k_values_to_check = [4, 8, 16, 32]
ndcg, _map, recall, precision = evaluator.evaluate(
    qrels, results, number_of_k_values_to_check
)

# ==================== 6. Save the metrics to JSON ====================
# Merge every metric dict into one. BEIR prefixes keys per metric
# (e.g. "NDCG@4", "Recall@4"), so the merge should not overwrite entries;
# re-check if the BEIR version changes.
full_metrics = {**ndcg, **_map, **recall, **precision}

# File-name suffix describing which optional pipeline features were enabled.
suffix_name = ""
if use_adaptive_k:
    suffix_name += "_adaptive_k_applied"

# NOTE(review): output directory chosen next to this script, mirroring the
# "datasets" folder; adjust if the project stores results elsewhere.
results_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "results")
os.makedirs(results_dir, exist_ok=True)
output_path = os.path.join(
    results_dir, f"{dataset}-{embedding_model_name}_first_20_q{suffix_name}.json"
)
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(full_metrics, f, indent=4)
print(f"✅ Wyniki zapisane do pliku: {output_path}")