11import abc
2+ import importlib
23import inspect
34import re
4- from enum import Enum
55from typing import Annotated , Any , ClassVar , Dict , Sequence , Type , TypedDict , cast
66
77from langchain_core .language_models import BaseChatModel
4040from django_ai_assistant .langchain .tools import tool as tool_decorator
4141
4242
43- class ProvidersEnum (Enum ):
44- OPENAI = "openai"
45- ANTHROPIC = "anthropic"
# Maps each supported provider name to the chat model class name exported by
# the matching ``langchain_<provider>`` integration package.
PROVIDER_LLM_LOOKUP = {
    "openai": "ChatOpenAI",
    "anthropic": "ChatAnthropic",
}
4647
4748
4849class AIAssistant (abc .ABC ): # noqa: F821
@@ -127,7 +128,7 @@ def __init__(
127128 user = None ,
128129 request = None ,
129130 view = None ,
130- provider = ProvidersEnum . OPENAI . value ,
131+ provider = "openai" ,
131132 ** kwargs : Any ,
132133 ):
133134 """Initialize the AIAssistant instance.\n
@@ -141,7 +142,9 @@ def __init__(
141142 A request instance. Defaults to `None`. Stored in `self._request`.
142143 view (Any | None): The current Django view the assistant was initialized with.
143144 A view instance. Defaults to `None`. Stored in `self._view`.
144- provider (str): TODO: add description
145+ provider (str): The provider that will be used for building the LLM instance.
146+ Requires the corresponding `langchain_[provider]` package to be installed.
147+ Defaults to `openai`. Stored in `self._provider`.
145148 **kwargs: Extra keyword arguments passed to the constructor. Stored in `self._init_kwargs`.
146149 """
147150
@@ -280,6 +283,29 @@ def get_model_kwargs(self) -> dict[str, Any]:
280283 """
281284 return {}
282285
286+ def _import_llm_class (self ):
287+ valid_providers_list = PROVIDER_LLM_LOOKUP .keys ()
288+ if self ._provider not in valid_providers_list :
289+ raise AIAssistantMisconfiguredError (
290+ f"Invalid provider={ self ._provider } , please use one "
291+ f"of the supported providers: { valid_providers_list } "
292+ )
293+
294+ # Performs a deferred import of the LLM class that corresponds to
295+ # the self._provider value and returns it.
296+ try :
297+ langchain_module = importlib .import_module (f"langchain_{ self ._provider } " )
298+ except ImportError as err :
299+ raise ImportError (
300+ f"'langchain_{ self ._provider } ' is required to use this provider. "
301+ f"Install it with: pip install django-ai-assistant[{ self ._provider } ]"
302+ ) from err
303+
304+ return getattr (
305+ langchain_module ,
306+ PROVIDER_LLM_LOOKUP [self ._provider ],
307+ )
308+
283309 def get_llm (self ) -> BaseChatModel :
284310 """Get the LangChain LLM instance for the assistant.
285311 By default, this uses the OpenAI implementation.\n
@@ -293,31 +319,7 @@ def get_llm(self) -> BaseChatModel:
293319 temperature = self .get_temperature ()
294320 model_kwargs = self .get_model_kwargs ()
295321
296- llm_class = None
297- valid_providers_list = [provider .value for provider in ProvidersEnum ]
298- if self ._provider in valid_providers_list :
299- try :
300- if self ._provider == ProvidersEnum .OPENAI .value :
301- from langchain_openai import ChatOpenAI
302-
303- llm_class = ChatOpenAI
304- elif self ._provider == ProvidersEnum .ANTHROPIC .value :
305- from langchain_anthropic import ChatAnthropic
306-
307- llm_class = ChatAnthropic
308- else :
309- raise ImportError
310- except ImportError as err :
311- raise ImportError (
312- f"'langchain_{ self ._provider } ' is required to use this provider. "
313- f"Install it with: pip install django-ai-assistant[{ self ._provider } ]"
314- ) from err
315- else :
316- raise AIAssistantMisconfiguredError (
317- f"Invalid provider={ self ._provider } , please use one "
318- "of the supported providers: "
319- f"{ valid_providers_list } "
320- )
322+ llm_class = self ._import_llm_class ()
321323
322324 if temperature is not None :
323325 return llm_class (
@@ -343,7 +345,7 @@ def get_structured_output_llm(self) -> Runnable:
343345 llm = self .get_llm ()
344346
345347 method = "json_mode"
346- if self ._provider == ProvidersEnum . OPENAI . value :
348+ if self ._provider == "openai" :
347349 # When using ChatOpenAI, it's better to use json_schema method
348350 # because it enables strict mode.
349351 # https://platform.openai.com/docs/guides/structured-outputs
0 commit comments