Skip to content

Commit bb792b0

Browse files
jay-thakurekzhu
andauthored
Fix: Azure AI Search Tool Client Lifetime Management (#6316)
## Why are these changes needed? This PR fixes a bug where the underlying azure `SearchClient` was being closed prematurely due to use of `async with client` : inside the tool's run method. this caused the users to encounter errors "HTTP transport has already been closed" ## Related issue number Closes #6308 " ## Checks - [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [X] I've made sure all auto checks have passed. --------- Co-authored-by: Eric Zhu <[email protected]>
1 parent 629fb86 commit bb792b0

File tree

1 file changed

+62
-55
lines changed
  • python/packages/autogen-ext/src/autogen_ext/tools/azure

1 file changed

+62
-55
lines changed

Diff for: python/packages/autogen-ext/src/autogen_ext/tools/azure/_ai_search.py

+62-55
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ def __init__(
241241
self._client: Optional[SearchClient] = None
242242
self._cache: Dict[str, Dict[str, Any]] = {}
243243

244+
async def close(self) -> None:
245+
"""Explicitly close the Azure SearchClient if needed (for cleanup in long-running apps/tests)."""
246+
if self._client is not None:
247+
await self._client.close()
248+
self._client = None
249+
244250
def _process_credential(
245251
self, credential: Union[AzureKeyCredential, TokenCredential, Dict[str, str]]
246252
) -> Union[AzureKeyCredential, TokenCredential]:
@@ -362,61 +368,62 @@ async def run(
362368
client = await self._get_client()
363369
results: List[SearchResult] = []
364370

365-
async with client:
366-
search_future = client.search(text_query, **search_options) # type: ignore
367-
368-
if cancellation_token is not None:
369-
import asyncio
370-
371-
# Using explicit type ignores to handle Azure SDK type complexity
372-
async def awaitable_wrapper(): # type: ignore # pyright: ignore[reportUnknownVariableType,reportUnknownLambdaType,reportUnknownMemberType]
373-
return await search_future # pyright: ignore[reportUnknownVariableType]
374-
375-
task = asyncio.create_task(awaitable_wrapper()) # type: ignore # pyright: ignore[reportUnknownVariableType]
376-
cancellation_token.link_future(task) # pyright: ignore[reportUnknownArgumentType]
377-
search_results = await task # pyright: ignore[reportUnknownVariableType]
378-
else:
379-
search_results = await search_future # pyright: ignore[reportUnknownVariableType]
380-
381-
async for doc in search_results: # type: ignore
382-
search_doc: Any = doc
383-
doc_dict: Dict[str, Any] = {}
384-
385-
try:
386-
if hasattr(search_doc, "items") and callable(search_doc.items):
387-
dict_like_doc = cast(Dict[str, Any], search_doc)
388-
for key, value in dict_like_doc.items():
389-
doc_dict[str(key)] = value
390-
else:
391-
for key in [
392-
k
393-
for k in dir(search_doc)
394-
if not k.startswith("_") and not callable(getattr(search_doc, k, None))
395-
]:
396-
doc_dict[key] = getattr(search_doc, key)
397-
except Exception as e:
398-
logger.warning(f"Error processing search document: {e}")
399-
continue
400-
401-
metadata: Dict[str, Any] = {}
402-
content: Dict[str, Any] = {}
403-
for key, value in doc_dict.items():
404-
key_str: str = str(key)
405-
if key_str.startswith("@") or key_str.startswith("_"):
406-
metadata[key_str] = value
407-
else:
408-
content[key_str] = value
409-
410-
score: float = 0.0
411-
if "@search.score" in doc_dict:
412-
score = float(doc_dict["@search.score"])
413-
414-
result = SearchResult(
415-
score=score,
416-
content=content,
417-
metadata=metadata,
418-
)
419-
results.append(result)
371+
# Use the persistent client directly. Do NOT close after each operation.
372+
# WARNING: The SearchClient must live as long as the tool/agent is in use.
373+
search_future = client.search(text_query, **search_options) # type: ignore
374+
375+
if cancellation_token is not None:
376+
import asyncio
377+
378+
# Using explicit type ignores to handle Azure SDK type complexity
379+
async def awaitable_wrapper(): # type: ignore # pyright: ignore[reportUnknownVariableType,reportUnknownLambdaType,reportUnknownMemberType]
380+
return await search_future # pyright: ignore[reportUnknownVariableType]
381+
382+
task = asyncio.create_task(awaitable_wrapper()) # type: ignore # pyright: ignore[reportUnknownVariableType]
383+
cancellation_token.link_future(task) # pyright: ignore[reportUnknownArgumentType]
384+
search_results = await task # pyright: ignore[reportUnknownVariableType]
385+
else:
386+
search_results = await search_future # pyright: ignore[reportUnknownVariableType]
387+
388+
async for doc in search_results: # type: ignore
389+
search_doc: Any = doc
390+
doc_dict: Dict[str, Any] = {}
391+
392+
try:
393+
if hasattr(search_doc, "items") and callable(search_doc.items):
394+
dict_like_doc = cast(Dict[str, Any], search_doc)
395+
for key, value in dict_like_doc.items():
396+
doc_dict[str(key)] = value
397+
else:
398+
for key in [
399+
k
400+
for k in dir(search_doc)
401+
if not k.startswith("_") and not callable(getattr(search_doc, k, None))
402+
]:
403+
doc_dict[key] = getattr(search_doc, key)
404+
except Exception as e:
405+
logger.warning(f"Error processing search document: {e}")
406+
continue
407+
408+
metadata: Dict[str, Any] = {}
409+
content: Dict[str, Any] = {}
410+
for key, value in doc_dict.items():
411+
key_str: str = str(key)
412+
if key_str.startswith("@") or key_str.startswith("_"):
413+
metadata[key_str] = value
414+
else:
415+
content[key_str] = value
416+
417+
score: float = 0.0
418+
if "@search.score" in doc_dict:
419+
score = float(doc_dict["@search.score"])
420+
421+
result = SearchResult(
422+
score=score,
423+
content=content,
424+
metadata=metadata,
425+
)
426+
results.append(result)
420427

421428
if self.search_config.enable_caching:
422429
cache_key = f"{text_query}_{self.search_config.top}"

0 commit comments

Comments
 (0)