Skip to content

Commit

Permalink
Merge pull request #87 from vintasoftware/fix/move-modules
Browse files Browse the repository at this point in the history
Restructure modules: separate use cases from API
  • Loading branch information
fjsj authored Jun 17, 2024
2 parents 3f2663b + dcb8dcb commit 9579846
Show file tree
Hide file tree
Showing 33 changed files with 340 additions and 317 deletions.
17 changes: 16 additions & 1 deletion django_ai_assistant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
from importlib import metadata

from django_ai_assistant.helpers.assistants import ( # noqa
AIAssistant,
register_assistant,
)
from django_ai_assistant.langchain.tools import ( # noqa
BaseModel,
BaseTool,
Field,
StructuredTool,
Tool,
method_tool,
tool,
)

__version__ = metadata.version(__package__)

version = __version__ = metadata.version(__package__)
package_name = __package__
2 changes: 1 addition & 1 deletion django_ai_assistant/admin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.contrib import admin

from .models import Message, Thread
from django_ai_assistant.models import Message, Thread


@admin.register(Thread)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ninja import Field, ModelSchema, Schema

from .models import Thread
from django_ai_assistant.models import Thread


class AssistantSchema(Schema):
Expand Down
53 changes: 28 additions & 25 deletions django_ai_assistant/views.py → django_ai_assistant/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,29 @@

from langchain_core.messages import message_to_dict
from ninja import NinjaAPI
from ninja.operation import Operation

from django_ai_assistant import __package__, __version__

from .exceptions import AIUserNotAllowedError
from .helpers import assistants
from .helpers.assistants import (
create_message,
get_assistants_info,
get_single_assistant_info,
get_single_thread,
get_thread_messages,
get_threads,
)
from .models import Message, Thread
from .schemas import (
from django_ai_assistant import package_name, version
from django_ai_assistant.api.schemas import (
AssistantSchema,
ThreadMessagesSchemaIn,
ThreadMessagesSchemaOut,
ThreadSchema,
ThreadSchemaIn,
)
from django_ai_assistant.exceptions import AIUserNotAllowedError
from django_ai_assistant.helpers import use_cases
from django_ai_assistant.models import Message, Thread


class API(NinjaAPI):
# Force "operationId" to be like "django_ai_assistant_delete_thread"
def get_openapi_operation_id(self, operation: Operation) -> str:
name = operation.view_func.__name__
return (package_name + "_" + name).replace(".", "_")


api = NinjaAPI(title=__package__, version=__version__, urls_namespace="django_ai_assistant")
api = API(title=package_name, version=version, urls_namespace="django_ai_assistant")


@api.exception_handler(AIUserNotAllowedError)
Expand All @@ -41,42 +40,44 @@ def ai_user_not_allowed_handler(request, exc):

@api.get("assistants/", response=List[AssistantSchema], url_name="assistants_list")
def list_assistants(request):
return list(get_assistants_info(user=request.user, request=request))
return list(use_cases.get_assistants_info(user=request.user, request=request))


@api.get("assistants/{assistant_id}/", response=AssistantSchema, url_name="assistant_detail")
def get_assistant(request, assistant_id: str):
return get_single_assistant_info(assistant_id=assistant_id, user=request.user, request=request)
return use_cases.get_single_assistant_info(
assistant_id=assistant_id, user=request.user, request=request
)


@api.get("threads/", response=List[ThreadSchema], url_name="threads_list_create")
def list_threads(request):
return list(get_threads(user=request.user, request=request))
return list(use_cases.get_threads(user=request.user, request=request))


@api.post("threads/", response=ThreadSchema, url_name="threads_list_create")
def create_thread(request, payload: ThreadSchemaIn):
name = payload.name
return assistants.create_thread(name=name, user=request.user, request=request)
return use_cases.create_thread(name=name, user=request.user, request=request)


@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
def get_thread(request, thread_id: str):
thread = get_single_thread(thread_id=thread_id, user=request.user, request=request)
thread = use_cases.get_single_thread(thread_id=thread_id, user=request.user, request=request)
return thread


@api.patch("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete")
def update_thread(request, thread_id: str, payload: ThreadSchemaIn):
thread = get_object_or_404(Thread, id=thread_id)
name = payload.name
return assistants.update_thread(thread=thread, name=name, user=request.user, request=request)
return use_cases.update_thread(thread=thread, name=name, user=request.user, request=request)


@api.delete("threads/{thread_id}/", response={204: None}, url_name="thread_detail_update_delete")
def delete_thread(request, thread_id: str):
thread = get_object_or_404(Thread, id=thread_id)
assistants.delete_thread(thread=thread, user=request.user, request=request)
use_cases.delete_thread(thread=thread, user=request.user, request=request)
return 204, None


Expand All @@ -86,7 +87,9 @@ def delete_thread(request, thread_id: str):
url_name="messages_list_create",
)
def list_thread_messages(request, thread_id: str):
messages = get_thread_messages(thread_id=thread_id, user=request.user, request=request)
messages = use_cases.get_thread_messages(
thread_id=thread_id, user=request.user, request=request
)
return [message_to_dict(m)["data"] for m in messages]


Expand All @@ -99,7 +102,7 @@ def list_thread_messages(request, thread_id: str):
def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchemaIn):
thread = Thread.objects.get(id=thread_id)

create_message(
use_cases.create_message(
assistant_id=payload.assistant_id,
thread=thread,
user=request.user,
Expand All @@ -114,7 +117,7 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema
)
def delete_thread_message(request, thread_id: str, message_id: str):
message = get_object_or_404(Message, id=message_id, thread_id=thread_id)
assistants.delete_message(
use_cases.delete_message(
message=message,
user=request.user,
request=request,
Expand Down
172 changes: 5 additions & 167 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import re
from typing import Any, ClassVar, Sequence, cast

from django.http import HttpRequest

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
format_to_tool_messages,
Expand All @@ -15,7 +13,6 @@
DEFAULT_DOCUMENT_SEPARATOR,
)
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
Expand All @@ -37,22 +34,11 @@
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI

from django_ai_assistant.ai.chat_message_histories import DjangoChatMessageHistory
from django_ai_assistant.exceptions import (
AIAssistantMisconfiguredError,
AIAssistantNotDefinedError,
AIUserNotAllowedError,
)
from django_ai_assistant.models import Message, Thread
from django_ai_assistant.permissions import (
can_create_message,
can_create_thread,
can_delete_message,
can_delete_thread,
can_run_assistant,
)
from django_ai_assistant.tools import Tool
from django_ai_assistant.tools import tool as tool_decorator
from django_ai_assistant.langchain.tools import Tool
from django_ai_assistant.langchain.tools import tool as tool_decorator


class AIAssistant(abc.ABC): # noqa: F821
Expand Down Expand Up @@ -156,6 +142,9 @@ def get_prompt_template(self):
)

def get_message_history(self, thread_id: int | None):
# DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere:
from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory

if thread_id is None:
return InMemoryChatMessageHistory()
return DjangoChatMessageHistory(thread_id)
Expand Down Expand Up @@ -309,154 +298,3 @@ def as_tool(self, description) -> BaseTool:
def register_assistant(cls: type[AIAssistant]):
ASSISTANT_CLS_REGISTRY[cls.id] = cls
return cls


def _get_assistant_cls(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
if assistant_id not in ASSISTANT_CLS_REGISTRY:
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id]
if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
request=request,
):
raise AIUserNotAllowedError("User is not allowed to use this assistant")
return assistant_cls


def get_single_assistant_info(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

return {
"id": assistant_id,
"name": assistant_cls.name,
}


def get_assistants_info(
user: Any,
request: HttpRequest | None = None,
):
return [
_get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
for assistant_id in ASSISTANT_CLS_REGISTRY.keys()
]


def create_message(
assistant_id: str,
thread: Thread,
user: Any,
content: Any,
request: HttpRequest | None = None,
):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

if not can_create_message(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")

# TODO: Check if we can separate the message creation from the chain invoke
assistant = assistant_cls(user=user, request=request)
assistant_message = assistant.invoke(
{"input": content},
thread_id=thread.id,
)
return assistant_message


def create_thread(
name: str,
user: Any,
request: HttpRequest | None = None,
):
if not can_create_thread(user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create threads")

thread = Thread.objects.create(name=name, created_by=user)
return thread


def get_single_thread(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
):
return Thread.objects.filter(created_by=user).get(id=thread_id)


def get_threads(
user: Any,
request: HttpRequest | None = None,
):
return list(Thread.objects.filter(created_by=user))


def update_thread(
thread: Thread,
name: str,
user: Any,
request: HttpRequest | None = None,
):
if not can_delete_thread(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to update this thread")

thread.name = name
thread.save()
return thread


def delete_thread(
thread: Thread,
user: Any,
request: HttpRequest | None = None,
):
if not can_delete_thread(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this thread")

return thread.delete()


def get_thread_messages(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
) -> list[BaseMessage]:
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
if user != thread.created_by:
raise AIUserNotAllowedError("User is not allowed to view messages in this thread")

return DjangoChatMessageHistory(thread.id).get_messages()


def create_thread_message_as_user(
thread_id: str,
content: str,
user: Any,
request: HttpRequest | None = None,
):
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
if user != thread.created_by:
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")

DjangoChatMessageHistory(thread.id).add_messages([HumanMessage(content=content)])


def delete_message(
message: Message,
user: Any,
request: HttpRequest | None = None,
):
if not can_delete_message(message=message, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this message")

return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)])
Loading

0 comments on commit 9579846

Please sign in to comment.