77
88import asyncio
99import contextlib
10+ import json
1011import os
1112from collections .abc import AsyncGenerator
1213from dataclasses import dataclass
@@ -139,6 +140,8 @@ async def _aprocess_query_and_retrieve_docs(
139140 )
140141 # documents already contains all retrieved docs, no action needed
141142
143+ documents = await self ._expand_skill_documents (documents )
144+
142145 # Ensure Grok summary is present and first in order (for generation context)
143146 if grok_summary_doc is not None :
144147 if grok_summary_doc in documents :
@@ -150,6 +153,84 @@ async def _aprocess_query_and_retrieve_docs(
150153
151154 return processed_query , documents , grok_citations
152155
async def _expand_skill_documents(self, documents: list[Document]) -> list[Document]:
    """
    Replace skill chunks with full skill documents when available.

    A "skill chunk" is a document whose metadata has
    ``source == DocumentSource.CAIRO_SKILLS`` and a truthy ``skillId``. For
    every distinct skill id, the row with unique id ``skill-<skillId>-full``
    is fetched from the vector DB; when that row's metadata carries both
    ``skillId`` and ``fullContent``, all chunks of the skill are replaced by
    one document built from ``fullContent``.

    If a full document row cannot be fetched for a skill, keep that skill's
    original chunks to degrade gracefully. Documents that are not skill
    chunks (including skill-sourced documents lacking a ``skillId``) are
    always kept.

    Args:
        documents: Retrieved documents, in retrieval order.

    Returns:
        The input list itself when it contains no skill chunks or when the
        fetch fails. Otherwise a new list ordered as: untouched documents
        (original order), then fallback chunks of skills with no full
        document, then the full skill documents (both in first-seen skill
        order).
    """
    # Single pass: group skill chunks by id and collect everything else.
    # Documents without a skillId cannot be expanded, so they must pass
    # through untouched rather than being dropped with the skill chunks.
    chunks_by_skill_id: dict[str, list[Document]] = {}
    untouched_documents: list[Document] = []
    for document in documents:
        skill_id = document.metadata.get("skillId")
        if document.metadata.get("source") == DocumentSource.CAIRO_SKILLS and skill_id:
            chunks_by_skill_id.setdefault(skill_id, []).append(document)
        else:
            untouched_documents.append(document)

    if not chunks_by_skill_id:
        return documents

    # dict preserves insertion order, so this is first-seen skill order.
    skill_ids = list(chunks_by_skill_id)
    unique_ids = [f"skill-{skill_id}-full" for skill_id in skill_ids]

    try:
        rows = await self.document_retriever.vector_db.afetch_by_unique_ids(unique_ids)
    except Exception as e:
        # Best-effort expansion: any retrieval failure keeps the original docs.
        logger.warning(
            "_expand_skill_documents: failed to fetch full rows, keeping original chunks",
            error=str(e),
            exc_info=True,
        )
        return documents

    full_documents_by_skill_id: dict[str, Document] = {}
    # Treat a None result like an empty one so every skill falls back to chunks.
    for row in rows or []:
        metadata: Any = row.get("metadata", {})
        if isinstance(metadata, str):
            # Metadata may arrive JSON-encoded; decode it before use.
            try:
                metadata = json.loads(metadata)
            except (ValueError, RecursionError):
                # json.JSONDecodeError subclasses ValueError; RecursionError
                # covers pathologically nested payloads.
                logger.warning(
                    "_expand_skill_documents: unable to decode metadata json, skipping row"
                )
                continue

        if not isinstance(metadata, dict):
            continue

        skill_id = metadata.get("skillId")
        full_content = metadata.get("fullContent")
        if skill_id and full_content:
            full_documents_by_skill_id[skill_id] = Document(
                page_content=full_content,
                metadata=metadata,
            )

    result_documents = list(untouched_documents)

    # Skills with no usable full row keep their original chunks.
    for skill_id in skill_ids:
        if skill_id not in full_documents_by_skill_id:
            result_documents.extend(chunks_by_skill_id[skill_id])
            logger.warning(
                "_expand_skill_documents: no full document found, keeping chunks",
                skill_id=skill_id,
            )

    # Full documents go last, one per skill, in first-seen skill order.
    result_documents.extend(
        full_documents_by_skill_id[skill_id]
        for skill_id in skill_ids
        if skill_id in full_documents_by_skill_id
    )

    return result_documents

233+
153234 @traceable (name = "RagPipeline" , run_type = "chain" )
154235 async def aforward (
155236 self ,
0 commit comments