Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop @register_assistant decorator #98

Merged
merged 9 commits into from
Jun 19, 2024
Next Next commit
Experimental: drop @register_assistant decorator
  • Loading branch information
pamella committed Jun 19, 2024
commit 9c944eb62c0d3974150e685fb91621d44b1e150b
2 changes: 1 addition & 1 deletion django_ai_assistant/__init__.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

from django_ai_assistant.helpers.assistants import ( # noqa
AIAssistant,
register_assistant,
get_assistant_cls_registry,
)
from django_ai_assistant.langchain.tools import ( # noqa
BaseModel,
16 changes: 10 additions & 6 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
@@ -113,6 +113,14 @@ def _set_method_tools(self):

self._method_tools = tools

@classmethod
def _get_assistant_cls_registry(cls: type["AIAssistant"]) -> dict[str, type["AIAssistant"]]:
registry: dict[str, type["AIAssistant"]] = {}
for subclass in cls.__subclasses__():
registry[subclass.id] = subclass
registry.update(subclass._get_assistant_cls_registry())
return registry

def get_name(self):
return self.name

@@ -297,9 +305,5 @@ def as_tool(self, description) -> BaseTool:
)


ASSISTANT_CLS_REGISTRY: dict[str, type[AIAssistant]] = {}


def register_assistant(cls: type[AIAssistant]):
ASSISTANT_CLS_REGISTRY[cls.id] = cls
return cls
def get_assistant_cls_registry() -> dict[str, type[AIAssistant]]:
return AIAssistant._get_assistant_cls_registry()
8 changes: 4 additions & 4 deletions django_ai_assistant/helpers/use_cases.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
AIAssistantNotDefinedError,
AIUserNotAllowedError,
)
from django_ai_assistant.helpers.assistants import ASSISTANT_CLS_REGISTRY
from django_ai_assistant.helpers.assistants import get_assistant_cls_registry
from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory
from django_ai_assistant.models import Message, Thread
from django_ai_assistant.permissions import (
@@ -25,9 +25,9 @@ def get_assistant_cls(
user: Any,
request: HttpRequest | None = None,
):
if assistant_id not in ASSISTANT_CLS_REGISTRY:
if assistant_id not in get_assistant_cls_registry():
raise AIAssistantNotDefinedError(f"Assistant with id={assistant_id} not found")
assistant_cls = ASSISTANT_CLS_REGISTRY[assistant_id]
assistant_cls = get_assistant_cls_registry()[assistant_id]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this usage, we could have a get_assistant_cls helper but there are a few side effects (please see the PR description for more details).

if not can_run_assistant(
assistant_cls=assistant_cls,
user=user,
@@ -56,7 +56,7 @@ def get_assistants_info(
):
return [
get_assistant_cls(assistant_id=assistant_id, user=user, request=request)
for assistant_id in ASSISTANT_CLS_REGISTRY.keys()
for assistant_id in get_assistant_cls_registry().keys()
]


3 changes: 1 addition & 2 deletions example/movies/ai_assistants.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.tools import BaseTool

from django_ai_assistant import AIAssistant, method_tool, register_assistant
from django_ai_assistant import AIAssistant, method_tool
from movies.models import MovieBacklogItem


@@ -53,7 +53,6 @@ def run_as_tool(self, message: str, **kwargs):
return super().run_as_tool(message, **kwargs)


@register_assistant
class MovieRecommendationAIAssistant(AIAssistant):
id = "movie_recommendation_assistant" # noqa: A003
instructions = (
3 changes: 1 addition & 2 deletions example/rag/ai_assistants.py
Original file line number Diff line number Diff line change
@@ -2,11 +2,10 @@
from langchain_core.retrievers import BaseRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter

from django_ai_assistant import AIAssistant, register_assistant
from django_ai_assistant import AIAssistant
from rag.models import DjangoDocPage


@register_assistant
class DjangoDocsAssistant(AIAssistant):
id = "django_docs_assistant" # noqa: A003
name = "Django Docs Assistant"
3 changes: 1 addition & 2 deletions example/weather/ai_assistants.py
Original file line number Diff line number Diff line change
@@ -3,14 +3,13 @@

import requests

from django_ai_assistant import AIAssistant, BaseModel, Field, method_tool, register_assistant
from django_ai_assistant import AIAssistant, BaseModel, Field, method_tool


BASE_URL = "https://api.weatherapi.com/v1/"
TIMEOUT = 10


@register_assistant
class WeatherAIAssistant(AIAssistant):
id = "weather_assistant" # noqa: A003
name = "Weather Assistant"
3 changes: 1 addition & 2 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -3,14 +3,13 @@
import pytest

from django_ai_assistant.exceptions import AIAssistantNotDefinedError
from django_ai_assistant.helpers.assistants import AIAssistant, register_assistant
from django_ai_assistant.helpers.assistants import AIAssistant
from django_ai_assistant.langchain.tools import BaseModel, Field, method_tool


# Set up


@register_assistant
class TemperatureAssistant(AIAssistant):
id = "temperature_assistant" # noqa: A003
name = "Temperature Assistant"
Loading
Oops, something went wrong.