Skip to content

Commit baf5664

Browse files
committed
fix several errors
1 parent 81d84f4 commit baf5664

17 files changed

Lines changed: 295 additions & 1634 deletions

File tree

autorag/autorag/nodes/passageaugmenter/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
validate_corpus_dataset,
1212
cast_corpus_dataset,
1313
)
14-
from autorag.utils.cast import cast_retrieve_infos
14+
from autorag.utils.cast import cast_retrieved_ids
1515
from autorag.utils.util import select_top_k
1616

1717
logger = logging.getLogger("AutoRAG")
@@ -42,8 +42,7 @@ def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
4242
validate_qa_dataset(previous_result)
4343

4444
# find ids columns
45-
retrieve_infos = cast_retrieve_infos(previous_result)
46-
return retrieve_infos["retrieved_ids"]
45+
return cast_retrieved_ids(previous_result)
4746

4847
@staticmethod
4948
def sort_by_scores(

autorag/autorag/nodes/passagecompressor/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from autorag import generator_models
99
from autorag.schema import BaseModule
1010
from autorag.utils import result_to_dataframe
11-
from autorag.utils.cast import cast_retrieve_infos
11+
from autorag.utils.cast import cast_retrieved_contents
1212

1313
logger = logging.getLogger("AutoRAG")
1414

@@ -34,9 +34,8 @@ def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
3434
assert len(previous_result) > 0, "previous_result must have at least one row."
3535

3636
queries = previous_result["query"].tolist()
37-
retrieve_infos = cast_retrieve_infos(previous_result)
3837

39-
return queries, retrieve_infos["retrieved_contents"]
38+
return queries, cast_retrieved_contents(previous_result)
4039

4140

4241
class LlamaIndexCompressor(BasePassageCompressor, metaclass=abc.ABCMeta):

autorag/autorag/nodes/promptmaker/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77

88
from autorag.schema.base import BaseModule
9-
from autorag.utils.cast import cast_retrieve_infos
9+
from autorag.utils.cast import cast_retrieved_contents
1010

1111
logger = logging.getLogger("AutoRAG")
1212

@@ -28,6 +28,5 @@ def cast_to_run(self, previous_result: pd.DataFrame, *args, **kwargs):
2828
), "previous_result must have query column."
2929

3030
query = previous_result["query"].tolist()
31-
retrieve_infos = cast_retrieve_infos(previous_result)
3231
prompt = kwargs.pop("prompt")
33-
return query, retrieve_infos["retrieved_contents"], prompt
32+
return query, cast_retrieved_contents(previous_result), prompt

autorag/autorag/nodes/retrieval/__init__.py

Whitespace-only changes.

autorag/autorag/nodes/retrieval/base.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

0 commit comments

Comments
 (0)