From 83410a9c4a0464c6f79673ef98b5581c9af3eba6 Mon Sep 17 00:00:00 2001 From: Filipe Ximenes Date: Thu, 12 Sep 2024 15:09:03 -0300 Subject: [PATCH] enables structured output assistants --- django_ai_assistant/helpers/assistants.py | 77 ++++++- example/demo/views.py | 4 +- example/tour_guide/ai_assistants.py | 42 ++-- poetry.lock | 89 ++++++++- ...IAssistant_pydantic_structured_output.yaml | 189 ++++++++++++++++++ ...Assistant_typeddict_structured_output.yaml | 167 ++++++++++++++++ tests/test_helpers/test_assistants.py | 57 +++++- 7 files changed, 582 insertions(+), 43 deletions(-) create mode 100644 tests/test_helpers/cassettes/test_assistants/test_AIAssistant_pydantic_structured_output.yaml create mode 100644 tests/test_helpers/cassettes/test_assistants/test_AIAssistant_typeddict_structured_output.yaml diff --git a/django_ai_assistant/helpers/assistants.py b/django_ai_assistant/helpers/assistants.py index a4010d2..bdbed13 100644 --- a/django_ai_assistant/helpers/assistants.py +++ b/django_ai_assistant/helpers/assistants.py @@ -1,7 +1,8 @@ import abc import inspect +import json import re -from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast +from typing import Annotated, Any, ClassVar, Dict, Sequence, Type, TypedDict, cast from langchain.chains.combine_documents.base import ( DEFAULT_DOCUMENT_PROMPT, @@ -37,6 +38,7 @@ from langgraph.graph import END, StateGraph from langgraph.graph.message import add_messages from langgraph.prebuilt import ToolNode +from pydantic import BaseModel from django_ai_assistant.decorators import with_cast_id from django_ai_assistant.exceptions import ( @@ -79,6 +81,12 @@ class AIAssistant(abc.ABC): # noqa: F821 When True, the assistant will use a retriever to get documents to provide as context to the LLM. Additionally, the assistant class should implement the `get_retriever` method to return the retriever to use.""" + structured_output: Dict[str, Any] | Type[BaseModel] | Type | None = None + """Structured output to use for the assistant.\n + Defaults to `None`. + When not `None`, the assistant will return a structured output in the provided format. + See https://python.langchain.com/v0.2/docs/how_to/structured_output/ for the available formats. + """ _user: Any | None """The current user the assistant is helping. A model instance.\n Set by the constructor. @@ -269,6 +277,27 @@ def get_llm(self) -> BaseChatModel: model_kwargs=model_kwargs, ) + def get_structured_output_llm(self) -> Runnable: + """Get the LLM model to use for the structured output. + By default, this is the `get_llm` method. + + Returns: + BaseChatModel: The LLM model to use for the structured output. + """ + if not self.structured_output: + raise ValueError("structured_output is not defined") + + llm = self.get_llm() + + method = "json_mode" + if isinstance(llm, ChatOpenAI): + # When using ChatOpenAI, it's better to use json_schema method + # because it enables strict mode. + # https://platform.openai.com/docs/guides/structured-outputs + method = "json_schema" + + return llm.with_structured_output(self.structured_output, method=method) + def get_tools(self) -> Sequence[BaseTool]: """Get the list of method tools the assistant can use. By default, this is the `_method_tools` attribute, which are all `@method_tool`s.\n @@ -422,7 +451,36 @@ class AgentState(TypedDict): output: str def setup(state: AgentState): - return {"messages": [SystemMessage(content=self.get_instructions())]} + messages = [SystemMessage(content=self.get_instructions())] + + if self.structured_output: + schema = None + + # If Pydantic + if inspect.isclass(self.structured_output) and issubclass( + self.structured_output, BaseModel + ): + schema = json.dumps(self.structured_output.model_json_schema()) + + schema_information = "" + if schema: + schema_information = f"JSON will have the following schema:\n\n{schema}\n\n" + + # The assistant won't have access to the schema of the structured output before + # the last step of the chat. This message gives visibility about what fields the + # response should have so it can gather the necessary information by using tools. + messages.append( + SystemMessage( + content=( + "In the last step of this chat you will be asked to respond in JSON. " + + schema_information + + "Gather information using tools. " + "Don't generate JSON until you are explicitly told to. " + ) + ) + ) + + return {"messages": messages} def retriever(state: AgentState): if not self.has_rag: @@ -462,7 +520,20 @@ def tool_selector(state: AgentState): return "continue" def record_response(state: AgentState): - return {"output": state["messages"][-1].content} + if self.structured_output: + llm_with_structured_output = self.get_structured_output_llm() + response = llm_with_structured_output.invoke( + [ + *state["messages"], + SystemMessage( + content="Use the information gathered in the conversation to answer." + ), + ] + ) + else: + response = state["messages"][-1].content + + return {"output": response} workflow = StateGraph(AgentState) diff --git a/example/demo/views.py b/example/demo/views.py index 74c2a1c..d2b91b6 100644 --- a/example/demo/views.py +++ b/example/demo/views.py @@ -1,5 +1,3 @@ -import json - from django.contrib import messages from django.http import JsonResponse from django.shortcuts import get_object_or_404, redirect, render @@ -122,4 +120,4 @@ def get(self, request, *args, **kwargs): a = TourGuideAIAssistant() data = a.run(f"My coordinates are: ({coordinates})") - return JsonResponse(json.loads(data)) + return JsonResponse(data.model_dump()) diff --git a/example/tour_guide/ai_assistants.py b/example/tour_guide/ai_assistants.py index e44840f..c91dd85 100644 --- a/example/tour_guide/ai_assistants.py +++ b/example/tour_guide/ai_assistants.py @@ -2,49 +2,37 @@ from django.utils import timezone +from pydantic import BaseModel, Field + from django_ai_assistant import AIAssistant, method_tool from tour_guide.integrations import fetch_points_of_interest -def _tour_guide_example_json(): - return json.dumps( - { - "nearby_attractions": [ - { - "attraction_name": f"", - "attraction_description": f"", - "attraction_url": f"", - } - for i in range(1, 6) - ] - }, - indent=2, - ).translate( # Necessary due to ChatPromptTemplate - str.maketrans( - { - "{": "{{", - "}": "}}", - } - ) +class Attraction(BaseModel): + attraction_name: str = Field(description="The name of the attraction in english") + attraction_description: str = Field( + description="The description of the attraction, provide information in an entertaining way" + ) + attraction_url: str = Field( + description="The URL of the attraction, keep empty if you don't have this information" ) +class TourGuide(BaseModel): + nearby_attractions: list[Attraction] = Field(description="The list of nearby attractions") + + class TourGuideAIAssistant(AIAssistant): id = "tour_guide_assistant" # noqa: A003 name = "Tour Guide Assistant" instructions = ( "You are a tour guide assistant that offers information about nearby attractions. " "You will receive the user coordinates and should use available tools to find nearby attractions. " + "Only include in your response the items that are relevant to a tourist visiting the area. " "Only call the find_nearby_attractions tool once. " - "Your response should only contain valid JSON data. DON'T include '```json' in your response. " - "The JSON should be formatted according to the following structure: \n" - f"\n\n{_tour_guide_example_json()}\n\n\n" - "In the 'attraction_name' field provide the name of the attraction in english. " - "In the 'attraction_description' field generate an overview about the attraction with the most important information, " - "curiosities and interesting facts. " - "Only include a value for the 'attraction_url' field if you find a real value in the provided data otherwise keep it empty. " ) model = "gpt-4o-2024-08-06" + structured_output = TourGuide def get_instructions(self): # Warning: this will use the server's timezone diff --git a/poetry.lock b/poetry.lock index b02684b..b1481ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1324,6 +1324,76 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jiter" +version = "0.5.0" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.5.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b599f4e89b3def9a94091e6ee52e1d7ad7bc33e238ebb9c4c63f211d74822c3f"}, + {file = "jiter-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a063f71c4b06225543dddadbe09d203dc0c95ba352d8b85f1221173480a71d5"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acc0d5b8b3dd12e91dd184b87273f864b363dfabc90ef29a1092d269f18c7e28"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c22541f0b672f4d741382a97c65609332a783501551445ab2df137ada01e019e"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63314832e302cc10d8dfbda0333a384bf4bcfce80d65fe99b0f3c0da8945a91a"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a25fbd8a5a58061e433d6fae6d5298777c0814a8bcefa1e5ecfff20c594bd749"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503b2c27d87dfff5ab717a8200fbbcf4714516c9d85558048b1fc14d2de7d8dc"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6d1f3d27cce923713933a844872d213d244e09b53ec99b7a7fdf73d543529d6d"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c95980207b3998f2c3b3098f357994d3fd7661121f30669ca7cb945f09510a87"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:afa66939d834b0ce063f57d9895e8036ffc41c4bd90e4a99631e5f261d9b518e"}, + {file = "jiter-0.5.0-cp310-none-win32.whl", hash = "sha256:f16ca8f10e62f25fd81d5310e852df6649af17824146ca74647a018424ddeccf"}, + {file = "jiter-0.5.0-cp310-none-win_amd64.whl", hash = "sha256:b2950e4798e82dd9176935ef6a55cf6a448b5c71515a556da3f6b811a7844f1e"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d4c8e1ed0ef31ad29cae5ea16b9e41529eb50a7fba70600008e9f8de6376d553"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c6f16e21276074a12d8421692515b3fd6d2ea9c94fd0734c39a12960a20e85f3"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5280e68e7740c8c128d3ae5ab63335ce6d1fb6603d3b809637b11713487af9e6"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:583c57fc30cc1fec360e66323aadd7fc3edeec01289bfafc35d3b9dcb29495e4"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26351cc14507bdf466b5f99aba3df3143a59da75799bf64a53a3ad3155ecded9"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829df14d656b3fb87e50ae8b48253a8851c707da9f30d45aacab2aa2ba2d614"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a42a4bdcf7307b86cb863b2fb9bb55029b422d8f86276a50487982d99eed7c6e"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04d461ad0aebf696f8da13c99bc1b3e06f66ecf6cfd56254cc402f6385231c06"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e6375923c5f19888c9226582a124b77b622f8fd0018b843c45eeb19d9701c403"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cec323a853c24fd0472517113768c92ae0be8f8c384ef4441d3632da8baa646"}, + {file = "jiter-0.5.0-cp311-none-win32.whl", hash = "sha256:aa1db0967130b5cab63dfe4d6ff547c88b2a394c3410db64744d491df7f069bb"}, + {file = "jiter-0.5.0-cp311-none-win_amd64.whl", hash = "sha256:aa9d2b85b2ed7dc7697597dcfaac66e63c1b3028652f751c81c65a9f220899ae"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9f664e7351604f91dcdd557603c57fc0d551bc65cc0a732fdacbf73ad335049a"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:044f2f1148b5248ad2c8c3afb43430dccf676c5a5834d2f5089a4e6c5bbd64df"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:702e3520384c88b6e270c55c772d4bd6d7b150608dcc94dea87ceba1b6391248"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:528d742dcde73fad9d63e8242c036ab4a84389a56e04efd854062b660f559544"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8cf80e5fe6ab582c82f0c3331df27a7e1565e2dcf06265afd5173d809cdbf9ba"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44dfc9ddfb9b51a5626568ef4e55ada462b7328996294fe4d36de02fce42721f"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c451f7922992751a936b96c5f5b9bb9312243d9b754c34b33d0cb72c84669f4e"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:308fce789a2f093dca1ff91ac391f11a9f99c35369117ad5a5c6c4903e1b3e3a"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7f5ad4a7c6b0d90776fdefa294f662e8a86871e601309643de30bf94bb93a64e"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ea189db75f8eca08807d02ae27929e890c7d47599ce3d0a6a5d41f2419ecf338"}, + {file = "jiter-0.5.0-cp312-none-win32.whl", hash = "sha256:e3bbe3910c724b877846186c25fe3c802e105a2c1fc2b57d6688b9f8772026e4"}, + {file = "jiter-0.5.0-cp312-none-win_amd64.whl", hash = "sha256:a586832f70c3f1481732919215f36d41c59ca080fa27a65cf23d9490e75b2ef5"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f04bc2fc50dc77be9d10f73fcc4e39346402ffe21726ff41028f36e179b587e6"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f433a4169ad22fcb550b11179bb2b4fd405de9b982601914ef448390b2954f3"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad4a6398c85d3a20067e6c69890ca01f68659da94d74c800298581724e426c7e"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6baa88334e7af3f4d7a5c66c3a63808e5efbc3698a1c57626541ddd22f8e4fbf"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ece0a115c05efca597c6d938f88c9357c843f8c245dbbb53361a1c01afd7148"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:335942557162ad372cc367ffaf93217117401bf930483b4b3ebdb1223dbddfa7"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649b0ee97a6e6da174bffcb3c8c051a5935d7d4f2f52ea1583b5b3e7822fbf14"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4be354c5de82157886ca7f5925dbda369b77344b4b4adf2723079715f823989"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5206144578831a6de278a38896864ded4ed96af66e1e63ec5dd7f4a1fce38a3a"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8120c60f8121ac3d6f072b97ef0e71770cc72b3c23084c72c4189428b1b1d3b6"}, + {file = "jiter-0.5.0-cp38-none-win32.whl", hash = "sha256:6f1223f88b6d76b519cb033a4d3687ca157c272ec5d6015c322fc5b3074d8a5e"}, + {file = "jiter-0.5.0-cp38-none-win_amd64.whl", hash = "sha256:c59614b225d9f434ea8fc0d0bec51ef5fa8c83679afedc0433905994fb36d631"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0af3838cfb7e6afee3f00dc66fa24695199e20ba87df26e942820345b0afc566"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:550b11d669600dbc342364fd4adbe987f14d0bbedaf06feb1b983383dcc4b961"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:489875bf1a0ffb3cb38a727b01e6673f0f2e395b2aad3c9387f94187cb214bbf"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b250ca2594f5599ca82ba7e68785a669b352156260c5362ea1b4e04a0f3e2389"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ea18e01f785c6667ca15407cd6dabbe029d77474d53595a189bdc813347218e"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:462a52be85b53cd9bffd94e2d788a09984274fe6cebb893d6287e1c296d50653"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92cc68b48d50fa472c79c93965e19bd48f40f207cb557a8346daa020d6ba973b"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c834133e59a8521bc87ebcad773608c6fa6ab5c7a022df24a45030826cf10bc"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab3a71ff31cf2d45cb216dc37af522d335211f3a972d2fe14ea99073de6cb104"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cccd3af9c48ac500c95e1bcbc498020c87e1781ff0345dd371462d67b76643eb"}, + {file = "jiter-0.5.0-cp39-none-win32.whl", hash = "sha256:368084d8d5c4fc40ff7c3cc513c4f73e02c85f6009217922d0823a48ee7adf61"}, + {file = "jiter-0.5.0-cp39-none-win_amd64.whl", hash = "sha256:ce03f7b4129eb72f1687fa11300fbf677b02990618428934662406d2a76742a1"}, + {file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"}, +] + [[package]] name = "joblib" version = "1.4.2" @@ -1463,18 +1533,18 @@ typing-extensions = ">=4.7" [[package]] name = "langchain-openai" -version = "0.1.14" +version = "0.1.23" description = "An integration package connecting OpenAI and LangChain" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_openai-0.1.14-py3-none-any.whl", hash = "sha256:fcd34cc5b5713798908a5828d364b4426e3b1afccbd564d344e5477acb86634a"}, - {file = "langchain_openai-0.1.14.tar.gz", hash = "sha256:1f13d6041e8bedddf6eb47ccad7416e05af38fa169324f7f1bdf4f385780f8d8"}, + {file = "langchain_openai-0.1.23-py3-none-any.whl", hash = "sha256:8e3d215803e157f26480c6108eb4333629832b1a0e746723060c24f93b8b78f4"}, + {file = "langchain_openai-0.1.23.tar.gz", hash = "sha256:ed7f16671ea0af177ac5f82d5645a746c5097c56f97b31798e5c07b5c84f0eed"}, ] [package.dependencies] -langchain-core = ">=0.2.2,<0.3" -openai = ">=1.32.0,<2.0.0" +langchain-core = ">=0.2.35,<0.3.0" +openai = ">=1.40.0,<2.0.0" tiktoken = ">=0.7,<1" [[package]] @@ -2251,23 +2321,24 @@ files = [ [[package]] name = "openai" -version = "1.35.10" +version = "1.44.1" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.35.10-py3-none-any.whl", hash = "sha256:962cb5c23224b5cbd16078308dabab97a08b0a5ad736a4fdb3dc2ffc44ac974f"}, - {file = "openai-1.35.10.tar.gz", hash = "sha256:85966949f4f960f3e4b239a659f9fd64d3a97ecc43c44dc0a044b5c7f11cccc6"}, + {file = "openai-1.44.1-py3-none-any.whl", hash = "sha256:07e2c2758d1c94151c740b14dab638ba0d04bcb41a2e397045c90e7661cdf741"}, + {file = "openai-1.44.1.tar.gz", hash = "sha256:e0ffdab601118329ea7529e684b606a72c6c9d4f05be9ee1116255fcf5593874"}, ] [package.dependencies] anyio = ">=3.5.0,<5" distro = ">=1.7.0,<2" httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" pydantic = ">=1.9.0,<3" sniffio = "*" tqdm = ">4" -typing-extensions = ">=4.7,<5" +typing-extensions = ">=4.11,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] diff --git a/tests/test_helpers/cassettes/test_assistants/test_AIAssistant_pydantic_structured_output.yaml b/tests/test_helpers/cassettes/test_assistants/test_AIAssistant_pydantic_structured_output.yaml new file mode 100644 index 0000000..5b10c66 --- /dev/null +++ b/tests/test_helpers/cassettes/test_assistants/test_AIAssistant_pydantic_structured_output.yaml @@ -0,0 +1,189 @@ +interactions: +- request: + body: '{"messages": [{"content": "You are a helpful assistant that provides information + about people.", "role": "system"}, {"content": "In the last step of this chat + you will be asked to respond in JSON. Gather information using tools. Don''t + generate JSON until you are explicitly told to. ", "role": "system"}, {"content": + "Tell me about John who is 30 years old and not a student.", "role": "user"}], + "model": "gpt-4o-2024-08-06", "n": 1, "stream": false, "temperature": 1.0}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + authorization: + - DUMMY + connection: + - keep-alive + content-length: + - '470' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA1RSwW7bMAy9+ysInZ0iTlyny2XoZQOG7rChQA/rECgybWmVRU2km2ZF/32Q4ybb + RRDeI58e9fhaACjXqi0oY7WYIfrFbeO+fzK3rglPvx/oW/Xn7u7rYWUflk19/KzK3EH7X2jkvevK + 0BA9iqNwok1CLZhVq82qqZq6ul5PxEAt+tzWR1nUtFgtV/ViebNYNnOjJWeQ1RZ+FAAAr9OZLYYW + X9QWluU7MiCz7lFtz0UAKpHPiNLMjkUHUeWFNBQEw+T6niAmenYtwpFGODixkNDjsw4CLnSUBp3H + Ab2nUeAL2VCCodG3U330qBmBrU4IAyWEFkU7z0AJpmdeBBL2OrUu9HCwzljgiMZ1zsCjynqPapLK + Cgk7TClXCn2Ee4sJJ3zQ4QgRKXo8WRSrBYIesAQmyKxuW5eNan+2wKOxoBmsYyBjxjhNUoInM98o + gQuCCVkYDtNUFn0E43Vy3XH2GzExhcnkgDpc/fuVCbuRdU4yjN7P+Ns5G099TLTnmT/jnQuO7S6h + Zgo5BxaKamLfCoCf0w6M/8WqYqIhyk7oCUMW3KxPcuqydBfyejOTQqL9Ba/Wy2I2qPjIgsOuc6HH + FJM7bUQXd1W93u9v6g8ro4q34i8AAAD//wMA5g4c6BoDAAA= + headers: + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: Sun, 09 Jun 2024 23:39:08 GMT + Server: DUMMY + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"content": "You are a helpful assistant that provides information + about people.", "role": "system"}, {"content": "In the last step of this chat + you will be asked to respond in JSON. JSON will have the following schema:\n\n{\"properties\": + {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"title\": + \"Age\", \"type\": \"integer\"}, \"is_student\": {\"title\": \"Is Student\", + \"type\": \"boolean\"}}, \"required\": [\"name\", \"age\", \"is_student\"], + \"title\": \"OutputSchema\", \"type\": \"object\"}\n\nGather information using + tools. Don''t generate JSON until you are explicitly told to. ", "role": "system"}, + {"content": "Tell me about John who is 30 years old and not a student.", "role": + "user"}], "model": "gpt-4o-2024-08-06", "n": 1, "stream": false, "temperature": + 1.0}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + authorization: + - DUMMY + connection: + - keep-alive + content-length: + - '811' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA1RSy27bMBC86ysWvPQiB5LiR+tLUaSXBL0ELdoAQWHQ5FpiTXFZ7iquEfjfC8qK + 3V54mOHMDnf4WgAoZ9UalOm0mD762ael+/rFfj/Yav94d9f2WD+F1cP9D//4lFiVWUHbX2jkTXVj + qI8exVE40yahFsyu9apZ1st509Qj0ZNFn2VtlNmcZk3VzGfV+1m1nIQdOYOs1vBcAAC8jmeOGCz+ + UWuoyjekR2bdolpfLgGoRD4jSjM7Fh1ElVfSUBAMY+pvBDHRi7MIRxrg4KQD6RAS/h6QBS24sKPU + 6/ymEu7feQ+tlg4TWBTtPIPe0iCgIWJiChB0jxYeqAtw6Agcw20FR9SJgbwFHSwEyvdZBotBbuAz + jbM7/YKgwxE4onE7Z2CM4Cgw0HXcGFMHASHYBzpM86VzPE4tgQfTgWbICBkzxCm8C4IJWbjMfp7M + iH/8dzMJdwPrXEwYvJ/w02XVntqYaMsTf8F3LjjuNgk1U8hrZaGoRvZUAPwcKx3+a0nFRH2UjdAe + QzasF6uzn7p+oiu7uJ1IIdH+ijd1VUwJFR9ZsN/sXGgxxeTODe/i5gPWC2PsUs9VcSr+AgAA//8D + AEEAD0DqAgAA + headers: + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: Sun, 09 Jun 2024 23:39:08 GMT + Server: DUMMY + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"content": "You are a helpful assistant that provides information + about people.", "role": "system"}, {"content": "In the last step of this chat + you will be asked to respond in JSON. JSON will have the following schema:\n\n{\"properties\": + {\"name\": {\"title\": \"Name\", \"type\": \"string\"}, \"age\": {\"title\": + \"Age\", \"type\": \"integer\"}, \"is_student\": {\"title\": \"Is Student\", + \"type\": \"boolean\"}}, \"required\": [\"name\", \"age\", \"is_student\"], + \"title\": \"OutputSchema\", \"type\": \"object\"}\n\nGather information using + tools. Don''t generate JSON until you are explicitly told to. ", "role": "system"}, + {"content": "Tell me about John who is 30 years old and not a student.", "role": + "user"}, {"content": "To provide you with the requested information, I''ll gather + details about a person named John who is 30 years old and not a student. Do + you have any specific questions or details you want to know about this John, + such as his occupation, interests, or location?", "role": "assistant"}, {"content": + "Use the information gathered in the conversation to answer.", "role": "system"}], + "model": "gpt-4o-2024-08-06", "n": 1, "response_format": {"type": "json_schema", + "json_schema": {"schema": {"properties": {"name": {"title": "Name", "type": + "string"}, "age": {"title": "Age", "type": "integer"}, "is_student": {"title": + "Is Student", "type": "boolean"}}, "required": ["name", "age", "is_student"], + "title": "OutputSchema", "type": "object", "additionalProperties": false}, "name": + "OutputSchema", "strict": true}}, "temperature": 1.0}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + authorization: + - DUMMY + connection: + - keep-alive + content-length: + - '1578' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python + x-stainless-helper-method: + - beta.chat.completions.parse + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA1SQwU+DMBjF7/wVzXcGA4zhxs3ExLgYPajxIIaU8gHdStvQEmcW/ndTxoZeeni/ + vpf3vZNHCPAKMgKspZZ1WgR3KX99KYfmvi3l80O6e/z4Xh9xfyjl+9Mb+M6hyj0ye3HdMNVpgZYr + ecasR2rRpUa3cRqlSRwnE+hUhcLZGm2DRAVxGCdBuAnCdDa2ijM0kJFPjxBCTtPrKsoKj5CR0L8o + HRpDG4Ts+okQ6JVwClBjuLFUWvAXyJS0KKfWpxwk7TCHLIedamUOfg60ccIq9HPgpjB2qFDaHLKa + CoPj36Qe68FQd4gchJj18VpNqEb3qjQzv+o1l9y0RY/UKOlqGKs0THT0CPmaJhj+XQW6V522hVUH + lC4w3kbnPFhGX2iUzNAqS8Wir8K1NzcE82MsdkXNZYO97vl5kVoXUbIqy02yjRl4o/cLAAD//wMA + b8XsohoCAAA= + headers: + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: Sun, 09 Jun 2024 23:39:08 GMT + Server: DUMMY + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_helpers/cassettes/test_assistants/test_AIAssistant_typeddict_structured_output.yaml b/tests/test_helpers/cassettes/test_assistants/test_AIAssistant_typeddict_structured_output.yaml new file mode 100644 index 0000000..d830e13 --- /dev/null +++ b/tests/test_helpers/cassettes/test_assistants/test_AIAssistant_typeddict_structured_output.yaml @@ -0,0 +1,167 @@ +interactions: +- request: + body: '{"messages": [{"content": "You are a helpful assistant that provides information + about movies.", "role": "system"}, {"content": "In the last step of this chat + you will be asked to respond in JSON. Gather information using tools. Don''t + generate JSON until you are explicitly told to. ", "role": "system"}, {"content": + "Provide information about the movie Shrek. It was released in 2001 and is an + animation and comedy movie.", "role": "user"}], "model": "gpt-4o-2024-08-06", + "n": 1, "stream": false, "temperature": 1.0}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + authorization: + - DUMMY + connection: + - keep-alive + content-length: + - '517' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//bFZRb9xGDn73ryD0dDV2F7brprbf3DhpmsYHw+u74HAuAu6IktidGaqc + kTdKkf9+4EjrdYJ7EXaH5Dfkx4/E/H0EUHFdXUHlOswu9H55/Yoffht+Pb/78s/txe7H3z//ZxPX + N018H+/avlpYhGz+JJf3USsnofeUWeJkdkqYyVBPfz57dfrq/OzipBiC1OQtrO3z8lyWZydn58uT + i+XJqzmwE3aUqiv47xEAwN/laynGmj5XV1BgykmglLCl6urZCaBS8XZSYUqcMsZcLQ5GJzFTLFk/ + VutOaftYASfACBg5WMbgJFA9QsM+QK9SD45q2Ixwo4Tho+g2wXXxZbGoGpQ8YaIaOMLZycnpCh46 + muI5waaYzBVOLy9PoEHWETJ6gp5dHpRgI7IFaSB3BAkDQbTPZoSP7D1jgHUmblfwjpQAlSBJINjS + CDVlZJ8ANzLkEh/kienqMT7GJRwf37CSyyX/4+MruI610g6uawxpTv7f7LYjvKeYZI65O9RcYlQi + fESNpAt4L12Ed6t9YmlRMN5T0yiN8DvmLxQ3pO2EtHZKFHuP44z1QDW88Z4l5wU8kOoI95ISiyET + rDN7HzBOqPfSksJ6ZfetXTeYZcbNqMqxNcjHCLCEW94S3I6kCTBBaSz848mE9MPk8KaumeB20L4b + zeVGovH3jc9rDGTF3jB+MZ875egoJXjLEvFb30LEB85dKzvz/SBaw1vUvwbE+uBq2d4Oid1MwDu0 + mn9VapPE5Z7FiURDvJMdef+yD0VlryX0GAvC/1PhvtcpK2+GQ7tf+N5NUkuT6/0QI8cWMgcqPe57 + lc9F/36EyxMIHIe89/6AsR2wLZ5vYus5dZPhtQwxa7nqX5Ht3nXG57D7aSzgBnMJvcURTi8WZURM + n8fHd14yrIcQsIA8xsNMZvI+TQORRUebDgQl54fET1QIa3UI/QjS6jQw9dz3XSfQcKwTdJwg7TD0 + IE+kOkSbqRfzV3aUsQK5wwwdPhFsiCJsMHLqprG3HOiJ/bctXsGDgFKLHKdrxHMealrMSQTcUgKE + mtDDjnN3EEe2wOQG+l5gVtTGdA0dqbl1HFZw7cU61RHscNzD28pyrqiCpzTRStpiNnbqSdwTKZPS + V3C9nwu7ZrqwpWzXbKPsgNB1ILmzKc8djYXC6VfhJYiSLTgnIUg0wuJk5ciZ0XtjSoa2y6upuffk + qDd1ftfZHRovTjmzQ1+yMUxSx+ghDc4YWcFvuXj2imz7sxEFzgm6IYgu5mXNEhdQRg160kY0oPE5 + rY8kQ6yzotu+2Mfok0BA3VJdmtOjlg3cqITCcVas2YDRv1RKwc7Gc8LMyl+sS3ZLjyr1yFOHQkl7 + Q87Wt0Ta7/SO245SXrZqyy62JZdkVpuFgrMrnFqWmvKSnkjh2mFNYYTrHepEwC+U8jz1VMPbSbzP + VJXaogSOxXwIqLEvs/m8jecOfaAW3feDl3rcRaohDD5z7wkS/TWQTwvg6PxQW/qzN5wtHqvnP5b9 + Q8da26GVtDe8FZ3qaTLp3jiThJC4jdyww5iNzGys9NIPHhXc4J8rLNVxTD1r6V3qOS6laQqXC3is + 7oaUTKC/iOS0vwYh2PY1nRkLRTKrlw8CpWZIaO+ROHg/n399fmF4aXuVTZrtz+cN24L4pIRJor0m + Upa+KtavRwB/lJfM8M3jpOpVQp8/ZdlSNMCL0wmuOjydDsbzy4vZmiWjPxh++vnyaM6wSmPKFD41 + HFvSXnl62DT9p9PzHzebi/PLM1cdfT36HwAAAP//AwBZgHU64QkAAA== + headers: + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: Sun, 09 Jun 2024 23:39:08 GMT + Server: DUMMY + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"content": "You are a helpful assistant that provides information + about movies.", "role": "system"}, {"content": "In the last step of this chat + you will be asked to respond in JSON. Gather information using tools. Don''t + generate JSON until you are explicitly told to. ", "role": "system"}, {"content": + "Provide information about the movie Shrek. It was released in 2001 and is an + animation and comedy movie.", "role": "user"}, {"content": "\"Shrek\" is an + animated comedy film produced by DreamWorks Animation and released in 2001. + The film is based on a 1990 fairy tale picture book of the same name by William + Steig. Here are some key details about the movie:\n\n- **Directed by**: Andrew + Adamson and Vicky Jenson\n- **Produced by**: Aron Warner, John H. Williams, + and Jeffrey Katzenberg\n- **Screenplay by**: Ted Elliott, Terry Rossio, Joe + Stillman, and Roger S. H. Schulman\n- **Starring**: \n - Mike Myers as Shrek + (voice)\n - Eddie Murphy as Donkey (voice)\n - Cameron Diaz as Princess Fiona + (voice)\n - John Lithgow as Lord Farquaad (voice)\n- **Music by**: Harry Gregson-Williams + and John Powell\n- **Production Company**: DreamWorks Animation\n- **Distributed + by**: DreamWorks Pictures\n- **Running time**: Approximately 90 minutes\n- **Language**: + English\n- **Country**: United States\n- **Release Date**: May 18, 2001\n\n**Plot + Summary**:\n\"Shrek\" tells the story of a reclusive and grumpy ogre named Shrek + who finds his swamp overrun by fairy tale creatures that have been banished + by the evil Lord Farquaad. To regain his solitude, Shrek makes a deal with Farquaad + to rescue Princess Fiona and bring her to him. Along the way, Shrek is accompanied + by a talkative donkey named Donkey. As Shrek and Fiona get to know each other, + they find they have more in common than they initially thought.\n\n**Reception**:\n\"Shrek\" + was a critical and commercial success. It was praised for its humor, animation, + voice performances, and soundtrack. The film also marked a departure from the + traditional fairy tale format by satirizing and parodying them. It became one + of the highest-grossing films of 2001 and won the first-ever Academy Award for + Best Animated Feature. It was also nominated for Best Adapted Screenplay.\n\n**Legacy**:\n\"Shrek\" + spawned multiple sequels, including \"Shrek 2,\" \"Shrek the Third,\" and \"Shrek + Forever After,\" and became a significant part of popular culture. It also inspired + a spin-off film, \"Puss in Boots,\" and a musical adaptation.", "role": "assistant"}, + {"content": "Use the information gathered in the conversation to answer.", "role": + "system"}], "model": "gpt-4o-2024-08-06", "n": 1, "response_format": {"type": + "json_schema", "json_schema": {"name": "OutputSchema", "description": "dict() + -> new empty dictionary\ndict(mapping) -> new dictionary initialized from a + mapping object''s\n (key, value) pairs\ndict(iterable) -> new dictionary + initialized as if via:\n d = {}\n for k, v in iterable:\n d[k] + = v\ndict(**kwargs) -> new dictionary initialized with the name=value pairs\n in + the keyword argument list. For example: dict(one=1, two=2)", "strict": true, + "schema": {"type": "object", "properties": {"title": {"type": "string"}, "year": + {"type": "integer"}, "genres": {"type": "array", "items": {"type": "string"}}}, + "required": ["title", "year", "genres"], "additionalProperties": false}}}, "temperature": + 1.0}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + authorization: + - DUMMY + connection: + - keep-alive + content-length: + - '3406' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python + x-stainless-helper-method: + - beta.chat.completions.parse + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAA1SRzW7CMBCE73kKa89JlaRpoLmhHipVqpAKN4yQEzbBxbEte5FKEe9eOYSfXnyY + b2c0uz5FjIHcQsWg2QlqequSWSmX88K6A7bzxRf9LmVRYz2rZx/d+yfEwWHqb2zo6npqTG8VkjT6 + ghuHgjCkZpO8zMoin5YD6M0WVbB1lpLCJHmaF0k6TdJyNO6MbNBDxVYRY4ydhjdU1Fv8gYql8VXp + 0XvRIVS3IcbAGRUUEN5LT0ITxHfYGE2oh9YnDiRJIYeKw2LncM8h5nBE4ThUeZpmMYcOtUPPoVpx + mGnZi7DgMPdmetweOazPj/kO24MXYT19UGrUz7fCynTWmdqP/Ka3Uku/2zgU3uhQzpOxMNBzxNh6 + OMzh365gnektbcjsUYfASZFd8uD+FXeaTUdIhoR6cL28RmND8EdP2G9aqTt01snLnVq7mWKWTSb1 + c1pCdI7+AAAA//8DAEyr8E8wAgAA + headers: + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: Sun, 09 Jun 2024 23:39:08 GMT + Server: DUMMY + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_helpers/test_assistants.py b/tests/test_helpers/test_assistants.py index 70cf7fd..f928c63 100644 --- a/tests/test_helpers/test_assistants.py +++ b/tests/test_helpers/test_assistants.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, TypedDict from unittest.mock import patch import pytest @@ -236,3 +236,58 @@ def tool_a(self, foo: str) -> str: "tool_b", "tool_a", ] + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.vcr +def test_AIAssistant_pydantic_structured_output(): + from pydantic import BaseModel + + class OutputSchema(BaseModel): + name: str + age: int + is_student: bool + + class StructuredOutputAssistant(AIAssistant): + id = "structured_output_assistant" # noqa: A003 + name = "Structured Output Assistant" + instructions = "You are a helpful assistant that provides information about people." + model = "gpt-4o-2024-08-06" + structured_output = OutputSchema + + assistant = StructuredOutputAssistant() + + # Test invoking the assistant with structured output + result = assistant.run("Tell me about John who is 30 years old and not a student.") + assert isinstance(result, OutputSchema) + assert result.name == "John" + assert result.age == 30 + assert result.is_student is False + + +@pytest.mark.django_db(transaction=True) +@pytest.mark.vcr +def test_AIAssistant_typeddict_structured_output(): + class OutputSchema(TypedDict): + title: str + year: int + genres: List[str] + + class DictStructuredOutputAssistant(AIAssistant): + id = "dict_structured_output_assistant" # noqa: A003 + name = "Dict Structured Output Assistant" + instructions = "You are a helpful assistant that provides information about movies." + model = "gpt-4o-2024-08-06" + structured_output = OutputSchema + + assistant = DictStructuredOutputAssistant() + + # Test invoking the assistant with dict structured output + result = assistant.run( + "Provide information about the movie Shrek. " + "It was released in 2001 and is an animation and comedy movie." + ) + + assert result["title"] == "Shrek" + assert result["year"] == 2001 + assert result["genres"] == ["Animation", "Comedy"]