diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index fb5a4d3..93763bf 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -4,7 +4,6 @@ from typing import Any, ClassVar, Sequence, cast from django.http import HttpRequest -from django.views import View from langchain.agents import AgentExecutor from langchain.agents.format_scratchpad.tools import ( @@ -223,7 +222,7 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]: def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]: # Based on: # - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/ - # - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/#memory + # - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/ # TODO: use langgraph instead? llm = self.get_llm() tools = self.get_tools() @@ -312,18 +311,43 @@ def register_assistant(cls: type[AIAssistant]): return cls +def _get_assistant_cls( + assistant_id: str, + user: Any, + request: HttpRequest | None = None, +): + if assistant_id not in ASSISTANT_CLS_REGISTRY: + raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found") + assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id] + if not can_run_assistant( + assistant_cls=assistant_cls, + user=user, + request=request, + ): + raise AIUserNotAllowedError("User is not allowed to use this assistant") + return assistant_cls + + +def get_single_assistant_info( + assistant_id: str, + user: Any, + request: HttpRequest | None = None, +): + assistant_cls = _get_assistant_cls(assistant_id, user, request) + + return { + "id": assistant_id, + "name": assistant_cls.name, + } + + def get_assistants_info( user: Any, request: HttpRequest | None = None, - view: View | None = None, ): return [ - { - "id": assistant_id, - "name": assistant_cls.name, - } - for assistant_id, assistant_cls in ASSISTANT_CLS_REGISTRY.items() - if can_run_assistant(assistant_cls=assistant_cls, user=user, request=request, view=view) + _get_assistant_cls(assistant_id=assistant_id, user=user, request=request) + for assistant_id in ASSISTANT_CLS_REGISTRY.keys() ] @@ -333,23 +357,14 @@ def create_message( user: Any, content: Any, request: HttpRequest | None = None, - view: View | None = None, ): - if not can_create_message(thread=thread, user=user, request=request, view=view): + assistant_cls = _get_assistant_cls(assistant_id, user, request) + + if not can_create_message(thread=thread, user=user, request=request): raise AIUserNotAllowedError("User is not allowed to create messages in this thread") - if assistant_id not in ASSISTANT_CLS_REGISTRY: - raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found") - assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id] - if not can_run_assistant( - assistant_cls=assistant_cls, - user=user, - request=request, - view=view, - ): - raise AIUserNotAllowedError("User is not allowed to use this assistant") # TODO: Check if we can separate the message creation from the chain invoke - assistant = assistant_cls(user=user, request=request, view=view) + assistant = assistant_cls(user=user, request=request) assistant_message = assistant.invoke( {"input": content}, thread_id=thread.id, @@ -361,19 +376,25 @@ def create_thread( name: str, user: Any, request: HttpRequest | None = None, - view: View | None = None, ): - if not can_create_thread(user=user, request=request, view=view): + if not can_create_thread(user=user, request=request): raise AIUserNotAllowedError("User is not allowed to create threads") thread = Thread.objects.create(name=name, created_by=user) return thread +def get_single_thread( + thread_id: str, + user: Any, + request: HttpRequest | None = None, +): + return Thread.objects.filter(created_by=user).get(id=thread_id) + + def get_threads( user: Any, request: HttpRequest | None = None, - view: View | None = None, ): return list(Thread.objects.filter(created_by=user)) @@ -382,9 +403,8 @@ def delete_thread( thread: Thread, user: Any, request: HttpRequest | None = None, - view: View | None = None, ): - if not can_delete_thread(thread=thread, user=user, request=request, view=view): + if not can_delete_thread(thread=thread, user=user, request=request): raise AIUserNotAllowedError("User is not allowed to delete this thread") return thread.delete() @@ -394,7 +414,6 @@ def get_thread_messages( thread_id: str, user: Any, request: HttpRequest | None = None, - view: View | None = None, ) -> list[BaseMessage]: # TODO: have more permissions for threads? View thread permission? thread = Thread.objects.get(id=thread_id) @@ -409,7 +428,6 @@ def create_thread_message_as_user( content: str, user: Any, request: HttpRequest | None = None, - view: View | None = None, ): # TODO: have more permissions for threads? View thread permission? thread = Thread.objects.get(id=thread_id) @@ -423,9 +441,8 @@ def delete_message( message: Message, user: Any, request: HttpRequest | None = None, - view: View | None = None, ): - if not can_delete_message(message=message, user=user, request=request, view=view): + if not can_delete_message(message=message, user=user, request=request): raise AIUserNotAllowedError("User is not allowed to delete this message") - return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([message.id]) + return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)]) diff --git a/django_ai_assistant/views.py b/django_ai_assistant/views.py index 0813725..e7d56ac 100644 --- a/django_ai_assistant/views.py +++ b/django_ai_assistant/views.py @@ -12,6 +12,8 @@ from .helpers.assistants import ( create_message, get_assistants_info, + get_single_assistant_info, + get_single_thread, get_thread_messages, get_threads, ) @@ -39,29 +41,35 @@ def ai_user_not_allowed_handler(request, exc): @api.get("assistants/", response=List[AssistantSchema], url_name="assistants_list") def list_assistants(request): - return list(get_assistants_info(user=request.user, request=request, view=None)) + return list(get_assistants_info(user=request.user, request=request)) + + +@api.get("assistants/{assistant_id}/", response=AssistantSchema, url_name="assistant_detail") +def get_assistant(request, assistant_id: str): + return get_single_assistant_info(assistant_id=assistant_id, user=request.user, request=request) @api.get("threads/", response=List[ThreadSchema], url_name="threads_list_create") def list_threads(request): - return list(get_threads(user=request.user, request=request, view=None)) + return list(get_threads(user=request.user, request=request)) + + +@api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail") +def get_thread(request, thread_id: str): + thread = get_single_thread(thread_id=thread_id, user=request.user, request=request) + return thread @api.post("threads/", response=ThreadSchema, url_name="threads_list_create") def create_thread(request, payload: ThreadSchemaIn): name = payload.name - return assistants.create_thread(name=name, user=request.user, request=request, view=None) + return assistants.create_thread(name=name, user=request.user, request=request) @api.delete("threads/{thread_id}/", response={204: None}, url_name="threads_delete") def delete_thread(request, thread_id: str): thread = get_object_or_404(Thread, id=thread_id) - assistants.delete_thread( - thread=thread, - user=request.user, - request=request, - view=None, - ) + assistants.delete_thread(thread=thread, user=request.user, request=request) return 204, None @@ -71,9 +79,7 @@ def delete_thread(request, thread_id: str): url_name="messages_list_create", ) def list_thread_messages(request, thread_id: str): - messages = get_thread_messages( - thread_id=thread_id, user=request.user, request=request, view=None - ) + messages = get_thread_messages(thread_id=thread_id, user=request.user, request=request) return [message_to_dict(m)["data"] for m in messages] @@ -92,9 +98,7 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema user=request.user, content=payload.content, request=request, - view=None, ) - return 201, None @@ -107,6 +111,5 @@ def delete_thread_message(request, thread_id: str, message_id: str): message=message, user=request.user, request=request, - view=None, ) return 204, None diff --git a/example/assets/js/App.tsx b/example/assets/js/App.tsx index b9627e1..4c4281a 100644 --- a/example/assets/js/App.tsx +++ b/example/assets/js/App.tsx @@ -1,8 +1,10 @@ import "@mantine/core/styles.css"; -import { createTheme, MantineProvider } from "@mantine/core"; -import { Chat } from "@/Chat"; +import { Container, createTheme, MantineProvider } from "@mantine/core"; +import { Chat } from "@/components"; +import { createBrowserRouter, Link, RouterProvider } from "react-router-dom"; import { configAIAssistant } from "django-ai-assistant-client"; +import React from "react"; const theme = createTheme({}); @@ -10,10 +12,62 @@ const theme = createTheme({}); // which can be found at example/demo/urls.py) configAIAssistant({ baseURL: "ai-assistant" }); +const ExampleIndex = () => { + return ( + +

Examples

+ +
+ ); +}; + +const Redirect = ({ to }: { to: string }) => { + window.location.href = to; + return null; +}; + +const router = createBrowserRouter([ + { + path: "/", + element: , + }, + { + path: "/weather-chat", + element: , + }, + { + path: "/movies-chat", + element: , + }, + { + path: "/rag-chat", + element: , + }, + { + path: "/htmx", + element: , + }, +]); + const App = () => { return ( - + + + ); }; diff --git a/example/assets/js/Chat/index.ts b/example/assets/js/Chat/index.ts deleted file mode 100644 index 279ce39..0000000 --- a/example/assets/js/Chat/index.ts +++ /dev/null @@ -1 +0,0 @@ -export { Chat } from "./Chat"; diff --git a/example/assets/js/Chat/Chat.module.css b/example/assets/js/components/Chat/Chat.module.css similarity index 88% rename from example/assets/js/Chat/Chat.module.css rename to example/assets/js/components/Chat/Chat.module.css index e4f2917..4c9c171 100644 --- a/example/assets/js/Chat/Chat.module.css +++ b/example/assets/js/components/Chat/Chat.module.css @@ -13,3 +13,7 @@ .chat { height: 100%; } + +.mdMessage p { + margin: 0; +} diff --git a/example/assets/js/Chat/Chat.tsx b/example/assets/js/components/Chat/Chat.tsx similarity index 92% rename from example/assets/js/Chat/Chat.tsx rename to example/assets/js/components/Chat/Chat.tsx index 6f66241..3abbf62 100644 --- a/example/assets/js/Chat/Chat.tsx +++ b/example/assets/js/components/Chat/Chat.tsx @@ -14,7 +14,7 @@ import { Title, Tooltip, } from "@mantine/core"; -import { ThreadsNav } from "./ThreadsNav"; +import { ThreadsNav } from "../ThreadsNav/ThreadsNav"; import classes from "./Chat.module.css"; import { useCallback, useEffect, useRef, useState } from "react"; @@ -82,7 +82,8 @@ function ChatMessage({ maw="75%" shadow="none" radius="lg" - p="lg" + p="xs" + px="md" bg="var(--mantine-color-gray-0)" > @@ -128,12 +129,10 @@ function ChatMessageList({ ); } -export function Chat() { - const [assistantId, setAssistantId] = useState(""); +export function Chat({ assistantId }: { assistantId: string }) { const [activeThread, setActiveThread] = useState(null); const [inputValue, setInputValue] = useState(""); - const { fetchAssistants, assistants } = useAssistant(); const { fetchThreads, threads, createThread, deleteThread } = useThread(); const { fetchMessages, @@ -147,8 +146,8 @@ export function Chat() { const loadingMessages = loadingFetchMessages || loadingCreateMessage || loadingDeleteMessage; - const isThreadSelected = assistantId && activeThread; - const isChatActive = assistantId && activeThread && !loadingMessages; + const isThreadSelected = Boolean(activeThread); + const isChatActive = activeThread && !loadingMessages; const scrollViewport = useRef(null); const scrollToBottom = useCallback( @@ -166,15 +165,6 @@ export function Chat() { [scrollViewport] ); - // Load assistantId when component mounts: - useEffect(() => { - if (assistants) { - setAssistantId(assistants[0].id); - } else { - fetchAssistants(); - } - }, [assistants, fetchAssistants]); - // Load threads when component mounts: useEffect(() => { fetchThreads(); @@ -217,7 +207,7 @@ export function Chat() { - Chat + Chat: {assistantId} ", views.react_index), ] diff --git a/example/demo/views.py b/example/demo/views.py index 5a4d2e4..cacc77f 100644 --- a/example/demo/views.py +++ b/example/demo/views.py @@ -3,6 +3,7 @@ from django.views.generic.base import TemplateView from pydantic import ValidationError +from weather.ai_assistants import WeatherAIAssistant from django_ai_assistant.helpers.assistants import ( create_message, @@ -16,10 +17,8 @@ ThreadSchemaIn, ) -from .ai_assistants import WeatherAIAssistant - -def react_index(request): +def react_index(request, **kwargs): return render(request, "demo/react_index.html") @@ -30,7 +29,12 @@ def get_assistant_id(self, **kwargs): def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) - threads = list(get_threads(user=self.request.user, request=self.request, view=None)) + threads = list( + get_threads( + user=self.request.user, + request=self.request, + ) + ) context.update( { "assistant_id": self.get_assistant_id(**kwargs), @@ -51,7 +55,11 @@ def post(self, request, *args, **kwargs): messages.error(request, "Invalid thread data") return redirect("chat_home") - thread = create_thread(name=thread_data.name, user=request.user, request=request, view=None) + thread = create_thread( + name=thread_data.name, + user=request.user, + request=request, + ) return redirect("chat_thread", thread_id=thread.id) @@ -65,7 +73,6 @@ def get_context_data(self, **kwargs): thread_id=self.kwargs["thread_id"], user=self.request.user, request=self.request, - view=None, ) context.update( { @@ -96,6 +103,5 @@ def post(self, request, *args, **kwargs): user=request.user, content=message.content, request=request, - view=None, ) return redirect("chat_thread", thread_id=thread_id) diff --git a/example/example/settings.py b/example/example/settings.py index 91101b1..a36d38c 100644 --- a/example/example/settings.py +++ b/example/example/settings.py @@ -29,7 +29,10 @@ "django.contrib.staticfiles", "webpack_loader", "django_ai_assistant", - "demo", + "demo", # contains the views + "weather", + "movies", + "rag", ] MIDDLEWARE = [ @@ -166,3 +169,4 @@ # Example specific settings: WEATHER_API_KEY = os.getenv("WEATHER_API_KEY") # get for free at https://www.weatherapi.com/ +DJANGO_DOCS_BRANCH = "stable/5.0.x" diff --git a/example_movies/example_movies/__init__.py b/example/movies/__init__.py similarity index 100% rename from example_movies/example_movies/__init__.py rename to example/movies/__init__.py diff --git a/example_movies/movies/admin.py b/example/movies/admin.py similarity index 66% rename from example_movies/movies/admin.py rename to example/movies/admin.py index cb97e59..806b8a7 100644 --- a/example_movies/movies/admin.py +++ b/example/movies/admin.py @@ -1,4 +1,5 @@ from django.contrib import admin +from django.utils.safestring import mark_safe from .models import MovieBacklogItem @@ -7,7 +8,7 @@ class MovieBacklogItemAdmin(admin.ModelAdmin): list_display = ( "movie_name", - "imdb_url", + "imdb_url_link", "imdb_rating", "position", "user", @@ -18,3 +19,7 @@ class MovieBacklogItemAdmin(admin.ModelAdmin): list_filter = ("user", "created_at", "updated_at") list_select_related = ("user",) raw_id_fields = ("user",) + + @admin.display(ordering="imdb_url", description="IMDB URL") + def imdb_url_link(self, obj): + return mark_safe(f'{obj.imdb_url}') # noqa: S308 diff --git a/example_movies/movies/ai_assistants.py b/example/movies/ai_assistants.py similarity index 100% rename from example_movies/movies/ai_assistants.py rename to example/movies/ai_assistants.py diff --git a/example_movies/movies/apps.py b/example/movies/apps.py similarity index 100% rename from example_movies/movies/apps.py rename to example/movies/apps.py diff --git a/example_movies/movies/migrations/0001_initial.py b/example/movies/migrations/0001_initial.py similarity index 100% rename from example_movies/movies/migrations/0001_initial.py rename to example/movies/migrations/0001_initial.py diff --git a/example_movies/movies/migrations/0002_alter_moviebacklogitem_options_and_more.py b/example/movies/migrations/0002_alter_moviebacklogitem_options_and_more.py similarity index 100% rename from example_movies/movies/migrations/0002_alter_moviebacklogitem_options_and_more.py rename to example/movies/migrations/0002_alter_moviebacklogitem_options_and_more.py diff --git a/example_movies/movies/migrations/0003_moviebacklogitem_imdb_rating.py b/example/movies/migrations/0003_moviebacklogitem_imdb_rating.py similarity index 100% rename from example_movies/movies/migrations/0003_moviebacklogitem_imdb_rating.py rename to example/movies/migrations/0003_moviebacklogitem_imdb_rating.py diff --git a/example_movies/movies/__init__.py b/example/movies/migrations/__init__.py similarity index 100% rename from example_movies/movies/__init__.py rename to example/movies/migrations/__init__.py diff --git a/example_movies/movies/models.py b/example/movies/models.py similarity index 100% rename from example_movies/movies/models.py rename to example/movies/models.py diff --git a/example/package-lock.json b/example/package-lock.json index a7fee38..c350159 100644 --- a/example/package-lock.json +++ b/example/package-lock.json @@ -14,7 +14,8 @@ "cookie": "^0.6.0", "django-ai-assistant-client": "file:../frontend", "modern-normalize": "^2.0.0", - "react-markdown": "^9.0.1" + "react-markdown": "^9.0.1", + "react-router-dom": "^6.23.1" }, "devDependencies": { "@babel/core": "^7.24.5", @@ -2990,6 +2991,14 @@ "node": ">=14" } }, + "node_modules/@remix-run/router": { + "version": "1.16.1", + "resolved": "https://registry.npmjs.org/@remix-run/router/-/router-1.16.1.tgz", + "integrity": "sha512-es2g3dq6Nb07iFxGk5GuHN20RwBZOsuDQN7izWIisUcv9r+d2C5jQxqmgkdebXgReWfiyUabcki6Fg77mSNrig==", + "engines": { + "node": ">=14.0.0" + } + }, "node_modules/@tabler/icons": { "version": "3.5.0", "license": "MIT", @@ -8038,6 +8047,36 @@ } } }, + "node_modules/react-router": { + "version": "6.23.1", + "resolved": "https://registry.npmjs.org/react-router/-/react-router-6.23.1.tgz", + "integrity": "sha512-fzcOaRF69uvqbbM7OhvQyBTFDVrrGlsFdS3AL+1KfIBtGETibHzi3FkoTRyiDJnWNc2VxrfvR+657ROHjaNjqQ==", + "dependencies": { + "@remix-run/router": "1.16.1" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "react": ">=16.8" + } + }, + "node_modules/react-router-dom": { + "version": "6.23.1", + "resolved": "https://registry.npmjs.org/react-router-dom/-/react-router-dom-6.23.1.tgz", + "integrity": "sha512-utP+K+aSTtEdbWpC+4gxhdlPFwuEfDKq8ZrPFU65bbRJY+l706qjR7yaidBpo3MSeA/fzwbXWbKBI6ftOnP3OQ==", + "dependencies": { + "@remix-run/router": "1.16.1", + "react-router": "6.23.1" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "react": ">=16.8", + "react-dom": ">=16.8" + } + }, "node_modules/react-style-singleton": { "version": "2.2.1", "license": "MIT", diff --git a/example/package.json b/example/package.json index 33160b0..fd6c27d 100644 --- a/example/package.json +++ b/example/package.json @@ -45,6 +45,7 @@ "cookie": "^0.6.0", "django-ai-assistant-client": "file:../frontend", "modern-normalize": "^2.0.0", - "react-markdown": "^9.0.1" + "react-markdown": "^9.0.1", + "react-router-dom": "^6.23.1" } } diff --git a/example_movies/movies/migrations/__init__.py b/example/rag/__init__.py similarity index 100% rename from example_movies/movies/migrations/__init__.py rename to example/rag/__init__.py diff --git a/example_rag/demo/admin.py b/example/rag/admin.py similarity index 85% rename from example_rag/demo/admin.py rename to example/rag/admin.py index 652071e..52e6049 100644 --- a/example_rag/demo/admin.py +++ b/example/rag/admin.py @@ -9,7 +9,6 @@ class DjangoDocPageAdmin(admin.ModelAdmin): list_display = ("path", "django_docs_url") search_fields = ("path",) + @admin.display(ordering="path", description="Django Docs URL") def django_docs_url(self, obj): return mark_safe(f'{obj.django_docs_url}') # noqa: S308 - - django_docs_url.short_description = "Django Docs URL" diff --git a/example_rag/demo/ai_assistants.py b/example/rag/ai_assistants.py similarity index 100% rename from example_rag/demo/ai_assistants.py rename to example/rag/ai_assistants.py diff --git a/example_rag/demo/apps.py b/example/rag/apps.py similarity index 66% rename from example_rag/demo/apps.py rename to example/rag/apps.py index 6fbee78..7b9be3a 100644 --- a/example_rag/demo/apps.py +++ b/example/rag/apps.py @@ -1,6 +1,6 @@ from django.apps import AppConfig -class DemoConfig(AppConfig): +class RAGConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "demo" + name = "rag" diff --git a/example_rag/demo/__init__.py b/example/rag/management/__init__.py similarity index 100% rename from example_rag/demo/__init__.py rename to example/rag/management/__init__.py diff --git a/example_rag/demo/management/__init__.py b/example/rag/management/commands/__init__.py similarity index 100% rename from example_rag/demo/management/__init__.py rename to example/rag/management/commands/__init__.py diff --git a/example_rag/demo/management/commands/fetch_django_docs.py b/example/rag/management/commands/fetch_django_docs.py similarity index 97% rename from example_rag/demo/management/commands/fetch_django_docs.py rename to example/rag/management/commands/fetch_django_docs.py index 7ab12c1..2687578 100644 --- a/example_rag/demo/management/commands/fetch_django_docs.py +++ b/example/rag/management/commands/fetch_django_docs.py @@ -6,7 +6,7 @@ from git import Repo -from demo.models import DjangoDocPage +from rag.models import DjangoDocPage class Command(BaseCommand): diff --git a/example_rag/demo/migrations/0001_initial.py b/example/rag/migrations/0001_initial.py similarity index 100% rename from example_rag/demo/migrations/0001_initial.py rename to example/rag/migrations/0001_initial.py diff --git a/example_rag/demo/management/commands/__init__.py b/example/rag/migrations/__init__.py similarity index 100% rename from example_rag/demo/management/commands/__init__.py rename to example/rag/migrations/__init__.py diff --git a/example_rag/demo/models.py b/example/rag/models.py similarity index 100% rename from example_rag/demo/models.py rename to example/rag/models.py diff --git a/example_movies/movies/templates/base.html b/example/rag/templates/base.html similarity index 100% rename from example_movies/movies/templates/base.html rename to example/rag/templates/base.html diff --git a/example_rag/demo/templates/demo/chat_home.html b/example/rag/templates/demo/chat_home.html similarity index 100% rename from example_rag/demo/templates/demo/chat_home.html rename to example/rag/templates/demo/chat_home.html diff --git a/example_rag/demo/templates/demo/chat_thread.html b/example/rag/templates/demo/chat_thread.html similarity index 100% rename from example_rag/demo/templates/demo/chat_thread.html rename to example/rag/templates/demo/chat_thread.html diff --git a/example_rag/demo/templates/demo/htmx_index.html b/example/rag/templates/demo/htmx_index.html similarity index 100% rename from example_rag/demo/templates/demo/htmx_index.html rename to example/rag/templates/demo/htmx_index.html diff --git a/example_movies/movies/templates/movies/react_index.html b/example/rag/templates/demo/react_index.html similarity index 100% rename from example_movies/movies/templates/movies/react_index.html rename to example/rag/templates/demo/react_index.html diff --git a/example_rag/demo/migrations/__init__.py b/example/weather/__init__.py similarity index 100% rename from example_rag/demo/migrations/__init__.py rename to example/weather/__init__.py diff --git a/example/demo/ai_assistants.py b/example/weather/ai_assistants.py similarity index 100% rename from example/demo/ai_assistants.py rename to example/weather/ai_assistants.py diff --git a/example/weather/apps.py b/example/weather/apps.py new file mode 100644 index 0000000..9a0f07d --- /dev/null +++ b/example/weather/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class WeatherConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "weather" diff --git a/example_rag/example_rag/__init__.py b/example/weather/migrations/__init__.py similarity index 100% rename from example_rag/example_rag/__init__.py rename to example/weather/migrations/__init__.py diff --git a/example_movies/assets/css/htmx_index.css b/example_movies/assets/css/htmx_index.css deleted file mode 100644 index f3b0e76..0000000 --- a/example_movies/assets/css/htmx_index.css +++ /dev/null @@ -1,12 +0,0 @@ -#threads-container, -#messages-container { - height: calc(100vh - 100px); -} - -.main-container { - max-width: 920px; -} - -[data-loading] { - display: none; -} diff --git a/example_movies/assets/js/App.tsx b/example_movies/assets/js/App.tsx deleted file mode 100644 index b9627e1..0000000 --- a/example_movies/assets/js/App.tsx +++ /dev/null @@ -1,21 +0,0 @@ -import "@mantine/core/styles.css"; - -import { createTheme, MantineProvider } from "@mantine/core"; -import { Chat } from "@/Chat"; -import { configAIAssistant } from "django-ai-assistant-client"; - -const theme = createTheme({}); - -// Relates to path("ai-assistant/", include("django_ai_assistant.urls")) -// which can be found at example/demo/urls.py) -configAIAssistant({ baseURL: "ai-assistant" }); - -const App = () => { - return ( - - - - ); -}; - -export default App; diff --git a/example_movies/assets/js/Chat/Chat.module.css b/example_movies/assets/js/Chat/Chat.module.css deleted file mode 100644 index e4f2917..0000000 --- a/example_movies/assets/js/Chat/Chat.module.css +++ /dev/null @@ -1,15 +0,0 @@ -.main { - font-size: var(--mantine-font-size-md); - height: 100vh; - width: calc(100% - rem(250px)); - margin-left: auto; -} - -.chatContainer { - max-width: calc(51.25rem * var(--mantine-scale)); - height: 100%; -} - -.chat { - height: 100%; -} diff --git a/example_movies/assets/js/Chat/Chat.tsx b/example_movies/assets/js/Chat/Chat.tsx deleted file mode 100644 index 21af812..0000000 --- a/example_movies/assets/js/Chat/Chat.tsx +++ /dev/null @@ -1,206 +0,0 @@ -import { - Container, - Text, - Stack, - Title, - Textarea, - Box, - Button, - LoadingOverlay, - ScrollArea, -} from "@mantine/core"; -import { ThreadsNav } from "./ThreadsNav"; - -import classes from "./Chat.module.css"; -import { useCallback, useEffect, useRef, useState } from "react"; -import { IconSend2 } from "@tabler/icons-react"; -import { getHotkeyHandler } from "@mantine/hooks"; -import Markdown from "react-markdown"; - -import { - ThreadMessagesSchemaOut, - ThreadSchema, - useAssistant, - useMessage, - useThread, -} from "django-ai-assistant-client"; - -function ChatMessage({ message }: { message: ThreadMessagesSchemaOut }) { - return ( - - {message.type === "ai" ? "AI" : "User"} - {message.content} - - ); -} - -function ChatMessageList({ - messages, -}: { - messages: ThreadMessagesSchemaOut[]; -}) { - if (messages.length === 0) { - return No messages.; - } - - // TODO: check why horizontal scroll appears - return ( -
- {messages.map((message, index) => ( - - ))} -
- ); -} - -export function Chat() { - const [assistantId, setAssistantId] = useState(""); - const [activeThread, setActiveThread] = useState(null); - const [inputValue, setInputValue] = useState(""); - - const { fetchAssistants, assistants } = useAssistant(); - const { fetchThreads, threads, createThread, deleteThread } = useThread(); - const { - fetchMessages, - messages, - loadingFetchMessages, - createMessage, - loadingCreateMessage, - } = useMessage(); - - const loadingMessages = loadingFetchMessages || loadingCreateMessage; - const isThreadSelected = assistantId && activeThread; - const isChatActive = assistantId && activeThread && !loadingMessages; - - const scrollViewport = useRef(null); - const scrollToBottom = useCallback( - () => - // setTimeout is used because scrollViewport.current?.scrollHeight update is not - // being triggered in time for the scrollTo method to work properly. - setTimeout( - () => - scrollViewport.current?.scrollTo({ - top: scrollViewport.current!.scrollHeight, - behavior: "smooth", - }), - 500 - ), - [scrollViewport] - ); - - // Load assistantId when component mounts: - useEffect(() => { - if (assistants) { - setAssistantId(assistants[0].id); - } else { - fetchAssistants(); - } - }, [assistants, fetchAssistants]); - - // Load threads when component mounts: - useEffect(() => { - fetchThreads(); - }, [fetchThreads]); - - // Load messages when threadId changes: - useEffect(() => { - if (!assistantId) return; - if (!activeThread) return; - - fetchMessages({ - threadId: activeThread.id, - }); - scrollToBottom(); - }, [assistantId, activeThread?.id, fetchMessages]); - - async function handleCreateMessage() { - if (!activeThread) return; - - await createMessage({ - threadId: activeThread.id, - assistantId, - messageTextValue: inputValue, - }); - - setInputValue(""); - scrollToBottom(); - } - - return ( - <> - -
- - - - Chat - - - - {isThreadSelected ? ( - - ) : ( - - Select or create a thread to start chatting. - - )} - -