-
Notifications
You must be signed in to change notification settings - Fork 175
Add support for NVIDIA NIM as an LLM provider #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
35b5287
fdfb51c
3df4261
809af24
c775f0f
9c6831f
b56cacb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,16 +21,17 @@ | |
| def get_llm(streaming: bool = False): | ||
| """A helper function to get the LLM instance. | ||
|
|
||
| Supports OpenAI (default), Anthropic and Ollama models. | ||
| Supports OpenAI (default), Anthropic, NVIDIA NIM and Ollama models. | ||
| Set the LLM_PROVIDER env variable to switch between providers: | ||
| - "openai" (default): uses OPENAI_API_KEY | ||
| - "anthropic": uses ANTHROPIC_API_KEY | ||
| - "nvidia": uses NVIDIA_API_KEY (NIM API) | ||
| - "ollama": uses local Ollama instance | ||
| """ | ||
| dotenv.load_dotenv(dotenv.find_dotenv()) | ||
|
|
||
| provider = os.getenv("LLM_PROVIDER", "openai").lower().strip() | ||
| supported = ("openai", "anthropic", "ollama") | ||
| supported = ("openai", "anthropic", "nvidia", "ollama") | ||
| if provider not in supported: | ||
| raise ValueError( | ||
| f"Unknown LLM_PROVIDER: '{provider}'. Must be one of: {', '.join(supported)}" | ||
|
|
@@ -55,6 +56,20 @@ def get_llm(streaming: bool = False): | |
| model=os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-5"), | ||
| streaming=streaming, | ||
| ) | ||
| elif provider == "nvidia": | ||
| try: | ||
| from langchain_nvidia_ai_endpoints import ChatNVIDIA | ||
| except ImportError: | ||
| raise ImportError( | ||
| "langchain-nvidia-ai-endpoints is required for NVIDIA NIM support. " | ||
| "Install the project's NVIDIA extra with: pip install '.[nvidia]'" | ||
| ) | ||
| llm = ChatNVIDIA( | ||
| api_key=get_env_variable("NVIDIA_API_KEY"), | ||
| model=os.getenv("NVIDIA_MODEL", "nvidia/nemotron-3-super-120b-a12b"), | ||
| base_url=os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"), | ||
|
Comment on lines
+67
to
+69
|
||
| streaming=streaming, | ||
| ) | ||
| elif provider == "ollama": | ||
| try: | ||
| from langchain_ollama import ChatOllama | ||
|
|
@@ -91,7 +106,7 @@ def get_env_variable(var_name: str) -> str: | |
| raise a ValueError, making it easier to debug configuration issues. | ||
| """ | ||
| value = os.getenv(var_name) | ||
| if value is None: | ||
| if not value or not value.strip(): | ||
| msg = f"Environment variable {var_name} is not set." | ||
| raise ValueError(msg) | ||
| return value | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When re-raising on a failed optional import, consider preserving or suppressing the original exception (
`except ImportError as e: ... raise ImportError(...) from e` or `from None`). This keeps tracebacks clearer for debugging dependency issues.