Skip to content

Commit 4819a7e

Browse files
committed
Add typing for AIAssistant.__init__ provider
1 parent f94983d commit 4819a7e

2 files changed

Lines changed: 46 additions & 4 deletions

File tree

django_ai_assistant/helpers/assistants.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,15 @@
5353
from django_ai_assistant.langchain.tools import tool as tool_decorator
5454

5555

56-
PROVIDER_LLM_LOOKUP = {
56+
ProviderName = Literal["openai", "anthropic", "google"]
57+
58+
59+
class ProviderConfig(TypedDict):
60+
langchain_module: str
61+
llm_class: str
62+
63+
64+
PROVIDER_LLM_LOOKUP: dict[ProviderName, ProviderConfig] = {
5765
"openai": {
5866
"langchain_module": "langchain_openai",
5967
"llm_class": "ChatOpenAI",
@@ -134,6 +142,8 @@ class AIAssistant(abc.ABC): # noqa: F821
134142
Can be used in any `@method_tool` to customize behavior."""
135143
_method_tools: Sequence[BaseTool]
136144
"""List of `@method_tool` tools the assistant can use. Automatically set by the constructor."""
145+
_provider: ProviderName
146+
"""The provider key used to resolve and import the chat model class."""
137147

138148
_registry: ClassVar[dict[str, type["AIAssistant"]]] = {}
139149
"""Registry of all AIAssistant subclasses by their id.\n
@@ -151,7 +161,7 @@ def __init__(
151161
user=None,
152162
request=None,
153163
view=None,
154-
provider="openai",
164+
provider: ProviderName = "openai",
155165
**kwargs: Any,
156166
):
157167
"""Initialize the AIAssistant instance.\n
@@ -165,7 +175,7 @@ def __init__(
165175
A request instance. Defaults to `None`. Stored in `self._request`.
166176
view (Any | None): The current Django view the assistant was initialized with.
167177
A view instance. Defaults to `None`. Stored in `self._view`.
168-
provider (str): The provider that will be used for building the LLM instance.
178+
provider (ProviderName): The provider used to build the LLM instance.
169179
Requires the corresponding langchain module to be installed (see `PROVIDER_LLM_LOOKUP`).
170180
Defaults to `openai`. Stored in `self._provider`.
171181
**kwargs: Extra keyword arguments passed to the constructor. Stored in `self._init_kwargs`.
@@ -326,7 +336,13 @@ def _import_llm_class(self):
326336
f"Install it with: pip install django-ai-assistant[{self._provider}]"
327337
) from err
328338

329-
return getattr(langchain_module, provider["llm_class"])
339+
llm_class_str = provider["llm_class"]
340+
try:
341+
return getattr(langchain_module, llm_class_str)
342+
except AttributeError as err:
343+
raise ImportError(
344+
f"'{llm_class_str}' is not a valid LLM class for provider '{self._provider}'."
345+
) from err
330346

331347
def get_llm(self) -> BaseChatModel:
332348
"""Get the LangChain LLM instance for the assistant.

tests/test_helpers/test_assistants.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,32 @@ class UninstalledAIAssistant(AIAssistant):
508508
assistant.get_llm()
509509

510510

511+
def test_AIAssistant_get_llm_invalid_llm_class_for_provider(monkeypatch):
512+
class InvalidClassAIAssistant(AIAssistant):
513+
id = "override_invalid_class_assistant" # noqa: A003
514+
name = "Override Invalid Class Assistant"
515+
instructions = "Instructions"
516+
model = "gpt-test"
517+
518+
assistant = InvalidClassAIAssistant(provider="openai")
519+
520+
from django_ai_assistant.helpers import assistants
521+
522+
monkeypatch.setattr(
523+
assistants,
524+
"PROVIDER_LLM_LOOKUP",
525+
{
526+
"openai": {
527+
"langchain_module": "math",
528+
"llm_class": "NotExistingClass",
529+
},
530+
},
531+
)
532+
533+
with pytest.raises(ImportError):
534+
assistant.get_llm()
535+
536+
511537
@pytest.mark.vcr
512538
def test_AIAssistant_pydantic_structured_output():
513539
from pydantic import BaseModel

0 commit comments

Comments
 (0)