Skip to content

Commit 91a2509

Browse files
authored
Merge pull request #211 from vintasoftware/fix-bulk-create-messages
Fix save_django_messages on DBs where can_return_rows_from_bulk_insert=False
2 parents bd5a9b7 + 7ceaf07 commit 91a2509

2 files changed

Lines changed: 70 additions & 4 deletions

File tree

django_ai_assistant/helpers/django_messages.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TYPE_CHECKING
22

3-
from django.db import transaction
3+
from django.db import connections, transaction
44

55
from langchain_core.messages import (
66
BaseMessage,
@@ -35,9 +35,23 @@ def save_django_messages(messages: list[BaseMessage], thread: "Thread") -> list[
3535

3636
messages_to_create = [m for m in messages if m.id not in existing_message_ids]
3737

38-
created_messages = DjangoMessage.objects.bulk_create(
39-
[DjangoMessage(thread=thread, message={}) for _ in messages_to_create]
40-
)
38+
# Insert in bulk only if primary keys are then assigned by the DB.
39+
# Please check https://docs.djangoproject.com/en/4.0/ref/models/querysets/#django.db.models.query.QuerySet.bulk_create
40+
# for more context on why this is required
41+
can_bulk_insert = connections[
42+
DjangoMessage.objects.db
43+
].features.can_return_rows_from_bulk_insert
44+
if can_bulk_insert:
45+
created_messages = DjangoMessage.objects.bulk_create(
46+
[DjangoMessage(thread=thread, message={}) for _ in messages_to_create],
47+
)
48+
else:
49+
for message in (
50+
created_messages := [
51+
DjangoMessage(thread=thread, message={}) for _ in messages_to_create
52+
]
53+
):
54+
message.save()
4155

4256
# Update langchain message IDs with Django message IDs
4357
for idx, created_message in enumerate(created_messages):
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from unittest.mock import patch
2+
3+
from django.contrib.auth.models import User
4+
from django.db import connection, connections
5+
from django.db.backends.sqlite3.features import DatabaseFeatures
6+
7+
import pytest
8+
from langchain_core.messages import HumanMessage
9+
from model_bakery import baker
10+
11+
from django_ai_assistant.helpers.django_messages import save_django_messages
12+
from django_ai_assistant.models import Message, Thread
13+
14+
15+
@pytest.mark.django_db()
16+
def test_django_messages_with_can_return_rows_from_bulk_insert_true():
17+
class MockFeatures(DatabaseFeatures):
18+
can_return_rows_from_bulk_insert = True
19+
20+
mock_features = MockFeatures(connections[Message.objects.db])
21+
22+
thread = baker.make(Thread, created_by=baker.make(User))
23+
with patch.object(
24+
Message.objects,
25+
"bulk_create",
26+
wraps=Message.objects.bulk_create,
27+
) as mock_bulk_create:
28+
with patch.object(connection, "features", mock_features):
29+
save_django_messages([HumanMessage(content="Hello")], thread=thread)
30+
mock_bulk_create.assert_called_once()
31+
assert Message.objects.count() == 1
32+
assert Message.objects.first().message["data"]["content"] == "Hello"
33+
34+
35+
@pytest.mark.django_db()
36+
def test_django_messages_with_can_return_rows_from_bulk_insert_false():
37+
class MockFeatures(DatabaseFeatures):
38+
can_return_rows_from_bulk_insert = False
39+
40+
mock_features = MockFeatures(connections[Message.objects.db])
41+
42+
thread = baker.make(Thread, created_by=baker.make(User))
43+
with patch.object(
44+
Message.objects,
45+
"bulk_create",
46+
wraps=Message.objects.bulk_create,
47+
) as mock_bulk_create:
48+
with patch.object(connection, "features", mock_features):
49+
save_django_messages([HumanMessage(content="Hello")], thread=thread)
50+
mock_bulk_create.assert_not_called()
51+
assert Message.objects.count() == 1
52+
assert Message.objects.first().message["data"]["content"] == "Hello"

0 commit comments

Comments
 (0)