Skip to content

Commit 65d1a76

Browse files
committed
continued rebase
1 parent 18f51d5 commit 65d1a76

13 files changed

Lines changed: 881 additions & 648 deletions

config.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
# AI Agent Model Configuration
22

33
# Default config
4-
agent_model:
5-
name: "gpt-5.1" # "gpt-4o" # Model name
6-
base_url: null # null for default OpenAI endpoint
7-
api_key_env: "OPENAI_API_KEY" # Environment variable containing API key
4+
# agent_model:
5+
# name: "gpt-5.1" # "gpt-4o" # Model name
6+
# base_url: null # null for default OpenAI endpoint
7+
# api_key_env: "OPENAI_API_KEY" # Environment variable containing API key
88

99
# Using EPFL's inference server
10-
# agent_model:
11-
# name: "openai/gpt-oss-120b"
12-
# base_url: "https://inference.rcp.epfl.ch/v1"
13-
# api_key_env: "EPFL_API_KEY" # Set EPFL_API_KEY in .env
10+
agent_model:
11+
name: "openai/gpt-oss-120b"
12+
base_url: "https://inference.rcp.epfl.ch/v1"
13+
api_key_env: "EPFL_API_KEY" # Set EPFL_API_KEY in .env

src/ai_agent/agent/agent.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
RepoSummaryInput,
2121
coerce_github_url_or_none,
2222
)
23-
from .tools.rerank_tool import tool_rerank, RerankInput
2423
from .tools.search_tool import tool_search_tools, SearchToolsInput
2524
from .tools.search_alternative_tool import tool_search_alternative, SearchAlternativeInput
2625
from .tools.gradio_space_tool import tool_run_example, RunExampleInput
@@ -124,32 +123,6 @@ async def search_tools(
124123
return [c.model_dump(mode="python") for c in out.candidates]
125124

126125

127-
@agent.tool(retries=2, prepare=cap_prepare)
128-
@limit_tool_calls("rerank", cap=3)
129-
async def rerank(
130-
ctx: RunContext[AgentState],
131-
query: str,
132-
candidate_names: List[str],
133-
top_k: int = 5,
134-
) -> List[dict]:
135-
"""
136-
Cross-encoder reranker over a small set of candidate tool names.
137-
"""
138-
out = tool_rerank(
139-
RerankInput(query=query, candidate_names=candidate_names, top_k=top_k)
140-
)
141-
ctx.deps.tool_calls.append(
142-
{
143-
"tool": "rerank",
144-
"query": query,
145-
"used_model": out.used_model,
146-
"count": len(out.reranked),
147-
"timestamp": datetime.now().isoformat()
148-
}
149-
)
150-
return list(out.reranked)
151-
152-
153126
@agent.tool(retries=2, prepare=cap_prepare)
154127
@limit_tool_calls("search_alternative", cap=3)
155128
async def search_alternative(
@@ -159,7 +132,7 @@ async def search_alternative(
159132
top_k: int = 12,
160133
) -> List[dict]:
161134
"""
162-
Search with an alternative query formulation.
135+
Search with an alternative query formulation (includes automatic reranking).
163136
"""
164137
# Merge exclusions
165138
explicit_excluded = excluded or []

src/ai_agent/agent/tools/search_alternative_tool.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class SearchAlternativeOutput(BaseModel):
3030

3131
def tool_search_alternative(inp: SearchAlternativeInput) -> SearchAlternativeOutput:
3232
"""
33-
Search with an alternative query formulation.
33+
Search with an alternative query formulation, with automatic reranking.
3434
3535
This tool allows the agent to explicitly try a different search approach
3636
when initial results are not satisfactory.
@@ -67,15 +67,12 @@ def tool_search_alternative(inp: SearchAlternativeInput) -> SearchAlternativeOut
6767
query + " " + " ".join(f"format:{t}" for t in fmt_tokens)
6868
).strip()
6969

70-
# Call retrieval with the alternative query
71-
# Set min_results=0 to prevent automatic retry (agent is already retrying)
72-
hits = pipe.retrieve_no_rerank(
70+
# Call retrieve() which includes automatic reranking
71+
hits = pipe.retrieve(
7372
query,
7473
image_paths=inp.image_paths or None,
7574
exclusions=inp.excluded,
7675
top_k=inp.top_k,
77-
min_results=0, # Disable automatic retry since agent controls this
78-
max_retries=0, # Disable automatic retry
7976
)
8077

8178
# Convert hits to CandidateDoc objects

src/ai_agent/agent/tools/search_tool.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ class SearchToolsOutput(BaseModel):
1818

1919
def tool_search_tools(inp: SearchToolsInput) -> SearchToolsOutput:
2020
"""
21-
Search tools WITHOUT reranker.
21+
Search tools with automatic reranking.
2222
23-
- Uses dense retrieval with similarity-based query expansion.
23+
- Uses dense retrieval with dictionary-based query expansion.
24+
- Applies CrossEncoder reranking automatically for best results.
2425
- Softly biases results using file-format hints (format:EXT).
2526
- Optionally uses `image_paths` so the pipeline can derive additional
2627
hints (modality / anatomy / dims) directly from the image files.
27-
- Includes automatic retry logic if insufficient results are found.
2828
"""
2929
pipe = get_pipeline()
3030

@@ -76,9 +76,8 @@ def tool_search_tools(inp: SearchToolsInput) -> SearchToolsOutput:
7676
base_query + " " + " ".join(f"format:{t}" for t in fmt_tokens)
7777
).strip()
7878

79-
# 5) Call the vector index with similarity expansion and automatic retry
80-
# The pipeline now handles similarity-based expansion internally
81-
hits = pipe.retrieve_no_rerank(
79+
# 5) Call retrieve() that includes automatic reranking
80+
hits = pipe.retrieve(
8281
base_query,
8382
image_paths=inp.image_paths or None,
8483
exclusions=inp.excluded,

src/ai_agent/api/pipeline.py

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,18 @@
1919

2020

2121
class RAGImagingPipeline:
22-
def __init__(self, index_dir: Optional[str] = None):
22+
def __init__(
23+
self,
24+
index_dir: Optional[str] = None,
25+
min_results: int = 5,
26+
max_retries: int = 2,
27+
):
28+
"""Initialize the RAG imaging pipeline."""
2329
self.index_dir = Path(index_dir or os.getenv("RAG_INDEX_DIR", "artifacts/rag_index"))
2430
self.index_dir.mkdir(parents=True, exist_ok=True)
31+
32+
self.min_results = min_results
33+
self.max_retries = max_retries
2534

2635
self.embedder = LocalBGEEmbedder()
2736
self.reranker = CrossEncoderReranker()
@@ -115,42 +124,37 @@ def retrieve_no_rerank(
115124
image_paths: Optional[List[str]] = None,
116125
top_k: int = 30,
117126
exclusions: Optional[List[str]] = None,
118-
max_retries: int = 2,
119-
min_results: int = 5,
120127
) -> List[dict]:
121128
"""
122129
Return raw vector hits WITHOUT applying the CrossEncoder reranker.
123130
124131
Each item: {id, doc, score}. Optional `image_paths` are used to derive
125132
additional text hints (format / modality / anatomy / dims) that are
126133
appended to the query before embedding.
134+
135+
Relies on BGE-M3 semantic embeddings + CrossEncoder reranking.
127136
"""
128137

129138
def _norm(s: str) -> str:
130139
return re.sub(r"\s+", " ", (s or "").strip().lower())
131140

132141
excluded_norm = {_norm(x) for x in (exclusions or []) if x}
133142

134-
# 1) Strip any tags from the query (your existing behavior)
143+
# 1) Strip any tags from the query
135144
clean_q = strip_tags(query)
136145

137146
# 2) Add image-derived hints (format, modality, anatomy, dims, ...)
138147
image_hints = self._build_image_hint_text(image_paths)
139148
if image_hints:
140-
clean_q = f"{clean_q} {image_hints}".strip() if clean_q else image_hints
141-
142-
# 3) Apply similarity-based expansion
143-
if hasattr(self.index, 'similarity_expander') and self.index.similarity_expander.vocabulary:
144-
expanded_q = self.index.similarity_expander.expand_query(clean_q)
145-
log.info(f"Similarity-expanded query: {clean_q}{expanded_q}")
149+
final_q = f"{clean_q} {image_hints}".strip()
146150
else:
147-
expanded_q = clean_q
151+
final_q = clean_q
152+
153+
log.info(f"Retrieval query: {clean_q}" + (f" + metadata: {image_hints[:50]}..." if image_hints else ""))
148154

149-
log.info(f"Final retrieval query: {expanded_q}")
150-
151-
# 4) Vector search with automatic retry logic
155+
# 4) Vector search
152156
pool_k = max(50, top_k * 3)
153-
hits = self.index.search(expanded_q, k=pool_k, reranker=None)
157+
hits = self.index.search(final_q, k=pool_k, reranker=None)
154158

155159
# 5) Apply name-based exclusions if any
156160
if excluded_norm:
@@ -160,46 +164,39 @@ def _norm(s: str) -> str:
160164
if _norm(getattr(h["doc"], "name", "")) not in excluded_norm
161165
]
162166

163-
# 6) Check if results are sufficient, retry with alternatives if not
167+
# 6) Check if results are sufficient, retry with broader terms if not
164168
attempt = 0
165-
while len(hits) < min_results and attempt < max_retries:
169+
while len(hits) < self.min_results and attempt < self.max_retries:
166170
attempt += 1
167-
log.info(f"Insufficient results ({len(hits)} < {min_results}), attempting retry {attempt}/{max_retries}")
171+
log.info(f"Insufficient results ({len(hits)} < {self.min_results}), attempting retry {attempt}/{self.max_retries}")
168172

169-
# Generate alternative query using similarity expander
170-
if hasattr(self.index, 'similarity_expander') and self.index.similarity_expander.vocabulary:
171-
alternatives = self.index.similarity_expander.suggest_alternative_queries(
172-
clean_q,
173-
num_alternatives=1
174-
)
175-
if alternatives:
176-
alt_query = alternatives[0]
177-
log.info(f"Trying alternative query: {alt_query}")
178-
179-
# Add image hints to alternative
180-
if image_hints:
181-
alt_query = f"{alt_query} {image_hints}".strip()
182-
183-
# Expand alternative query
184-
expanded_alt = self.index.similarity_expander.expand_query(alt_query)
185-
186-
# Search with alternative
187-
alt_hits = self.index.search(expanded_alt, k=pool_k, reranker=None)
188-
189-
# Merge results (avoiding duplicates)
190-
existing_ids = {h["id"] for h in hits}
191-
for h in alt_hits:
192-
if h["id"] not in existing_ids:
193-
if not excluded_norm or _norm(getattr(h["doc"], "name", "")) not in excluded_norm:
194-
hits.append(h)
195-
existing_ids.add(h["id"])
196-
197-
log.info(f"After retry {attempt}: {len(hits)} total results")
173+
# Generate alternative by simplifying query (remove specific terms, keep general ones)
174+
# Strategy: use first 2-3 words only to broaden the search
175+
words = clean_q.split()
176+
if len(words) > 3:
177+
alt_task = " ".join(words[:3])
178+
log.info(f"Trying broader query: {alt_task}")
179+
180+
# Build alternative query with image hints
181+
if image_hints:
182+
alt_q = f"{alt_task} {image_hints}".strip()
198183
else:
199-
log.warning(f"Could not generate alternative query for retry {attempt}")
200-
break
184+
alt_q = alt_task
185+
186+
# Search with alternative
187+
alt_hits = self.index.search(alt_q, k=pool_k, reranker=None)
188+
189+
# Merge results (avoiding duplicates)
190+
existing_ids = {h["id"] for h in hits}
191+
for h in alt_hits:
192+
if h["id"] not in existing_ids:
193+
if not excluded_norm or _norm(getattr(h["doc"], "name", "")) not in excluded_norm:
194+
hits.append(h)
195+
existing_ids.add(h["id"])
196+
197+
log.info(f"After retry {attempt}: {len(hits)} total results")
201198
else:
202-
log.warning("Similarity expander not available for retry")
199+
log.warning(f"Query too short to generate alternative for retry {attempt}")
203200
break
204201

205202
# 7) Attach convenience fields expected downstream
@@ -218,6 +215,37 @@ def rerank_only(self, query: str, hits: List[dict], top_k: int = 10) -> List[dic
218215
# Recreate query with any existing format tokens already embedded in retrieval
219216
ranked = self._apply_reranker(strip_tags(query), hits, top_k=top_k)
220217
return ranked
218+
219+
def retrieve(
220+
self,
221+
query: str,
222+
image_paths: Optional[List[str]] = None,
223+
top_k: int = 10,
224+
exclusions: Optional[List[str]] = None,
225+
) -> List[dict]:
226+
"""
227+
Retrieve and automatically rerank results using BGE-M3 + CrossEncoder.
228+
229+
This is the main retrieval method that combines:
230+
1. Semantic search via BGE-M3 embeddings (no query expansion)
231+
2. Precision reranking via CrossEncoder
232+
3. Image metadata hints (format, modality, dimensions)
233+
234+
Returns top_k results after CrossEncoder reranking.
235+
"""
236+
# Get more candidates than needed for reranking
237+
pool_k = max(30, top_k * 3)
238+
hits = self.retrieve_no_rerank(
239+
query=query,
240+
image_paths=image_paths,
241+
top_k=pool_k,
242+
exclusions=exclusions,
243+
)
244+
245+
# Apply reranking to get final top_k
246+
if hits:
247+
return self.rerank_only(query, hits, top_k=top_k)
248+
return []
221249

222250
def get_doc(self, name: str) -> Optional[SoftwareDoc]:
223251
"""Lookup a SoftwareDoc by name (case-sensitive match)."""

0 commit comments

Comments
 (0)