|
2 | 2 | import os |
3 | 3 | import pathlib |
4 | 4 | from copy import deepcopy |
5 | | -from typing import List, Dict, Optional |
| 5 | +from typing import List, Dict, Optional, Union |
6 | 6 |
|
7 | 7 | import pandas as pd |
8 | 8 |
|
9 | | -from autorag.nodes.retrieval.run import evaluate_retrieval_node |
| 9 | +from autorag.evaluation import evaluate_retrieval |
10 | 10 | from autorag.schema.metricinput import MetricInput |
11 | 11 | from autorag.strategy import measure_speed, filter_by_threshold, select_best |
12 | 12 | from autorag.support import get_support_modules |
| 13 | +from autorag.utils.cast import cast_retrieve_infos |
13 | 14 | from autorag.utils.util import make_combinations, explode |
14 | 15 |
|
15 | 16 | logger = logging.getLogger("AutoRAG") |
@@ -217,14 +218,16 @@ def evaluate_one_query_expansion_node( |
217 | 218 | zip(retrieval_funcs, retrieval_params), |
218 | 219 | ) |
219 | 220 | ) |
| 221 | + # Cast each retrieval results |
| 222 | + retrieval_result_dicts = list(map(cast_retrieve_infos, retrieval_results)) |
220 | 223 | evaluation_results = list( |
221 | 224 | map( |
222 | 225 | lambda x: evaluate_retrieval_node( |
223 | 226 | x, |
224 | 227 | metric_inputs, |
225 | 228 | metrics, |
226 | 229 | ), |
227 | | - retrieval_results, |
| 230 | + retrieval_result_dicts, |
228 | 231 | ) |
229 | 232 | ) |
230 | 233 | best_result, _ = select_best( |
@@ -274,3 +277,32 @@ def make_retrieval_callable_params(strategy_dict: Dict): |
274 | 277 | ) |
275 | 278 | ) |
276 | 279 | return explode(modules, param_combinations) |
| 280 | + |
| 281 | + |
| 282 | +def evaluate_retrieval_node( |
| 283 | + result_dict: Dict, |
| 284 | + metric_inputs: List[MetricInput], |
| 285 | + metrics: Union[List[str], List[Dict]], |
| 286 | +) -> pd.DataFrame: |
| 287 | + """ |
| 288 | + Evaluate retrieval node from retrieval node result dataframe. |
| 289 | +
|
| 290 | + :param result_df: The result dataframe from a retrieval node. |
| 291 | + :param metric_inputs: List of metric input schema for AutoRAG. |
| 292 | + :param metrics: Metric list from input strategies. |
| 293 | + :return: Return result_df with metrics columns. |
| 294 | + The columns will be 'retrieved_contents', 'retrieved_ids', 'retrieve_scores', and metric names. |
| 295 | + """ |
| 296 | + |
| 297 | + @evaluate_retrieval( |
| 298 | + metric_inputs=metric_inputs, |
| 299 | + metrics=metrics, |
| 300 | + ) |
| 301 | + def evaluate_this_module(_dict: Dict): |
| 302 | + return ( |
| 303 | + _dict["retrieved_contents"], |
| 304 | + _dict["retrieved_ids"], |
| 305 | + _dict["retrieve_scores"], |
| 306 | + ) |
| 307 | + |
| 308 | + return evaluate_this_module(result_dict) |
0 commit comments