Skip to content

Commit 81d84f4

Browse files
committed
Fix query expansion with new retrieval
1 parent 17b84a1 commit 81d84f4

4 files changed

Lines changed: 47 additions & 20 deletions

File tree

autorag/autorag/nodes/queryexpansion/run.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
import os
33
import pathlib
44
from copy import deepcopy
5-
from typing import List, Dict, Optional
5+
from typing import List, Dict, Optional, Union
66

77
import pandas as pd
88

9-
from autorag.nodes.retrieval.run import evaluate_retrieval_node
9+
from autorag.evaluation import evaluate_retrieval
1010
from autorag.schema.metricinput import MetricInput
1111
from autorag.strategy import measure_speed, filter_by_threshold, select_best
1212
from autorag.support import get_support_modules
13+
from autorag.utils.cast import cast_retrieve_infos
1314
from autorag.utils.util import make_combinations, explode
1415

1516
logger = logging.getLogger("AutoRAG")
@@ -217,14 +218,16 @@ def evaluate_one_query_expansion_node(
217218
zip(retrieval_funcs, retrieval_params),
218219
)
219220
)
221+
# Cast each retrieval result
222+
retrieval_result_dicts = list(map(cast_retrieve_infos, retrieval_results))
220223
evaluation_results = list(
221224
map(
222225
lambda x: evaluate_retrieval_node(
223226
x,
224227
metric_inputs,
225228
metrics,
226229
),
227-
retrieval_results,
230+
retrieval_result_dicts,
228231
)
229232
)
230233
best_result, _ = select_best(
@@ -274,3 +277,32 @@ def make_retrieval_callable_params(strategy_dict: Dict):
274277
)
275278
)
276279
return explode(modules, param_combinations)
280+
281+
282+
def evaluate_retrieval_node(
283+
result_dict: Dict,
284+
metric_inputs: List[MetricInput],
285+
metrics: Union[List[str], List[Dict]],
286+
) -> pd.DataFrame:
287+
"""
288+
Evaluate retrieval node from a retrieval node result dict.
289+
290+
:param result_dict: The result dict from a retrieval node. It must contain the keys 'retrieved_contents', 'retrieved_ids', and 'retrieve_scores'.
291+
:param metric_inputs: List of metric input schema for AutoRAG.
292+
:param metrics: Metric list from input strategies.
293+
:return: A dataframe with metrics columns.
294+
The columns will be 'retrieved_contents', 'retrieved_ids', 'retrieve_scores', and metric names.
295+
"""
296+
297+
@evaluate_retrieval(
298+
metric_inputs=metric_inputs,
299+
metrics=metrics,
300+
)
301+
def evaluate_this_module(_dict: Dict):
302+
return (
303+
_dict["retrieved_contents"],
304+
_dict["retrieved_ids"],
305+
_dict["retrieve_scores"],
306+
)
307+
308+
return evaluate_this_module(result_dict)

docs/source/migration.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,7 @@ node_lines:
151151
```
152152

153153
This YAML file does the same thing as the previous v0.3.7 version.
154+
155+
Also, you’re no longer able to use the hybrid retrieval node in the `query_expansion` node as `retrieval_modules`.
156+
We’re considering adding this feature in the future, but for now, you can use semantic and lexical retrieval nodes to evaluate query expansion.
157+
In most cases, you don't need to use the hybrid retrieval node in the `query_expansion` node.

docs/source/nodes/query_expansion/query_expansion.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@ Please refer to the parameter of [retrieval Node](../retrieval/retrieval.md) for
3535
1. **Metrics**: Metrics such as `retrieval_f1`,`retrieval_recall`, and `retrieval_precision` are used to evaluate the performance of the query expansion process through its impact on retrieval outcomes.
3636
2. **Speed Threshold**: `speed_threshold` is applied across all nodes, ensuring that any method exceeding the average processing time for a query is not used.
3737
3. **Top_k**: This parameter specifies the number of top results to consider during the retrieval evaluation phase.
38-
4. **Retrieval Modules**: The query expansion node can use all modules and module parameters from the retrieval node, including:
38+
4. **Retrieval Modules**: The query expansion node can use modules and module parameters from the lexical retrieval and semantic retrieval nodes, including:
3939
- [bm25](../retrieval/bm25.md)
40-
- [vectordb](../retrieval/vectordb.md): with `embedding_model` parameter
41-
- [hybrid_rrf](../retrieval/hybrid_rrf.md): with `target_modules` and `rrf_k` parameters
42-
- [hybrid_cc](../retrieval/hybrid_cc.md): with `target_modules` and `weights` parameters
40+
- [vectordb](../retrieval/vectordb.md): with `vectordb` parameter
41+
```{warning}
42+
You cannot use the hybrid retrieval modules in the query expansion node.
43+
```
4344

4445
### Example config.yaml file
4546
```yaml

tests/autorag/nodes/queryexpansion/test_query_expansion_run.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from autorag.nodes.queryexpansion import QueryDecompose, HyDE
1515
from autorag.nodes.queryexpansion.run import evaluate_one_query_expansion_node
1616
from autorag.nodes.queryexpansion.run import run_query_expansion_node
17-
from autorag.nodes.retrieval import BM25, VectorDB, HybridCC
17+
from autorag.nodes.lexicalretrieval import BM25
18+
from autorag.nodes.semanticretrieval import VectorDB
19+
from autorag.nodes.hybridretrieval import HybridCC
1820
from autorag.nodes.semanticretrieval.vectordb import vectordb_ingest_api
1921
from autorag.schema.metricinput import MetricInput
2022
from autorag.utils.util import load_summary_file, get_event_loop
@@ -112,18 +114,6 @@ def test_evaluate_one_query_expansion_node_vectordb(node_line_dir):
112114
retrieval_params = [
113115
{"top_k": 3, "vectordb": "chroma_large"},
114116
{"top_k": 5, "vectordb": "chroma_small"},
115-
{
116-
"top_k": 5,
117-
"target_modules": ("bm25", "vectordb"),
118-
"target_module_params": (
119-
{"top_k": 3, "bm25_tokenizer": "gpt2"},
120-
{
121-
"top_k": 3,
122-
"vectordb": "chroma_large",
123-
},
124-
),
125-
"weight": 0.36,
126-
},
127117
]
128118
base_test_evaluate_one_query_expansion_node(
129119
node_line_dir, retrieval_funcs, retrieval_params

0 commit comments

Comments
 (0)