Skip to content

Commit d60f465

Browse files
committed
make again retrieval/base.py and run_util.py
1 parent baf5664 commit d60f465

3 files changed

Lines changed: 279 additions & 0 deletions

File tree

autorag/autorag/nodes/retrieval/__init__.py

Whitespace-only changes.
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import abc
2+
import logging
3+
import os
4+
from typing import List, Union, Tuple
5+
6+
import pandas as pd
7+
8+
from autorag.schema import BaseModule
9+
from autorag.support import get_support_modules
10+
from autorag.utils import fetch_contents, result_to_dataframe, validate_qa_dataset
11+
from autorag.utils.util import pop_params
12+
13+
logger = logging.getLogger("AutoRAG")
14+
15+
16+
class BaseRetrieval(BaseModule, metaclass=abc.ABCMeta):
17+
def __init__(self, project_dir: str, *args, **kwargs):
18+
logger.info(f"Initialize retrieval node - {self.__class__.__name__}")
19+
20+
self.resources_dir = os.path.join(project_dir, "resources")
21+
data_dir = os.path.join(project_dir, "data")
22+
# fetch data from corpus_data
23+
self.corpus_df = pd.read_parquet(
24+
os.path.join(data_dir, "corpus.parquet"), engine="pyarrow"
25+
)
26+
27+
def __del__(self):
28+
logger.info(f"Deleting retrieval node - {self.__class__.__name__} module...")
29+
30+
def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
31+
logger.info(f"Running retrieval node - {self.__class__.__name__} module...")
32+
validate_qa_dataset(previous_result)
33+
# find queries columns & type cast queries
34+
assert (
35+
"query" in previous_result.columns
36+
), "previous_result must have query column."
37+
if "queries" not in previous_result.columns:
38+
previous_result["queries"] = previous_result["query"]
39+
previous_result.loc[:, "queries"] = previous_result["queries"].apply(
40+
cast_queries
41+
)
42+
queries = previous_result["queries"].tolist()
43+
return queries
44+
45+
46+
class HybridRetrieval(BaseRetrieval, metaclass=abc.ABCMeta):
47+
def __init__(
48+
self, project_dir: str, target_modules, target_module_params, *args, **kwargs
49+
):
50+
super().__init__(project_dir)
51+
self.target_modules = list(
52+
map(
53+
lambda x, y: get_support_modules(x)(
54+
**y,
55+
project_dir=project_dir,
56+
),
57+
target_modules,
58+
target_module_params,
59+
)
60+
)
61+
self.target_module_params = target_module_params
62+
63+
@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
64+
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
65+
result_dfs: List[pd.DataFrame] = list(
66+
map(
67+
lambda x, y: x.pure(
68+
**y,
69+
previous_result=previous_result,
70+
),
71+
self.target_modules,
72+
self.target_module_params,
73+
)
74+
)
75+
ids = tuple(
76+
map(lambda df: df["retrieved_ids"].apply(list).tolist(), result_dfs)
77+
)
78+
scores = tuple(
79+
map(
80+
lambda df: df["retrieve_scores"].apply(list).tolist(),
81+
result_dfs,
82+
)
83+
)
84+
85+
_pure_params = pop_params(self._pure, kwargs)
86+
if "ids" in _pure_params or "scores" in _pure_params:
87+
raise ValueError(
88+
"With specifying ids or scores, you must use HybridRRF.run_evaluator instead."
89+
)
90+
ids, scores = self._pure(ids=ids, scores=scores, **_pure_params)
91+
contents = fetch_contents(self.corpus_df, ids)
92+
return contents, ids, scores
93+
94+
95+
def cast_queries(queries: Union[str, List[str]]) -> List[str]:
96+
if isinstance(queries, str):
97+
return [queries]
98+
elif isinstance(queries, List):
99+
return queries
100+
else:
101+
raise ValueError(f"queries must be str or list, but got {type(queries)}")
102+
103+
104+
def evenly_distribute_passages(
105+
ids: List[List[str]], scores: List[List[float]], top_k: int
106+
) -> Tuple[List[str], List[float]]:
107+
assert len(ids) == len(scores), "ids and scores must have same length."
108+
query_cnt = len(ids)
109+
avg_len = top_k // query_cnt
110+
remainder = top_k % query_cnt
111+
112+
new_ids = []
113+
new_scores = []
114+
for i in range(query_cnt):
115+
if i < remainder:
116+
new_ids.extend(ids[i][: avg_len + 1])
117+
new_scores.extend(scores[i][: avg_len + 1])
118+
else:
119+
new_ids.extend(ids[i][:avg_len])
120+
new_scores.extend(scores[i][:avg_len])
121+
122+
return new_ids, new_scores
123+
124+
125+
def get_bm25_pkl_name(bm25_tokenizer: str):
126+
bm25_tokenizer = bm25_tokenizer.replace("/", "")
127+
return f"bm25_{bm25_tokenizer}.pkl"
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import os
2+
import pathlib
3+
from typing import Tuple, List, Union, Dict
4+
5+
import pandas as pd
6+
7+
from autorag.evaluation import evaluate_retrieval
8+
from autorag.schema.metricinput import MetricInput
9+
from autorag.strategy import measure_speed, filter_by_threshold, select_best
10+
11+
12+
def evaluate_retrieval_node(
13+
result_df: pd.DataFrame,
14+
metric_inputs: List[MetricInput],
15+
metrics: Union[List[str], List[Dict]],
16+
) -> pd.DataFrame:
17+
"""
18+
Evaluate retrieval node from retrieval node result dataframe.
19+
:param result_df: The result dataframe from a retrieval node.
20+
:param metric_inputs: List of metric input schema for AutoRAG.
21+
:param metrics: Metric list from input strategies.
22+
:return: Return result_df with metrics columns.
23+
The columns will be 'retrieved_contents', 'retrieved_ids', 'retrieve_scores', and metric names.
24+
"""
25+
26+
@evaluate_retrieval(
27+
metric_inputs=metric_inputs,
28+
metrics=metrics,
29+
)
30+
def evaluate_this_module(df: pd.DataFrame):
31+
return (
32+
df["retrieved_contents"].tolist(),
33+
df["retrieved_ids"].tolist(),
34+
df["retrieve_scores"].tolist(),
35+
)
36+
37+
return evaluate_this_module(result_df)
38+
39+
40+
def run(
41+
input_modules,
42+
input_module_params,
43+
project_dir: Union[str, pathlib.Path, pathlib.PurePath],
44+
previous_result: pd.DataFrame,
45+
strategies,
46+
metric_inputs: List[MetricInput],
47+
) -> Tuple[List[pd.DataFrame], List]:
48+
"""
49+
Run input modules and parameters.
50+
:param input_modules: Input modules
51+
:param input_module_params: Input module parameters
52+
:param project_dir: Project directory path.
53+
:param previous_result: Previous result dataframe.
54+
:param strategies: Strategies for retrieval node.
55+
:param metric_inputs: List of metric input schema for AutoRAG.
56+
:return: First, it returns list of result dataframe.
57+
Second, it returns list of execution times.
58+
"""
59+
result, execution_times = zip(
60+
*map(
61+
lambda task: measure_speed(
62+
task[0].run_evaluator,
63+
project_dir=project_dir,
64+
previous_result=previous_result,
65+
**task[1],
66+
),
67+
zip(input_modules, input_module_params),
68+
)
69+
)
70+
average_times = list(map(lambda x: x / len(result[0]), execution_times))
71+
72+
# run metrics before filtering
73+
if strategies.get("metrics") is None:
74+
raise ValueError("You must at least one metrics for retrieval evaluation.")
75+
result = list(
76+
map(
77+
lambda x: evaluate_retrieval_node(
78+
x,
79+
metric_inputs,
80+
strategies.get("metrics"),
81+
),
82+
result,
83+
)
84+
)
85+
86+
return result, average_times
87+
88+
89+
def save_and_summary(
90+
input_modules,
91+
input_module_params,
92+
result_list,
93+
execution_time_list,
94+
filename_start: int,
95+
save_dir: Union[str, pathlib.Path, pathlib.PurePath],
96+
strategies,
97+
):
98+
"""
99+
Save the result and make summary file
100+
:param input_modules: Input modules
101+
:param input_module_params: Input module parameters
102+
:param result_list: Result list
103+
:param execution_time_list: Execution times
104+
:param filename_start: The first filename to use
105+
:return: First, it returns list of result dataframe.
106+
Second, it returns list of execution times.
107+
"""
108+
109+
# save results to folder
110+
filepaths = list(
111+
map(
112+
lambda x: os.path.join(save_dir, f"{x}.parquet"),
113+
range(filename_start, filename_start + len(input_modules)),
114+
)
115+
)
116+
list(
117+
map(
118+
lambda x: x[0].to_parquet(x[1], index=False),
119+
zip(result_list, filepaths),
120+
)
121+
) # execute save to parquet
122+
filename_list = list(map(lambda x: os.path.basename(x), filepaths))
123+
124+
summary_df = pd.DataFrame(
125+
{
126+
"filename": filename_list,
127+
"module_name": list(map(lambda module: module.__name__, input_modules)),
128+
"module_params": input_module_params,
129+
"execution_time": execution_time_list,
130+
**{
131+
metric: list(map(lambda result: result[metric].mean(), result_list))
132+
for metric in strategies.get("metrics")
133+
},
134+
}
135+
)
136+
summary_df.to_csv(os.path.join(save_dir, "summary.csv"), index=False)
137+
return summary_df
138+
139+
140+
def find_best(results, average_times, filenames, strategies):
141+
# filter by strategies
142+
if strategies.get("speed_threshold") is not None:
143+
results, filenames = filter_by_threshold(
144+
results, average_times, strategies["speed_threshold"], filenames
145+
)
146+
selected_result, selected_filename = select_best(
147+
results,
148+
strategies.get("metrics"),
149+
filenames,
150+
strategies.get("strategy", "mean"),
151+
)
152+
return selected_result, selected_filename

0 commit comments

Comments
 (0)