Skip to content

Commit

Permalink
Incomplete
Browse files Browse the repository at this point in the history
  • Loading branch information
fjsj committed Sep 27, 2024
1 parent 39560bc commit cdfe463
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 45 deletions.
48 changes: 22 additions & 26 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class AIAssistant(abc.ABC): # noqa: F821
"""
temperature: float = 1.0
"""Temperature to use for the assistant LLM model.\nDefaults to `1.0`."""
tool_max_concurrency: int = 1
"""Maximum number of tools to run concurrently / in parallel.\nDefaults to `1` (no concurrency)."""
has_rag: bool = False
"""Whether the assistant uses RAG (Retrieval-Augmented Generation) or not.\n
Defaults to `False`.
Expand Down Expand Up @@ -430,7 +432,7 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
llm_with_tools = llm.bind_tools(tools) if tools else llm

def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
result = add_messages(left, right)
result = add_messages(left, right) # type: ignore

if message_history:
messages_to_store = [
Expand All @@ -447,40 +449,30 @@ class AgentState(TypedDict):
messages: Annotated[list[AnyMessage], custom_add_messages]
input: str # noqa: A003
context: str
output: str
output: Any

def setup(state: AgentState):
messages: list[AnyMessage] = [SystemMessage(content=self.get_instructions())]
system_prompt = self.get_instructions()

if self.structured_output:
schema = None

# If Pydantic
if inspect.isclass(self.structured_output) and issubclass(
self.structured_output, BaseModel
):
schema = json.dumps(self.structured_output.model_json_schema())

schema_information = (
f"JSON will have the following schema:\n\n{schema}\n\n" if schema else ""
)
tools_information = "Gather information using tools. " if tools else ""

# The assistant won't have access to the schema of the structured output before
# the last step of the chat. This message gives visibility about what fields the
# response should have so it can gather the necessary information by using tools.
messages.append(
HumanMessage(
content=(
"In the last step of this chat you will be asked to respond in JSON. "
+ schema_information
+ tools_information
+ "Don't generate JSON until you are explicitly told to. "
)
schema_information = (
f"Your JSON output must have the following schema:\n{schema}\n"
if schema
else ""
)
)
json_info = (
"In the last step of this chat you will be asked to respond in JSON. "
+ schema_information
+ "Don't generate JSON until you are explicitly told to. "
)
system_prompt += "\n" + json_info

return {"messages": messages}
return {"messages": [SystemMessage(content=system_prompt)]}

def history(state: AgentState):
history = message_history.messages if message_history else []
Expand Down Expand Up @@ -523,12 +515,14 @@ def tool_selector(state: AgentState):

def record_response(state: AgentState):
if self.structured_output:
# Structured output must happen in the end, to avoid disabling tool calling.
# Tool calling + structured output is not supported by OpenAI:
llm_with_structured_output = self.get_structured_output_llm()
response = llm_with_structured_output.invoke(
[
*state["messages"],
HumanMessage(
content="Use the information gathered in the conversation to answer."
content="Use the information gathered in the conversation to answer with JSON."
),
]
)
Expand Down Expand Up @@ -581,7 +575,9 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
structured like `{"output": "assistant response", "history": ...}`.
"""
graph = self.as_graph(thread_id)
return graph.invoke(*args, **kwargs)
config = kwargs.pop("config", {})
config["max_concurrency"] = config.pop("max_concurrency", self.tool_max_concurrency)
return graph.invoke(*args, config=config, **kwargs)

@with_cast_id
def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
Expand Down
1 change: 0 additions & 1 deletion example/demo/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,4 @@ def get(self, request, *args, **kwargs):

a = TourGuideAIAssistant()
data = a.run(f"My coordinates are: ({coordinates})")

return JsonResponse(data.model_dump())
3 changes: 3 additions & 0 deletions example/example/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,7 @@
# Example specific settings:

WEATHER_API_KEY = os.getenv("WEATHER_API_KEY") # get for free at https://www.weatherapi.com/
BRAVE_SEARCH_API_KEY = os.getenv(
"BRAVE_SEARCH_API_KEY"
) # get for free at https://brave.com/search/api/
DJANGO_DOCS_BRANCH = "stable/5.0.x"
53 changes: 35 additions & 18 deletions example/movies/ai_assistants.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,42 @@
from typing import Sequence

from django.conf import settings
from django.db import transaction
from django.db.models import Max
from django.utils import timezone

import requests
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools import BraveSearch
from langchain_core.tools import BaseTool
from pydantic import BaseModel

from django_ai_assistant import AIAssistant, method_tool
from movies.models import MovieBacklogItem


class IMDbMovie(BaseModel):
imdb_url: str
imdb_rating: float
scrapped_imdb_page_markdown: str


# Note this assistant is not registered, but we'll use it as a tool on the other.
# This one shouldn't be used directly, as it does web searches and scraping.
class IMDbURLFinderTool(AIAssistant):
id = "imdb_url_finder" # noqa: A003
class IMDbScraper(AIAssistant):
id = "imdb_scraper" # noqa: A003
instructions = (
"You're a tool to find the IMDb URL of a given movie. "
"Use the Tavily Search API to find the IMDb URL. "
"You're a tool to find the IMDb URL of a given movie, "
"and scrape this URL to get the movie rating and other information.\n"
"Use the search tool to find the IMDb URL. "
"Make search queries like: \n"
"- IMDb page of The Matrix\n"
"- IMDb page of The Godfather\n"
"- IMDb page of The Shawshank Redemption\n"
"Then check results and provide only the IMDb URL to the user."
"Then check results, scape the IMDb URL, process the page, and produce a JSON output."
)
name = "IMDb URL Finder"
model = "gpt-4o-mini"
name = "IMDb Scraper"
model = "gpt-4o"
structured_output = IMDbMovie

def get_instructions(self):
# Warning: this will use the server's timezone
Expand All @@ -35,9 +45,16 @@ def get_instructions(self):
current_date_str = timezone.now().date().isoformat()
return f"{self.instructions} Today is: {current_date_str}."

@method_tool
def scrape_imdb_url(self, url: str) -> str:
"""Scrape the IMDb URL and return the content as markdown."""
return requests.get("https://r.jina.ai/" + url, timeout=20).text[:10000]

def get_tools(self) -> Sequence[BaseTool]:
return [
TavilySearchResults(),
BraveSearch.from_api_key(
api_key=settings.BRAVE_SEARCH_API_KEY, search_kwargs={"count": 5}
),
*super().get_tools(),
]

Expand All @@ -47,6 +64,11 @@ class MovieRecommendationAIAssistant(AIAssistant):
instructions = (
"You're a helpful movie recommendation assistant. "
"Help the user find movies to watch and manage their movie backlogs. "
"Use the provided tools for that.\n"
"Note the backlog is stored in a DB. "
"When managing the backlog, you must call the tools, to keep the sync with the DB. "
"The backlog has an order, and you should respect it. Call `reorder_backlog` when necessary.\n"
"Include the IMDb URL and rating of the movies when displaying the backlog.\n"
"Ask the user if they want to add your recommended movies to their backlog, "
"but only if the movie is not on the user's backlog yet."
)
Expand All @@ -70,18 +92,13 @@ def get_instructions(self):

def get_tools(self) -> Sequence[BaseTool]:
return [
TavilySearchResults(),
IMDbURLFinderTool().as_tool(description="Tool to find the IMDb URL of a given movie."),
BraveSearch.from_api_key(
api_key=settings.BRAVE_SEARCH_API_KEY, search_kwargs={"count": 5}
),
IMDbScraper().as_tool(description="Tool to get the IMDb data a given movie."),
*super().get_tools(),
]

@method_tool
def scrape_imdb_url(self, url: str) -> str:
"""Scrape the IMDb URL and return the content as markdown.
Use this to get more info about a movie / show, including its rating, plot, cast, etc."""

return requests.get("https://r.jina.ai/" + url, timeout=20).text[:10000]

@method_tool
def get_movies_backlog(self) -> str:
"""Get what movies are on user's backlog."""
Expand Down

0 comments on commit cdfe463

Please sign in to comment.