Skip to content

Commit

Permalink
Adds decorator to helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
amandasavluchinske committed Jun 21, 2024
1 parent 0e55aa1 commit a7fc6d9
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 6 deletions.
5 changes: 5 additions & 0 deletions django_ai_assistant/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def with_cast_id(func):
def wrapper(*args, **kwargs):
thread_id = kwargs.get("thread_id")
message_id = kwargs.get("message_id")
message_ids = kwargs.get("message_ids")

if thread_id:
thread_id = cast_id(thread_id, Thread)
Expand All @@ -18,6 +19,10 @@ def wrapper(*args, **kwargs):
message_id = cast_id(message_id, Message)
kwargs["message_id"] = message_id

if message_ids:
message_ids = [cast_id(message_id, Message) for message_id in message_ids]
kwargs["message_ids"] = message_ids

return func(*args, **kwargs)

return wrapper
5 changes: 5 additions & 0 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
)
Expand Down Expand Up @@ -279,6 +280,7 @@ def get_prompt_template(self) -> ChatPromptTemplate:
]
)

@with_cast_id
def get_message_history(self, thread_id: Any | None) -> BaseChatMessageHistory:
"""Get the chat message history instance for the given `thread_id`.\n
The Langchain chain uses the return of this method to get the thread messages
Expand Down Expand Up @@ -430,6 +432,7 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
prompt | llm | StrOutputParser() | retriever,
)

@with_cast_id
def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:
"""Create the Langchain chain for the assistant.\n
This chain is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
Expand Down Expand Up @@ -514,6 +517,7 @@ def as_chain(self, thread_id: Any | None) -> Runnable[dict, dict]:

return agent_with_chat_history

@with_cast_id
def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
"""Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n
This is the lower-level method to run the assistant.\n
Expand All @@ -533,6 +537,7 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
chain = self.as_chain(thread_id)
return chain.invoke(*args, **kwargs)

@with_cast_id
def run(self, message: str, thread_id: Any | None, **kwargs: Any) -> str:
"""Run the assistant with the given message and thread ID.\n
This is the higher-level method to run the assistant.\n
Expand Down
8 changes: 6 additions & 2 deletions django_ai_assistant/langchain/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.models import Message


logger = logging.getLogger(__name__)


class DjangoChatMessageHistory(BaseChatMessageHistory):
@with_cast_id
def __init__(
self,
thread_id: Any,
Expand Down Expand Up @@ -103,15 +105,17 @@ async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:

await Message.objects.abulk_update(created_messages, ["message"])

def remove_messages(self, message_ids: List[str]) -> None:
@with_cast_id
def remove_messages(self, message_ids: List[Any]) -> None:
"""Remove messages from the chat thread.
Args:
message_ids: A list of message IDs to remove.
"""
Message.objects.filter(id__in=message_ids).delete()

async def aremove_messages(self, message_ids: List[str]) -> None:
@with_cast_id
async def aremove_messages(self, message_ids: List[Any]) -> None:
"""Remove messages from the chat thread.
Args:
Expand Down
4 changes: 1 addition & 3 deletions example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
"django.contrib.messages",
"django.contrib.staticfiles",
"webpack_loader",
# "django_ai_assistant",
"example.apps.AIAssistantConfigOverride",
"django_ai_assistant",
"demo", # contains the views
"weather",
"movies",
Expand Down Expand Up @@ -167,7 +166,6 @@
AI_ASSISTANT_CAN_UPDATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_PRIMARY_KEY_FIELD = "uuid" # Options: 'auto', 'uuid', 'string'


# Example specific settings:
Expand Down
1 change: 0 additions & 1 deletion tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,3 @@
AI_ASSISTANT_CAN_UPDATE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_DELETE_MESSAGE_FN = "django_ai_assistant.permissions.owns_thread"
AI_ASSISTANT_CAN_RUN_ASSISTANT = "django_ai_assistant.permissions.allow_all"
AI_ASSISTANT_PRIMARY_KEY_FIELD = "uuid" # Options: 'auto', 'uuid', 'string'

0 comments on commit a7fc6d9

Please sign in to comment.