Skip to content

Commit c61242a

Browse files
authored
Merge pull request #464 from CadeYu/sync-validator-models
sync model validation with cli catalog
2 parents 58e9942 + bd6a5b7 commit c61242a

8 files changed

Lines changed: 193 additions & 125 deletions

File tree

cli/utils.py

Lines changed: 3 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from rich.console import Console
55

66
from cli.models import AnalystType
7+
from tradingagents.llm_clients.model_catalog import get_model_options
78

89
console = Console()
910

@@ -136,48 +137,11 @@ def select_research_depth() -> int:
136137
def select_shallow_thinking_agent(provider) -> str:
137138
"""Select shallow thinking llm engine using an interactive selection."""
138139

139-
# Define shallow thinking llm engine options with their corresponding model names
140-
# Ordering: medium → light → heavy (balanced first for quick tasks)
141-
# Within same tier, newer models first
142-
SHALLOW_AGENT_OPTIONS = {
143-
"openai": [
144-
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
145-
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"),
146-
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
147-
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
148-
],
149-
"anthropic": [
150-
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
151-
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
152-
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
153-
],
154-
"google": [
155-
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
156-
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
157-
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
158-
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
159-
],
160-
"xai": [
161-
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
162-
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
163-
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
164-
],
165-
"openrouter": [
166-
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
167-
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
168-
],
169-
"ollama": [
170-
("Qwen3:latest (8B, local)", "qwen3:latest"),
171-
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
172-
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
173-
],
174-
}
175-
176140
choice = questionary.select(
177141
"Select Your [Quick-Thinking LLM Engine]:",
178142
choices=[
179143
questionary.Choice(display, value=value)
180-
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
144+
for display, value in get_model_options(provider, "quick")
181145
],
182146
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
183147
style=questionary.Style(
@@ -201,50 +165,11 @@ def select_shallow_thinking_agent(provider) -> str:
201165
def select_deep_thinking_agent(provider) -> str:
202166
"""Select deep thinking llm engine using an interactive selection."""
203167

204-
# Define deep thinking llm engine options with their corresponding model names
205-
# Ordering: heavy → medium → light (most capable first for deep tasks)
206-
# Within same tier, newer models first
207-
DEEP_AGENT_OPTIONS = {
208-
"openai": [
209-
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
210-
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
211-
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
212-
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
213-
],
214-
"anthropic": [
215-
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
216-
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
217-
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
218-
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
219-
],
220-
"google": [
221-
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
222-
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
223-
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
224-
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
225-
],
226-
"xai": [
227-
("Grok 4 - Flagship model", "grok-4-0709"),
228-
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
229-
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
230-
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
231-
],
232-
"openrouter": [
233-
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
234-
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
235-
],
236-
"ollama": [
237-
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
238-
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
239-
("Qwen3:latest (8B, local)", "qwen3:latest"),
240-
],
241-
}
242-
243168
choice = questionary.select(
244169
"Select Your [Deep-Thinking LLM Engine]:",
245170
choices=[
246171
questionary.Choice(display, value=value)
247-
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
172+
for display, value in get_model_options(provider, "deep")
248173
],
249174
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
250175
style=questionary.Style(

tests/test_model_validation.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import unittest
2+
import warnings
3+
4+
from tradingagents.llm_clients.base_client import BaseLLMClient
5+
from tradingagents.llm_clients.model_catalog import get_known_models
6+
from tradingagents.llm_clients.validators import validate_model
7+
8+
9+
class DummyLLMClient(BaseLLMClient):
    """Minimal concrete client used to exercise the base-class model warning."""

    def __init__(self, provider: str, model: str):
        # The provider is recorded before the base initializer runs.
        self.provider = provider
        super().__init__(model)

    def get_llm(self):
        """Run the unknown-model check, then hand back a placeholder object."""
        self.warn_if_unknown_model()
        placeholder = object()
        return placeholder

    def validate_model(self) -> bool:
        """Delegate to the module-level ``validate_model`` for this client."""
        provider, model = self.provider, self.model
        return validate_model(provider, model)
20+
21+
22+
class ModelValidationTests(unittest.TestCase):
    """Catalog/validator consistency and unknown-model warning behaviour."""

    def test_cli_catalog_models_are_all_validator_approved(self):
        """Every catalog model of a non-exempt provider passes validation."""
        exempt_providers = ("ollama", "openrouter")
        catalog_pairs = [
            (provider, model)
            for provider, models in get_known_models().items()
            if provider not in exempt_providers
            for model in models
        ]

        for provider, model in catalog_pairs:
            with self.subTest(provider=provider, model=model):
                self.assertTrue(validate_model(provider, model))

    def test_unknown_model_emits_warning_for_strict_provider(self):
        """An unlisted OpenAI model yields exactly one warning naming both."""
        client = DummyLLMClient("openai", "not-a-real-openai-model")

        with warnings.catch_warnings(record=True) as recorded:
            # "always" keeps the warning from being suppressed by prior filters.
            warnings.simplefilter("always")
            client.get_llm()

        self.assertEqual(len(recorded), 1)
        warning_text = str(recorded[0].message)
        for expected in ("not-a-real-openai-model", "openai"):
            self.assertIn(expected, warning_text)

    def test_openrouter_and_ollama_accept_custom_models_without_warning(self):
        """Arbitrary model names on openrouter/ollama produce no warnings."""
        for provider in ("openrouter", "ollama"):
            client = DummyLLMClient(provider, "custom-model-name")

            with self.subTest(provider=provider):
                with warnings.catch_warnings(record=True) as recorded:
                    warnings.simplefilter("always")
                    client.get_llm()

                self.assertEqual(recorded, [])

tradingagents/llm_clients/anthropic_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
3131

3232
def get_llm(self) -> Any:
3333
"""Return configured ChatAnthropic instance."""
34+
self.warn_if_unknown_model()
3435
llm_kwargs = {"model": self.model}
3536

3637
if self.base_url:

tradingagents/llm_clients/base_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
22
from typing import Any, Optional
3+
import warnings
34

45

56
def normalize_content(response):
@@ -29,6 +30,27 @@ def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
2930
self.base_url = base_url
3031
self.kwargs = kwargs
3132

33+
def get_provider_name(self) -> str:
    """Return the provider name used in warning messages.

    A truthy ``provider`` attribute on the instance wins and is returned as a
    string. Otherwise the name is derived from the class name by dropping a
    trailing ``"Client"`` and lower-casing the rest.
    """
    explicit = getattr(self, "provider", None)
    if not explicit:
        class_name = self.__class__.__name__
        return class_name.removesuffix("Client").lower()
    return str(explicit)
39+
40+
def warn_if_unknown_model(self) -> None:
    """Warn when the model is outside the known list for the provider.

    Nothing happens when ``self.validate_model()`` is truthy. Otherwise a
    ``RuntimeWarning`` naming the model and provider is issued and execution
    continues; no exception is raised here.
    """
    if not self.validate_model():
        message = (
            f"Model '{self.model}' is not in the known model list for "
            f"provider '{self.get_provider_name()}'. Continuing anyway."
        )
        # stacklevel=2 attributes the warning to the caller of this helper.
        warnings.warn(message, RuntimeWarning, stacklevel=2)
53+
3254
@abstractmethod
3355
def get_llm(self) -> Any:
3456
"""Return the configured LLM instance."""

tradingagents/llm_clients/google_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
2525

2626
def get_llm(self) -> Any:
2727
"""Return configured ChatGoogleGenerativeAI instance."""
28+
self.warn_if_unknown_model()
2829
llm_kwargs = {"model": self.model}
2930

3031
if self.base_url:
tradingagents/llm_clients/model_catalog.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Shared model catalog for CLI selections and validation."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Dict, List, Tuple
6+
7+
# One selectable model: (display label used for the CLI choice, model identifier).
ModelOption = Tuple[str, str]
8+
# provider name -> selection mode ("quick" or "deep") -> ordered list of options.
ProviderModeOptions = Dict[str, Dict[str, List[ModelOption]]]
9+
10+
11+
# Shared catalog read by get_model_options() and get_known_models().
# The same model identifier may appear under both modes of a provider.
MODEL_OPTIONS: ProviderModeOptions = {
12+
"openai": {
13+
"quick": [
14+
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
15+
("GPT-5 Nano - High-throughput, simple tasks", "gpt-5-nano"),
16+
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
17+
("GPT-4.1 - Smartest non-reasoning model", "gpt-4.1"),
18+
],
19+
"deep": [
20+
("GPT-5.4 - Latest frontier, 1M context", "gpt-5.4"),
21+
("GPT-5.2 - Strong reasoning, cost-effective", "gpt-5.2"),
22+
("GPT-5 Mini - Balanced speed, cost, and capability", "gpt-5-mini"),
23+
("GPT-5.4 Pro - Most capable, expensive ($30/$180 per 1M tokens)", "gpt-5.4-pro"),
24+
],
25+
},
26+
"anthropic": {
27+
"quick": [
28+
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
29+
("Claude Haiku 4.5 - Fast, near-instant responses", "claude-haiku-4-5"),
30+
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
31+
],
32+
"deep": [
33+
("Claude Opus 4.6 - Most intelligent, agents and coding", "claude-opus-4-6"),
34+
("Claude Opus 4.5 - Premium, max intelligence", "claude-opus-4-5"),
35+
("Claude Sonnet 4.6 - Best speed and intelligence balance", "claude-sonnet-4-6"),
36+
("Claude Sonnet 4.5 - Agents and coding", "claude-sonnet-4-5"),
37+
],
38+
},
39+
"google": {
40+
"quick": [
41+
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
42+
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
43+
("Gemini 3.1 Flash Lite - Most cost-efficient", "gemini-3.1-flash-lite-preview"),
44+
("Gemini 2.5 Flash Lite - Fast, low-cost", "gemini-2.5-flash-lite"),
45+
],
46+
"deep": [
47+
("Gemini 3.1 Pro - Reasoning-first, complex workflows", "gemini-3.1-pro-preview"),
48+
("Gemini 3 Flash - Next-gen fast", "gemini-3-flash-preview"),
49+
("Gemini 2.5 Pro - Stable pro model", "gemini-2.5-pro"),
50+
("Gemini 2.5 Flash - Balanced, stable", "gemini-2.5-flash"),
51+
],
52+
},
53+
"xai": {
54+
"quick": [
55+
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
56+
("Grok 4 Fast (Non-Reasoning) - Speed optimized", "grok-4-fast-non-reasoning"),
57+
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
58+
],
59+
"deep": [
60+
("Grok 4 - Flagship model", "grok-4-0709"),
61+
("Grok 4.1 Fast (Reasoning) - High-performance, 2M ctx", "grok-4-1-fast-reasoning"),
62+
("Grok 4 Fast (Reasoning) - High-performance", "grok-4-fast-reasoning"),
63+
("Grok 4.1 Fast (Non-Reasoning) - Speed optimized, 2M ctx", "grok-4-1-fast-non-reasoning"),
64+
],
65+
},
66+
"openrouter": {
67+
"quick": [
68+
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
69+
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
70+
],
71+
"deep": [
72+
("Z.AI GLM 4.5 Air (free)", "z-ai/glm-4.5-air:free"),
73+
("NVIDIA Nemotron 3 Nano 30B (free)", "nvidia/nemotron-3-nano-30b-a3b:free"),
74+
],
75+
},
76+
"ollama": {
77+
"quick": [
78+
("Qwen3:latest (8B, local)", "qwen3:latest"),
79+
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
80+
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
81+
],
82+
"deep": [
83+
("GLM-4.7-Flash:latest (30B, local)", "glm-4.7-flash:latest"),
84+
("GPT-OSS:latest (20B, local)", "gpt-oss:latest"),
85+
("Qwen3:latest (8B, local)", "qwen3:latest"),
86+
],
87+
},
88+
}
89+
90+
91+
def get_model_options(provider: str, mode: str) -> List[ModelOption]:
    """Return shared model options for a provider and selection mode.

    Args:
        provider: Provider name. Matched case-insensitively, ignoring
            surrounding whitespace.
        mode: Selection mode key under the provider (the catalog defines
            ``"quick"`` and ``"deep"``). Normalised the same way as
            ``provider``.

    Returns:
        A new list of ``(display label, model identifier)`` tuples in catalog
        order. Mutating it does not affect ``MODEL_OPTIONS``.

    Raises:
        KeyError: If the provider or the mode is not in the catalog. The
            message names the rejected value and the accepted ones.
    """
    provider_key = provider.strip().lower()
    try:
        mode_options = MODEL_OPTIONS[provider_key]
    except KeyError:
        raise KeyError(
            f"Unknown provider {provider!r}; "
            f"expected one of {sorted(MODEL_OPTIONS)}"
        ) from None

    mode_key = mode.strip().lower()
    try:
        options = mode_options[mode_key]
    except KeyError:
        raise KeyError(
            f"Unknown mode {mode!r} for provider {provider_key!r}; "
            f"expected one of {sorted(mode_options)}"
        ) from None

    # Copy so callers cannot mutate the shared module-level catalog.
    return list(options)
94+
95+
96+
def get_known_models() -> Dict[str, List[str]]:
    """Build known model names from the shared CLI catalog.

    Returns:
        A dict mapping each provider in ``MODEL_OPTIONS`` to the sorted,
        de-duplicated model identifiers found across all of its modes.
    """
    known: Dict[str, List[str]] = {}
    for provider, mode_options in MODEL_OPTIONS.items():
        unique_ids = set()
        for options in mode_options.values():
            # Only the identifier half of each (label, identifier) pair matters.
            unique_ids.update(model_id for _label, model_id in options)
        known[provider] = sorted(unique_ids)
    return known

tradingagents/llm_clients/openai_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353

5454
def get_llm(self) -> Any:
5555
"""Return configured ChatOpenAI instance."""
56+
self.warn_if_unknown_model()
5657
llm_kwargs = {"model": self.model}
5758

5859
# Provider-specific base URL and auth

tradingagents/llm_clients/validators.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,12 @@
1-
"""Model name validators for each provider.
1+
"""Model name validators for each provider."""
2+
3+
from .model_catalog import get_known_models
24

3-
Only validates model names - does NOT enforce limits.
4-
Let LLM providers use their own defaults for unspecified params.
5-
"""
65

76
VALID_MODELS = {
8-
"openai": [
9-
# GPT-5 series
10-
"gpt-5.4-pro",
11-
"gpt-5.4",
12-
"gpt-5.2",
13-
"gpt-5.1",
14-
"gpt-5",
15-
"gpt-5-mini",
16-
"gpt-5-nano",
17-
# GPT-4.1 series
18-
"gpt-4.1",
19-
"gpt-4.1-mini",
20-
"gpt-4.1-nano",
21-
],
22-
"anthropic": [
23-
# Claude 4.6 series (latest)
24-
"claude-opus-4-6",
25-
"claude-sonnet-4-6",
26-
# Claude 4.5 series
27-
"claude-opus-4-5",
28-
"claude-sonnet-4-5",
29-
"claude-haiku-4-5",
30-
],
31-
"google": [
32-
# Gemini 3.1 series (preview)
33-
"gemini-3.1-pro-preview",
34-
"gemini-3.1-flash-lite-preview",
35-
# Gemini 3 series (preview)
36-
"gemini-3-flash-preview",
37-
# Gemini 2.5 series
38-
"gemini-2.5-pro",
39-
"gemini-2.5-flash",
40-
"gemini-2.5-flash-lite",
41-
],
42-
"xai": [
43-
# Grok 4.1 series
44-
"grok-4-1-fast-reasoning",
45-
"grok-4-1-fast-non-reasoning",
46-
# Grok 4 series
47-
"grok-4-0709",
48-
"grok-4-fast-reasoning",
49-
"grok-4-fast-non-reasoning",
50-
],
7+
provider: models
8+
for provider, models in get_known_models().items()
9+
if provider not in ("ollama", "openrouter")
5110
}
5211

5312

0 commit comments

Comments
 (0)