From 22768b13341d0bcf766d6857d7240283b0faab26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Mon, 17 Jun 2024 10:19:11 -0300 Subject: [PATCH 1/6] Reestructure modules: separate API --- django_ai_assistant/__init__.py | 13 ++++- django_ai_assistant/{ai => api}/__init__.py | 0 django_ai_assistant/{ => api}/schemas.py | 2 +- django_ai_assistant/{ => api}/views.py | 35 ++++++++----- django_ai_assistant/helpers/assistants.py | 6 +-- django_ai_assistant/langchain/__init__.py | 0 .../chat_message_histories.py | 0 django_ai_assistant/{ => langchain}/tools.py | 0 .../commands/generate_openapi_schema.py | 2 +- django_ai_assistant/urls.py | 2 +- example/demo/views.py | 8 +-- example/movies/ai_assistants.py | 2 +- example/weather/ai_assistants.py | 2 +- frontend/openapi_schema.json | 20 +++---- frontend/src/client/services.gen.ts | 22 ++++---- frontend/src/client/types.gen.ts | 52 +++++++++---------- frontend/src/hooks/useAssistantList.ts | 4 +- frontend/src/hooks/useMessageList.ts | 12 ++--- frontend/src/hooks/useThreadList.ts | 12 ++--- frontend/tests/useAssistantList.test.ts | 8 +-- frontend/tests/useMessageList.test.ts | 28 +++++----- frontend/tests/useThreadList.test.ts | 32 ++++++------ tests/test_assistants.py | 2 +- 23 files changed, 141 insertions(+), 123 deletions(-) rename django_ai_assistant/{ai => api}/__init__.py (100%) rename django_ai_assistant/{ => api}/schemas.py (94%) rename django_ai_assistant/{ => api}/views.py (83%) create mode 100644 django_ai_assistant/langchain/__init__.py rename django_ai_assistant/{ai => langchain}/chat_message_histories.py (100%) rename django_ai_assistant/{ => langchain}/tools.py (100%) diff --git a/django_ai_assistant/__init__.py b/django_ai_assistant/__init__.py index e189e73..5f4ebb0 100644 --- a/django_ai_assistant/__init__.py +++ b/django_ai_assistant/__init__.py @@ -1,4 +1,15 @@ from importlib import metadata +from django_ai_assistant.langchain.tools import ( # noqa + BaseModel, + BaseTool, + Field, + StructuredTool, + Tool, + method_tool, + tool, +) -__version__ = metadata.version(__package__) + +version = __version__ = metadata.version(__package__) +package_name = __package__ diff --git a/django_ai_assistant/ai/__init__.py b/django_ai_assistant/api/__init__.py similarity index 100% rename from django_ai_assistant/ai/__init__.py rename to django_ai_assistant/api/__init__.py diff --git a/django_ai_assistant/schemas.py b/django_ai_assistant/api/schemas.py similarity index 94% rename from django_ai_assistant/schemas.py rename to django_ai_assistant/api/schemas.py index 950179b..72a920e 100644 --- a/django_ai_assistant/schemas.py +++ b/django_ai_assistant/api/schemas.py @@ -4,7 +4,7 @@ from ninja import Field, ModelSchema, Schema -from .models import Thread +from django_ai_assistant.models import Thread class AssistantSchema(Schema): diff --git a/django_ai_assistant/views.py b/django_ai_assistant/api/views.py similarity index 83% rename from django_ai_assistant/views.py rename to django_ai_assistant/api/views.py index e0e5035..d015167 100644 --- a/django_ai_assistant/views.py +++ b/django_ai_assistant/api/views.py @@ -4,12 +4,19 @@ from langchain_core.messages import message_to_dict from ninja import NinjaAPI +from ninja.operation import Operation -from django_ai_assistant import __package__, __version__ - -from .exceptions import AIUserNotAllowedError -from .helpers import assistants -from .helpers.assistants import ( +from django_ai_assistant import package_name, version +from django_ai_assistant.api.schemas import ( + AssistantSchema, + ThreadMessagesSchemaIn, + ThreadMessagesSchemaOut, + ThreadSchema, + ThreadSchemaIn, +) +from django_ai_assistant.exceptions import AIUserNotAllowedError +from django_ai_assistant.helpers import assistants +from django_ai_assistant.helpers.assistants import ( create_message, get_assistants_info, get_single_assistant_info, @@ -17,17 +24,17 @@ get_thread_messages, get_threads, ) -from .models import Message, Thread -from .schemas import ( - AssistantSchema, - ThreadMessagesSchemaIn, - ThreadMessagesSchemaOut, - ThreadSchema, - ThreadSchemaIn, -) +from django_ai_assistant.models import Message, Thread + + +class API(NinjaAPI): + # Force "operationId" to be like "django_ai_assistant_delete_thread" + def get_openapi_operation_id(self, operation: Operation) -> str: + name = operation.view_func.__name__ + return (package_name + "_" + name).replace(".", "_") -api = NinjaAPI(title=__package__, version=__version__, urls_namespace="django_ai_assistant") +api = API(title=package_name, version=version, urls_namespace="django_ai_assistant") @api.exception_handler(AIUserNotAllowedError) diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index 85118dd..6d4c150 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -37,12 +37,14 @@ from langchain_core.tools import BaseTool from langchain_openai import ChatOpenAI -from django_ai_assistant.ai.chat_message_histories import DjangoChatMessageHistory from django_ai_assistant.exceptions import ( AIAssistantMisconfiguredError, AIAssistantNotDefinedError, AIUserNotAllowedError, ) +from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory +from django_ai_assistant.langchain.tools import Tool +from django_ai_assistant.langchain.tools import tool as tool_decorator from django_ai_assistant.models import Message, Thread from django_ai_assistant.permissions import ( can_create_message, @@ -51,8 +53,6 @@ can_delete_thread, can_run_assistant, ) -from django_ai_assistant.tools import Tool -from django_ai_assistant.tools import tool as tool_decorator class AIAssistant(abc.ABC): # noqa: F821 diff --git a/django_ai_assistant/langchain/__init__.py b/django_ai_assistant/langchain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/django_ai_assistant/ai/chat_message_histories.py b/django_ai_assistant/langchain/chat_message_histories.py similarity index 100% rename from django_ai_assistant/ai/chat_message_histories.py rename to django_ai_assistant/langchain/chat_message_histories.py diff --git a/django_ai_assistant/tools.py b/django_ai_assistant/langchain/tools.py similarity index 100% rename from django_ai_assistant/tools.py rename to django_ai_assistant/langchain/tools.py diff --git a/django_ai_assistant/management/commands/generate_openapi_schema.py b/django_ai_assistant/management/commands/generate_openapi_schema.py index 1b8cd20..bcd0ec3 100644 --- a/django_ai_assistant/management/commands/generate_openapi_schema.py +++ b/django_ai_assistant/management/commands/generate_openapi_schema.py @@ -3,7 +3,7 @@ from django.core.management.base import BaseCommand, CommandError -from django_ai_assistant.views import api +from django_ai_assistant.api.views import api class Command(BaseCommand): diff --git a/django_ai_assistant/urls.py b/django_ai_assistant/urls.py index b34f219..ad27118 100644 --- a/django_ai_assistant/urls.py +++ b/django_ai_assistant/urls.py @@ -1,6 +1,6 @@ from django.urls import path -from .views import api +from .api.views import api urlpatterns = [ diff --git a/example/demo/views.py b/example/demo/views.py index cacc77f..499aebd 100644 --- a/example/demo/views.py +++ b/example/demo/views.py @@ -5,6 +5,10 @@ from pydantic import ValidationError from weather.ai_assistants import WeatherAIAssistant +from django_ai_assistant.api.schemas import ( + ThreadMessagesSchemaIn, + ThreadSchemaIn, +) from django_ai_assistant.helpers.assistants import ( create_message, create_thread, @@ -12,10 +16,6 @@ get_threads, ) from django_ai_assistant.models import Thread -from django_ai_assistant.schemas import ( - ThreadMessagesSchemaIn, - ThreadSchemaIn, -) def react_index(request, **kwargs): diff --git a/example/movies/ai_assistants.py b/example/movies/ai_assistants.py index c044b0d..476fb6c 100644 --- a/example/movies/ai_assistants.py +++ b/example/movies/ai_assistants.py @@ -10,8 +10,8 @@ from langchain_community.utilities import WikipediaAPIWrapper from langchain_core.tools import BaseTool +from django_ai_assistant import method_tool from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant -from django_ai_assistant.tools import method_tool from .models import MovieBacklogItem diff --git a/example/weather/ai_assistants.py b/example/weather/ai_assistants.py index 383470e..bafb599 100644 --- a/example/weather/ai_assistants.py +++ b/example/weather/ai_assistants.py @@ -4,7 +4,7 @@ import requests from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant -from django_ai_assistant.tools import BaseModel, Field, method_tool +from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool BASE_URL = "https://api.weatherapi.com/v1/" diff --git a/frontend/openapi_schema.json b/frontend/openapi_schema.json index fb4f982..75f5898 100644 --- a/frontend/openapi_schema.json +++ b/frontend/openapi_schema.json @@ -8,7 +8,7 @@ "paths": { "/assistants/": { "get": { - "operationId": "django_ai_assistant_views_list_assistants", + "operationId": "django_ai_assistant_list_assistants", "summary": "List Assistants", "parameters": [], "responses": { @@ -31,7 +31,7 @@ }, "/assistants/{assistant_id}/": { "get": { - "operationId": "django_ai_assistant_views_get_assistant", + "operationId": "django_ai_assistant_get_assistant", "summary": "Get Assistant", "parameters": [ { @@ -60,7 +60,7 @@ }, "/threads/": { "get": { - "operationId": "django_ai_assistant_views_list_threads", + "operationId": "django_ai_assistant_list_threads", "summary": "List Threads", "parameters": [], "responses": { @@ -81,7 +81,7 @@ } }, "post": { - "operationId": "django_ai_assistant_views_create_thread", + "operationId": "django_ai_assistant_create_thread", "summary": "Create Thread", "parameters": [], "responses": { @@ -110,7 +110,7 @@ }, "/threads/{thread_id}/": { "get": { - "operationId": "django_ai_assistant_views_get_thread", + "operationId": "django_ai_assistant_get_thread", "summary": "Get Thread", "parameters": [ { @@ -137,7 +137,7 @@ } }, "patch": { - "operationId": "django_ai_assistant_views_update_thread", + "operationId": "django_ai_assistant_update_thread", "summary": "Update Thread", "parameters": [ { @@ -174,7 +174,7 @@ } }, "delete": { - "operationId": "django_ai_assistant_views_delete_thread", + "operationId": "django_ai_assistant_delete_thread", "summary": "Delete Thread", "parameters": [ { @@ -196,7 +196,7 @@ }, "/threads/{thread_id}/messages/": { "get": { - "operationId": "django_ai_assistant_views_list_thread_messages", + "operationId": "django_ai_assistant_list_thread_messages", "summary": "List Thread Messages", "parameters": [ { @@ -227,7 +227,7 @@ } }, "post": { - "operationId": "django_ai_assistant_views_create_thread_message", + "operationId": "django_ai_assistant_create_thread_message", "summary": "Create Thread Message", "parameters": [ { @@ -259,7 +259,7 @@ }, "/threads/{thread_id}/messages/{message_id}/": { "delete": { - "operationId": "django_ai_assistant_views_delete_thread_message", + "operationId": "django_ai_assistant_delete_thread_message", "summary": "Delete Thread Message", "parameters": [ { diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 4efb537..f441a98 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -3,14 +3,14 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { DjangoAiAssistantViewsListAssistantsResponse, DjangoAiAssistantViewsGetAssistantData, DjangoAiAssistantViewsGetAssistantResponse, DjangoAiAssistantViewsListThreadsResponse, DjangoAiAssistantViewsCreateThreadData, DjangoAiAssistantViewsCreateThreadResponse, DjangoAiAssistantViewsGetThreadData, DjangoAiAssistantViewsGetThreadResponse, DjangoAiAssistantViewsUpdateThreadData, DjangoAiAssistantViewsUpdateThreadResponse, DjangoAiAssistantViewsDeleteThreadData, DjangoAiAssistantViewsDeleteThreadResponse, DjangoAiAssistantViewsListThreadMessagesData, DjangoAiAssistantViewsListThreadMessagesResponse, DjangoAiAssistantViewsCreateThreadMessageData, DjangoAiAssistantViewsCreateThreadMessageResponse, DjangoAiAssistantViewsDeleteThreadMessageData, DjangoAiAssistantViewsDeleteThreadMessageResponse } from './types.gen'; +import type { DjangoAiAssistantListAssistantsResponse, DjangoAiAssistantGetAssistantData, DjangoAiAssistantGetAssistantResponse, DjangoAiAssistantListThreadsResponse, DjangoAiAssistantCreateThreadData, DjangoAiAssistantCreateThreadResponse, DjangoAiAssistantGetThreadData, DjangoAiAssistantGetThreadResponse, DjangoAiAssistantUpdateThreadData, DjangoAiAssistantUpdateThreadResponse, DjangoAiAssistantDeleteThreadData, DjangoAiAssistantDeleteThreadResponse, DjangoAiAssistantListThreadMessagesData, DjangoAiAssistantListThreadMessagesResponse, DjangoAiAssistantCreateThreadMessageData, DjangoAiAssistantCreateThreadMessageResponse, DjangoAiAssistantDeleteThreadMessageData, DjangoAiAssistantDeleteThreadMessageResponse } from './types.gen'; /** * List Assistants * @returns AssistantSchema OK * @throws ApiError */ -export const djangoAiAssistantViewsListAssistants = (): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantListAssistants = (): CancelablePromise => { return __request(OpenAPI, { method: 'GET', url: '/assistants/' }); }; @@ -22,7 +22,7 @@ export const djangoAiAssistantViewsListAssistants = (): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantGetAssistant = (data: DjangoAiAssistantGetAssistantData): CancelablePromise => { return __request(OpenAPI, { method: 'GET', url: '/assistants/{assistant_id}/', path: { @@ -35,7 +35,7 @@ export const djangoAiAssistantViewsGetAssistant = (data: DjangoAiAssistantViewsG * @returns ThreadSchema OK * @throws ApiError */ -export const djangoAiAssistantViewsListThreads = (): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantListThreads = (): CancelablePromise => { return __request(OpenAPI, { method: 'GET', url: '/threads/' }); }; @@ -47,7 +47,7 @@ export const djangoAiAssistantViewsListThreads = (): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantCreateThread = (data: DjangoAiAssistantCreateThreadData): CancelablePromise => { return __request(OpenAPI, { method: 'POST', url: '/threads/', body: data.requestBody, @@ -61,7 +61,7 @@ export const djangoAiAssistantViewsCreateThread = (data: DjangoAiAssistantViewsC * @returns ThreadSchema OK * @throws ApiError */ -export const djangoAiAssistantViewsGetThread = (data: DjangoAiAssistantViewsGetThreadData): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantGetThread = (data: DjangoAiAssistantGetThreadData): CancelablePromise => { return __request(OpenAPI, { method: 'GET', url: '/threads/{thread_id}/', path: { @@ -77,7 +77,7 @@ export const djangoAiAssistantViewsGetThread = (data: DjangoAiAssistantViewsGetT * @returns ThreadSchema OK * @throws ApiError */ -export const djangoAiAssistantViewsUpdateThread = (data: DjangoAiAssistantViewsUpdateThreadData): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantUpdateThread = (data: DjangoAiAssistantUpdateThreadData): CancelablePromise => { return __request(OpenAPI, { method: 'PATCH', url: '/threads/{thread_id}/', path: { @@ -94,7 +94,7 @@ export const djangoAiAssistantViewsUpdateThread = (data: DjangoAiAssistantViewsU * @returns void No Content * @throws ApiError */ -export const djangoAiAssistantViewsDeleteThread = (data: DjangoAiAssistantViewsDeleteThreadData): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantDeleteThread = (data: DjangoAiAssistantDeleteThreadData): CancelablePromise => { return __request(OpenAPI, { method: 'DELETE', url: '/threads/{thread_id}/', path: { @@ -109,7 +109,7 @@ export const djangoAiAssistantViewsDeleteThread = (data: DjangoAiAssistantViewsD * @returns ThreadMessagesSchemaOut OK * @throws ApiError */ -export const djangoAiAssistantViewsListThreadMessages = (data: DjangoAiAssistantViewsListThreadMessagesData): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantListThreadMessages = (data: DjangoAiAssistantListThreadMessagesData): CancelablePromise => { return __request(OpenAPI, { method: 'GET', url: '/threads/{thread_id}/messages/', path: { @@ -125,7 +125,7 @@ export const djangoAiAssistantViewsListThreadMessages = (data: DjangoAiAssistant * @returns unknown Created * @throws ApiError */ -export const djangoAiAssistantViewsCreateThreadMessage = (data: DjangoAiAssistantViewsCreateThreadMessageData): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantCreateThreadMessage = (data: DjangoAiAssistantCreateThreadMessageData): CancelablePromise => { return __request(OpenAPI, { method: 'POST', url: '/threads/{thread_id}/messages/', path: { @@ -143,7 +143,7 @@ export const djangoAiAssistantViewsCreateThreadMessage = (data: DjangoAiAssistan * @returns void No Content * @throws ApiError */ -export const djangoAiAssistantViewsDeleteThreadMessage = (data: DjangoAiAssistantViewsDeleteThreadMessageData): CancelablePromise => { return __request(OpenAPI, { +export const djangoAiAssistantDeleteThreadMessage = (data: DjangoAiAssistantDeleteThreadMessageData): CancelablePromise => { return __request(OpenAPI, { method: 'DELETE', url: '/threads/{thread_id}/messages/{message_id}/', path: { diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 4766c4e..65846d9 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -29,60 +29,60 @@ export type ThreadMessagesSchemaIn = { content: string; }; -export type DjangoAiAssistantViewsListAssistantsResponse = Array; +export type DjangoAiAssistantListAssistantsResponse = Array; -export type DjangoAiAssistantViewsGetAssistantData = { +export type DjangoAiAssistantGetAssistantData = { assistantId: string; }; -export type DjangoAiAssistantViewsGetAssistantResponse = AssistantSchema; +export type DjangoAiAssistantGetAssistantResponse = AssistantSchema; -export type DjangoAiAssistantViewsListThreadsResponse = Array; +export type DjangoAiAssistantListThreadsResponse = Array; -export type DjangoAiAssistantViewsCreateThreadData = { +export type DjangoAiAssistantCreateThreadData = { requestBody: ThreadSchemaIn; }; -export type DjangoAiAssistantViewsCreateThreadResponse = ThreadSchema; +export type DjangoAiAssistantCreateThreadResponse = ThreadSchema; -export type DjangoAiAssistantViewsGetThreadData = { +export type DjangoAiAssistantGetThreadData = { threadId: string; }; -export type DjangoAiAssistantViewsGetThreadResponse = ThreadSchema; +export type DjangoAiAssistantGetThreadResponse = ThreadSchema; -export type DjangoAiAssistantViewsUpdateThreadData = { +export type DjangoAiAssistantUpdateThreadData = { requestBody: ThreadSchemaIn; threadId: string; }; -export type DjangoAiAssistantViewsUpdateThreadResponse = ThreadSchema; +export type DjangoAiAssistantUpdateThreadResponse = ThreadSchema; -export type DjangoAiAssistantViewsDeleteThreadData = { +export type DjangoAiAssistantDeleteThreadData = { threadId: string; }; -export type DjangoAiAssistantViewsDeleteThreadResponse = void; +export type DjangoAiAssistantDeleteThreadResponse = void; -export type DjangoAiAssistantViewsListThreadMessagesData = { +export type DjangoAiAssistantListThreadMessagesData = { threadId: string; }; -export type DjangoAiAssistantViewsListThreadMessagesResponse = Array; +export type DjangoAiAssistantListThreadMessagesResponse = Array; -export type DjangoAiAssistantViewsCreateThreadMessageData = { +export type DjangoAiAssistantCreateThreadMessageData = { requestBody: ThreadMessagesSchemaIn; threadId: string; }; -export type DjangoAiAssistantViewsCreateThreadMessageResponse = unknown; +export type DjangoAiAssistantCreateThreadMessageResponse = unknown; -export type DjangoAiAssistantViewsDeleteThreadMessageData = { +export type DjangoAiAssistantDeleteThreadMessageData = { messageId: string; threadId: string; }; -export type DjangoAiAssistantViewsDeleteThreadMessageResponse = void; +export type DjangoAiAssistantDeleteThreadMessageResponse = void; export type $OpenApiTs = { '/assistants/': { @@ -97,7 +97,7 @@ export type $OpenApiTs = { }; '/assistants/{assistant_id}/': { get: { - req: DjangoAiAssistantViewsGetAssistantData; + req: DjangoAiAssistantGetAssistantData; res: { /** * OK @@ -116,7 +116,7 @@ export type $OpenApiTs = { }; }; post: { - req: DjangoAiAssistantViewsCreateThreadData; + req: DjangoAiAssistantCreateThreadData; res: { /** * OK @@ -127,7 +127,7 @@ export type $OpenApiTs = { }; '/threads/{thread_id}/': { get: { - req: DjangoAiAssistantViewsGetThreadData; + req: DjangoAiAssistantGetThreadData; res: { /** * OK @@ -136,7 +136,7 @@ export type $OpenApiTs = { }; }; patch: { - req: DjangoAiAssistantViewsUpdateThreadData; + req: DjangoAiAssistantUpdateThreadData; res: { /** * OK @@ -145,7 +145,7 @@ export type $OpenApiTs = { }; }; delete: { - req: DjangoAiAssistantViewsDeleteThreadData; + req: DjangoAiAssistantDeleteThreadData; res: { /** * No Content @@ -156,7 +156,7 @@ export type $OpenApiTs = { }; '/threads/{thread_id}/messages/': { get: { - req: DjangoAiAssistantViewsListThreadMessagesData; + req: DjangoAiAssistantListThreadMessagesData; res: { /** * OK @@ -165,7 +165,7 @@ export type $OpenApiTs = { }; }; post: { - req: DjangoAiAssistantViewsCreateThreadMessageData; + req: DjangoAiAssistantCreateThreadMessageData; res: { /** * Created @@ -176,7 +176,7 @@ export type $OpenApiTs = { }; '/threads/{thread_id}/messages/{message_id}/': { delete: { - req: DjangoAiAssistantViewsDeleteThreadMessageData; + req: DjangoAiAssistantDeleteThreadMessageData; res: { /** * No Content diff --git a/frontend/src/hooks/useAssistantList.ts b/frontend/src/hooks/useAssistantList.ts index a3bc7d1..28a5f36 100644 --- a/frontend/src/hooks/useAssistantList.ts +++ b/frontend/src/hooks/useAssistantList.ts @@ -2,7 +2,7 @@ import { useCallback } from "react"; import { useState } from "react"; import { AssistantSchema, - djangoAiAssistantViewsListAssistants, + djangoAiAssistantListAssistants, } from "../client"; /** @@ -21,7 +21,7 @@ export function useAssistantList() { const fetchAssistants = useCallback(async (): Promise => { try { setLoadingFetchAssistants(true); - const fetchedAssistants = await djangoAiAssistantViewsListAssistants(); + const fetchedAssistants = await djangoAiAssistantListAssistants(); setAssistants(fetchedAssistants); return fetchedAssistants; } finally { diff --git a/frontend/src/hooks/useMessageList.ts b/frontend/src/hooks/useMessageList.ts index f52d943..d7bb281 100644 --- a/frontend/src/hooks/useMessageList.ts +++ b/frontend/src/hooks/useMessageList.ts @@ -1,9 +1,9 @@ import { useCallback } from "react"; import { useState } from "react"; import { - djangoAiAssistantViewsCreateThreadMessage, - djangoAiAssistantViewsDeleteThreadMessage, - djangoAiAssistantViewsListThreadMessages, + djangoAiAssistantCreateThreadMessage, + djangoAiAssistantDeleteThreadMessage, + djangoAiAssistantListThreadMessages, ThreadMessagesSchemaOut, } from "../client"; @@ -45,7 +45,7 @@ export function useMessageList({ threadId }: { threadId: string | null }) { try { setLoadingFetchMessages(true); - const fetchedMessages = await djangoAiAssistantViewsListThreadMessages({ + const fetchedMessages = await djangoAiAssistantListThreadMessages({ threadId: threadId, }); setMessages(fetchedMessages); @@ -77,7 +77,7 @@ export function useMessageList({ threadId }: { threadId: string | null }) { try { setLoadingCreateMessage(true); // successful response is 201, None - await djangoAiAssistantViewsCreateThreadMessage({ + await djangoAiAssistantCreateThreadMessage({ threadId, requestBody: { content: messageTextValue, @@ -109,7 +109,7 @@ export function useMessageList({ threadId }: { threadId: string | null }) { try { setLoadingDeleteMessage(true); - await djangoAiAssistantViewsDeleteThreadMessage({ + await djangoAiAssistantDeleteThreadMessage({ threadId, messageId, }); diff --git a/frontend/src/hooks/useThreadList.ts b/frontend/src/hooks/useThreadList.ts index 9045ac0..9ffc645 100644 --- a/frontend/src/hooks/useThreadList.ts +++ b/frontend/src/hooks/useThreadList.ts @@ -2,9 +2,9 @@ import { useCallback } from "react"; import { useState } from "react"; import { ThreadSchema, - djangoAiAssistantViewsCreateThread, - djangoAiAssistantViewsDeleteThread, - djangoAiAssistantViewsListThreads, + djangoAiAssistantCreateThread, + djangoAiAssistantDeleteThread, + djangoAiAssistantListThreads, } from "../client"; /** @@ -27,7 +27,7 @@ export function useThreadList() { const fetchThreads = useCallback(async (): Promise => { try { setLoadingFetchThreads(true); - const fetchedThreads = await djangoAiAssistantViewsListThreads(); + const fetchedThreads = await djangoAiAssistantListThreads(); setThreads(fetchedThreads); return fetchedThreads; } finally { @@ -44,7 +44,7 @@ export function useThreadList() { async ({ name }: { name?: string } = {}): Promise => { try { setLoadingCreateThread(true); - const thread = await djangoAiAssistantViewsCreateThread({ + const thread = await djangoAiAssistantCreateThread({ requestBody: { name: name }, }); await fetchThreads(); @@ -65,7 +65,7 @@ export function useThreadList() { async ({ threadId }: { threadId: string }): Promise => { try { setLoadingDeleteThread(true); - await djangoAiAssistantViewsDeleteThread({ threadId }); + await djangoAiAssistantDeleteThread({ threadId }); await fetchThreads(); } finally { setLoadingDeleteThread(false); diff --git a/frontend/tests/useAssistantList.test.ts b/frontend/tests/useAssistantList.test.ts index 087a18f..44f7b41 100644 --- a/frontend/tests/useAssistantList.test.ts +++ b/frontend/tests/useAssistantList.test.ts @@ -1,9 +1,9 @@ import { act, renderHook } from "@testing-library/react"; import { useAssistantList } from "../src/hooks"; -import { djangoAiAssistantViewsListAssistants } from "../src/client"; +import { djangoAiAssistantListAssistants } from "../src/client"; jest.mock("../src/client", () => ({ - djangoAiAssistantViewsListAssistants: jest + djangoAiAssistantListAssistants: jest .fn() .mockImplementation(() => Promise.resolve()), })); @@ -26,7 +26,7 @@ describe("useAssistantList", () => { { id: 1, name: "Assistant 1" }, { id: 2, name: "Assistant 2" }, ]; - (djangoAiAssistantViewsListAssistants as jest.Mock).mockResolvedValue( + (djangoAiAssistantListAssistants as jest.Mock).mockResolvedValue( mockAssistants ); @@ -44,7 +44,7 @@ describe("useAssistantList", () => { }); it("should set loading to false if fetch fails", async () => { - (djangoAiAssistantViewsListAssistants as jest.Mock).mockRejectedValue( + (djangoAiAssistantListAssistants as jest.Mock).mockRejectedValue( new Error("Failed to fetch") ); diff --git a/frontend/tests/useMessageList.test.ts b/frontend/tests/useMessageList.test.ts index a07f102..5570d17 100644 --- a/frontend/tests/useMessageList.test.ts +++ b/frontend/tests/useMessageList.test.ts @@ -1,20 +1,20 @@ import { act, renderHook } from "@testing-library/react"; import { useMessageList } from "../src/hooks"; import { - djangoAiAssistantViewsCreateThreadMessage, - djangoAiAssistantViewsDeleteThreadMessage, - djangoAiAssistantViewsListThreadMessages, + djangoAiAssistantCreateThreadMessage, + djangoAiAssistantDeleteThreadMessage, + djangoAiAssistantListThreadMessages, ThreadMessagesSchemaOut, } from "../src/client"; jest.mock("../src/client", () => ({ - djangoAiAssistantViewsCreateThreadMessage: jest + djangoAiAssistantCreateThreadMessage: jest .fn() .mockImplementation(() => Promise.resolve()), - djangoAiAssistantViewsListThreadMessages: jest + djangoAiAssistantListThreadMessages: jest .fn() .mockImplementation(() => Promise.resolve()), - djangoAiAssistantViewsDeleteThreadMessage: jest + djangoAiAssistantDeleteThreadMessage: jest .fn() .mockImplementation(() => Promise.resolve()), })); @@ -47,7 +47,7 @@ describe("useMessageList", () => { describe("fetchMessages", () => { it("should fetch messages and update state correctly", async () => { - (djangoAiAssistantViewsListThreadMessages as jest.Mock).mockResolvedValue( + (djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue( mockMessages ); @@ -65,7 +65,7 @@ describe("useMessageList", () => { }); it("should set loading to false if fetch fails", async () => { - (djangoAiAssistantViewsListThreadMessages as jest.Mock).mockRejectedValue( + (djangoAiAssistantListThreadMessages as jest.Mock).mockRejectedValue( new Error("Failed to fetch") ); @@ -98,9 +98,9 @@ describe("useMessageList", () => { }, ]; ( - djangoAiAssistantViewsCreateThreadMessage as jest.Mock + djangoAiAssistantCreateThreadMessage as jest.Mock ).mockResolvedValue(null); - (djangoAiAssistantViewsListThreadMessages as jest.Mock).mockResolvedValue( + (djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue( [...mockMessages, ...mockNewMessages] ); @@ -127,7 +127,7 @@ describe("useMessageList", () => { it("should set loading to false if create fails", async () => { ( - djangoAiAssistantViewsCreateThreadMessage as jest.Mock + djangoAiAssistantCreateThreadMessage as jest.Mock ).mockRejectedValue(new Error("Failed to create")); const { result } = renderHook(() => useMessageList({ threadId: "1" })); @@ -152,7 +152,7 @@ describe("useMessageList", () => { describe("deleteMessage", () => { it("should delete a message and update state correctly", async () => { const deletedMessageId = mockMessages[0].id; - (djangoAiAssistantViewsListThreadMessages as jest.Mock).mockResolvedValue( + (djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue( mockMessages.filter((message) => message.id !== deletedMessageId) ); @@ -177,11 +177,11 @@ describe("useMessageList", () => { it("should set loading to false if delete fails", async () => { const deletedMessageId = mockMessages[0].id; - (djangoAiAssistantViewsListThreadMessages as jest.Mock).mockResolvedValue( + (djangoAiAssistantListThreadMessages as jest.Mock).mockResolvedValue( mockMessages.filter((message) => message.id !== deletedMessageId) ); ( - djangoAiAssistantViewsDeleteThreadMessage as jest.Mock + djangoAiAssistantDeleteThreadMessage as jest.Mock ).mockRejectedValue(new Error("Failed to delete")); const { result } = renderHook(() => useMessageList({ threadId: "1" })); diff --git a/frontend/tests/useThreadList.test.ts b/frontend/tests/useThreadList.test.ts index 17c1d60..6ed97f0 100644 --- a/frontend/tests/useThreadList.test.ts +++ b/frontend/tests/useThreadList.test.ts @@ -1,19 +1,19 @@ import { act, renderHook } from "@testing-library/react"; import { useThreadList } from "../src/hooks"; import { - djangoAiAssistantViewsCreateThread, - djangoAiAssistantViewsDeleteThread, - djangoAiAssistantViewsListThreads, + djangoAiAssistantCreateThread, + djangoAiAssistantDeleteThread, + djangoAiAssistantListThreads, } from "../src/client"; jest.mock("../src/client", () => ({ - djangoAiAssistantViewsCreateThread: jest + djangoAiAssistantCreateThread: jest .fn() .mockImplementation(() => Promise.resolve()), - djangoAiAssistantViewsListThreads: jest + djangoAiAssistantListThreads: jest .fn() .mockImplementation(() => Promise.resolve()), - djangoAiAssistantViewsDeleteThread: jest + djangoAiAssistantDeleteThread: jest .fn() .mockImplementation(() => Promise.resolve()), })); @@ -48,7 +48,7 @@ describe("useThreadList", () => { describe("fetchThreads", () => { it("should fetch threads and update state correctly", async () => { - (djangoAiAssistantViewsListThreads as jest.Mock).mockResolvedValue( + (djangoAiAssistantListThreads as jest.Mock).mockResolvedValue( mockThreads ); @@ -66,7 +66,7 @@ describe("useThreadList", () => { }); it("should set loading to false if fetch fails", async () => { - (djangoAiAssistantViewsListThreads as jest.Mock).mockRejectedValue( + (djangoAiAssistantListThreads as jest.Mock).mockRejectedValue( new Error("Failed to fetch") ); @@ -94,10 +94,10 @@ describe("useThreadList", () => { created_at: "2024-06-11T00:00:00Z", updated_at: "2024-06-11T00:00:00Z", }; - (djangoAiAssistantViewsCreateThread as jest.Mock).mockResolvedValue( + (djangoAiAssistantCreateThread as jest.Mock).mockResolvedValue( mockNewThread ); - (djangoAiAssistantViewsListThreads as jest.Mock).mockResolvedValue([ + (djangoAiAssistantListThreads as jest.Mock).mockResolvedValue([ mockNewThread, ...mockThreads, ]); @@ -125,10 +125,10 @@ describe("useThreadList", () => { created_at: "2024-06-11T00:00:00Z", updated_at: "2024-06-11T00:00:00Z", }; - (djangoAiAssistantViewsCreateThread as jest.Mock).mockResolvedValue( + (djangoAiAssistantCreateThread as jest.Mock).mockResolvedValue( mockNewThread ); - (djangoAiAssistantViewsListThreads as jest.Mock).mockResolvedValue([ + (djangoAiAssistantListThreads as jest.Mock).mockResolvedValue([ mockNewThread, ...mockThreads, ]); @@ -148,7 +148,7 @@ describe("useThreadList", () => { }); it("should set loading to false if create fails", async () => { - (djangoAiAssistantViewsCreateThread as jest.Mock).mockRejectedValue( + (djangoAiAssistantCreateThread as jest.Mock).mockRejectedValue( new Error("Failed to create") ); @@ -171,7 +171,7 @@ describe("useThreadList", () => { describe("deleteThread", () => { it("should delete a thread and update state correctly", async () => { const deletedThreadId = mockThreads[0].id; - (djangoAiAssistantViewsListThreads as jest.Mock).mockResolvedValue( + (djangoAiAssistantListThreads as jest.Mock).mockResolvedValue( mockThreads.filter((thread) => thread.id !== deletedThreadId) ); @@ -196,10 +196,10 @@ describe("useThreadList", () => { it("should set loading to false if delete fails", async () => { const deletedThreadId = mockThreads[0].id; - (djangoAiAssistantViewsListThreads as jest.Mock).mockResolvedValue( + (djangoAiAssistantListThreads as jest.Mock).mockResolvedValue( mockThreads.filter((thread) => thread.id !== deletedThreadId) ); - (djangoAiAssistantViewsDeleteThread as jest.Mock).mockRejectedValue( + (djangoAiAssistantDeleteThread as jest.Mock).mockRejectedValue( new Error("Failed to delete") ); diff --git a/tests/test_assistants.py b/tests/test_assistants.py index dd2bb7f..4ba543b 100644 --- a/tests/test_assistants.py +++ b/tests/test_assistants.py @@ -6,8 +6,8 @@ from langchain_core.retrievers import BaseRetriever from django_ai_assistant.helpers.assistants import AIAssistant +from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool from django_ai_assistant.models import Thread -from django_ai_assistant.tools import BaseModel, Field, method_tool class TemperatureAssistant(AIAssistant): From 89930a9e077939968cfe8c6334539164042ef128 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Mon, 17 Jun 2024 11:19:42 -0300 Subject: [PATCH 2/6] Separate assistants from use cases --- django_ai_assistant/__init__.py | 4 + django_ai_assistant/api/views.py | 34 ++--- django_ai_assistant/helpers/assistants.py | 168 +-------------------- django_ai_assistant/helpers/use_cases.py | 171 ++++++++++++++++++++++ example/demo/views.py | 2 +- example/movies/ai_assistants.py | 3 +- example/rag/ai_assistants.py | 2 +- example/weather/ai_assistants.py | 3 +- 8 files changed, 197 insertions(+), 190 deletions(-) create mode 100644 django_ai_assistant/helpers/use_cases.py diff --git a/django_ai_assistant/__init__.py b/django_ai_assistant/__init__.py index 5f4ebb0..6c0f3d2 100644 --- a/django_ai_assistant/__init__.py +++ b/django_ai_assistant/__init__.py @@ -1,5 +1,9 @@ from importlib import metadata +from django_ai_assistant.helpers.assistants import ( # noqa + AIAssistant, + register_assistant, +) from django_ai_assistant.langchain.tools import ( # noqa BaseModel, BaseTool, diff --git a/django_ai_assistant/api/views.py b/django_ai_assistant/api/views.py index d015167..10eb0d6 100644 --- a/django_ai_assistant/api/views.py +++ b/django_ai_assistant/api/views.py @@ -15,15 +15,7 @@ ThreadSchemaIn, ) from django_ai_assistant.exceptions import AIUserNotAllowedError -from django_ai_assistant.helpers import assistants -from django_ai_assistant.helpers.assistants import ( - create_message, - get_assistants_info, - get_single_assistant_info, - get_single_thread, - get_thread_messages, - get_threads, -) +from django_ai_assistant.helpers import use_cases from django_ai_assistant.models import Message, Thread @@ -48,28 +40,30 @@ def ai_user_not_allowed_handler(request, exc): @api.get("assistants/", response=List[AssistantSchema], url_name="assistants_list") def list_assistants(request): - return list(get_assistants_info(user=request.user, request=request)) + return list(use_cases.get_assistants_info(user=request.user, request=request)) @api.get("assistants/{assistant_id}/", response=AssistantSchema, url_name="assistant_detail") def get_assistant(request, assistant_id: str): - return get_single_assistant_info(assistant_id=assistant_id, user=request.user, request=request) + return use_cases.get_single_assistant_info( + assistant_id=assistant_id, user=request.user, request=request + ) @api.get("threads/", response=List[ThreadSchema], url_name="threads_list_create") def list_threads(request): - return list(get_threads(user=request.user, request=request)) + return list(use_cases.get_threads(user=request.user, request=request)) @api.post("threads/", response=ThreadSchema, url_name="threads_list_create") def create_thread(request, payload: ThreadSchemaIn): name = payload.name - return assistants.create_thread(name=name, user=request.user, request=request) + return use_cases.create_thread(name=name, user=request.user, request=request) @api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete") def get_thread(request, thread_id: str): - thread = get_single_thread(thread_id=thread_id, user=request.user, request=request) + thread = use_cases.get_single_thread(thread_id=thread_id, user=request.user, request=request) return thread @@ -77,13 +71,13 @@ def get_thread(request, thread_id: str): def update_thread(request, thread_id: str, payload: ThreadSchemaIn): thread = get_object_or_404(Thread, id=thread_id) name = payload.name - return assistants.update_thread(thread=thread, name=name, user=request.user, request=request) + return use_cases.update_thread(thread=thread, name=name, user=request.user, request=request) @api.delete("threads/{thread_id}/", response={204: None}, url_name="thread_detail_update_delete") def delete_thread(request, thread_id: str): thread = get_object_or_404(Thread, id=thread_id) - assistants.delete_thread(thread=thread, user=request.user, request=request) + use_cases.delete_thread(thread=thread, user=request.user, request=request) return 204, None @@ -93,7 +87,9 @@ def delete_thread(request, thread_id: str): url_name="messages_list_create", ) def list_thread_messages(request, thread_id: str): - messages = get_thread_messages(thread_id=thread_id, user=request.user, request=request) + messages = use_cases.get_thread_messages( + thread_id=thread_id, user=request.user, request=request + ) return [message_to_dict(m)["data"] for m in messages] @@ -106,7 +102,7 @@ def list_thread_messages(request, thread_id: str): def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchemaIn): thread = Thread.objects.get(id=thread_id) - create_message( + use_cases.create_message( assistant_id=payload.assistant_id, thread=thread, user=request.user, @@ -121,7 +117,7 @@ def create_thread_message(request, thread_id: str, payload: ThreadMessagesSchema ) def delete_thread_message(request, thread_id: str, message_id: str): message = get_object_or_404(Message, id=message_id, thread_id=thread_id) - assistants.delete_message( + use_cases.delete_message( message=message, user=request.user, request=request, diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index 6d4c150..f10dd3c 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -3,8 +3,6 @@ import re from typing import Any, ClassVar, Sequence, cast -from django.http import HttpRequest - from langchain.agents import AgentExecutor from langchain.agents.format_scratchpad.tools import ( format_to_tool_messages, @@ -15,7 +13,6 @@ DEFAULT_DOCUMENT_SEPARATOR, ) from langchain_core.chat_history import InMemoryChatMessageHistory -from langchain_core.messages import BaseMessage, HumanMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ( ChatPromptTemplate, @@ -39,20 +36,9 @@ from django_ai_assistant.exceptions import ( AIAssistantMisconfiguredError, - AIAssistantNotDefinedError, - AIUserNotAllowedError, ) -from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory from django_ai_assistant.langchain.tools import Tool from django_ai_assistant.langchain.tools import tool as tool_decorator -from django_ai_assistant.models import Message, Thread -from django_ai_assistant.permissions import ( - can_create_message, - can_create_thread, - can_delete_message, - can_delete_thread, - can_run_assistant, -) class AIAssistant(abc.ABC): # noqa: F821 @@ -156,6 +142,9 @@ def get_prompt_template(self): ) def get_message_history(self, thread_id: int | None): + # DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere: + from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory + if thread_id is None: return InMemoryChatMessageHistory() return DjangoChatMessageHistory(thread_id) @@ -309,154 +298,3 @@ def as_tool(self, description) -> BaseTool: def register_assistant(cls: type[AIAssistant]): ASSISTANT_CLS_REGISTRY[cls.id] = cls return cls - - -def _get_assistant_cls( - assistant_id: str, - user: Any, - request: HttpRequest | None = None, -): - if assistant_id not in ASSISTANT_CLS_REGISTRY: - raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found") - assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id] - if not can_run_assistant( - assistant_cls=assistant_cls, - user=user, - request=request, - ): - raise AIUserNotAllowedError("User is not allowed to use this assistant") - return assistant_cls - - -def get_single_assistant_info( - assistant_id: str, - user: Any, - request: HttpRequest | None = None, -): - assistant_cls = _get_assistant_cls(assistant_id, user, request) - - return { - "id": assistant_id, - "name": assistant_cls.name, - } - - -def get_assistants_info( - user: Any, - request: HttpRequest | None = None, -): - return [ - _get_assistant_cls(assistant_id=assistant_id, user=user, request=request) - for assistant_id in ASSISTANT_CLS_REGISTRY.keys() - ] - - -def create_message( - assistant_id: str, - thread: Thread, - user: Any, - content: Any, - request: HttpRequest | None = None, -): - assistant_cls = _get_assistant_cls(assistant_id, user, request) - - if not can_create_message(thread=thread, user=user, request=request): - raise AIUserNotAllowedError("User is not allowed to create messages in this thread") - - # TODO: Check if we can separate the message creation from the chain invoke - assistant = assistant_cls(user=user, request=request) - assistant_message = assistant.invoke( - {"input": content}, - thread_id=thread.id, - ) - return assistant_message - - -def create_thread( - name: str, - user: Any, - request: HttpRequest | None = None, -): - if not can_create_thread(user=user, request=request): - raise AIUserNotAllowedError("User is not allowed to create threads") - - thread = Thread.objects.create(name=name, created_by=user) - return thread - - -def get_single_thread( - thread_id: str, - user: Any, - request: HttpRequest | None = None, -): - return Thread.objects.filter(created_by=user).get(id=thread_id) - - -def get_threads( - user: Any, - request: HttpRequest | None = None, -): - return list(Thread.objects.filter(created_by=user)) - - -def update_thread( - thread: Thread, - name: str, - user: Any, - request: HttpRequest | None = None, -): - if not can_delete_thread(thread=thread, user=user, request=request): - raise AIUserNotAllowedError("User is not allowed to update this thread") - - thread.name = name - thread.save() - return thread - - -def delete_thread( - thread: Thread, - user: Any, - request: HttpRequest | None = None, -): - if not can_delete_thread(thread=thread, user=user, request=request): - raise AIUserNotAllowedError("User is not allowed to delete this thread") - - return thread.delete() - - -def get_thread_messages( - thread_id: str, - user: Any, - request: HttpRequest | None = None, -) -> list[BaseMessage]: - # TODO: have more permissions for threads? View thread permission? - thread = Thread.objects.get(id=thread_id) - if user != thread.created_by: - raise AIUserNotAllowedError("User is not allowed to view messages in this thread") - - return DjangoChatMessageHistory(thread.id).get_messages() - - -def create_thread_message_as_user( - thread_id: str, - content: str, - user: Any, - request: HttpRequest | None = None, -): - # TODO: have more permissions for threads? View thread permission? - thread = Thread.objects.get(id=thread_id) - if user != thread.created_by: - raise AIUserNotAllowedError("User is not allowed to create messages in this thread") - - DjangoChatMessageHistory(thread.id).add_messages([HumanMessage(content=content)]) - - -def delete_message( - message: Message, - user: Any, - request: HttpRequest | None = None, -): - if not can_delete_message(message=message, user=user, request=request): - raise AIUserNotAllowedError("User is not allowed to delete this message") - - return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)]) diff --git a/django_ai_assistant/helpers/use_cases.py b/django_ai_assistant/helpers/use_cases.py new file mode 100644 index 0000000..595a424 --- /dev/null +++ b/django_ai_assistant/helpers/use_cases.py @@ -0,0 +1,171 @@ +from typing import Any + +from django.http import HttpRequest + +from langchain_core.messages import BaseMessage, HumanMessage + +from django_ai_assistant.exceptions import ( + AIAssistantNotDefinedError, + AIUserNotAllowedError, +) +from django_ai_assistant.helpers.assistants import ASSISTANT_CLS_REGISTRY +from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory +from django_ai_assistant.models import Message, Thread +from django_ai_assistant.permissions import ( + can_create_message, + can_create_thread, + can_delete_message, + can_delete_thread, + can_run_assistant, +) + + +def get_assistant_cls( + assistant_id: str, + user: Any, + request: HttpRequest | None = None, +): + if assistant_id not in ASSISTANT_CLS_REGISTRY: + raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found") + assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id] + if not can_run_assistant( + assistant_cls=assistant_cls, + user=user, + request=request, + ): + raise AIUserNotAllowedError("User is not allowed to use this assistant") + return assistant_cls + + +def get_single_assistant_info( + assistant_id: str, + user: Any, + request: HttpRequest | None = None, +): + assistant_cls = get_assistant_cls(assistant_id, user, request) + + return { + "id": assistant_id, + "name": assistant_cls.name, + } + + +def get_assistants_info( + user: Any, + request: HttpRequest | None = None, +): + return [ + get_assistant_cls(assistant_id=assistant_id, user=user, request=request) + for assistant_id in ASSISTANT_CLS_REGISTRY.keys() + ] + + +def create_message( + assistant_id: str, + thread: Thread, + user: Any, + content: Any, + request: HttpRequest | None = None, +): + assistant_cls = get_assistant_cls(assistant_id, user, request) + + if not can_create_message(thread=thread, user=user, request=request): + raise AIUserNotAllowedError("User is not allowed to create messages in this thread") + + # TODO: Check if we can separate the message creation from the chain invoke + assistant = assistant_cls(user=user, request=request) + assistant_message = assistant.invoke( + {"input": content}, + thread_id=thread.id, + ) + return assistant_message + + +def create_thread( + name: str, + user: Any, + request: HttpRequest | None = None, +): + if not can_create_thread(user=user, request=request): + raise AIUserNotAllowedError("User is not allowed to create threads") + + thread = Thread.objects.create(name=name, created_by=user) + return thread + + +def get_single_thread( + thread_id: str, + user: Any, + request: HttpRequest | None = None, +): + return Thread.objects.filter(created_by=user).get(id=thread_id) + + +def get_threads( + user: Any, + request: HttpRequest | None = None, +): + return list(Thread.objects.filter(created_by=user)) + + +def update_thread( + thread: Thread, + name: str, + user: Any, + request: HttpRequest | None = None, +): + if not can_delete_thread(thread=thread, user=user, request=request): + raise AIUserNotAllowedError("User is not allowed to update this thread") + + thread.name = name + thread.save() + return thread + + +def delete_thread( + thread: Thread, + user: Any, + request: HttpRequest | None = None, +): + if not can_delete_thread(thread=thread, user=user, request=request): + raise AIUserNotAllowedError("User is not allowed to delete this thread") + + return thread.delete() + + +def get_thread_messages( + thread_id: str, + user: Any, + request: HttpRequest | None = None, +) -> list[BaseMessage]: + # TODO: have more permissions for threads? View thread permission? + thread = Thread.objects.get(id=thread_id) + if user != thread.created_by: + raise AIUserNotAllowedError("User is not allowed to view messages in this thread") + + return DjangoChatMessageHistory(thread.id).get_messages() + + +def create_thread_message_as_user( + thread_id: str, + content: str, + user: Any, + request: HttpRequest | None = None, +): + # TODO: have more permissions for threads? View thread permission? + thread = Thread.objects.get(id=thread_id) + if user != thread.created_by: + raise AIUserNotAllowedError("User is not allowed to create messages in this thread") + + DjangoChatMessageHistory(thread.id).add_messages([HumanMessage(content=content)]) + + +def delete_message( + message: Message, + user: Any, + request: HttpRequest | None = None, +): + if not can_delete_message(message=message, user=user, request=request): + raise AIUserNotAllowedError("User is not allowed to delete this message") + + return DjangoChatMessageHistory(thread_id=message.thread_id).remove_messages([str(message.id)]) diff --git a/example/demo/views.py b/example/demo/views.py index 499aebd..2ec59bd 100644 --- a/example/demo/views.py +++ b/example/demo/views.py @@ -9,7 +9,7 @@ ThreadMessagesSchemaIn, ThreadSchemaIn, ) -from django_ai_assistant.helpers.assistants import ( +from django_ai_assistant.helpers.use_cases import ( create_message, create_thread, get_thread_messages, diff --git a/example/movies/ai_assistants.py b/example/movies/ai_assistants.py index 476fb6c..e3c3330 100644 --- a/example/movies/ai_assistants.py +++ b/example/movies/ai_assistants.py @@ -10,8 +10,7 @@ from langchain_community.utilities import WikipediaAPIWrapper from langchain_core.tools import BaseTool -from django_ai_assistant import method_tool -from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant +from django_ai_assistant import AIAssistant, method_tool, register_assistant from .models import MovieBacklogItem diff --git a/example/rag/ai_assistants.py b/example/rag/ai_assistants.py index f4b3084..a72ee29 100644 --- a/example/rag/ai_assistants.py +++ b/example/rag/ai_assistants.py @@ -2,7 +2,7 @@ from langchain_core.retrievers import BaseRetriever from langchain_text_splitters import RecursiveCharacterTextSplitter -from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant +from django_ai_assistant import AIAssistant, register_assistant from .models import DjangoDocPage diff --git a/example/weather/ai_assistants.py b/example/weather/ai_assistants.py index bafb599..847afbf 100644 --- a/example/weather/ai_assistants.py +++ b/example/weather/ai_assistants.py @@ -3,8 +3,7 @@ import requests -from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant -from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool +from django_ai_assistant import AIAssistant, BaseModel, Field, method_tool, register_assistant BASE_URL = "https://api.weatherapi.com/v1/" From ce8fe42136e3e77ca87c627963c9a1264f67540a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Mon, 17 Jun 2024 11:24:01 -0300 Subject: [PATCH 3/6] Fix tests --- tests/ai/test_chat_message_histories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ai/test_chat_message_histories.py b/tests/ai/test_chat_message_histories.py index 1127717..4e54cce 100644 --- a/tests/ai/test_chat_message_histories.py +++ b/tests/ai/test_chat_message_histories.py @@ -1,7 +1,7 @@ import pytest from langchain_core.messages import AIMessage, HumanMessage -from django_ai_assistant.ai.chat_message_histories import DjangoChatMessageHistory +from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory from django_ai_assistant.models import Message, Thread From eaa7d7a78ddbe78386c863c532cfc2d91c96d703 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Mon, 17 Jun 2024 13:59:21 -0300 Subject: [PATCH 4/6] Fix useAssistant --- frontend/src/hooks/useAssistant.ts | 4 ++-- frontend/tests/useAssistant.test.ts | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/src/hooks/useAssistant.ts b/frontend/src/hooks/useAssistant.ts index d075800..adfae52 100644 --- a/frontend/src/hooks/useAssistant.ts +++ b/frontend/src/hooks/useAssistant.ts @@ -2,7 +2,7 @@ import { useCallback } from "react"; import { useState } from "react"; import { AssistantSchema, - djangoAiAssistantViewsGetAssistant, + djangoAiAssistantGetAssistant, } from "../client"; /** @@ -23,7 +23,7 @@ export function useAssistant({ assistantId }: { const fetchAssistant = useCallback(async (): Promise => { try { setLoadingFetchAssistant(true); - const fetchedAssistant = await djangoAiAssistantViewsGetAssistant({ assistantId }); + const fetchedAssistant = await djangoAiAssistantGetAssistant({ assistantId }); setAssistant(fetchedAssistant); return fetchedAssistant; } finally { diff --git a/frontend/tests/useAssistant.test.ts b/frontend/tests/useAssistant.test.ts index f6527fc..960b8d3 100644 --- a/frontend/tests/useAssistant.test.ts +++ b/frontend/tests/useAssistant.test.ts @@ -1,9 +1,9 @@ import { act, renderHook } from "@testing-library/react"; import { useAssistant } from "../src/hooks"; -import { djangoAiAssistantViewsGetAssistant } from "../src/client"; +import { djangoAiAssistantGetAssistant } from "../src/client"; jest.mock("../src/client", () => ({ - djangoAiAssistantViewsGetAssistant: jest + djangoAiAssistantGetAssistant: jest .fn() .mockImplementation(() => Promise.resolve()), })); @@ -26,7 +26,7 @@ describe("useAssistant", () => { { id: 'weather_assistant', name: "Assistant 1" }, { id: 'movies_assistant', name: "Assistant 2" }, ]; - (djangoAiAssistantViewsGetAssistant as jest.Mock).mockResolvedValue( + (djangoAiAssistantGetAssistant as jest.Mock).mockResolvedValue( mockAssistants ); @@ -44,7 +44,7 @@ describe("useAssistant", () => { }); it("should set loading to false if fetch fails", async () => { - (djangoAiAssistantViewsGetAssistant as jest.Mock).mockRejectedValue( + (djangoAiAssistantGetAssistant as jest.Mock).mockRejectedValue( new Error("Failed to fetch") ); From a4e6f8640bd1e3f4bd82efd7894c968391d012f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Mon, 17 Jun 2024 14:07:18 -0300 Subject: [PATCH 5/6] Absolute imports everywhere --- django_ai_assistant/admin.py | 2 +- django_ai_assistant/urls.py | 2 +- example/demo/urls.py | 2 +- example/movies/admin.py | 2 +- example/movies/ai_assistants.py | 3 +-- example/rag/admin.py | 2 +- example/rag/ai_assistants.py | 3 +-- 7 files changed, 7 insertions(+), 9 deletions(-) diff --git a/django_ai_assistant/admin.py b/django_ai_assistant/admin.py index e8ec85d..e81bc0f 100644 --- a/django_ai_assistant/admin.py +++ b/django_ai_assistant/admin.py @@ -1,6 +1,6 @@ from django.contrib import admin -from .models import Message, Thread +from django_ai_assistant.models import Message, Thread @admin.register(Thread) diff --git a/django_ai_assistant/urls.py b/django_ai_assistant/urls.py index ad27118..2f41d47 100644 --- a/django_ai_assistant/urls.py +++ b/django_ai_assistant/urls.py @@ -1,6 +1,6 @@ from django.urls import path -from .api.views import api +from django_ai_assistant.api.views import api urlpatterns = [ diff --git a/example/demo/urls.py b/example/demo/urls.py index 40c4dd5..fed2660 100644 --- a/example/demo/urls.py +++ b/example/demo/urls.py @@ -1,6 +1,6 @@ from django.urls import include, path -from . import views +from demo import views urlpatterns = [ diff --git a/example/movies/admin.py b/example/movies/admin.py index 806b8a7..34d045c 100644 --- a/example/movies/admin.py +++ b/example/movies/admin.py @@ -1,7 +1,7 @@ from django.contrib import admin from django.utils.safestring import mark_safe -from .models import MovieBacklogItem +from movies.models import MovieBacklogItem @admin.register(MovieBacklogItem) diff --git a/example/movies/ai_assistants.py b/example/movies/ai_assistants.py index e3c3330..4a55263 100644 --- a/example/movies/ai_assistants.py +++ b/example/movies/ai_assistants.py @@ -11,8 +11,7 @@ from langchain_core.tools import BaseTool from django_ai_assistant import AIAssistant, method_tool, register_assistant - -from .models import MovieBacklogItem +from movies.models import MovieBacklogItem # Note this assistant is not registered, but we'll use it as a tool on the other. diff --git a/example/rag/admin.py b/example/rag/admin.py index 52e6049..9243592 100644 --- a/example/rag/admin.py +++ b/example/rag/admin.py @@ -1,7 +1,7 @@ from django.contrib import admin from django.utils.safestring import mark_safe -from .models import DjangoDocPage +from rag.models import DjangoDocPage @admin.register(DjangoDocPage) diff --git a/example/rag/ai_assistants.py b/example/rag/ai_assistants.py index a72ee29..932a4c5 100644 --- a/example/rag/ai_assistants.py +++ b/example/rag/ai_assistants.py @@ -3,8 +3,7 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter from django_ai_assistant import AIAssistant, register_assistant - -from .models import DjangoDocPage +from rag.models import DjangoDocPage @register_assistant From 0f485da0eb4e379b699826dbb08db8802e0a4801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Mon, 17 Jun 2024 14:09:59 -0300 Subject: [PATCH 6/6] Rename tests/ai to tests/langchain --- tests/{ai => langchain}/__init__.py | 0 tests/{ai => langchain}/test_chat_message_histories.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ai => langchain}/__init__.py (100%) rename tests/{ai => langchain}/test_chat_message_histories.py (100%) diff --git a/tests/ai/__init__.py b/tests/langchain/__init__.py similarity index 100% rename from tests/ai/__init__.py rename to tests/langchain/__init__.py diff --git a/tests/ai/test_chat_message_histories.py b/tests/langchain/test_chat_message_histories.py similarity index 100% rename from tests/ai/test_chat_message_histories.py rename to tests/langchain/test_chat_message_histories.py