-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathagent.py
More file actions
63 lines (50 loc) · 2.03 KB
/
agent.py
File metadata and controls
63 lines (50 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, cast
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMListwiseRerank
from langchain_core.documents import Document
from langchain_core.runnables import Runnable
from automation.agents import BaseAgent
from automation.retrievers import MultiQueryRephraseRetriever
from .conf import settings
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.retrievers import BaseRetriever
# Module-level logger. The name is the literal "daiv.agents" rather than __name__.
# NOTE(review): it is not referenced by the code visible in this module chunk.
logger = logging.getLogger("daiv.agents")
class CodebaseSearchAgent(BaseAgent[Runnable[str, list[Document]]]):
    """
    Agent to search for code snippets in the codebase.

    The compiled agent is a retrieval pipeline built around the injected retriever:

    1. When ``rephrase`` is enabled, the retriever is wrapped in a ``MultiQueryRephraseRetriever``
       driven by the ``settings.REPHRASE_MODEL_NAME`` model (used for query rephrasing).
    2. The documents produced by that (possibly wrapped) retriever are reranked by an
       ``LLMListwiseRerank`` compressor using the ``settings.RERANKING_MODEL_NAME`` model,
       keeping the top ``settings.TOP_N`` of them.
    """

    def __init__(self, retriever: BaseRetriever, rephrase: bool = True, *args, **kwargs):
        """
        Args:
            retriever: Retriever that fetches candidate documents from the codebase.
            rephrase: Whether to wrap ``retriever`` with multi-query rephrasing before reranking.
            *args: Forwarded unchanged to ``BaseAgent.__init__``.
            **kwargs: Forwarded unchanged to ``BaseAgent.__init__``.
        """
        # Assigned before super().__init__(): presumably BaseAgent.__init__ calls compile(),
        # which reads both attributes — TODO confirm against BaseAgent.
        self.retriever = retriever
        self.rephrase = rephrase
        super().__init__(*args, **kwargs)

    def compile(self) -> Runnable[str, list[Document]]:
        """
        Compile the agent into a Runnable.

        Returns:
            Runnable[str, list[Document]]: A ``ContextualCompressionRetriever`` that reranks the
            (optionally rephrased) retrieval results, bound with ``run_name=settings.NAME``.
        """
        # Annotate the name once; annotating it separately in each branch is reported as a
        # redefinition by type checkers (mypy ``no-redef``).
        base_retriever: BaseRetriever = self.retriever
        if self.rephrase:
            base_retriever = MultiQueryRephraseRetriever.from_llm(
                self.retriever,
                llm=cast(
                    "BaseChatModel",
                    # this model shows better results for rephrasing
                    self.get_model(model=settings.REPHRASE_MODEL_NAME),
                ),
            )

        return ContextualCompressionRetriever(
            base_compressor=LLMListwiseRerank.from_llm(
                llm=cast(
                    "BaseChatModel",
                    # this model shows better results for listwise reranking
                    self.get_model(model=settings.RERANKING_MODEL_NAME),
                ),
                top_n=settings.TOP_N,
            ),
            base_retriever=base_retriever,
        ).with_config({"run_name": settings.NAME})