5353from django_ai_assistant .langchain .tools import tool as tool_decorator
5454
5555
56- PROVIDER_LLM_LOOKUP = {
56+ ProviderName = Literal ["openai" , "anthropic" , "google" ]
57+
58+
59+ class ProviderConfig (TypedDict ):
60+ langchain_module : str
61+ llm_class : str
62+
63+
64+ PROVIDER_LLM_LOOKUP : dict [ProviderName , ProviderConfig ] = {
5765 "openai" : {
5866 "langchain_module" : "langchain_openai" ,
5967 "llm_class" : "ChatOpenAI" ,
@@ -134,6 +142,8 @@ class AIAssistant(abc.ABC): # noqa: F821
134142 Can be used in any `@method_tool` to customize behavior."""
135143 _method_tools : Sequence [BaseTool ]
136144 """List of `@method_tool` tools the assistant can use. Automatically set by the constructor."""
145+ _provider : ProviderName
146+ """The provider key used to resolve and import the chat model class."""
137147
138148 _registry : ClassVar [dict [str , type ["AIAssistant" ]]] = {}
139149 """Registry of all AIAssistant subclasses by their id.\n
@@ -151,7 +161,7 @@ def __init__(
151161 user = None ,
152162 request = None ,
153163 view = None ,
154- provider = "openai" ,
164+ provider : ProviderName = "openai" ,
155165 ** kwargs : Any ,
156166 ):
157167 """Initialize the AIAssistant instance.\n
@@ -165,7 +175,7 @@ def __init__(
165175 A request instance. Defaults to `None`. Stored in `self._request`.
166176 view (Any | None): The current Django view the assistant was initialized with.
167177 A view instance. Defaults to `None`. Stored in `self._view`.
168- provider (str ): The provider that will be used for building the LLM instance.
178+ provider (ProviderName ): The provider used to build the LLM instance.
169179 Requires the corresponding langchain module to be installed (see `PROVIDER_LLM_LOOKUP`).
170180 Defaults to `openai`. Stored in `self._provider`.
171181 **kwargs: Extra keyword arguments passed to the constructor. Stored in `self._init_kwargs`.
@@ -326,7 +336,13 @@ def _import_llm_class(self):
326336 f"Install it with: pip install django-ai-assistant[{ self ._provider } ]"
327337 ) from err
328338
329- return getattr (langchain_module , provider ["llm_class" ])
339+ llm_class_str = provider ["llm_class" ]
340+ try :
341+ return getattr (langchain_module , llm_class_str )
342+ except AttributeError as err :
343+ raise ImportError (
344+ f"'{ llm_class_str } ' is not a valid LLM class for provider '{ self ._provider } '."
345+ ) from err
330346
331347 def get_llm (self ) -> BaseChatModel :
332348 """Get the LangChain LLM instance for the assistant.
0 commit comments