Skip to content

Commit 2c4e334

Browse files
committed
Add provider selecting flow to AIAssistant class
1 parent c26a2f4 commit 2c4e334

1 file changed

Lines changed: 27 additions & 5 deletions

File tree

django_ai_assistant/helpers/assistants.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
RunnableBranch,
2828
)
2929
from langchain_core.tools import BaseTool, StructuredTool
30-
from langchain_openai import ChatOpenAI
3130
from langgraph.graph import END, StateGraph, add_messages
3231
from langgraph.prebuilt import ToolNode
3332
from pydantic import BaseModel
@@ -40,6 +39,11 @@
4039
from django_ai_assistant.langchain.tools import tool as tool_decorator
4140

4241

42+
class ProvidersEnum:
43+
OPENAI = "openai"
44+
ANTHROPIC = "anthropic"
45+
46+
4347
class AIAssistant(abc.ABC): # noqa: F821
4448
"""Base class for AI Assistants. Subclasses must define at least the following attributes:
4549
@@ -116,7 +120,9 @@ class AIAssistant(abc.ABC): # noqa: F821
116120
)
117121
DEFAULT_DOCUMENT_SEPARATOR: ClassVar[str] = "\n\n"
118122

119-
def __init__(self, *, user=None, request=None, view=None, **kwargs: Any):
123+
def __init__(
124+
self, *, user=None, request=None, view=None, provider=ProvidersEnum.OPENAI, **kwargs: Any
125+
):
120126
"""Initialize the AIAssistant instance.\n
121127
Optionally set the current user, request, and view for the assistant.\n
122128
Those can be used in any `@method_tool` to customize behavior.\n
@@ -128,12 +134,14 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs: Any):
128134
A request instance. Defaults to `None`. Stored in `self._request`.
129135
view (Any | None): The current Django view the assistant was initialized with.
130136
A view instance. Defaults to `None`. Stored in `self._view`.
137+
provider (str): TODO: add description
131138
**kwargs: Extra keyword arguments passed to the constructor. Stored in `self._init_kwargs`.
132139
"""
133140

134141
self._user = user
135142
self._request = request
136143
self._view = view
144+
self._provider = provider
137145
self._init_kwargs = kwargs
138146

139147
self._set_method_tools()
@@ -278,14 +286,28 @@ def get_llm(self) -> BaseChatModel:
278286
temperature = self.get_temperature()
279287
model_kwargs = self.get_model_kwargs()
280288

289+
llm_class = None
290+
if self._provider == ProvidersEnum.OPENAI:
291+
try:
292+
from langchain_openai import ChatOpenAI
293+
except ImportError as err:
294+
raise ImportError(
295+
"'langchain_openai' is required to use this provider. "
296+
"Install it with: pip install django-ai-assistant[openai]"
297+
) from err
298+
llm_class = ChatOpenAI
299+
else:
300+
# TODO: raise exception due to incorrect provider
301+
raise
302+
281303
if temperature is not None:
282-
return ChatOpenAI(
304+
return llm_class(
283305
model=model,
284306
temperature=temperature,
285307
model_kwargs=model_kwargs,
286308
)
287309
else:
288-
return ChatOpenAI(
310+
return llm_class(
289311
model=model,
290312
model_kwargs=model_kwargs,
291313
)
@@ -302,7 +324,7 @@ def get_structured_output_llm(self) -> Runnable:
302324
llm = self.get_llm()
303325

304326
method = "json_mode"
305-
if isinstance(llm, ChatOpenAI):
327+
if self._provider == ProvidersEnum.OPENAI:
306328
# When using ChatOpenAI, it's better to use json_schema method
307329
# because it enables strict mode.
308330
# https://platform.openai.com/docs/guides/structured-outputs

0 commit comments

Comments
 (0)