Skip to content

Commit 4d6f07c

Browse files
committed
Refactors how get_llm imports the LLM class and adds tests
1 parent 6890a09 commit 4d6f07c

2 files changed

Lines changed: 95 additions & 32 deletions

File tree

django_ai_assistant/helpers/assistants.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import abc
2+
import importlib
23
import inspect
34
import re
4-
from enum import Enum
55
from typing import Annotated, Any, ClassVar, Dict, Sequence, Type, TypedDict, cast
66

77
from langchain_core.language_models import BaseChatModel
@@ -40,9 +40,10 @@
4040
from django_ai_assistant.langchain.tools import tool as tool_decorator
4141

4242

43-
class ProvidersEnum(Enum):
44-
OPENAI = "openai"
45-
ANTHROPIC = "anthropic"
43+
PROVIDER_LLM_LOOKUP = {
44+
"openai": "ChatOpenAI",
45+
"anthropic": "ChatAnthropic",
46+
}
4647

4748

4849
class AIAssistant(abc.ABC): # noqa: F821
@@ -127,7 +128,7 @@ def __init__(
127128
user=None,
128129
request=None,
129130
view=None,
130-
provider=ProvidersEnum.OPENAI.value,
131+
provider="openai",
131132
**kwargs: Any,
132133
):
133134
"""Initialize the AIAssistant instance.\n
@@ -141,7 +142,9 @@ def __init__(
141142
A request instance. Defaults to `None`. Stored in `self._request`.
142143
view (Any | None): The current Django view the assistant was initialized with.
143144
A view instance. Defaults to `None`. Stored in `self._view`.
144-
provider (str): TODO: add description
145+
provider (str): The provider that will be used for building the LLM instance.
146+
Requires the corresponding `langchain_[provider]` package to be installed.
147+
Defaults to `openai`. Stored in `self._provider`.
145148
**kwargs: Extra keyword arguments passed to the constructor. Stored in `self._init_kwargs`.
146149
"""
147150

@@ -280,6 +283,29 @@ def get_model_kwargs(self) -> dict[str, Any]:
280283
"""
281284
return {}
282285

286+
def _import_llm_class(self):
287+
valid_providers_list = PROVIDER_LLM_LOOKUP.keys()
288+
if self._provider not in valid_providers_list:
289+
raise AIAssistantMisconfiguredError(
290+
f"Invalid provider={self._provider}, please use one "
291+
f"of the supported providers: {valid_providers_list}"
292+
)
293+
294+
# Performs a deferred import of the LLM class that corresponds to
295+
# the self._provider value and returns it.
296+
try:
297+
langchain_module = importlib.import_module(f"langchain_{self._provider}")
298+
except ImportError as err:
299+
raise ImportError(
300+
f"'langchain_{self._provider}' is required to use this provider. "
301+
f"Install it with: pip install django-ai-assistant[{self._provider}]"
302+
) from err
303+
304+
return getattr(
305+
langchain_module,
306+
PROVIDER_LLM_LOOKUP[self._provider],
307+
)
308+
283309
def get_llm(self) -> BaseChatModel:
284310
"""Get the LangChain LLM instance for the assistant.
285311
By default, this uses the OpenAI implementation.\n
@@ -293,31 +319,7 @@ def get_llm(self) -> BaseChatModel:
293319
temperature = self.get_temperature()
294320
model_kwargs = self.get_model_kwargs()
295321

296-
llm_class = None
297-
valid_providers_list = [provider.value for provider in ProvidersEnum]
298-
if self._provider in valid_providers_list:
299-
try:
300-
if self._provider == ProvidersEnum.OPENAI.value:
301-
from langchain_openai import ChatOpenAI
302-
303-
llm_class = ChatOpenAI
304-
elif self._provider == ProvidersEnum.ANTHROPIC.value:
305-
from langchain_anthropic import ChatAnthropic
306-
307-
llm_class = ChatAnthropic
308-
else:
309-
raise ImportError
310-
except ImportError as err:
311-
raise ImportError(
312-
f"'langchain_{self._provider}' is required to use this provider. "
313-
f"Install it with: pip install django-ai-assistant[{self._provider}]"
314-
) from err
315-
else:
316-
raise AIAssistantMisconfiguredError(
317-
f"Invalid provider={self._provider}, please use one "
318-
"of the supported providers: "
319-
f"{valid_providers_list}"
320-
)
322+
llm_class = self._import_llm_class()
321323

322324
if temperature is not None:
323325
return llm_class(
@@ -343,7 +345,7 @@ def get_structured_output_llm(self) -> Runnable:
343345
llm = self.get_llm()
344346

345347
method = "json_mode"
346-
if self._provider == ProvidersEnum.OPENAI.value:
348+
if self._provider == "openai":
347349
# When using ChatOpenAI, it's better to use json_schema method
348350
# because it enables strict mode.
349351
# https://platform.openai.com/docs/guides/structured-outputs

tests/test_helpers/test_assistants.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
)
1212
from langchain_core.retrievers import BaseRetriever
1313

14+
from django_ai_assistant.exceptions import (
15+
AIAssistantMisconfiguredError,
16+
)
1417
from django_ai_assistant.helpers.assistants import AIAssistant
1518
from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool
1619
from django_ai_assistant.models import Thread
@@ -384,6 +387,64 @@ def get_temperature(self) -> float | None:
384387
AIAssistant.clear_cls_registry()
385388

386389

390+
@patch("langchain_anthropic.ChatAnthropic")
391+
def test_AIAssistant_get_llm_anthropic_provider(mock_chat_anthropic):
392+
class AnthropicAIAssistant(AIAssistant):
393+
id = "override_anthropic_assistant" # noqa: A003
394+
name = "Override Anthropic Assistant"
395+
instructions = "Instructions"
396+
model = "gpt-test"
397+
398+
assistant = AnthropicAIAssistant(provider="anthropic")
399+
assistant.get_llm()
400+
401+
mock_chat_anthropic.assert_called_once_with(
402+
model="gpt-test",
403+
temperature=1.0,
404+
model_kwargs={},
405+
)
406+
407+
AIAssistant.clear_cls_registry()
408+
409+
410+
def test_AIAssistant_get_llm_invalid_provider():
411+
class InvalidAIAssistant(AIAssistant):
412+
id = "override_invalid_assistant" # noqa: A003
413+
name = "Override Invalid Assistant"
414+
instructions = "Instructions"
415+
model = "gpt-test"
416+
417+
assistant = InvalidAIAssistant(provider="invalid")
418+
with pytest.raises(AIAssistantMisconfiguredError):
419+
assistant.get_llm()
420+
421+
422+
def test_AIAssistant_get_llm_uninstalled_provider(monkeypatch):
423+
class UninstalledAIAssistant(AIAssistant):
424+
id = "override_uninstalled_assistant" # noqa: A003
425+
name = "Override Uninstalled Assistant"
426+
instructions = "Instructions"
427+
model = "gpt-test"
428+
429+
assistant = UninstalledAIAssistant(provider="uninstalled")
430+
431+
# Simulates a scenario where the user tries to use a valid provider
432+
# that isn't installed with lib (i.e.: user tries to access the
433+
# openai provider, but langchain_openai isn't installed)
434+
from django_ai_assistant.helpers import assistants
435+
436+
monkeypatch.setattr(
437+
assistants,
438+
"PROVIDER_LLM_LOOKUP",
439+
{
440+
"uninstalled": "UninstalledChat",
441+
},
442+
)
443+
444+
with pytest.raises(ImportError):
445+
assistant.get_llm()
446+
447+
387448
@pytest.mark.vcr
388449
def test_AIAssistant_pydantic_structured_output():
389450
from pydantic import BaseModel

0 commit comments

Comments (0)