113 changes: 79 additions & 34 deletions cli/utils.py
@@ -152,24 +152,40 @@ def select_shallow_thinking_agent(provider) -> str:
"ollama": [
("llama3.1 local", "llama3.1"),
("llama3.2 local", "llama3.2"),
]
],
"vllm": [],
}

choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
if provider == "vllm":
choice = questionary.text(
"Please input the vllm model name for shallow thinking (default: llama3.1):",
default="llama3.1",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()

else:
# Use questionary to create an interactive selection menu for shallow thinking LLM engines
choice = questionary.select(
"Select Your [Quick-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in SHALLOW_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
Comment on lines +159 to +188


medium

This if/else block for selecting a model is nearly identical to the one in select_deep_thinking_agent (lines 243-272). This duplication makes the code harder to maintain. Consider extracting this logic into a helper function that can be called by both select_shallow_thinking_agent and select_deep_thinking_agent to reduce redundancy.
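A minimal sketch of such a helper, along the lines the reviewer suggests (the name _prompt_for_model and its signature are illustrative, not part of this PR):

import questionary

def _prompt_for_model(provider: str, label: str, options: dict) -> str:
    """Prompt for a model name: free text for vllm, a selection menu otherwise."""
    if provider == "vllm":
        return questionary.text(
            f"Please input the vllm model name for {label} (default: llama3.1):",
            default="llama3.1",
            validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
        ).ask()
    return questionary.select(
        f"Select Your [{label} LLM Engine]:",
        choices=[
            questionary.Choice(display, value=value)
            for display, value in options[provider.lower()]
        ],
        instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
    ).ask()

Both select_shallow_thinking_agent and select_deep_thinking_agent would then reduce to a single call, e.g. _prompt_for_model(provider, "Quick-Thinking", SHALLOW_AGENT_OPTIONS).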


if choice is None:
console.print(
@@ -214,24 +230,40 @@ def select_deep_thinking_agent(provider) -> str:
"ollama": [
("llama3.1 local", "llama3.1"),
("qwen3", "qwen3"),
]
}

choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()
"vllm": [],
}

if provider == "vllm":
choice = questionary.text(
"Please input the vllm model name for deep thinking (default: llama3.1):",
default="llama3.1",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid model name.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()

else:
# Use questionary to create an interactive selection menu for deep thinking LLM engines
choice = questionary.select(
"Select Your [Deep-Thinking LLM Engine]:",
choices=[
questionary.Choice(display, value=value)
for display, value in DEEP_AGENT_OPTIONS[provider.lower()]
],
instruction="\n- Use arrow keys to navigate\n- Press Enter to select",
style=questionary.Style(
[
("selected", "fg:magenta noinherit"),
("highlighted", "fg:magenta noinherit"),
("pointer", "fg:magenta noinherit"),
]
),
).ask()

if choice is None:
console.print("\n[red]No deep thinking llm engine selected. Exiting...[/red]")
@@ -247,7 +279,8 @@ def select_llm_provider() -> tuple[str, str]:
("Anthropic", "https://api.anthropic.com/"),
("Google", "https://generativelanguage.googleapis.com/v1"),
("Openrouter", "https://openrouter.ai/api/v1"),
("Ollama", "http://localhost:11434/v1"),
("Ollama", "http://localhost:11434/v1"),
("vllm", "http://localhost:8000/v1"),
]

choice = questionary.select(
@@ -271,6 +304,18 @@
exit(1)

display_name, url = choice
if display_name == "vllm":
url = questionary.text(
"Please input the vllm api url (default: http://localhost:8000/v1):",
default="http://localhost:8000/v1",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid URL.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()
print(f"You selected: {display_name}\tURL: {url}")

return display_name, url
2 changes: 2 additions & 0 deletions main.py
@@ -20,6 +20,8 @@
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
}

config["embeddings"] = "text-embedding-3-small"

# Initialize with custom config
ta = TradingAgentsGraph(debug=True, config=config)

2 changes: 2 additions & 0 deletions tradingagents/agents/utils/memory.py
@@ -7,6 +7,8 @@ class FinancialSituationMemory:
def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
elif config["llm_provider"] == "vllm":
self.embedding = config["embeddings"]
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
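Since vllm exposes the OpenAI embeddings endpoint, the selected model is queried through the same client. A sketch of the lookup this class presumably performs elsewhere (only __init__ is shown in this diff, so the surrounding method is hypothetical):

# somewhere inside FinancialSituationMemory:
response = self.client.embeddings.create(model=self.embedding, input=[situation_text])
vector = response.data[0].embedding  # list[float] used for similarity search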
1 change: 1 addition & 0 deletions tradingagents/default_config.py
@@ -30,4 +30,5 @@
# Example: "get_stock_data": "alpha_vantage", # Override category default
# Example: "get_news": "openai", # Override category default
},
"embeddings": "text-embedding-3-small", # Options: text-embedding-3-small, nomic-embed-text, vllm model name
}
17 changes: 16 additions & 1 deletion tradingagents/graph/trading_graph.py
@@ -6,6 +6,8 @@
from datetime import date
from typing import Dict, Any, Tuple, List, Optional

import questionary

from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -72,7 +74,7 @@ def __init__(
)

# Initialize LLMs
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter":
if self.config["llm_provider"].lower() == "openai" or self.config["llm_provider"] == "ollama" or self.config["llm_provider"] == "openrouter" or self.config["llm_provider"] == "vllm":
self.deep_thinking_llm = ChatOpenAI(model=self.config["deep_think_llm"], base_url=self.config["backend_url"])
self.quick_thinking_llm = ChatOpenAI(model=self.config["quick_think_llm"], base_url=self.config["backend_url"])
elif self.config["llm_provider"].lower() == "anthropic":
@@ -85,6 +87,19 @@
raise ValueError(f"Unsupported LLM provider: {self.config['llm_provider']}")

# Initialize memories
if self.config["llm_provider"] == "vllm":
self.config["embeddings"] = questionary.text(
"Please input the vllm embedding model name (default: None):",
default="None",
validate=lambda x: len(x.strip()) > 0 or "Please enter a valid embedding model name.",
style=questionary.Style(
[
("text", "fg:green"),
("highlighted", "noinherit"),
]
),
).ask()

self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
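Because vllm speaks the OpenAI wire protocol, no new adapter is needed: the existing ChatOpenAI path is reused, as the provider check above shows. A minimal standalone sketch (the model name and URL are examples; ChatOpenAI still expects OPENAI_API_KEY to be set, even to a dummy value that vllm ignores):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="llama3.1", base_url="http://localhost:8000/v1")
print(llm.invoke("Summarize today's market in one sentence.").content)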