From 0e55aa101d193fbcf6b418938d6220eb425cca42 Mon Sep 17 00:00:00 2001 From: amandasavluchinske Date: Fri, 21 Jun 2024 18:02:30 +0100 Subject: [PATCH] Transforms cast_id into decorator --- django_ai_assistant/api/views.py | 17 +++++++++-------- django_ai_assistant/decorators.py | 23 +++++++++++++++++++++++ django_ai_assistant/helpers/formatters.py | 2 +- 3 files changed, 33 insertions(+), 9 deletions(-) create mode 100644 django_ai_assistant/decorators.py diff --git a/django_ai_assistant/api/views.py b/django_ai_assistant/api/views.py index 0fc06a5..a28c94b 100644 --- a/django_ai_assistant/api/views.py +++ b/django_ai_assistant/api/views.py @@ -17,8 +17,9 @@ ThreadSchemaIn, ) from django_ai_assistant.conf import app_settings +from django_ai_assistant.decorators import with_cast_id from django_ai_assistant.exceptions import AIAssistantNotDefinedError, AIUserNotAllowedError -from django_ai_assistant.helpers import formatters, use_cases +from django_ai_assistant.helpers import use_cases from django_ai_assistant.models import Message, Thread @@ -85,8 +86,8 @@ def create_thread(request, payload: ThreadSchemaIn): @api.get("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete") +@with_cast_id def get_thread(request, thread_id: Any): - thread_id = formatters.format_id(id, Thread) try: thread = use_cases.get_single_thread( thread_id=thread_id, user=request.user, request=request @@ -97,16 +98,16 @@ def get_thread(request, thread_id: Any): @api.patch("threads/{thread_id}/", response=ThreadSchema, url_name="thread_detail_update_delete") +@with_cast_id def update_thread(request, thread_id: Any, payload: ThreadSchemaIn): - thread_id = formatters.format_id(thread_id, Thread) thread = get_object_or_404(Thread, id=thread_id) name = payload.name return use_cases.update_thread(thread=thread, name=name, user=request.user, request=request) @api.delete("threads/{thread_id}/", response={204: None}, url_name="thread_detail_update_delete") +@with_cast_id def delete_thread(request, thread_id: Any): - thread_id = formatters.format_id(thread_id, Thread) thread = get_object_or_404(Thread, id=thread_id) use_cases.delete_thread(thread=thread, user=request.user, request=request) return 204, None @@ -117,8 +118,9 @@ def delete_thread(request, thread_id: Any): response=List[ThreadMessagesSchemaOut], url_name="messages_list_create", ) +@with_cast_id def list_thread_messages(request, thread_id: Any): - thread = get_object_or_404(Thread, id=formatters.format_id(thread_id, Thread)) + thread = get_object_or_404(Thread, id=thread_id) messages = use_cases.get_thread_messages(thread=thread, user=request.user, request=request) return [message_to_dict(m)["data"] for m in messages] @@ -129,8 +131,8 @@ def list_thread_messages(request, thread_id: Any): response={201: None}, url_name="messages_list_create", ) +@with_cast_id def create_thread_message(request, thread_id: Any, payload: ThreadMessagesSchemaIn): - thread_id = formatters.format_id(thread_id, Thread) thread = Thread.objects.get(id=thread_id) use_cases.create_message( @@ -146,9 +148,8 @@ def create_thread_message(request, thread_id: Any, payload: ThreadMessagesSchema @api.delete( "threads/{thread_id}/messages/{message_id}/", response={204: None}, url_name="messages_delete" ) +@with_cast_id def delete_thread_message(request, thread_id: Any, message_id: Any): - thread_id = formatters.format_id(thread_id, Message) - message_id = formatters.format_id(message_id, Message) message = get_object_or_404(Message, id=message_id, thread_id=thread_id) use_cases.delete_message( message=message, diff --git a/django_ai_assistant/decorators.py b/django_ai_assistant/decorators.py new file mode 100644 index 0000000..9596320 --- /dev/null +++ b/django_ai_assistant/decorators.py @@ -0,0 +1,23 @@ +from functools import wraps + +from django_ai_assistant.helpers.formatters import cast_id +from django_ai_assistant.models import Message, Thread + + +def with_cast_id(func): + @wraps(func) + def wrapper(*args, **kwargs): + thread_id = kwargs.get("thread_id") + message_id = kwargs.get("message_id") + + if thread_id: + thread_id = cast_id(thread_id, Thread) + kwargs["thread_id"] = thread_id + + if message_id: + message_id = cast_id(message_id, Message) + kwargs["message_id"] = message_id + + return func(*args, **kwargs) + + return wrapper diff --git a/django_ai_assistant/helpers/formatters.py b/django_ai_assistant/helpers/formatters.py index 0614f4b..74cafee 100644 --- a/django_ai_assistant/helpers/formatters.py +++ b/django_ai_assistant/helpers/formatters.py @@ -1,7 +1,7 @@ import uuid -def format_id(item_id, model): +def cast_id(item_id, model): if isinstance(item_id, str) and "UUID" in model._meta.pk.get_internal_type(): return uuid.UUID(item_id) return item_id