7 changes: 6 additions & 1 deletion .env
@@ -1,4 +1,4 @@
# LLM Provider: "openai" (default), "anthropic", or "ollama"
# LLM Provider: "openai" (default), "anthropic", "nvidia", or "ollama"
LLM_PROVIDER=openai

# OpenAI Configuration
@@ -14,3 +14,8 @@ ANTHROPIC_MODEL=claude-sonnet-4-5
# Ollama Configuration (local models)
OLLAMA_MODEL=llama3
OLLAMA_BASE_URL=http://localhost:11434

# NVIDIA NIM Configuration
NVIDIA_API_KEY=
NVIDIA_MODEL=nvidia/llama-3.3-nemotron-super-49b-v1.5
NVIDIA_BASE_URL=https://integrate.api.nvidia.com/v1
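
For reviewers who want to smoke-test these new settings outside ROSA, a minimal sketch; the model and base URL are the template defaults above, and the key value is a placeholder, not a real credential:

# Standalone check of the NIM configuration added to .env (hypothetical values).
import os
from langchain_nvidia_ai_endpoints import ChatNVIDIA

os.environ.setdefault("NVIDIA_API_KEY", "nvapi-your-key-here")  # placeholder key
llm = ChatNVIDIA(
    model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
    base_url="https://integrate.api.nvidia.com/v1",
)
print(llm.invoke("Reply with one word: ready").content)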
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -40,8 +40,9 @@ dependencies = [

[project.optional-dependencies]
anthropic = ["langchain-anthropic~=0.3.12"]
nvidia = ["langchain-nvidia-ai-endpoints~=0.3.9"]
ollama = ["langchain-ollama~=0.3.2"]
all = ["langchain-anthropic~=0.3.12", "langchain-ollama~=0.3.2"]
all = ["langchain-anthropic~=0.3.12", "langchain-nvidia-ai-endpoints~=0.3.9", "langchain-ollama~=0.3.2"]

[project.urls]
"Homepage" = "https://github.com/nasa-jpl/rosa"
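
Since the new extra mirrors the existing ones, pip install 'jpl-rosa[nvidia]' (or [all]) pulls in langchain-nvidia-ai-endpoints. A quick guard, in the spirit of the try/except used in llm.py below, to confirm the extra resolved:

# Sketch: verify the optional dependency is installed (package name per the extra above).
try:
    import langchain_nvidia_ai_endpoints  # noqa: F401  # provided by jpl-rosa[nvidia]
except ImportError:
    raise SystemExit("Missing extra. Run: pip install 'jpl-rosa[nvidia]'")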
5 changes: 3 additions & 2 deletions src/rosa/rosa.py
@@ -28,6 +28,7 @@

if TYPE_CHECKING:
    from langchain_anthropic import ChatAnthropic
    from langchain_nvidia_ai_endpoints import ChatNVIDIA
    from langchain_ollama import ChatOllama

from .prompts import RobotSystemPrompts, system_prompts
@@ -37,7 +38,7 @@

# Tested providers for static analysis; BaseChatModel accepted at runtime.
if TYPE_CHECKING:
    ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatAnthropic, ChatOllama]
    ChatModel = Union[ChatOpenAI, AzureChatOpenAI, ChatAnthropic, ChatNVIDIA, ChatOllama]
else:
    ChatModel = BaseChatModel

@@ -49,7 +50,7 @@ class ROSA:
    Args:
        ros_version (Literal[1, 2]): The version of ROS that the agent will interact with.
        llm (ChatModel): The language model to use for generating responses. Tested providers:
            ChatOpenAI, AzureChatOpenAI, ChatAnthropic, and ChatOllama. Other BaseChatModel
            ChatOpenAI, AzureChatOpenAI, ChatAnthropic, ChatNVIDIA, and ChatOllama. Other BaseChatModel
            subclasses that support tool calling may work but are not officially tested.
            Note: token usage tracking is only supported for ChatOpenAI and AzureChatOpenAI.
        tools (Optional[list]): A list of additional LangChain tool functions to use with the agent.
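
A minimal sketch of handing the newly-typed provider to the agent; the from rosa import ROSA path and ros_version=1 are assumptions for illustration, and any tool-calling BaseChatModel is accepted at runtime per the comment above:

# Sketch: constructing ROSA with ChatNVIDIA (hypothetical wiring).
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from rosa import ROSA  # assumed import path

llm = ChatNVIDIA(model="nvidia/llama-3.3-nemotron-super-49b-v1.5")
agent = ROSA(ros_version=1, llm=llm)  # ros_version is Literal[1, 2] per the docstring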
22 changes: 18 additions & 4 deletions src/turtle_agent/scripts/llm.py
@@ -21,16 +21,17 @@
def get_llm(streaming: bool = False):
    """A helper function to get the LLM instance.

    Supports OpenAI (default), Anthropic and Ollama models.
    Supports OpenAI (default), Anthropic, NVIDIA NIM and Ollama models.
    Set the LLM_PROVIDER env variable to switch between providers:
    - "openai" (default): uses OPENAI_API_KEY
    - "anthropic": uses ANTHROPIC_API_KEY
    - "nvidia": uses NVIDIA_API_KEY (NIM API)
    - "ollama": uses local Ollama instance
    """
    dotenv.load_dotenv(dotenv.find_dotenv())

    provider = os.getenv("LLM_PROVIDER", "openai").lower().strip()
    supported = ("openai", "anthropic", "ollama")
    supported = ("openai", "anthropic", "nvidia", "ollama")
    if provider not in supported:
        raise ValueError(
            f"Unknown LLM_PROVIDER: '{provider}'. Must be one of: {', '.join(supported)}"
@@ -55,6 +56,19 @@ def get_llm(streaming: bool = False):
            model=os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5"),
            streaming=streaming,
        )
    elif provider == "nvidia":
        try:
            from langchain_nvidia_ai_endpoints import ChatNVIDIA
        except ImportError as e:
            raise ImportError(
                "Install the project's NVIDIA extra with: pip install 'jpl-rosa[nvidia]'"
            ) from e
        llm = ChatNVIDIA(
            api_key=get_env_variable("NVIDIA_API_KEY"),
            model=os.getenv("NVIDIA_MODEL", "nvidia/llama-3.3-nemotron-super-49b-v1.5"),
            base_url=os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
Comment on lines +67 to +69
Copilot AI Apr 5, 2026

get_env_variable() treats an empty string as “set”, and the .env template sets NVIDIA_API_KEY= (empty). With LLM_PROVIDER=nvidia, this will pass an empty API key into ChatNVIDIA and fail later with a less actionable auth error. Consider treating empty/whitespace values as unset (e.g., if not value or not value.strip(): ...) either here or inside get_env_variable() so misconfiguration is caught early with a clear message.
            streaming=streaming,
        )
    elif provider == "ollama":
        try:
            from langchain_ollama import ChatOllama
@@ -91,7 +105,7 @@ def get_env_variable(var_name: str) -> str:
    raise a ValueError, making it easier to debug configuration issues.
    """
    value = os.getenv(var_name)
    if value is None:
    if not value or not value.strip():
        msg = f"Environment variable {var_name} is not set."
        raise ValueError(msg)
    return value
    return value.strip()
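
To see what the review comment above buys, a small REPL-style sketch of the stricter check (assumes get_env_variable as patched in this file):

# With the .env template shipping NVIDIA_API_KEY= (empty), the old
# `value is None` check returned "" and ChatNVIDIA failed later with an
# opaque auth error. The patched check raises up front instead:
import os

os.environ["NVIDIA_API_KEY"] = ""  # as in the template
try:
    get_env_variable("NVIDIA_API_KEY")
except ValueError as e:
    print(e)  # -> Environment variable NVIDIA_API_KEY is not set.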