From 03e4a011a1cca60a14b8d228ca817ddc2c284e36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?=
Date: Wed, 19 Jun 2024 14:56:43 -0300
Subject: [PATCH] Reference docstrings at AIAssistant

---
 django_ai_assistant/helpers/assistants.py | 283 +++++++++++++++++++---
 example/movies/ai_assistants.py           |  10 -
 2 files changed, 253 insertions(+), 40 deletions(-)

diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py
index bc07cf7..c275357 100644
--- a/django_ai_assistant/helpers/assistants.py
+++ b/django_ai_assistant/helpers/assistants.py
@@ -12,7 +12,8 @@
     DEFAULT_DOCUMENT_PROMPT,
     DEFAULT_DOCUMENT_SEPARATOR,
 )
-from langchain_core.chat_history import InMemoryChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory
+from langchain_core.language_models import BaseChatModel
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import (
     ChatPromptTemplate,
@@ -42,20 +43,64 @@ class AIAssistant(abc.ABC):  # noqa: F821
+    """Base class for AI Assistants. Subclasses must define at least the following attributes:
+
+    - id: str
+    - name: str
+    - instructions: str
+    - model: str
+
+    Subclasses can override the public methods to customize the behavior of the assistant.\n
+    Tools can be added to the assistant by decorating methods with `@method_tool`.\n
+    Check the docs Tutorial for more info on how to build an AI Assistant.
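+
+    Example (a minimal sketch; the weather tool is illustrative and
+    returns a hard-coded value):
+
+        from django_ai_assistant import AIAssistant, method_tool
+
+        class WeatherAIAssistant(AIAssistant):
+            id = "weather_assistant"
+            name = "Weather Assistant"
+            instructions = "You are a weather bot."
+            model = "gpt-4o"
+
+            @method_tool
+            def fetch_current_weather(self, location: str) -> str:
+                '''Fetch the current weather data for a location'''
+                return "30 degrees Celsius"  # illustrative placeholder value
+    """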
+
     id: ClassVar[str]  # noqa: A003
-    name: str
+    """Class variable with the id of the assistant. Used to select the assistant to use.\n
+    Must be unique across the whole Django project and match the pattern '^[a-zA-Z0-9_-]+$'."""
+    name: ClassVar[str]
+    """Class variable with the name of the assistant.
+    Should be a friendly name to optionally display to users."""
     instructions: str
+    """Instructions that tell the AI assistant what to do. This is the LLM system prompt."""
     model: str
-    temperature: float
+    """LLM model name to use for the assistant.\n
+    Should be a valid model name from OpenAI, because the default `get_llm` method uses OpenAI.\n
+    `get_llm` can be overridden to use a different LLM implementation.
+    """
+    temperature: float = 1.0
+    """Temperature to use for the assistant LLM model.\nDefaults to `1.0`."""
     has_rag: bool = False
-
+    """Whether the assistant uses RAG (Retrieval-Augmented Generation) or not.\n
+    Defaults to `False`.
+    When True, the assistant will use a retriever to get documents to provide as context to the LLM.
+    For this to work, the `instructions` should contain a placeholder for the context,
+    which is `{context}` by default.
+    Additionally, the assistant class should implement the `get_retriever` method to return
+    the retriever to use."""
     _user: Any | None
+    """The current user the assistant is helping. A model instance.\n
+    Set by the constructor.
+    When API views are used, this is set to the current request user.\n
+    Can be used in any `@method_tool` to customize behavior."""
     _request: Any | None
+    """The current Django request the assistant was initialized with. A request instance.\n
+    Set by the constructor.\n
+    Can be used in any `@method_tool` to customize behavior."""
     _view: Any | None
+    """The current Django view the assistant was initialized with. A view instance.\n
+    Set by the constructor.\n
+    Can be used in any `@method_tool` to customize behavior."""
     _init_kwargs: dict[str, Any]
+    """Extra keyword arguments passed to the constructor.\n
+    Set by the constructor.\n
+    Can be used in any `@method_tool` to customize behavior."""
     _method_tools: Sequence[BaseTool]
+    """List of `@method_tool` tools the assistant can use. Automatically set by the constructor."""
 
     _registry: ClassVar[dict[str, type["AIAssistant"]]] = {}
+    """Registry of all AIAssistant subclasses by their id.\n
+    Automatically populated when a subclass is declared.\n
+    Use `get_cls_registry` and `get_cls` to access the registry."""
 
     def __init__(self, *, user=None, request=None, view=None, **kwargs):
         self._user = user
@@ -63,13 +108,10 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs):
         self._view = view
         self._init_kwargs = kwargs
 
-        self.temperature = 1.0  # default OpenAI temperature for Assistant
-
         self._set_method_tools()
 
     def __init_subclass__(cls, **kwargs):
-        """
-        Called when a class is subclassed from AIAssistant.
+        """Called when a class is subclassed from AIAssistant.
 
         This method is automatically invoked when a new subclass of AIAssistant
         is created. It allows AIAssistant to perform additional setup or configuration
@@ -129,40 +171,88 @@ def _set_method_tools(self):
 
     @classmethod
     def get_cls_registry(cls) -> dict[str, type["AIAssistant"]]:
-        """Get the registry of AIAssistant classes."""
+        """Get the registry of AIAssistant classes.
+
+        Returns:
+            dict[str, type[AIAssistant]]: A dictionary mapping assistant ids to their classes.
+        """
         return cls._registry
 
     @classmethod
     def get_cls(cls, assistant_id: str) -> type["AIAssistant"]:
-        """Get the AIAssistant class for the given assistant ID."""
+        """Get the AIAssistant class for the given assistant ID.
+
+        Args:
+            assistant_id (str): The ID of the assistant to get.
+
+        Returns:
+            type[AIAssistant]: The AIAssistant subclass for the given ID.
+        """
         return cls.get_cls_registry()[assistant_id]
 
     @classmethod
     def clear_cls_registry(cls: type["AIAssistant"]) -> None:
+        """Clear the registry of AIAssistant classes."""
+
         cls._registry.clear()
 
-    def get_name(self):
-        return self.name
+    def get_instructions(self) -> str:
+        """Get the instructions for the assistant. By default, this is the `instructions` attribute.\n
+        Override the `instructions` attribute or this method to use different instructions.
 
-    def get_instructions(self):
+        Returns:
+            str: The instructions for the assistant, i.e., the LLM system prompt.
+        """
         return self.instructions
 
-    def get_model(self):
+    def get_model(self) -> str:
+        """Get the LLM model name for the assistant. By default, this is the `model` attribute.\n
+        Used by the `get_llm` method to create the LLM instance.\n
+        Override the `model` attribute or this method to use a different LLM model.
+
+        Returns:
+            str: The LLM model name for the assistant.
+        """
         return self.model
 
-    def get_temperature(self):
+    def get_temperature(self) -> float:
+        """Get the temperature to use for the assistant LLM model.
+        By default, this is the `temperature` attribute, which defaults to `1.0`.\n
+        Used by the `get_llm` method to create the LLM instance.\n
+        Override the `temperature` attribute or this method to use a different temperature.
+
+        Returns:
+            float: The temperature to use for the assistant LLM model.
+        """
         return self.temperature
 
-    def get_model_kwargs(self):
+    def get_model_kwargs(self) -> dict[str, Any]:
+        """Get additional keyword arguments to pass to the LLM model constructor.\n
+        Used by the `get_llm` method to create the LLM instance.\n
+        Override this method to pass additional keyword arguments to the LLM model constructor.
+
+        Returns:
+            dict[str, Any]: Additional keyword arguments to pass to the LLM model constructor.
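+
+        Example (a sketch of an override; whether a given key is accepted
+        depends on the model provider and the Langchain integration):
+
+            def get_model_kwargs(self):
+                return {"frequency_penalty": 0.5}
+        """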
+ """ return self.temperature - def get_model_kwargs(self): + def get_model_kwargs(self) -> dict[str, Any]: + """Get additional keyword arguments to pass to the LLM model constructor.\n + Used by the `get_llm` method to create the LLM instance.\n + Override this method to pass additional keyword arguments to the LLM model constructor. + + Returns: + dict[str, Any]: Additional keyword arguments to pass to the LLM model constructor. + """ return {} - def get_prompt_template(self): + def get_prompt_template(self) -> ChatPromptTemplate: + """Get the `ChatPromptTemplate` for the Langchain chain to use.\n + The system prompt come from the `get_instructions` method.\n + The template includes placeholders for the instructions, chat `{history}`, user `{input}`, + and `{agent_scratchpad}`, all which are necessary for the chain to work properly.\n + The chat history is filled by the chain using the message history from `get_message_history`.\n + If the assistant uses RAG, the instructions should contain a placeholder + for the context, which is `{context}` by default, defined by the `get_context_placeholder` method. + + Returns: + ChatPromptTemplate: The chat prompt template for the Langchain chain. + """ instructions = self.get_instructions() - context_key = self.get_context_key() - if self.has_rag and f"{context_key}" not in instructions: + context_placeholder = self.get_context_placeholder() + if self.has_rag and f"{context_placeholder}" not in instructions: raise AIAssistantMisconfiguredError( f"{self.__class__.__name__} has_rag=True" - f"but does not have a {{{context_key}}} placeholder in instructions." + f"but does not have a {{{context_placeholder}}} placeholder in instructions." ) return ChatPromptTemplate.from_messages( @@ -174,7 +264,19 @@ def get_prompt_template(self): ] ) - def get_message_history(self, thread_id: int | None): + def get_message_history(self, thread_id: int | None) -> BaseChatMessageHistory: + """Get the chat message history instance for the given `thread_id`.\n + The Langchain chain uses the return of this method to get the thread messages + for the assistant, filling the `history` placeholder in the `get_prompt_template`.\n + + Args: + thread_id (int | None): The thread ID for the chat message history. + If `None`, an in-memory chat message history is used. + + Returns: + BaseChatMessageHistory: The chat message history instance for the given `thread_id`. + """ + # DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere: from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory @@ -182,7 +284,15 @@ def get_message_history(self, thread_id: int | None): return InMemoryChatMessageHistory() return DjangoChatMessageHistory(thread_id) - def get_llm(self): + def get_llm(self) -> BaseChatModel: + """Get the Langchain LLM instance for the assistant. + By default, this uses the OpenAI implementation.\n + `get_model`, `get_temperature`, and `get_model_kwargs` are used to create the LLM instance.\n + Override this method to use a different LLM implementation. + + Returns: + BaseChatModel: The LLM instance for the assistant. + """ model = self.get_model() temperature = self.get_temperature() model_kwargs = self.get_model_kwargs() @@ -193,23 +303,71 @@ def get_llm(self): ) def get_tools(self) -> Sequence[BaseTool]: + """Get the list of method tools the assistant can use. 
         return self._method_tools
 
     def get_document_separator(self) -> str:
+        """Get the RAG document separator to use in the prompt. Only used when `has_rag=True`.\n
+        Defaults to `"\\n\\n"`, which is the Langchain default.\n
+        Override this method to use a different separator.
+
+        Returns:
+            str: a separator for documents in the prompt.
+        """
         return DEFAULT_DOCUMENT_SEPARATOR
 
     def get_document_prompt(self) -> PromptTemplate:
+        """Get the `PromptTemplate` to use when rendering RAG documents in the prompt.
+        Only used when `has_rag=True`.\n
+        Defaults to `PromptTemplate.from_template("{page_content}")`, which is the Langchain default.\n
+        Override this method to use a different template.
+
+        Returns:
+            PromptTemplate: a prompt template for RAG documents.
+        """
         return DEFAULT_DOCUMENT_PROMPT
 
-    def get_context_key(self) -> str:
+    def get_context_placeholder(self) -> str:
+        """Get the RAG context placeholder to use in the prompt when `has_rag=True`.\n
+        Defaults to `"context"`. Override this method to use a different placeholder.
+
+        Returns:
+            str: the RAG context placeholder to use in the prompt.
+        """
         return "context"
 
     def get_retriever(self) -> BaseRetriever:
+        """Get the RAG retriever to use for fetching documents.\n
+        Must be implemented by subclasses when `has_rag=True`.
+
+        Returns:
+            BaseRetriever: the RAG retriever to use for fetching documents.
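+
+        Example (a sketch of an override; assumes `self.vector_store` is a Langchain
+        vector store you have set up yourself, as it is not provided by this class):
+
+            def get_retriever(self) -> BaseRetriever:
+                return self.vector_store.as_retriever(search_kwargs={"k": 4})
+        """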
         raise NotImplementedError(
             f"Override the get_retriever with your implementation at {self.__class__.__name__}"
         )
 
     def get_contextualize_prompt(self) -> ChatPromptTemplate:
+        """Get the contextualize prompt template for the assistant.\n
+        This is used when `has_rag=True` and there are previous messages in the thread.
+        Since the latest user question might reference the chat history,
+        the LLM needs to generate a new standalone question,
+        and use that question to query the retriever for relevant documents.\n
+        By default, this is a prompt that asks the LLM to reformulate the
+        latest user question into a standalone question that can be understood
+        without the chat history.\n
+        Override this method to use a different contextualize prompt.\n
+        See `get_history_aware_retriever` for how this prompt is used.
+
+        Returns:
+            ChatPromptTemplate: The contextualize prompt template for the assistant.
+        """
         contextualize_q_system_prompt = (
             "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
@@ -228,6 +386,20 @@ def get_contextualize_prompt(self) -> ChatPromptTemplate:
         )
 
     def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
+        """Get the history-aware retriever Langchain chain for the assistant.\n
+        This is used when `has_rag=True` to fetch documents based on the chat history.\n
+        By default, this is a chain that checks if there is chat history,
+        and if so, it uses the chat history to generate a new standalone question
+        to query the retriever for relevant documents.\n
+        When there is no chat history, it just passes the input to the retriever.\n
+        Override this method to use a different history-aware retriever chain.
+
+        Read more about the history-aware retriever in the
+        [Langchain docs](https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/).
+
+        Returns:
+            Runnable[dict, RetrieverOutput]: a history-aware retriever Langchain chain.
+        """
         llm = self.get_llm()
         retriever = self.get_retriever()
         prompt = self.get_contextualize_prompt()
@@ -244,6 +416,23 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
         )
 
     def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:
+        """Create the Langchain chain for the assistant.\n
+        This chain is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
+        `as_chain` uses many other methods to create the chain.\n
+        Prefer to override the other methods to customize the chain for the assistant.
+        Only override this method if you need to customize the chain at a lower level.
+
+        The chain input is a dictionary with the key `"input"` containing the user message.\n
+        The chain output is a dictionary with the key `"output"` containing the assistant response,
+        along with the key `"history"` containing the previous chat history.
+
+        Args:
+            thread_id (int | None): The thread ID for the chat message history.
+                If `None`, an in-memory chat message history is used.
+
+        Returns:
+            Runnable[dict, dict]: The Langchain chain for the assistant.
+        """
         # Based on:
         # - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/
         # - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/
@@ -273,10 +462,10 @@ def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:
         # based on create_stuff_documents_chain:
         document_separator = self.get_document_separator()
         document_prompt = self.get_document_prompt()
-        context_key = self.get_context_key()
+        context_placeholder = self.get_context_placeholder()
         chain = chain | RunnablePassthrough.assign(
             **{
-                context_key: lambda x: document_separator.join(
+                context_placeholder: lambda x: document_separator.join(
                     format_document(doc, document_prompt) for doc in x["docs"]
                 )
             }
@@ -311,10 +500,37 @@ def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]:
         return agent_with_chat_history
 
     def invoke(self, *args, thread_id: int | None, **kwargs):
+        """Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n
+        This is the lower-level method to run the assistant.\n
+        The chain is created by the `as_chain` method.
+
+        Args:
+            *args: Positional arguments to pass to the chain.
+                Make sure to include a `dict` like `{"input": "user message"}`.
+            thread_id (int | None): The thread ID for the chat message history.
+                If `None`, an in-memory chat message history is used.
+            **kwargs: Keyword arguments to pass to the chain.
+
+        Returns:
+            dict: The output of the assistant chain,
+                structured like `{"output": "assistant response", "history": ...}`.
+        """
         chain = self.as_chain(thread_id)
         return chain.invoke(*args, **kwargs)
 
     def run(self, message, thread_id: int | None, **kwargs):
+        """Run the assistant with the given message and thread ID.\n
+        This is the higher-level method to run the assistant.
+
+        Args:
+            message (str): The user message to pass to the assistant.
+            thread_id (int | None): The thread ID for the chat message history.
+                If `None`, an in-memory chat message history is used.
+            **kwargs: Additional keyword arguments to pass to the chain.
+
+        Returns:
+            str: The assistant response to the user message.
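+
+        Example (illustrative usage from e.g. a Django view; assumes `thread` is an
+        existing django_ai_assistant Thread instance and `WeatherAIAssistant` is the
+        sketch from the class docstring):
+
+            assistant = WeatherAIAssistant(user=request.user, request=request)
+            response = assistant.run("What is the weather in Recife?", thread_id=thread.id)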
+ """ return self.invoke( { "input": message, @@ -323,14 +539,21 @@ def run(self, message, thread_id: int | None, **kwargs): **kwargs, )["output"] - def run_as_tool(self, message: str, **kwargs): - chain = self.as_chain(thread_id=None) - output = chain.invoke({"input": message}, **kwargs) - return output["output"] + def _run_as_tool(self, message: str, **kwargs): + return self.run(message, thread_id=None, **kwargs) def as_tool(self, description) -> BaseTool: + """Create a tool from the assistant.\n + This is useful to compose assistants.\n + + Args: + description (str): The description for the tool. + + Returns: + BaseTool: A tool that runs the assistant. The tool name is this assistant's id. + """ return Tool.from_function( - func=self.run_as_tool, + func=self._run_as_tool, name=self.id, description=description, ) diff --git a/example/movies/ai_assistants.py b/example/movies/ai_assistants.py index 8bd7370..4fb9037 100644 --- a/example/movies/ai_assistants.py +++ b/example/movies/ai_assistants.py @@ -1,4 +1,3 @@ -import functools from typing import Sequence from django.db.models import Max @@ -43,15 +42,6 @@ def get_tools(self) -> Sequence[BaseTool]: *super().get_tools(), ] - @functools.lru_cache(maxsize=1024) # noqa: B019 - def run_as_tool(self, message: str, **kwargs): - # We may already know the IMDB URL, so we can return it directly: - any_item = MovieBacklogItem.objects.filter(movie_name__icontains=message).first() - if any_item: - return any_item.imdb_url - - return super().run_as_tool(message, **kwargs) - class MovieRecommendationAIAssistant(AIAssistant): id = "movie_recommendation_assistant" # noqa: A003