2727 RunnableBranch ,
2828)
2929from langchain_core .tools import BaseTool , StructuredTool
30- from langchain_openai import ChatOpenAI
3130from langgraph .graph import END , StateGraph , add_messages
3231from langgraph .prebuilt import ToolNode
3332from pydantic import BaseModel
4039from django_ai_assistant .langchain .tools import tool as tool_decorator
4140
4241
class ProvidersEnum:
    """String identifiers for the supported LLM providers.

    Plain string class attributes (not an ``enum.Enum``) so values compare
    equal to raw strings such as ``"openai"`` passed by callers.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
45+
46+
4347class AIAssistant (abc .ABC ): # noqa: F821
4448 """Base class for AI Assistants. Subclasses must define at least the following attributes:
4549
@@ -116,7 +120,9 @@ class AIAssistant(abc.ABC): # noqa: F821
116120 )
117121 DEFAULT_DOCUMENT_SEPARATOR : ClassVar [str ] = "\n \n "
118122
119- def __init__ (self , * , user = None , request = None , view = None , ** kwargs : Any ):
123+ def __init__ (
124+ self , * , user = None , request = None , view = None , provider = ProvidersEnum .OPENAI , ** kwargs : Any
125+ ):
120126 """Initialize the AIAssistant instance.\n
121127 Optionally set the current user, request, and view for the assistant.\n
122128 Those can be used in any `@method_tool` to customize behavior.\n
@@ -128,12 +134,14 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs: Any):
128134 A request instance. Defaults to `None`. Stored in `self._request`.
129135 view (Any | None): The current Django view the assistant was initialized with.
130136 A view instance. Defaults to `None`. Stored in `self._view`.
137+ provider (str): The LLM provider to use, one of the `ProvidersEnum` values
(e.g. `ProvidersEnum.OPENAI`). Defaults to `ProvidersEnum.OPENAI`. Stored in `self._provider`.
131138 **kwargs: Extra keyword arguments passed to the constructor. Stored in `self._init_kwargs`.
132139 """
133140
134141 self ._user = user
135142 self ._request = request
136143 self ._view = view
144+ self ._provider = provider
137145 self ._init_kwargs = kwargs
138146
139147 self ._set_method_tools ()
@@ -278,14 +286,28 @@ def get_llm(self) -> BaseChatModel:
278286 temperature = self .get_temperature ()
279287 model_kwargs = self .get_model_kwargs ()
280288
289+ llm_class = None
290+ if self ._provider == ProvidersEnum .OPENAI :
291+ try :
292+ from langchain_openai import ChatOpenAI
293+ except ImportError as err :
294+ raise ImportError (
295+ "'langchain_openai' is required to use this provider. "
296+ "Install it with: pip install django-ai-assistant[openai]"
297+ ) from err
298+ llm_class = ChatOpenAI
299+ else :
300+ # TODO: raise exception due to incorrect provider
301+ raise
302+
281303 if temperature is not None :
282- return ChatOpenAI (
304+ return llm_class (
283305 model = model ,
284306 temperature = temperature ,
285307 model_kwargs = model_kwargs ,
286308 )
287309 else :
288- return ChatOpenAI (
310+ return llm_class (
289311 model = model ,
290312 model_kwargs = model_kwargs ,
291313 )
@@ -302,7 +324,7 @@ def get_structured_output_llm(self) -> Runnable:
302324 llm = self .get_llm ()
303325
304326 method = "json_mode"
305- if isinstance ( llm , ChatOpenAI ) :
327+ if self . _provider == ProvidersEnum . OPENAI :
306328 # When using ChatOpenAI, it's better to use json_schema method
307329 # because it enables strict mode.
308330 # https://platform.openai.com/docs/guides/structured-outputs
0 commit comments