Skip to content

Commit 023e432

Browse files
committed
switch of model during usage now working
1 parent 54db1ca commit 023e432

7 files changed

Lines changed: 234 additions & 48 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ All notable changes to this project will be documented in this file.
5959
- **Legacy Method**: Removed `recommend_and_link()` method from `api/pipeline.py` (~180 lines) - only used by outdated tests, replaced by agent-based approach.
6060
- **State Variables**: Removed 3 Gradio State objects: `last_task_state`, `last_suggestions_state`, `excluded_names`.
6161
- **Outdated Tests**: Removed `tests/full_test.py` which only tested the removed `recommend_and_link()` method.
62+
- CLI no longer supports the `ai_agent ui` command.
6263

6364
### Fixed
6465
- **Conversation Context**: Agent now properly maintains conversation history, enabling natural understanding of follow-up requests like "show me alternatives".

src/ai_agent/agent/agent.py

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pydantic_ai.models.openai import OpenAIChatModel
99
from pydantic_ai.providers.openai import OpenAIProvider
1010

11-
from ai_agent.generator.prompts import AGENT_SYSTEM_PROMPT
11+
from ai_agent.generator.prompts import get_agent_system_prompt
1212
from ai_agent.generator.schema import ToolSelection
1313
from ai_agent.api.pipeline import RAGImagingPipeline
1414
from ai_agent.utils.utils import _best_runnable_link
@@ -54,7 +54,7 @@
5454

5555
agent = Agent(
5656
model=openai_model,
57-
system_prompt=AGENT_SYSTEM_PROMPT,
57+
system_prompt=get_agent_system_prompt(os.getenv("NUM_CHOICES", "3")),
5858
deps_type=AgentState,
5959
)
6060

@@ -65,7 +65,9 @@
6565
async def search_tools(ctx: RunContext[AgentState], query: str, excluded: List[str] | None = None, top_k: int = 12, original_formats: List[str] | None = None):
6666
# Merge explicit excluded param with state's excluded_tools
6767
all_excluded = list(set((excluded or []) + ctx.deps.excluded_tools))
68-
out = tool_search_tools(SearchToolsInput(query=query, excluded=all_excluded, top_k=top_k, original_formats=original_formats or []))
68+
# Use override from context if available
69+
effective_top_k = ctx.deps.override_top_k if ctx.deps.override_top_k is not None else top_k
70+
out = tool_search_tools(SearchToolsInput(query=query, excluded=all_excluded, top_k=effective_top_k, original_formats=original_formats or []))
6971
payload = [c.model_dump(mode="python") for c in out.candidates]
7072
ctx.deps.tool_calls.append({"tool": "search_tools", "query": query, "count": len(payload), "original_formats": original_formats or [], "excluded": all_excluded, "timestamp": datetime.now().isoformat()})
7173
return payload
@@ -161,6 +163,7 @@ def run_agent(
161163
image_meta: str | None = None,
162164
conversation_history: List[str] | None = None,
163165
model: str | None = None,
166+
base_url: str | None = None,
164167
top_k: int | None = None,
165168
num_choices: int | None = None,
166169
) -> AgentToolSelection:
@@ -172,7 +175,15 @@ def run_agent(
172175

173176
tool_logs: List[ToolRunLog] = []
174177

175-
deps = AgentState(excluded_tools=excluded or [])
178+
# Create AgentState with runtime overrides
179+
deps = AgentState(
180+
excluded_tools=excluded or [],
181+
override_model=model,
182+
override_base_url=base_url,
183+
override_top_k=top_k,
184+
override_num_choices=num_choices,
185+
)
186+
176187
# Provide hidden metadata context lines (non-user-visible) below a delimiter
177188
hidden_meta = ""
178189
if original_formats:
@@ -194,7 +205,95 @@ def run_agent(
194205
else:
195206
prompt = task + extra_context + hidden_meta
196207

197-
result = agent.run_sync(prompt, deps=deps, output_type=ToolSelection, usage_limits=UsageLimits(tool_calls_limit=10)).output
208+
# Determine which agent instance to use
209+
agent_instance = agent # Default to global agent
210+
effective_num_choices = num_choices if num_choices is not None else 3
211+
effective_model = model if model else agent_model_config.name
212+
effective_top_k = top_k if top_k is not None else 12
213+
214+
# When model is provided from UI, base_url comes with it (can be None for OpenAI)
215+
# When model is NOT provided, use config defaults
216+
if model:
217+
# Model selected from dropdown - base_url parameter is authoritative
218+
if base_url and "inference.rcp.epfl.ch" in base_url:
219+
# EPFL model selected
220+
runtime_api_key = os.getenv("EPFL_API_KEY")
221+
if not runtime_api_key:
222+
raise ValueError("EPFL_API_KEY not found. Cannot use EPFL models without VPN and API key.")
223+
effective_base_url = base_url
224+
log.info("✓ Using EPFL_API_KEY for EPFL inference server")
225+
else:
226+
# OpenAI or other model selected (base_url=None means OpenAI)
227+
runtime_api_key = os.getenv("OPENAI_API_KEY")
228+
if not runtime_api_key:
229+
raise ValueError("OPENAI_API_KEY not found. Cannot use OpenAI models.")
230+
effective_base_url = base_url # Will be None for OpenAI
231+
log.info("✓ Using OPENAI_API_KEY for OpenAI endpoint")
232+
else:
233+
# No model override - use config defaults
234+
effective_base_url = agent_model_config.base_url
235+
if effective_base_url and "inference.rcp.epfl.ch" in effective_base_url:
236+
runtime_api_key = os.getenv("EPFL_API_KEY")
237+
if not runtime_api_key:
238+
raise ValueError("EPFL_API_KEY not found")
239+
log.info("✓ Using EPFL_API_KEY from config")
240+
else:
241+
runtime_api_key = os.getenv("OPENAI_API_KEY")
242+
if not runtime_api_key:
243+
raise ValueError("OPENAI_API_KEY not found")
244+
log.info("✓ Using OPENAI_API_KEY from config")
245+
246+
# Log runtime configuration
247+
endpoint_display = effective_base_url if effective_base_url else "api.openai.com"
248+
log.info(
249+
f"🤖 Agent execution - Model: {effective_model}, endpoint: {endpoint_display}, "
250+
f"top_k: {effective_top_k}, num_choices: {effective_num_choices}, excluded: {len(excluded or [])}"
251+
)
252+
253+
# Create dynamic agent:
254+
needs_dynamic_agent = (
255+
(model and model != agent_model_config.name) or
256+
(base_url is not None and base_url != agent_model_config.base_url) or
257+
(runtime_api_key != api_key) # API key mismatch - need new agent!
258+
)
259+
260+
if needs_dynamic_agent:
261+
log.info(f"📦 Creating runtime agent with model={effective_model}, endpoint={effective_base_url or 'api.openai.com'}")
262+
263+
runtime_provider = OpenAIProvider(
264+
base_url=effective_base_url,
265+
api_key=runtime_api_key,
266+
)
267+
runtime_model = OpenAIChatModel(model_name=effective_model, provider=runtime_provider)
268+
agent_instance = Agent(
269+
model=runtime_model,
270+
system_prompt=get_agent_system_prompt(effective_num_choices),
271+
deps_type=AgentState,
272+
)
273+
# Register tools on the dynamic agent
274+
agent_instance.tool(search_tools, retries=2, prepare=cap_prepare)
275+
agent_instance.tool(rerank, retries=2, prepare=cap_prepare)
276+
agent_instance.tool(repo_info, retries=0, prepare=cap_prepare)
277+
agent_instance.tool(resolve_demo_link, retries=2, prepare=cap_prepare)
278+
elif num_choices is not None and num_choices != 3:
279+
# Model/base_url same but num_choices differs - create agent with updated prompt
280+
log.info(f"📦 Creating runtime agent with num_choices={effective_num_choices} (model: {effective_model})")
281+
agent_instance = Agent(
282+
model=openai_model,
283+
system_prompt=get_agent_system_prompt(effective_num_choices),
284+
deps_type=AgentState,
285+
)
286+
# Register tools on the dynamic agent
287+
agent_instance.tool(search_tools, retries=2, prepare=cap_prepare)
288+
agent_instance.tool(rerank, retries=2, prepare=cap_prepare)
289+
agent_instance.tool(repo_info, retries=0, prepare=cap_prepare)
290+
agent_instance.tool(resolve_demo_link, retries=2, prepare=cap_prepare)
291+
else:
292+
log.info(f"♻️ Using global agent (model: {effective_model}, num_choices: {effective_num_choices})")
293+
294+
log.debug(f"Prompt length: {len(prompt)} chars, has_image: {image_data_url is not None}")
295+
result = agent_instance.run_sync(prompt, deps=deps, output_type=ToolSelection, usage_limits=UsageLimits(tool_calls_limit=10)).output
296+
log.info(f"✅ Agent execution complete - choices returned: {len(result.choices)}")
198297

199298
# Convert tool call dicts into ToolRunLog entries
200299
for tc in deps.tool_calls:

src/ai_agent/agent/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ class AgentState(BaseModel):
1818
tool_counts: Dict[str, int] = Field(default_factory=dict)
1919
disabled_tools: Set[str] = Field(default_factory=set)
2020
excluded_tools: List[str] = Field(default_factory=list) # Tools to exclude from search
21+
22+
# Runtime overrides (session-only, not persisted)
23+
override_model: Optional[str] = None
24+
override_base_url: Optional[str] = None
25+
override_top_k: Optional[int] = None
26+
override_num_choices: Optional[int] = None
2127

2228
# Quota decorator + prepare hook -----------------------------------------------
2329

src/ai_agent/generator/prompts.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
- Also include a one-line context explaining why you need this info (≤ 15 words).
2828
2929
SCORING WHEN CLEAR (no question)
30-
- Rank up to NUM_CHOICES tools that truly match.
30+
- Rank up to {num_choices} tools that truly match.
3131
- Accuracy (0–100) = Task match (40) + Input compatibility (30) + Features (30).
3232
- Consider format friction (e.g., TIF→NIfTI conversion) in “compatibility” (±5 points).
3333
- Prefer tools matching the file extension/modality and 2D/3D nature.
@@ -37,19 +37,19 @@
3737
and include a structured reason and explanation.
3838
3939
OUTPUT (valid JSON):
40-
{
41-
"conversation": {
40+
{{
41+
"conversation": {{
4242
"status": "needs_clarification" | "complete",
4343
"question": "string, required if status=needs_clarification",
4444
"context": "string, explain why you need this information",
4545
"options": ["option1", "option2", ...] // optional; 3–5 max if present
46-
},
46+
}},
4747
"choices": [
48-
{"name": "tool-name", "rank": 1, "accuracy": 95.5, "why": "...", "demo_link": "optional"}
48+
{{"name": "tool-name", "rank": 1, "accuracy": 95.5, "why": "...", "demo_link": "optional"}}
4949
],
5050
"reason": "no_suitable_tool | no_modality_match | no_task_match | no_dimension_match",
5151
"explanation": "string (required if choices is empty)"
52-
}
52+
}}
5353
5454
CONSISTENCY RULES
5555
- If you return choices = [], you MUST set conversation.status = "complete" and include a reason + explanation.
@@ -77,7 +77,7 @@
7777
+ "\n1. If task ambiguous (operation OR target structure missing) -> immediately return clarification JSON (NO tool calls). Treat ultra-generic inputs like 'help', 'help me', 'suggest tools', 'what can you do', or empty/emoji-only as ambiguous. Do NOT guess a modality or claim PNG just from a preview."
7878
+ "\n2. Otherwise: call search_tools(query) ONCE early (pass original_formats param if present; do NOT manufacture or over-weight formats — they are a soft compatibility hint)."
7979
+ "\n3. If you have >=3 plausible candidates and high confidence, you MAY skip rerank; else call rerank(query,candidate_names)."
80-
+ "\n4. Mandatory repo verification before final output: After search_tools (and optional rerank), take the top K ≤ 3 candidates you plan to return and you MUST call repo_info(url) once for each. Use the repo URL from the candidate payload (field name repo_url; fallback keys: github, url, homepage). If a candidate has no repo URL, drop it rather than guessing. Only after repo_info confirms alignment with the requested task should you call resolve_demo_link(name). Do not return any candidate that wasnt verified by repo_info. Call `repo_info(url)` **only** with a GitHub repo URL or `owner/repo`. If a candidate lacks that, **drop it** (dont pass papers, docs, or homepages)."
80+
+ "\n4. Mandatory repo verification before final output: After search_tools (and optional rerank), take the top K ≤ {num_choices} candidates you plan to return and you MUST call repo_info(url) once for each. Use the repo URL from the candidate payload (field name repo_url; fallback keys: github, url, homepage). If a candidate has no repo URL, drop it rather than guessing. Only after repo_info confirms alignment with the requested task should you call resolve_demo_link(name). Do not return any candidate that wasn't verified by repo_info. Call `repo_info(url)` **only** with a GitHub repo URL or `owner/repo`. If a candidate lacks that, **drop it** (don't pass papers, docs, or homepages)."
8181
+ "\n5. The preview you receive may be PNG even if the original file is TIFF/DICOM/NIfTI, etc. Use provided original_formats hint (if any) for compatibility scoring only; do NOT assume a TIFF implies microscopy (could still be CT exported). Ask for modality if unclear."
8282
+ "\n6. FINAL RESPONSE: ONE JSON object only — no prose, no code fences. Include conversation + choices (rank, accuracy, why) OR clarification question."
8383
+ "\n7. Accuracy scoring: task(40)+compat(30)+features(30); incorporate original formats & 2D/3D nature from metadata; penalize format conversions (−5) if heavy."
@@ -89,4 +89,14 @@
8989
- repo_info(url="https://github.com/org/repo") # for each finalist
9090
- resolve_demo_link(tool_name="ToolName")
9191
"""
92-
)
92+
)
93+
94+
95+
def get_selector_system_prompt(num_choices: int = 3) -> str:
96+
"""Generate the system prompt with dynamic num_choices."""
97+
return SELECTOR_SYSTEM.format(num_choices=num_choices)
98+
99+
100+
def get_agent_system_prompt(num_choices: int = 3) -> str:
101+
"""Generate the full agent system prompt with dynamic num_choices."""
102+
return AGENT_SYSTEM_PROMPT.format(num_choices=num_choices)

src/ai_agent/generator/schema.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,25 @@ class NoToolReason(str, Enum):
145145
NO_DIMENSION_MATCH = "no_dimension_match"
146146
INVALID_FILES = "invalid_files"
147147

148+
class ConversationStatus(str, Enum):
149+
NEEDS_CLARIFICATION = "needs_clarification"
150+
COMPLETE = "complete"
151+
152+
class Conversation(BaseModel):
153+
status: ConversationStatus
154+
question: Optional[str] = None
155+
context: Optional[str] = None
156+
options: Optional[List[str]] = None
157+
158+
@model_validator(mode='after')
159+
def validate_fields(self) -> 'Conversation':
160+
if self.status == ConversationStatus.NEEDS_CLARIFICATION:
161+
if not self.question:
162+
raise ValueError("Question required when status is needs_clarification")
163+
if not self.context:
164+
raise ValueError("Context required when status is needs_clarification")
165+
return self
166+
148167
class ToolChoice(BaseModel):
149168
name: str
150169
rank: int
@@ -197,28 +216,12 @@ def _expl_empty_to_none(cls, v):
197216
return None if v is None or str(v).strip() == "" else v
198217

199218

200-
class ConversationStatus(str, Enum):
201-
NEEDS_CLARIFICATION = "needs_clarification"
202-
COMPLETE = "complete"
203-
204-
class Conversation(BaseModel):
205-
status: ConversationStatus
206-
question: Optional[str] = None
207-
context: Optional[str] = None
208-
options: Optional[List[str]] = None
209-
210-
@model_validator(mode='after')
211-
def validate_fields(self) -> 'Conversation':
212-
if self.status == ConversationStatus.NEEDS_CLARIFICATION:
213-
if not self.question:
214-
raise ValueError("Question required when status is needs_clarification")
215-
if not self.context:
216-
raise ValueError("Context required when status is needs_clarification")
217-
return self
218-
219-
220219
__all__ = [
221220
"CandidateDoc",
222221
"PlanAndCode",
223222
"ToolSelection",
223+
"Conversation",
224+
"ConversationStatus",
225+
"ToolChoice",
226+
"NoToolReason",
224227
]

src/ai_agent/ui/components.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,30 @@
1313

1414
log = logging.getLogger("chat_components")
1515

16+
# Model configurations with their inference servers
17+
MODEL_CONFIGS = {
18+
# OpenAI models (default endpoint)
19+
"gpt-4o-mini": {"name": "gpt-4o-mini", "base_url": None, "provider": "OpenAI"},
20+
"gpt-4o": {"name": "gpt-4o", "base_url": None, "provider": "OpenAI"},
21+
"gpt-4-turbo": {"name": "gpt-4-turbo", "base_url": None, "provider": "OpenAI"},
22+
23+
# EPFL inference server models
24+
"openai/gpt-oss-120b [EPFL]": {
25+
"name": "openai/gpt-oss-120b",
26+
"base_url": "https://inference.rcp.epfl.ch/v1",
27+
"provider": "EPFL"
28+
},
29+
"mistralai/Mistral-Small-3.2-24B-Instruct-2506 [EPFL]": {
30+
"name": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
31+
"base_url": "https://inference.rcp.epfl.ch/v1",
32+
"provider": "EPFL"
33+
},
34+
}
35+
36+
def get_model_config(model_display_name: str) -> Dict[str, str]:
37+
"""Get model configuration from display name."""
38+
return MODEL_CONFIGS.get(model_display_name, {"name": model_display_name, "base_url": None, "provider": "Unknown"})
39+
1640

1741
def create_chat_interface(doc_index: Dict[str, SoftwareDoc]):
1842
"""
@@ -115,10 +139,10 @@ def create_chat_interface(doc_index: Dict[str, SoftwareDoc]):
115139
with gr.Accordion("⚙️ Settings", open=False):
116140
with gr.Row():
117141
model_dropdown = gr.Dropdown(
118-
choices=["gpt-4o-mini", "gpt-4o", "gpt-4-turbo"],
119-
value=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
142+
choices=list(MODEL_CONFIGS.keys()),
143+
value="gpt-4o-mini",
120144
label="Model",
121-
info="OpenAI model for agent reasoning",
145+
info="Select AI model and inference server",
122146
)
123147
top_k_slider = gr.Slider(
124148
minimum=5,

0 commit comments

Comments
 (0)