Skip to content

Commit aef25a1

Browse files
committed
small refactor of the tests and files for the PR
1 parent f0fc57a commit aef25a1

43 files changed

Lines changed: 1771 additions & 1061 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/ai_agent/agent/agent.py

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
import os, logging
3+
import os
4+
import logging
45
from datetime import datetime
56
from typing import List
67

@@ -17,7 +18,10 @@
1718
from .tools.repo_info_tool import tool_repo_summary, RepoSummaryInput
1819
from ai_agent.agent.utils import coerce_github_url_or_none
1920
from .tools.search_tool import tool_search_tools, SearchToolsInput
20-
from .tools.search_alternative_tool import tool_search_alternative, SearchAlternativeInput
21+
from .tools.search_alternative_tool import (
22+
tool_search_alternative,
23+
SearchAlternativeInput,
24+
)
2125
from .utils import AgentState, limit_tool_calls, cap_prepare
2226
from ai_agent.utils.image_meta import summarize_image_metadata, detect_ext_token
2327

@@ -68,6 +72,7 @@
6872
# Tool adapters for the agent
6973
# ---------------------------------------------------------------------------
7074

75+
7176
@agent.tool(retries=2, prepare=cap_prepare)
7277
@limit_tool_calls("search_tools", cap=1)
7378
async def search_tools(
@@ -93,7 +98,9 @@ async def search_tools(
9398
original_formats = getattr(ctx.deps, "original_formats", []) or []
9499
image_paths = getattr(ctx.deps, "image_paths", []) or []
95100

96-
effective_top_k = ctx.deps.override_top_k if ctx.deps.override_top_k is not None else top_k
101+
effective_top_k = (
102+
ctx.deps.override_top_k if ctx.deps.override_top_k is not None else top_k
103+
)
97104

98105
inp = SearchToolsInput(
99106
query=query,
@@ -111,7 +118,7 @@ async def search_tools(
111118
"count": len(out.candidates),
112119
"original_formats": original_formats,
113120
"excluded": all_excluded,
114-
"timestamp": datetime.now().isoformat()
121+
"timestamp": datetime.now().isoformat(),
115122
}
116123
)
117124

@@ -154,7 +161,7 @@ async def search_alternative(
154161
"count": len(out.candidates),
155162
"original_formats": original_formats,
156163
"excluded": all_excluded,
157-
"timestamp": datetime.now().isoformat()
164+
"timestamp": datetime.now().isoformat(),
158165
}
159166
)
160167

@@ -163,24 +170,28 @@ async def search_alternative(
163170

164171
@agent.tool(retries=2, prepare=cap_prepare)
165172
@limit_tool_calls("repo_info", cap=12)
166-
async def repo_info(ctx: RunContext[AgentState], url: str, tool_name: str | None = None) -> dict:
173+
async def repo_info(
174+
ctx: RunContext[AgentState], url: str, tool_name: str | None = None
175+
) -> dict:
167176
"""
168177
Fetch a short summary of a GitHub repository.
169178
170179
Non-GitHub URLs are ignored; the tool returns a small dict noting
171180
that it was skipped. If a tool_name is provided and the URL is not
172181
a GitHub URL, the tool will attempt to look up the GitHub URL from
173182
the catalog.
174-
183+
175184
Args:
176185
url: Repository URL or GitHub owner/repo format
177186
tool_name: Optional tool name to look up in catalog if URL is not GitHub
178187
"""
179188
norm_url = coerce_github_url_or_none(url)
180-
189+
181190
# If URL is not a GitHub URL and tool_name is provided, try catalog lookup
182191
if not norm_url and tool_name:
183-
log.info(f"Non-GitHub URL provided, tool_name={tool_name}, attempting catalog lookup")
192+
log.info(
193+
f"Non-GitHub URL provided, tool_name={tool_name}, attempting catalog lookup"
194+
)
184195
# The tool_repo_summary will handle the catalog lookup
185196
norm_url = url # Pass through, tool_repo_summary will handle it
186197
elif not norm_url:
@@ -190,16 +201,24 @@ async def repo_info(ctx: RunContext[AgentState], url: str, tool_name: str | None
190201
"skipped": True,
191202
"reason": "NON_GITHUB_URL",
192203
"hint": "Pass a GitHub repo URL or 'owner/repo' to repo_info(url). Optionally provide tool_name for catalog lookup.",
193-
"timestamp": datetime.now().isoformat()
204+
"timestamp": datetime.now().isoformat(),
194205
}
195206
ctx.deps.tool_calls.append(payload)
196207
return {k: v for k, v in payload.items() if k != "tool"}
197208

198209
try:
199-
out = await tool_repo_summary(RepoSummaryInput(url=norm_url, tool_name=tool_name))
210+
out = await tool_repo_summary(
211+
RepoSummaryInput(url=norm_url, tool_name=tool_name)
212+
)
200213
except Exception as e:
201214
ctx.deps.tool_calls.append(
202-
{"tool": "repo_info", "url": norm_url, "tool_name": tool_name, "error": str(e), "timestamp": datetime.now().isoformat()}
215+
{
216+
"tool": "repo_info",
217+
"url": norm_url,
218+
"tool_name": tool_name,
219+
"error": str(e),
220+
"timestamp": datetime.now().isoformat(),
221+
}
203222
)
204223
raise
205224

@@ -209,7 +228,7 @@ async def repo_info(ctx: RunContext[AgentState], url: str, tool_name: str | None
209228
"url": norm_url,
210229
"tool_name": tool_name,
211230
"truncated": getattr(out, "truncated", False),
212-
"timestamp": datetime.now().isoformat()
231+
"timestamp": datetime.now().isoformat(),
213232
}
214233
)
215234
return out.model_dump(mode="python")
@@ -251,7 +270,11 @@ def run_agent(
251270
tool_logs: List[ToolRunLog] = []
252271

253272
# ---- 1) Derive image-based metadata and format hints --------------------
254-
meta_str = image_metadata if image_metadata is not None else (summarize_image_metadata(image_paths) or "")
273+
meta_str = (
274+
image_metadata
275+
if image_metadata is not None
276+
else (summarize_image_metadata(image_paths) or "")
277+
)
255278
fmt_str = detect_ext_token(image_paths) or ""
256279
original_formats = [t.lower() for t in fmt_str.split()] if fmt_str else []
257280

@@ -273,7 +296,12 @@ def run_agent(
273296
hidden_meta += "\n(Formats Hint: " + ",".join(original_formats) + ")"
274297
if meta_str:
275298
short_meta = " ".join(x.strip() for x in meta_str.splitlines() if x.strip())
276-
hidden_meta += "\n(Image Metadata: " + short_meta[:500] + ("…" if len(short_meta) > 500 else "") + ")"
299+
hidden_meta += (
300+
"\n(Image Metadata: "
301+
+ short_meta[:500]
302+
+ ("…" if len(short_meta) > 500 else "")
303+
+ ")"
304+
)
277305
if top_k is not None:
278306
hidden_meta += f"\n(Search top_k: {top_k})"
279307

@@ -303,7 +331,9 @@ def run_agent(
303331
key_env_name = api_key_env if api_key_env else "OPENAI_API_KEY"
304332
runtime_api_key = os.getenv(key_env_name)
305333
if not runtime_api_key:
306-
raise ValueError(f"{key_env_name} not found in environment. Cannot use this model.")
334+
raise ValueError(
335+
f"{key_env_name} not found in environment. Cannot use this model."
336+
)
307337
effective_base_url = base_url # Can be None for OpenAI
308338
log.info(f"✓ Using {key_env_name} for model {effective_model}")
309339
log.debug(f"{key_env_name} is set: {bool(runtime_api_key)}")
@@ -335,13 +365,17 @@ def run_agent(
335365
base_url=effective_base_url,
336366
api_key=runtime_api_key,
337367
)
338-
368+
339369
# Use OpenAIChatModel (chat/completions) for custom endpoints, OpenAIResponsesModel for default OpenAI
340370
if effective_base_url:
341371
log.info("Using OpenAIChatModel (chat/completions API) for custom endpoint")
342-
runtime_model = OpenAIChatModel(model_name=effective_model, provider=runtime_provider)
372+
runtime_model = OpenAIChatModel(
373+
model_name=effective_model, provider=runtime_provider
374+
)
343375
else:
344-
runtime_model = OpenAIResponsesModel(model_name=effective_model, provider=runtime_provider)
376+
runtime_model = OpenAIResponsesModel(
377+
model_name=effective_model, provider=runtime_provider
378+
)
345379

346380
agent_instance = Agent(
347381
model=runtime_model,
@@ -370,7 +404,9 @@ def run_agent(
370404
agent_instance.tool(repo_info, retries=2, prepare=cap_prepare)
371405

372406
else:
373-
log.info(f"♻️ Using global agent (model: {effective_model}, num_choices: {effective_num_choices})")
407+
log.info(
408+
f"♻️ Using global agent (model: {effective_model}, num_choices: {effective_num_choices})"
409+
)
374410

375411
log.debug(
376412
f"Prompt length: {len(prompt)} chars, has_image_paths: {bool(image_paths)}, has_image_bytes: {bool(image_bytes)}"
@@ -389,7 +425,9 @@ def run_agent(
389425
),
390426
]
391427
else:
392-
log.warning("⚠️ No image bytes provided - the model will not see the image preview")
428+
log.warning(
429+
"⚠️ No image bytes provided - the model will not see the image preview"
430+
)
393431
user_prompt = prompt
394432

395433
# ---- 6) Run the agent --------------------------------------------------
@@ -402,7 +440,9 @@ def run_agent(
402440
)
403441
result = run_result.output
404442

405-
log.info(f"✅ Agent execution complete - choices returned: {len(result.choices)}")
443+
log.info(
444+
f"✅ Agent execution complete - choices returned: {len(result.choices)}"
445+
)
406446

407447
# Log usage (helpful, but may not explicitly expose image-specific counters)
408448
if run_result.usage:
@@ -414,28 +454,35 @@ def run_agent(
414454

415455
# Warn if using non-OpenAI endpoint with images
416456
if image_bytes and effective_base_url:
417-
log.warning("⚠️ Using custom endpoint - confirm the selected model supports vision.")
457+
log.warning(
458+
"⚠️ Using custom endpoint - confirm the selected model supports vision."
459+
)
418460

419461
except Exception as e:
420462
# Handle global tool quota limit (UsageLimitExceeded) and other errors gracefully
421463
error_msg = str(e)
422464
log.warning(f"⚠️ Agent execution encountered an error: {error_msg}")
423465
run_result = None # Ensure run_result is defined for usage stats extraction
424-
466+
425467
# Check if this is a usage limit error (global tool quota)
426-
if "UsageLimitExceeded" in str(type(e).__name__) or "tool_calls_limit" in error_msg.lower():
427-
log.warning("Global tool call quota reached - continuing with partial results")
468+
if (
469+
"UsageLimitExceeded" in str(type(e).__name__)
470+
or "tool_calls_limit" in error_msg.lower()
471+
):
472+
log.warning(
473+
"Global tool call quota reached - continuing with partial results"
474+
)
428475

429476
result = ToolSelection(
430477
conversation=Conversation(
431478
status=ConversationStatus.COMPLETE,
432479
context="The agent reached the maximum number of tool calls allowed. Please try a more specific query or break down your request into smaller parts.",
433480
question=None,
434-
options=None
481+
options=None,
435482
),
436483
choices=[],
437484
explanation="Tool call limit reached during execution. Try refining your query.",
438-
reason=None
485+
reason=None,
439486
)
440487
else:
441488
raise
@@ -445,7 +492,9 @@ def run_agent(
445492
tool_name = tc.get("tool")
446493
timestamp = tc.get("timestamp")
447494
error = tc.get("error")
448-
inputs = {k: v for k, v in tc.items() if k not in ("tool", "timestamp", "error")}
495+
inputs = {
496+
k: v for k, v in tc.items() if k not in ("tool", "timestamp", "error")
497+
}
449498
tool_logs.append(
450499
ToolRunLog(
451500
tool=tool_name,
@@ -464,7 +513,7 @@ def run_agent(
464513
input_tokens=usage.input_tokens,
465514
output_tokens=usage.output_tokens,
466515
)
467-
516+
468517
# ---- 9) Wrap into high-level AgentToolSelection ------------------------
469518
return AgentToolSelection(
470519
conversation=result.conversation,
@@ -476,4 +525,4 @@ def run_agent(
476525
)
477526

478527

479-
__all__ = ["run_agent", "agent"]
528+
__all__ = ["run_agent", "agent"]

src/ai_agent/agent/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,22 @@
55

66
from ai_agent.generator.schema import ToolSelection, CandidateDoc
77

8+
89
class ToolRunLog(BaseModel):
910
tool: str
1011
inputs: Dict[str, Any] = Field(default_factory=dict)
1112
error: Optional[str] = None
1213
timestamp: Optional[str] = None
1314

15+
1416
class UsageStats(BaseModel):
1517
"""Token usage statistics from the agent."""
18+
1619
total_tokens: int = 0
1720
input_tokens: int = 0
1821
output_tokens: int = 0
1922

23+
2024
class AgentToolSelection(ToolSelection):
2125
tool_calls: List[ToolRunLog] = Field(default_factory=list)
2226
usage: Optional[UsageStats] = None
@@ -32,6 +36,7 @@ def to_legacy_dict(self) -> Dict[str, Any]:
3236
"usage": self.usage.model_dump(mode="python") if self.usage else None,
3337
}
3438

39+
3540
__all__ = [
3641
"AgentToolSelection",
3742
"ToolRunLog",

src/ai_agent/agent/tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ def ensure_tools_registered():
3131
from .search_alternative_tool import tool_search_alternative
3232
from .repo_info_tool import tool_repo_summary
3333
from .gradio_space_tool import tool_run_example
34-
34+
3535
# Import MCP tools
3636
ensure_mcp_tools_registered()

src/ai_agent/agent/tools/deepwiki_tool.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
DEEPWIKI_TIMEOUT = 60
2020

2121

22-
2322
class DeepWikiInput(BaseModel):
2423
"""Input for DeepWiki operations."""
24+
2525
url: str # GitHub repository URL or owner/repo format
2626

2727

2828
class DeepWikiContentsOutput(BaseModel):
2929
"""Output from read_wiki_contents."""
30+
3031
success: bool
3132
contents: Optional[str] = None
3233
error: Optional[str] = None
@@ -55,11 +56,11 @@ async def get_wiki_contents(input: DeepWikiInput) -> DeepWikiContentsOutput:
5556
if isinstance(result, str):
5657
# Direct string result
5758
text = result
58-
elif hasattr(result, 'content'):
59+
elif hasattr(result, "content"):
5960
# MCP ToolResult with content field
6061
text_parts = []
6162
for item in result.content:
62-
if hasattr(item, 'text'):
63+
if hasattr(item, "text"):
6364
text_parts.append(item.text)
6465
elif isinstance(item, str):
6566
text_parts.append(item)
@@ -76,7 +77,9 @@ async def get_wiki_contents(input: DeepWikiInput) -> DeepWikiContentsOutput:
7677
truncated=truncated,
7778
)
7879

79-
return DeepWikiContentsOutput(success=False, error="No content returned from DeepWiki")
80+
return DeepWikiContentsOutput(
81+
success=False, error="No content returned from DeepWiki"
82+
)
8083

8184
except asyncio.TimeoutError:
8285
log.warning(f"DeepWiki timed out after {DEEPWIKI_TIMEOUT}s for {repo}")
@@ -96,4 +99,4 @@ async def get_wiki_contents(input: DeepWikiInput) -> DeepWikiContentsOutput:
9699
"get_wiki_contents",
97100
"DeepWikiInput",
98101
"DeepWikiContentsOutput",
99-
]
102+
]

0 commit comments

Comments
 (0)