34 changes: 34 additions & 0 deletions akd/agents/search/aspect_search/aspect_search.py
@@ -23,6 +23,8 @@
update_references,
update_search_results,
)
from akd.structures import DecompositionClassification
from akd.tools.decomp_classifier import DecompClassifierConfig, DecompClassifierTool
from akd.tools.search import SearchResultItem, SearxNGSearchTool


@@ -89,6 +91,23 @@ class AspectSearchConfig(BaseAgentConfig):
description="Maximum length of the search result context during interviews.",
)

# Query classification configuration
enable_query_classification: bool = Field(
default=False,
description="Whether to classify decomposed queries before execution",
)
classifier_config: Optional[DecompClassifierConfig] = Field(
default=None,
description="Configuration for the decomposition classifier tool",
)
filter_classifications: Optional[List[DecompositionClassification]] = Field(
default=None,
description=(
"If set, only execute queries with these classifications. "
"Example: [EXACT, CALCULATOR, PROXY] to skip TANGENTIAL queries"
),
)


class AspectSearchAgent(BaseAgent):
input_schema = AspectSearchInputSchema
@@ -113,6 +132,19 @@ def _post_init(

self.search_tool = self.config.search_tool

# Initialize query classifier if enabled
self.classifier_tool = None
if self.config.enable_query_classification:
classifier_config = self.config.classifier_config or DecompClassifierConfig()
self.classifier_tool = DecompClassifierTool(
config=classifier_config,
debug=self.debug,
)
if self.debug:
logger.debug(
f"Query classification enabled with model: {classifier_config.model_name}"
)

builder = StateGraph(InterviewState)
builder.add_node(
"ask_question",
@@ -126,6 +158,8 @@ def _post_init(
search_tool=self.search_tool,
search_category=self.config.category,
max_context_len=self.config.max_ctx_len,
classifier_tool=self.classifier_tool,
filter_classifications=self.config.filter_classifications,
),
retry=RetryPolicy(max_attempts=self.config.retry_attempts),
)
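Usage sketch (not part of this PR): one way the new classification options might be enabled on the aspect search agent. The import paths follow the files touched in this diff; the `AspectSearchAgent` constructor signature and any other required `AspectSearchConfig` fields (e.g. `search_tool`) are assumed to use their defaults.

```python
# Hypothetical usage sketch; constructor details beyond this diff are assumed.
from akd.agents.search.aspect_search.aspect_search import (
    AspectSearchAgent,
    AspectSearchConfig,
)
from akd.structures import DecompositionClassification
from akd.tools.decomp_classifier import DecompClassifierConfig

config = AspectSearchConfig(
    enable_query_classification=True,
    classifier_config=DecompClassifierConfig(),  # or omit to fall back to defaults
    filter_classifications=[
        DecompositionClassification.EXACT,
        DecompositionClassification.CALCULATOR,
        DecompositionClassification.PROXY,
    ],  # TANGENTIAL queries are dropped before search
)
agent = AspectSearchAgent(config=config)
```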
33 changes: 32 additions & 1 deletion akd/agents/search/aspect_search/interview_utils.py
@@ -181,6 +181,8 @@ async def generate_answer(
search_category: str = None,
name: str = "Subject_Matter_Expert",
max_ctx_len: int = 15000,
classifier_tool=None,
filter_classifications=None,
**kwargs,
) -> Dict:
"""
@@ -191,8 +193,12 @@
state (InterviewState): Current interview state.
llm (ChatOpenAI): Language model for generating queries and answers.
search_tool (SearchTool): Tool for retrieving search results.
search_category (str, optional): Category for the search tool.
name (str, optional): AI participant name. Defaults to "Subject_Matter_Expert".
max_ctx_len (int, optional): Max context length for search data. Defaults to 15000.
classifier_tool (DecompClassifierTool, optional): Tool for classifying queries.
filter_classifications (List[DecompositionClassification], optional):
If set, only execute queries with these classifications.

Returns:
Dict: Generated answer message, cited references, and search results.
@@ -208,9 +214,34 @@
)
swapped_state = swap_roles(state, name)
queries = await gen_queries_chain.ainvoke(swapped_state)

# Classify queries if classifier is enabled
queries_to_execute = queries["parsed"].queries
if classifier_tool is not None:
# Extract original topic from the last editor question
last_question = state["messages"][-2].content if len(state["messages"]) >= 2 else "research topic"

# Classify all queries
classification_result = await classifier_tool.arun(
classifier_tool.input_schema(
original_topic=last_question,
queries=queries["parsed"].queries,
)
)

# Store classifications in the queries object
queries["parsed"].classified_queries = classification_result.classified_queries

# Filter queries if filter_classifications is set
if filter_classifications is not None:
queries_to_execute = [
cq.query for cq in classification_result.classified_queries
if cq.classification in filter_classifications
]

query_results = await search_tool.arun(
search_tool.input_schema(
-            queries=queries["parsed"].queries,
+            queries=queries_to_execute,
category=search_category,
),
)
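A self-contained sketch of the filtering behaviour added to `generate_answer`, using only the `ClassifiedQuery` and `DecompositionClassification` models introduced in `akd/structures.py`. The helper function and example values are illustrative, not part of the PR.

```python
# Illustrative only: mirrors the filtering step added to generate_answer above.
from typing import List, Optional

from akd.structures import ClassifiedQuery, DecompositionClassification


def filter_queries(
    classified: List[ClassifiedQuery],
    keep: Optional[List[DecompositionClassification]],
) -> List[str]:
    """Keep only queries whose classification is in `keep`; keep everything if unset."""
    if keep is None:
        return [cq.query for cq in classified]
    return [cq.query for cq in classified if cq.classification in keep]


classified = [
    ClassifiedQuery(
        query="fire weather index global dataset",
        classification=DecompositionClassification.EXACT,
        reasoning="Directly measures fire risk",
    ),
    ClassifiedQuery(
        query="regional humidity trends",
        classification=DecompositionClassification.TANGENTIAL,
        reasoning="Contextual only; not core to the analysis",
    ),
]
print(filter_queries(classified, [DecompositionClassification.EXACT]))
# -> ['fire weather index global dataset']
```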
6 changes: 6 additions & 0 deletions akd/agents/search/aspect_search/structures.py
@@ -4,6 +4,8 @@
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from akd.structures import ClassifiedQuery

# ---------------------------------------------------
# Interview state helper functions
# ---------------------------------------------------
@@ -96,6 +98,10 @@ class Queries(BaseModel):
queries: List[str] = Field(
description="Comprehensive list of search engine queries to answer the user's questions.",
)
classified_queries: Optional[List[ClassifiedQuery]] = Field(
default=None,
description="Optional classified version of queries with labels and reasoning",
)


class AnswerWithCitations(BaseModel):
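A minimal sketch of the extended `Queries` model carrying the optional `classified_queries` field; the module path is assumed from the file location above.

```python
# Minimal sketch; module path assumed from the file location above.
from akd.agents.search.aspect_search.structures import Queries
from akd.structures import ClassifiedQuery, DecompositionClassification

q = Queries(
    queries=["chlorophyll-a satellite products"],
    classified_queries=[
        ClassifiedQuery(
            query="chlorophyll-a satellite products",
            classification=DecompositionClassification.PROXY,
            reasoning="Chlorophyll-a is a stand-in for phytoplankton biomass",
        ),
    ],
)
assert q.classified_queries[0].classification is DecompositionClassification.PROXY
```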
58 changes: 58 additions & 0 deletions akd/structures.py
@@ -6,6 +6,7 @@
organized into logical sections for better maintainability.
"""

from enum import Enum
from typing import Any

from pydantic import (
@@ -22,6 +23,51 @@
# from akd.common_types import ToolType
from akd.configs.project import CONFIG

# =============================================================================
# Classification Enums
# =============================================================================


class DecompositionClassification(str, Enum):
"""
Classification categories for decomposed queries relative to original topic.

Categories are defined by their relationship to the research topic:
- EXACT: Direct measurement of the phenomenon ("That is the thing you asked for")
- CALCULATOR: Mechanistic input/driver that physically affects the topic
- PROXY: Surrogate/stand-in measurement used because it correlates with the topic
- TANGENTIAL: Weakly related, contextual information not core to the analysis

Examples:
- Fire risk → Fire Weather Index: EXACT
- Fire risk → soil moisture: CALCULATOR
- Phytoplankton biomass → chlorophyll-a: PROXY
- Fire risk → regional humidity: TANGENTIAL
"""

EXACT = "exact"
CALCULATOR = "calculator"
PROXY = "proxy"
TANGENTIAL = "tangential"


class ClassifiedQuery(BaseModel):
"""
A decomposed query with its classification relative to the original topic.

Attributes:
query: The search query text
classification: Classification category (EXACT, CALCULATOR, PROXY, TANGENTIAL)
reasoning: Brief explanation for why this classification was assigned
"""

query: str = Field(description="The search query text")
classification: DecompositionClassification = Field(
description="Classification of query relative to topic"
)
reasoning: str = Field(description="Brief explanation for the classification")


# =============================================================================
# Search and Data Models
# =============================================================================
@@ -242,6 +288,13 @@ class PaperDataItem(BaseModel):
)


class BaseCriterion(BaseModel):
"""Base class for criteria used in reranking."""

name: str = Field(..., description="Criterion name")
description: str = Field(..., description="Detailed description of what this criterion evaluates")


# =============================================================================
# Extraction Schemas
# =============================================================================
@@ -331,7 +384,11 @@ def name(self) -> str:
# Type alias for semantic clarity in literature search contexts
LitSearchResult = SearchResultItem


__all__ = [
# Classification
"DecompositionClassification",
"ClassifiedQuery",
# Search and Data Models
"SearchResult",
"SearchResultItem",
Expand All @@ -342,4 +399,5 @@ def name(self) -> str:
"SingleEstimation",
# Tool Models
"ToolSearchResult",
"BaseCriterion",
]
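A quick sketch exercising the new exports from `akd/structures.py`; the field values are illustrative only.

```python
# Illustrative values only; exercises the models newly exported above.
from akd.structures import (
    BaseCriterion,
    ClassifiedQuery,
    DecompositionClassification,
)

cq = ClassifiedQuery(
    query="soil moisture anomaly data",
    classification=DecompositionClassification.CALCULATOR,
    reasoning="Soil moisture is a mechanistic driver of fire risk",
)
criterion = BaseCriterion(
    name="relevance",
    description="How directly a result addresses the research topic",
)
print(cq.classification.value)  # "calculator"
print(criterion.name)           # "relevance"
```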