Skip to content

Commit

Permalink
Move all examples to a single Django project
Browse files Browse the repository at this point in the history
  • Loading branch information
fjsj committed Jun 14, 2024
1 parent 8bace80 commit bf8a887
Show file tree
Hide file tree
Showing 96 changed files with 341 additions and 22,439 deletions.
81 changes: 49 additions & 32 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, ClassVar, Sequence, cast

from django.http import HttpRequest
from django.views import View

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad.tools import (
Expand Down Expand Up @@ -223,7 +222,7 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:
# Based on:
# - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/
# - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/#memory
# - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/
# TODO: use langgraph instead?
llm = self.get_llm()
tools = self.get_tools()
Expand Down Expand Up @@ -312,18 +311,43 @@ def register_assistant(cls: type[AIAssistant]):
return cls


def _get_assistant_cls(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
if assistant_id not in ASSISTANT_CLS_REGISTRY:
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id]
if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
request=request,
):
raise AIUserNotAllowedError("User is not allowed to use this assistant")
return assistant_cls


def get_single_assistant_info(
assistant_id: str,
user: Any,
request: HttpRequest | None = None,
):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

return {
"id": assistant_id,
"name": assistant_cls.name,
}


def get_assistants_info(
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
return [
{
"id": assistant_id,
"name": assistant_cls.name,
}
for assistant_id, assistant_cls in ASSISTANT_CLS_REGISTRY.items()
if can_run_assistant(assistant_cls=assistant_cls, user=user, request=request, view=view)
_get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
for assistant_id in ASSISTANT_CLS_REGISTRY.keys()
]


Expand All @@ -333,23 +357,14 @@ def create_message(
user: Any,
content: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_create_message(thread=thread, user=user, request=request, view=view):
assistant_cls = _get_assistant_cls(assistant_id, user, request)

if not can_create_message(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create messages in this thread")
if assistant_id not in ASSISTANT_CLS_REGISTRY:
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id]
if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
request=request,
view=view,
):
raise AIUserNotAllowedError("User is not allowed to use this assistant")

# TODO: Check if we can separate the message creation from the chain invoke
assistant = assistant_cls(user=user, request=request, view=view)
assistant = assistant_cls(user=user, request=request)
assistant_message = assistant.invoke(
{"input": content},
thread_id=thread.id,
Expand All @@ -361,19 +376,25 @@ def create_thread(
name: str,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_create_thread(user=user, request=request, view=view):
if not can_create_thread(user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to create threads")

thread = Thread.objects.create(name=name, created_by=user)
return thread


def get_single_thread(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
):
return Thread.objects.filter(created_by=user).get(id=thread_id)


def get_threads(
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
return list(Thread.objects.filter(created_by=user))

Expand All @@ -382,9 +403,8 @@ def delete_thread(
thread: Thread,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_delete_thread(thread=thread, user=user, request=request, view=view):
if not can_delete_thread(thread=thread, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this thread")

return thread.delete()
Expand All @@ -394,7 +414,6 @@ def get_thread_messages(
thread_id: str,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
) -> list[BaseMessage]:
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
Expand All @@ -409,7 +428,6 @@ def create_thread_message_as_user(
content: str,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
# TODO: have more permissions for threads? View thread permission?
thread = Thread.objects.get(id=thread_id)
Expand All @@ -423,9 +441,8 @@ def delete_message(
message: Message,
user: Any,
request: HttpRequest | None = None,
view: View | None = None,
):
if not can_delete_message(message=message, user=user, request=request, view=view):
if not can_delete_message(message=message, user=user, request=request):
raise AIUserNotAllowedError("User is not allowed to delete this message")

return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([message.id])
return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)])
33 changes: 18 additions & 15 deletions django_ai_assistant/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .helpers.assistants import (
create_message,
get_assistants_info,
get_single_assistant_info,
get_single_thread,
get_thread_messages,
get_threads,
)
Expand Down Expand Up @@ -39,29 +41,35 @@ def ai_user_not_allowed_handler(request, exc):

@api.get("assistants/", response=List[AssistantSchema], url_name="assistants_list")
def list_assistants(request):
return list(get_assistants_info(user=request.user, request=request, view=None))
return list(get_assistants_info(user=request.user, request=request))


@api.get("assistants/{assistant_id}/", response=AssistantSchema, url_name="assistant_detail")
def get_assistant(request, assistant_id: str):
return get_single_assistant_info(assistant_id=assistant_id, user=request.user, request=request)


@api.get("threads/", response=List[ThreadSchema], url_name="threads_list_create")
def list_threads(request):
return list(get_threads(user=request.user, request=request, view=None))
return list(get_threads(user=request.user, request=request))


@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail")
def get_thread(request, thread_id: str):
thread = get_single_thread(thread_id=thread_id, user=request.user, request=request)
return thread


@api.post("threads/", response=ThreadSchema, url_name="threads_list_create")
def create_thread(request, payload: ThreadSchemaIn):
name = payload.name
return assistants.create_thread(name=name, user=request.user, request=request, view=None)
return assistants.create_thread(name=name, user=request.user, request=request)


@api.delete("threads/{thread_id}/", response={204: None}, url_name="threads_delete")
def delete_thread(request, thread_id: str):
thread = get_object_or_404(Thread, id=thread_id)
assistants.delete_thread(
thread=thread,
user=request.user,
request=request,
view=None,
)
assistants.delete_thread(thread=thread, user=request.user, request=request)
return 204, None


Expand All @@ -71,9 +79,7 @@ def delete_thread(request, thread_id: str):
url_name="messages_list_create",
)
def list_thread_messages(request, thread_id: str):
messages = get_thread_messages(
thread_id=thread_id, user=request.user, request=request, view=None
)
messages = get_thread_messages(thread_id=thread_id, user=request.user, request=request)
return [message_to_dict(m)["data"] for m in messages]


Expand All @@ -92,9 +98,7 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema
user=request.user,
content=payload.content,
request=request,
view=None,
)

return 201, None


Expand All @@ -107,6 +111,5 @@ def delete_thread_message(request, thread_id: str, message_id: str):
message=message,
user=request.user,
request=request,
view=None,
)
return 204, None
60 changes: 57 additions & 3 deletions example/assets/js/App.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,73 @@
import "@mantine/core/styles.css";

import { createTheme, MantineProvider } from "@mantine/core";
import { Chat } from "@/Chat";
import { Container, createTheme, MantineProvider } from "@mantine/core";
import { Chat } from "@/components";
import { createBrowserRouter, Link, RouterProvider } from "react-router-dom";
import { configAIAssistant } from "django-ai-assistant-client";
import React from "react";

const theme = createTheme({});

// Relates to path("ai-assistant/", include("django_ai_assistant.urls"))
// which can be found at example/demo/urls.py)
configAIAssistant({ baseURL: "ai-assistant" });

const ExampleIndex = () => {
return (
<Container>
<h1>Examples</h1>
<ul>
<li>
<Link to="/weather-chat">Weather Chat</Link>
</li>
<li>
<Link to="/movies-chat">Movie Recommendation Chat</Link>
</li>
<li>
<Link to="/rag-chat">Django Docs RAG Chat</Link>
</li>
<li>
<Link to="/htmx">HTMX demo (no React)</Link>
</li>
</ul>
</Container>
);
};

const Redirect = ({ to }: { to: string }) => {
window.location.href = to;
return null;
};

const router = createBrowserRouter([
{
path: "/",
element: <ExampleIndex />,
},
{
path: "/weather-chat",
element: <Chat assistantId="weather_assistant" />,
},
{
path: "/movies-chat",
element: <Chat assistantId="movie_recommendation_assistant" />,
},
{
path: "/rag-chat",
element: <Chat assistantId="django_docs_assistant" />,
},
{
path: "/htmx",
element: <Redirect to="/htmx/" />,
},
]);

const App = () => {
return (
<MantineProvider theme={theme}>
<Chat />
<React.StrictMode>
<RouterProvider router={router} />
</React.StrictMode>
</MantineProvider>
);
};
Expand Down
1 change: 0 additions & 1 deletion example/assets/js/Chat/index.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@
.chat {
height: 100%;
}

.mdMessage p {
margin: 0;
}
Loading

0 comments on commit bf8a887

Please sign in to comment.