Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small refactor on DjangoChatMessageHistory:add_messages #76

Merged
merged 1 commit into from
Jun 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 14 additions & 16 deletions django_ai_assistant/ai/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,15 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
messages: A list of BaseMessage objects to store.
"""
with transaction.atomic():
message_objects = [
Message(thread_id=self._thread_id, message=message_to_dict(message))
for message in messages
]

created_messages = Message.objects.bulk_create(message_objects)
created_messages = Message.objects.bulk_create(
[Message(thread_id=self._thread_id, message=dict()) for message in messages]
)

# Update langchain message IDs with Django message IDs
for created_message in created_messages:
created_message.message["data"]["id"] = str(created_message.id)
for idx, created_message in enumerate(created_messages):
message_with_id = messages[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

Message.objects.bulk_update(created_messages, ["message"])

Expand All @@ -92,16 +91,15 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""
# NOTE: This method does not use transactions because it do not yet work in async mode.
# Source: https://docs.djangoproject.com/en/5.0/topics/async/#queries-the-orm
message_objects = [
Message(thread_id=self._thread_id, message=message_to_dict(message))
for message in messages
]

created_messages = await Message.objects.abulk_create(message_objects)
created_messages = await Message.objects.abulk_create(
[Message(thread_id=self._thread_id, message=dict()) for message in messages]
)

# Update langchain message IDs with Django message IDs
for created_message in created_messages:
created_message.message["data"]["id"] = str(created_message.id)
for idx, created_message in enumerate(created_messages):
message_with_id = messages[idx]
message_with_id.id = str(created_message.id)
created_message.message = message_to_dict(message_with_id)

await Message.objects.abulk_update(created_messages, ["message"])

Expand Down