1919
2020
2121class RAGImagingPipeline :
22- def __init__ (self , index_dir : Optional [str ] = None ):
22+ def __init__ (
23+ self ,
24+ index_dir : Optional [str ] = None ,
25+ min_results : int = 5 ,
26+ max_retries : int = 2 ,
27+ ):
28+ """Initialize the RAG imaging pipeline."""
2329 self .index_dir = Path (index_dir or os .getenv ("RAG_INDEX_DIR" , "artifacts/rag_index" ))
2430 self .index_dir .mkdir (parents = True , exist_ok = True )
31+
32+ self .min_results = min_results
33+ self .max_retries = max_retries
2534
2635 self .embedder = LocalBGEEmbedder ()
2736 self .reranker = CrossEncoderReranker ()
@@ -115,42 +124,37 @@ def retrieve_no_rerank(
115124 image_paths : Optional [List [str ]] = None ,
116125 top_k : int = 30 ,
117126 exclusions : Optional [List [str ]] = None ,
118- max_retries : int = 2 ,
119- min_results : int = 5 ,
120127 ) -> List [dict ]:
121128 """
122129 Return raw vector hits WITHOUT applying the CrossEncoder reranker.
123130
124131 Each item: {id, doc, score}. Optional `image_paths` are used to derive
125132 additional text hints (format / modality / anatomy / dims) that are
126133 appended to the query before embedding.
134+
135+ Relies on BGE-M3 semantic embeddings only; CrossEncoder reranking is applied separately (see `rerank_only`).
127136 """
128137
129138 def _norm (s : str ) -> str :
130139 return re .sub (r"\s+" , " " , (s or "" ).strip ().lower ())
131140
132141 excluded_norm = {_norm (x ) for x in (exclusions or []) if x }
133142
134- # 1) Strip any tags from the query (your existing behavior)
143+ # 1) Strip any tags from the query
135144 clean_q = strip_tags (query )
136145
137146 # 2) Add image-derived hints (format, modality, anatomy, dims, ...)
138147 image_hints = self ._build_image_hint_text (image_paths )
139148 if image_hints :
140- clean_q = f"{ clean_q } { image_hints } " .strip () if clean_q else image_hints
141-
142- # 3) Apply similarity-based expansion
143- if hasattr (self .index , 'similarity_expander' ) and self .index .similarity_expander .vocabulary :
144- expanded_q = self .index .similarity_expander .expand_query (clean_q )
145- log .info (f"Similarity-expanded query: { clean_q } → { expanded_q } " )
149+ final_q = f"{ clean_q } { image_hints } " .strip ()
146150 else :
147- expanded_q = clean_q
151+ final_q = clean_q
152+
153+ log .info (f"Retrieval query: { clean_q } " + (f" + metadata: { image_hints [:50 ]} ..." if image_hints else "" ))
148154
149- log .info (f"Final retrieval query: { expanded_q } " )
150-
151- # 4) Vector search with automatic retry logic
155+ # 4) Vector search
152156 pool_k = max (50 , top_k * 3 )
153- hits = self .index .search (expanded_q , k = pool_k , reranker = None )
157+ hits = self .index .search (final_q , k = pool_k , reranker = None )
154158
155159 # 5) Apply name-based exclusions if any
156160 if excluded_norm :
@@ -160,46 +164,39 @@ def _norm(s: str) -> str:
160164 if _norm (getattr (h ["doc" ], "name" , "" )) not in excluded_norm
161165 ]
162166
163- # 6) Check if results are sufficient, retry with alternatives if not
167+ # 6) Check if results are sufficient, retry with broader terms if not
164168 attempt = 0
165- while len (hits ) < min_results and attempt < max_retries :
169+ while len (hits ) < self . min_results and attempt < self . max_retries :
166170 attempt += 1
167- log .info (f"Insufficient results ({ len (hits )} < { min_results } ), attempting retry { attempt } /{ max_retries } " )
171+ log .info (f"Insufficient results ({ len (hits )} < { self . min_results } ), attempting retry { attempt } /{ self . max_retries } " )
168172
169- # Generate alternative query using similarity expander
170- if hasattr (self .index , 'similarity_expander' ) and self .index .similarity_expander .vocabulary :
171- alternatives = self .index .similarity_expander .suggest_alternative_queries (
172- clean_q ,
173- num_alternatives = 1
174- )
175- if alternatives :
176- alt_query = alternatives [0 ]
177- log .info (f"Trying alternative query: { alt_query } " )
178-
179- # Add image hints to alternative
180- if image_hints :
181- alt_query = f"{ alt_query } { image_hints } " .strip ()
182-
183- # Expand alternative query
184- expanded_alt = self .index .similarity_expander .expand_query (alt_query )
185-
186- # Search with alternative
187- alt_hits = self .index .search (expanded_alt , k = pool_k , reranker = None )
188-
189- # Merge results (avoiding duplicates)
190- existing_ids = {h ["id" ] for h in hits }
191- for h in alt_hits :
192- if h ["id" ] not in existing_ids :
193- if not excluded_norm or _norm (getattr (h ["doc" ], "name" , "" )) not in excluded_norm :
194- hits .append (h )
195- existing_ids .add (h ["id" ])
196-
197- log .info (f"After retry { attempt } : { len (hits )} total results" )
173+ # Generate alternative by simplifying query (remove specific terms, keep general ones)
174+ # Strategy: use first 2-3 words only to broaden the search
175+ words = clean_q .split ()
176+ if len (words ) > 3 :
177+ alt_task = " " .join (words [:3 ])
178+ log .info (f"Trying broader query: { alt_task } " )
179+
180+ # Build alternative query with image hints
181+ if image_hints :
182+ alt_q = f"{ alt_task } { image_hints } " .strip ()
198183 else :
199- log .warning (f"Could not generate alternative query for retry { attempt } " )
200- break
184+ alt_q = alt_task
185+
186+ # Search with alternative
187+ alt_hits = self .index .search (alt_q , k = pool_k , reranker = None )
188+
189+ # Merge results (avoiding duplicates)
190+ existing_ids = {h ["id" ] for h in hits }
191+ for h in alt_hits :
192+ if h ["id" ] not in existing_ids :
193+ if not excluded_norm or _norm (getattr (h ["doc" ], "name" , "" )) not in excluded_norm :
194+ hits .append (h )
195+ existing_ids .add (h ["id" ])
196+
197+ log .info (f"After retry { attempt } : { len (hits )} total results" )
201198 else :
202- log .warning ("Similarity expander not available for retry" )
199+ log .warning (f"Query too short to generate alternative for retry { attempt } " )
203200 break
204201
205202 # 7) Attach convenience fields expected downstream
@@ -218,6 +215,37 @@ def rerank_only(self, query: str, hits: List[dict], top_k: int = 10) -> List[dic
218215 # Recreate query with any existing format tokens already embedded in retrieval
219216 ranked = self ._apply_reranker (strip_tags (query ), hits , top_k = top_k )
220217 return ranked
218+
219+ def retrieve (
220+ self ,
221+ query : str ,
222+ image_paths : Optional [List [str ]] = None ,
223+ top_k : int = 10 ,
224+ exclusions : Optional [List [str ]] = None ,
225+ ) -> List [dict ]:
226+ """
227+ Retrieve and automatically rerank results using BGE-M3 + CrossEncoder.
228+
229+ This is the main retrieval method that combines:
230+ 1. Semantic search via BGE-M3 embeddings (no query expansion)
231+ 2. Precision reranking via CrossEncoder
232+ 3. Image metadata hints (format, modality, dimensions)
233+
234+ Returns top_k results after CrossEncoder reranking.
235+ """
236+ # Get more candidates than needed for reranking
237+ pool_k = max (30 , top_k * 3 )
238+ hits = self .retrieve_no_rerank (
239+ query = query ,
240+ image_paths = image_paths ,
241+ top_k = pool_k ,
242+ exclusions = exclusions ,
243+ )
244+
245+ # Apply reranking to get final top_k
246+ if hits :
247+ return self .rerank_only (query , hits , top_k = top_k )
248+ return []
221249
222250 def get_doc (self , name : str ) -> Optional [SoftwareDoc ]:
223251 """Lookup a SoftwareDoc by name (case-sensitive match)."""
0 commit comments