Skip to content

Commit aaddfae

Browse files
committed
now supporting refinement
1 parent 2297767 commit aaddfae

3 files changed

Lines changed: 177 additions & 121 deletions

File tree

src/ai_agent/api/pipeline.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,10 @@ def _apply_reranker(self, query: str, hits: List[dict], top_k: int) -> List[dict
9999
break
100100
return out
101101

102-
def recommend(
103-
self, user_task: str, image_paths: Optional[List[str]], top_k: int = 5
104-
) -> Tuple[List[dict], Dict[str, float]]:
102+
def recommend(self, user_task: str, image_paths: Optional[List[str]], top_k: int = 5,
103+
persisted_exclusions: Optional[List[str]] = None
104+
) -> Tuple[List[dict], Dict[str, float]]:
105+
105106
"""
106107
Retrieve candidate tools for the given request. Control tags:
107108
[NO_RERANK] -> skip CrossEncoder reranker
@@ -114,7 +115,9 @@ def _norm(s: str) -> str:
114115

115116
# --- Control tags ---------------------------------------------------------
116117
skip_rerank = has_no_rerank(user_task)
117-
excluded_raw = parse_exclusions(user_task)
118+
excluded_raw = set(parse_exclusions(user_task))
119+
if persisted_exclusions:
120+
excluded_raw |= set(persisted_exclusions)
118121
excluded_norm = {_norm(x) for x in excluded_raw}
119122

120123
# Work with a clean task (no control tags) for retrieval
@@ -267,6 +270,7 @@ def recommend_and_link(
267270
image_paths: Optional[List[str]],
268271
user_task: str,
269272
conversation_history: Optional[List[str]] = None,
273+
persisted_exclusions: Optional[List[str]] = None,
270274
) -> Dict[str, Any]:
271275
# --- helpers ------------------------------------------------------------
272276

@@ -283,6 +287,8 @@ def _norm(s: str) -> str:
283287
# --- control tags ------------------------------------------------------
284288
force_clarification = has_refine(full_task)
285289
exclude_names = set(parse_exclusions(full_task))
290+
if persisted_exclusions:
291+
exclude_names |= set(persisted_exclusions)
286292
selector_task_clean = strip_tags(full_task)
287293
excluded_norm = {_norm(x) for x in exclude_names}
288294

@@ -309,7 +315,11 @@ def _norm(s: str) -> str:
309315
top_k = int(os.getenv("TOP_K", "8"))
310316
num_choices = int(os.getenv("NUM_CHOICES", "3"))
311317

312-
hits, _scores = self.recommend(full_task, image_paths, top_k=top_k)
318+
hits, _scores = self.recommend(
319+
full_task, image_paths, top_k=top_k,
320+
persisted_exclusions=list(exclude_names) if exclude_names else None
321+
)
322+
313323
if not hits:
314324
return {
315325
"conversation": {"status": "complete"},

0 commit comments

Comments
 (0)