From 03e4a011a1cca60a14b8d228ca817ddc2c284e36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Wed, 19 Jun 2024 14:56:43 -0300 Subject: [PATCH 1/4] Reference docstrings at AIAssistant --- django_ai_assistant/helpers/assistants.py | 283 +++++++++++++++++++--- example/movies/ai_assistants.py | 10 - 2 files changed, 253 insertions(+), 40 deletions(-) diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index bc07cf7..c275357 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -12,7 +12,8 @@ DEFAULT_DOCUMENT_PROMPT, DEFAULT_DOCUMENT_SEPARATOR, ) -from langchain_core.chat_history import InMemoryChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory, InMemoryChatMessageHistory +from langchain_core.language_models import BaseChatModel from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ( ChatPromptTemplate, @@ -42,20 +43,64 @@ class AIAssistant(abc.ABC): # noqa: F821 + """Base class for AI Assistants. Subclasses must define at least the following attributes: + - id: str + - name: str + - instructions: str + - model: str + + Subclasses can override the public methods to customize the behavior of the assistant.\n + Tools can be added to the assistant by decorating methods with `@method_tool`.\n + Check the docs Tutorial for more info on how to build an AI Assistant. + """ + id: ClassVar[str] # noqa: A003 - name: str + """Class variable with the id of the assistant. Used to select the assistant to use.\n + Must be unique across the whole Django project and match the pattern '^[a-zA-Z0-9_-]+$'.""" + name: ClassVar[str] + """Class variable with the name of the assistant. + Should be a friendly name to optionally display to users.""" instructions: str + """Instructions for the AI assistant knowing what to do. This is the LLM system prompt.""" model: str - temperature: float + """LLM model name to use for the assistant.\n + Should be a valid model name from OpenAI, because the default `get_llm` method uses OpenAI.\n + `get_llm` can be overridden to use a different LLM implementation. + """ + temperature: float = 1.0 + """Temperature to use for the assistant LLM model.\nDefaults to `1.0`.""" has_rag: bool = False - + """Whether the assistant uses RAG (Retrieval-Augmented Generation) or not.\n + Defaults to `False`. + When True, the assistant will use a retriever to get documents to provide as context to the LLM. + For this to work, the `instructions` should contain a placeholder for the context, + which is `{context}` by default. + Additionally, the assistant class should implement the `get_retriever` method to return + the retriever to use.""" _user: Any | None + """The current user the assistant is helping. A model instance.\n + Set by the constructor. + When API views are used, this is set to the current request user.\n + Can be used in any `@method_tool` to customize behavior.""" _request: Any | None + """The current Django request the assistant was initialized with. A request instance.\n + Set by the constructor.\n + Can be used in any `@method_tool` to customize behavior.""" _view: Any | None + """The current Django view the assistant was initialized with. A view instance.\n + Set by the constructor.\n + Can be used in any `@method_tool` to customize behavior.""" _init_kwargs: dict[str, Any] + """Extra keyword arguments passed to the constructor.\n + Set by the constructor.\n + Can be used in any `@method_tool` to customize behavior.""" _method_tools: Sequence[BaseTool] + """List of `@method_tool` tools the assistant can use. Automatically set by the constructor.""" _registry: ClassVar[dict[str, type["AIAssistant"]]] = {} + """Registry of all AIAssistant subclasses by their id.\n + Automatically populated by when a subclass is declared.\n + Use `get_cls_registry` and `get_cls` to access the registry.""" def __init__(self, *, user=None, request=None, view=None, **kwargs): self._user = user @@ -63,13 +108,10 @@ def __init__(self, *, user=None, request=None, view=None, **kwargs): self._view = view self._init_kwargs = kwargs - self.temperature = 1.0 # default OpenAI temperature for Assistant - self._set_method_tools() def __init_subclass__(cls, **kwargs): - """ - Called when a class is subclassed from AIAssistant. + """Called when a class is subclassed from AIAssistant. This method is automatically invoked when a new subclass of AIAssistant is created. It allows AIAssistant to perform additional setup or configuration @@ -129,40 +171,88 @@ def _set_method_tools(self): @classmethod def get_cls_registry(cls) -> dict[str, type["AIAssistant"]]: - """Get the registry of AIAssistant classes.""" + """Get the registry of AIAssistant classes. + + Returns: + dict[str, type[AIAssistant]]: A dictionary mapping assistant ids to their classes. + """ return cls._registry @classmethod def get_cls(cls, assistant_id: str) -> type["AIAssistant"]: - """Get the AIAssistant class for the given assistant ID.""" + """Get the AIAssistant class for the given assistant ID. + + Args: + assistant_id (str): The ID of the assistant to get. + Returns: + type[AIAssistant]: The AIAssistant subclass for the given ID. + """ return cls.get_cls_registry()[assistant_id] @classmethod def clear_cls_registry(cls: type["AIAssistant"]) -> None: + """Clear the registry of AIAssistant classes.""" + cls._registry.clear() - def get_name(self): - return self.name + def get_instructions(self) -> str: + """Get the instructions for the assistant. By default, this is the `instructions` attribute.\n + Override the `instructions` attribute or this method to use different instructions. - def get_instructions(self): + Returns: + str: The instructions for the assistant, i.e., the LLM system prompt. + """ return self.instructions - def get_model(self): + def get_model(self) -> str: + """Get the LLM model name for the assistant. By default, this is the `model` attribute.\n + Used by the `get_llm` method to create the LLM instance.\n + Override the `model` attribute or this method to use a different LLM model. + + Returns: + str: The LLM model name for the assistant. + """ return self.model - def get_temperature(self): + def get_temperature(self) -> float: + """Get the temperature to use for the assistant LLM model. + By default, this is the `temperature` attribute, which is `1.0` by default.\n + Used by the `get_llm` method to create the LLM instance.\n + Override the `temperature` attribute or this method to use a different temperature. + + Returns: + float: The temperature to use for the assistant LLM model. + """ return self.temperature - def get_model_kwargs(self): + def get_model_kwargs(self) -> dict[str, Any]: + """Get additional keyword arguments to pass to the LLM model constructor.\n + Used by the `get_llm` method to create the LLM instance.\n + Override this method to pass additional keyword arguments to the LLM model constructor. + + Returns: + dict[str, Any]: Additional keyword arguments to pass to the LLM model constructor. + """ return {} - def get_prompt_template(self): + def get_prompt_template(self) -> ChatPromptTemplate: + """Get the `ChatPromptTemplate` for the Langchain chain to use.\n + The system prompt come from the `get_instructions` method.\n + The template includes placeholders for the instructions, chat `{history}`, user `{input}`, + and `{agent_scratchpad}`, all which are necessary for the chain to work properly.\n + The chat history is filled by the chain using the message history from `get_message_history`.\n + If the assistant uses RAG, the instructions should contain a placeholder + for the context, which is `{context}` by default, defined by the `get_context_placeholder` method. + + Returns: + ChatPromptTemplate: The chat prompt template for the Langchain chain. + """ instructions = self.get_instructions() - context_key = self.get_context_key() - if self.has_rag and f"{context_key}" not in instructions: + context_placeholder = self.get_context_placeholder() + if self.has_rag and f"{context_placeholder}" not in instructions: raise AIAssistantMisconfiguredError( f"{self.__class__.__name__} has_rag=True" - f"but does not have a {{{context_key}}} placeholder in instructions." + f"but does not have a {{{context_placeholder}}} placeholder in instructions." ) return ChatPromptTemplate.from_messages( @@ -174,7 +264,19 @@ def get_prompt_template(self): ] ) - def get_message_history(self, thread_id: int | None): + def get_message_history(self, thread_id: int | None) -> BaseChatMessageHistory: + """Get the chat message history instance for the given `thread_id`.\n + The Langchain chain uses the return of this method to get the thread messages + for the assistant, filling the `history` placeholder in the `get_prompt_template`.\n + + Args: + thread_id (int | None): The thread ID for the chat message history. + If `None`, an in-memory chat message history is used. + + Returns: + BaseChatMessageHistory: The chat message history instance for the given `thread_id`. + """ + # DjangoChatMessageHistory must be here because Django may not be loaded yet elsewhere: from django_ai_assistant.langchain.chat_message_histories import DjangoChatMessageHistory @@ -182,7 +284,15 @@ def get_message_history(self, thread_id: int | None): return InMemoryChatMessageHistory() return DjangoChatMessageHistory(thread_id) - def get_llm(self): + def get_llm(self) -> BaseChatModel: + """Get the Langchain LLM instance for the assistant. + By default, this uses the OpenAI implementation.\n + `get_model`, `get_temperature`, and `get_model_kwargs` are used to create the LLM instance.\n + Override this method to use a different LLM implementation. + + Returns: + BaseChatModel: The LLM instance for the assistant. + """ model = self.get_model() temperature = self.get_temperature() model_kwargs = self.get_model_kwargs() @@ -193,23 +303,71 @@ def get_llm(self): ) def get_tools(self) -> Sequence[BaseTool]: + """Get the list of method tools the assistant can use. + By default, this is the `_method_tools` attribute, which are all `@method_tool`s.\n + Override and call super to add additional tools, + such as [any langchain_community tools](https://python.langchain.com/v0.2/docs/integrations/tools/). + + Returns: + Sequence[BaseTool]: The list of tools the assistant can use. + """ return self._method_tools def get_document_separator(self) -> str: + """Get the RAG document separator to use in the prompt. Only used when `has_rag=True`.\n + Defaults to `"\\n\\n"`, which is the Langchain default.\n + Override this method to use a different separator. + + Returns: + str: a separator for documents in the prompt. + """ return DEFAULT_DOCUMENT_SEPARATOR def get_document_prompt(self) -> PromptTemplate: + """Get the PromptTemplate template to use when rendering RAG documents in the prompt. + Only used when `has_rag=True`.\n + Defaults to `PromptTemplate.from_template("{page_content}")`, which is the Langchain default.\n + Override this method to use a different template. + + Returns: + PromptTemplate: a prompt template for RAG documents. + """ return DEFAULT_DOCUMENT_PROMPT - def get_context_key(self) -> str: + def get_context_placeholder(self) -> str: + """Get the RAG context placeholder to use in the prompt when `has_rag=True`.\n + Defaults to `"context"`. Override this method to use a different placeholder. + + Returns: + str: the RAG context placeholder to use in the prompt. + """ return "context" def get_retriever(self) -> BaseRetriever: + """Get the RAG retriever to use for fetching documents.\n + Must be implemented by subclasses when `has_rag=True`.\n + + Returns: + BaseRetriever: the RAG retriever to use for fetching documents. + """ raise NotImplementedError( f"Override the get_retriever with your implementation at {self.__class__.__name__}" ) def get_contextualize_prompt(self) -> ChatPromptTemplate: + """Get the contextualize prompt template for the assistant.\n + This is used when `has_rag=True` and there are previous messages in the thread. + Since the latest user question might reference the chat history, + the LLM needs to generate a new standalone question, + and use that question to query the retriever for relevant documents.\n + By default, this is a prompt that asks the LLM to + reformulate the latest user question without the chat history.\n + Override this method to use a different contextualize prompt.\n + See `get_history_aware_retriever` for how this prompt is used.\n + + Returns: + ChatPromptTemplate: The contextualize prompt template for the assistant. + """ contextualize_q_system_prompt = ( "Given a chat history and the latest user question " "which might reference context in the chat history, " @@ -228,6 +386,20 @@ def get_contextualize_prompt(self) -> ChatPromptTemplate: ) def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]: + """Get the history-aware retriever Langchain chain for the assistant.\n + This is used when `has_rag=True` to fetch documents based on the chat history.\n + By default, this is a chain that checks if there is chat history, + and if so, it uses the chat history to generate a new standalone question + to query the retriever for relevant documents.\n + When there is no chat history, it just passes the input to the retriever.\n + Override this method to use a different history-aware retriever chain. + + Read more about the history-aware retriever in the + [Langchain docs](https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/). + + Returns: + Runnable[dict, RetrieverOutput]: a history-aware retriever Langchain chain. + """ llm = self.get_llm() retriever = self.get_retriever() prompt = self.get_contextualize_prompt() @@ -244,6 +416,23 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]: ) def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]: + """Create the Langchain chain for the assistant.\n + This chain is a agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n + `as_chain` uses many other methods to create the chain.\n + Prefer to override the other methods to customize the chain for the assistant. + Only override this method if you need to customize the chain at a lower level. + + The chain input is a dictionary with the key `"input"` containing the user message.\n + The chain output is a dictionary with the key `"output"` containing the assistant response, + along with the key `"history"` containing the previous chat history. + + Args: + thread_id (int | None): The thread ID for the chat message history. + If `None`, an in-memory chat message history is used. + + Returns: + Runnable[dict, dict]: The Langchain chain for the assistant. + """ # Based on: # - https://python.langchain.com/v0.2/docs/how_to/qa_chat_history_how_to/ # - https://python.langchain.com/v0.2/docs/how_to/migrate_agent/ @@ -273,10 +462,10 @@ def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]: # based on create_stuff_documents_chain: document_separator = self.get_document_separator() document_prompt = self.get_document_prompt() - context_key = self.get_context_key() + context_placeholder = self.get_context_placeholder() chain = chain | RunnablePassthrough.assign( **{ - context_key: lambda x: document_separator.join( + context_placeholder: lambda x: document_separator.join( format_document(doc, document_prompt) for doc in x["docs"] ) } @@ -311,10 +500,37 @@ def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]: return agent_with_chat_history def invoke(self, *args, thread_id: int | None, **kwargs): + """Invoke the assistant Langchain chain with the given arguments and keyword arguments.\n + This is the lower-level method to run the assistant.\n + The chain is created by the `as_chain` method.\n + + Args: + *args: Positional arguments to pass to the chain. + Make sure to include a `dict` like `{"input": "user message"}`. + thread_id (int | None): The thread ID for the chat message history. + If `None`, an in-memory chat message history is used. + **kwargs: Keyword arguments to pass to the chain. + + Returns: + dict: The output of the assistant chain, + structured like `{"output": "assistant response", "history": ...}`. + """ chain = self.as_chain(thread_id) return chain.invoke(*args, **kwargs) def run(self, message, thread_id: int | None, **kwargs): + """Run the assistant with the given message and thread ID.\n + This is the higher-level method to run the assistant.\n + + Args: + message (str): The user message to pass to the assistant. + thread_id (int | None): The thread ID for the chat message history. + If `None`, an in-memory chat message history is used. + **kwargs: Additional keyword arguments to pass to the chain. + + Returns: + str: The assistant response to the user message. + """ return self.invoke( { "input": message, @@ -323,14 +539,21 @@ def run(self, message, thread_id: int | None, **kwargs): **kwargs, )["output"] - def run_as_tool(self, message: str, **kwargs): - chain = self.as_chain(thread_id=None) - output = chain.invoke({"input": message}, **kwargs) - return output["output"] + def _run_as_tool(self, message: str, **kwargs): + return self.run(message, thread_id=None, **kwargs) def as_tool(self, description) -> BaseTool: + """Create a tool from the assistant.\n + This is useful to compose assistants.\n + + Args: + description (str): The description for the tool. + + Returns: + BaseTool: A tool that runs the assistant. The tool name is this assistant's id. + """ return Tool.from_function( - func=self.run_as_tool, + func=self._run_as_tool, name=self.id, description=description, ) diff --git a/example/movies/ai_assistants.py b/example/movies/ai_assistants.py index 8bd7370..4fb9037 100644 --- a/example/movies/ai_assistants.py +++ b/example/movies/ai_assistants.py @@ -1,4 +1,3 @@ -import functools from typing import Sequence from django.db.models import Max @@ -43,15 +42,6 @@ def get_tools(self) -> Sequence[BaseTool]: *super().get_tools(), ] - @functools.lru_cache(maxsize=1024) # noqa: B019 - def run_as_tool(self, message: str, **kwargs): - # We may already know the IMDB URL, so we can return it directly: - any_item = MovieBacklogItem.objects.filter(movie_name__icontains=message).first() - if any_item: - return any_item.imdb_url - - return super().run_as_tool(message, **kwargs) - class MovieRecommendationAIAssistant(AIAssistant): id = "movie_recommendation_assistant" # noqa: A003 From 3e3b6211a5d5342b9e25b2bad25531691c279ff9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Wed, 19 Jun 2024 16:01:14 -0300 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Amanda Savluchinske --- django_ai_assistant/helpers/assistants.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index c275357..0a019cc 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -114,7 +114,7 @@ def __init_subclass__(cls, **kwargs): """Called when a class is subclassed from AIAssistant. This method is automatically invoked when a new subclass of AIAssistant - is created. It allows AIAssistant to perform additional setup or configuration + is created. It allows AIAssistant to perform additional setup or configuration. for the subclass, such as registering the subclass in a registry. Args: @@ -237,7 +237,7 @@ def get_model_kwargs(self) -> dict[str, Any]: def get_prompt_template(self) -> ChatPromptTemplate: """Get the `ChatPromptTemplate` for the Langchain chain to use.\n - The system prompt come from the `get_instructions` method.\n + The system prompt comes from the `get_instructions` method.\n The template includes placeholders for the instructions, chat `{history}`, user `{input}`, and `{agent_scratchpad}`, all which are necessary for the chain to work properly.\n The chat history is filled by the chain using the message history from `get_message_history`.\n @@ -417,7 +417,7 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]: def as_chain(self, thread_id: int | None) -> Runnable[dict, dict]: """Create the Langchain chain for the assistant.\n - This chain is a agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n + This chain is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n `as_chain` uses many other methods to create the chain.\n Prefer to override the other methods to customize the chain for the assistant. Only override this method if you need to customize the chain at a lower level. From bcda8f79f60c0a51c123b2c463904e5de1568442 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Thu, 20 Jun 2024 09:43:14 -0300 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Pamella Bezerra --- django_ai_assistant/helpers/assistants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index 0a019cc..68977df 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -114,7 +114,7 @@ def __init_subclass__(cls, **kwargs): """Called when a class is subclassed from AIAssistant. This method is automatically invoked when a new subclass of AIAssistant - is created. It allows AIAssistant to perform additional setup or configuration. + is created. It allows AIAssistant to perform additional setup or configuration for the subclass, such as registering the subclass in a registry. Args: From 2539c61dafc21018976bb88c55d0cc436541c63c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Thu, 20 Jun 2024 09:44:03 -0300 Subject: [PATCH 4/4] Fix typo --- django_ai_assistant/helpers/assistants.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index 68977df..8a89cc0 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -378,9 +378,9 @@ def get_contextualize_prompt(self) -> ChatPromptTemplate: return ChatPromptTemplate.from_messages( [ ("system", contextualize_q_system_prompt), - # TODO: make history key confirgurable? + # TODO: make history key configurable? MessagesPlaceholder("history"), - # TODO: make input key confirgurable? + # TODO: make input key configurable? ("human", "{input}"), ] )