Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .env
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# LLM Provider: "openai" (default), "anthropic", or "ollama"
# LLM Provider: "openai" (default), "anthropic", "nvidia", or "ollama"
LLM_PROVIDER=openai

# OpenAI Configuration
Expand All @@ -14,3 +14,8 @@ ANTHROPIC_MODEL=claude-sonnet-4-5
# Ollama Configuration (local models)
OLLAMA_MODEL=llama3
OLLAMA_BASE_URL=http://localhost:11434

# NVIDIA NIM Configuration
NVIDIA_API_KEY=
NVIDIA_MODEL=nvidia/nemotron-3-super-120b-a12b
NVIDIA_BASE_URL=https://integrate.api.nvidia.com/v1
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ dependencies = [

[project.optional-dependencies]
anthropic = ["langchain-anthropic~=0.3.12"]
nvidia = ["langchain-nvidia-ai-endpoints~=0.3.9"]
ollama = ["langchain-ollama~=0.3.2"]
all = ["langchain-anthropic~=0.3.12", "langchain-ollama~=0.3.2"]
all = ["langchain-anthropic~=0.3.12", "langchain-nvidia-ai-endpoints~=0.3.9", "langchain-ollama~=0.3.2"]

[project.urls]
"Homepage" = "https://github.com/nasa-jpl/rosa"
Expand Down
5 changes: 3 additions & 2 deletions src/rosa/rosa.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

if TYPE_CHECKING:
from langchain_anthropic import ChatAnthropic
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_ollama import ChatOllama

from .prompts import RobotSystemPrompts, system_prompts
Expand All @@ -37,7 +38,7 @@

# Tested providers for static analysis; BaseChatModel accepted at runtime.
if TYPE_CHECKING:
ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatAnthropic, ChatOllama]
ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatAnthropic, ChatNVIDIA, ChatOllama]
else:
ChatModel = BaseChatModel

Expand All @@ -49,7 +50,7 @@ class ROSA:
Args:
ros_version (Literal[1, 2]): The version of ROS that the agent will interact with.
llm (ChatModel): The language model to use for generating responses. Tested providers:
ChatOpenAI, AzureChatOpenAI, ChatAnthropic, and ChatOllama. Other BaseChatModel
ChatOpenAI, AzureChatOpenAI, ChatAnthropic, ChatNVIDIA, and ChatOllama. Other BaseChatModel
subclasses that support tool calling may work but are not officially tested.
Note: token usage tracking is only supported for ChatOpenAI and AzureChatOpenAI.
tools (Optional[list]): A list of additional LangChain tool functions to use with the agent.
Expand Down
21 changes: 18 additions & 3 deletions src/turtle_agent/scripts/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,17 @@
def get_llm(streaming: bool = False):
"""A helper function to get the LLM instance.

Supports OpenAI (default), Anthropic and Ollama models.
Supports OpenAI (default), Anthropic, NVIDIA NIM and Ollama models.
Set the LLM_PROVIDER env variable to switch between providers:
- "openai" (default): uses OPENAI_API_KEY
- "anthropic": uses ANTHROPIC_API_KEY
- "nvidia": uses NVIDIA_API_KEY (NIM API)
- "ollama": uses local Ollama instance
"""
dotenv.load_dotenv(dotenv.find_dotenv())

provider = os.getenv("LLM_PROVIDER", "openai").lower().strip()
supported = ("openai", "anthropic", "ollama")
supported = ("openai", "anthropic", "nvidia", "ollama")
if provider not in supported:
raise ValueError(
f"Unknown LLM_PROVIDER: '{provider}'. Must be one of: {', '.join(supported)}"
Expand All @@ -55,6 +56,20 @@ def get_llm(streaming: bool = False):
model=os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5"),
streaming=streaming,
)
elif provider == "nvidia":
try:
from langchain_nvidia_ai_endpoints import ChatNVIDIA
except ImportError:
raise ImportError(
"langchain-nvidia-ai-endpoints is required for NVIDIA NIM support. "
"Install the project's NVIDIA extra with: pip install '.[nvidia]'"
)

Copilot AI Apr 5, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When re-raising after a failed optional import, explicitly chain or suppress the original exception (`except ImportError as e: ... raise ImportError(...) from e`, or `raise ImportError(...) from None`). This keeps tracebacks clearer when debugging dependency issues.

Copilot uses AI. Check for mistakes.
llm = ChatNVIDIA(
api_key=get_env_variable("NVIDIA_API_KEY"),
model=os.getenv("NVIDIA_MODEL", "nvidia/nemotron-3-super-120b-a12b"),
base_url=os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
Comment on lines +67 to +69

Copilot AI Apr 5, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`get_env_variable()` treats an empty string as “set”, and the `.env` template sets `NVIDIA_API_KEY=` (empty). With `LLM_PROVIDER=nvidia`, this will pass an empty API key into `ChatNVIDIA` and fail later with a less actionable auth error. Consider treating empty/whitespace values as unset (e.g., `if not value or not value.strip(): ...`) either here or inside `get_env_variable()` so misconfiguration is caught early with a clear message.

Copilot uses AI. Check for mistakes.
streaming=streaming,
)
elif provider == "ollama":
try:
from langchain_ollama import ChatOllama
Expand Down Expand Up @@ -91,7 +106,7 @@ def get_env_variable(var_name: str) -> str:
raise a ValueError, making it easier to debug configuration issues.
"""
value = os.getenv(var_name)
if value is None:
if not value or not value.strip():
msg = f"Environment variable {var_name} is not set."
raise ValueError(msg)
return value

Copilot AI Apr 5, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`get_env_variable` now rejects whitespace-only values (good), but it returns the unstripped value. If an env var has accidental leading/trailing whitespace, it will pass validation but can still cause auth/URL failures downstream. Consider returning `value.strip()` (and optionally adjusting the error text to say “not set or empty”).

Copilot uses AI. Check for mistakes.
Loading