Commit
Simplify graph for structured output
fjsj committed Sep 30, 2024
1 parent cdfe463 commit ece4ac2
Showing 3 changed files with 130 additions and 805 deletions.
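Note: the diff below drops the JSON-schema prompt injection from the setup node and concentrates structured-output handling in the record_response node at the end of the graph. For context, here is a minimal usage sketch of the structured_output feature this touches; the AIAssistant import path and every class attribute other than structured_output are assumptions based on the surrounding library, not part of this commit:

from pydantic import BaseModel

from django_ai_assistant import AIAssistant  # assumed import path


class MovieFacts(BaseModel):  # hypothetical output schema
    title: str
    release_year: int


class MovieAssistant(AIAssistant):
    # Attribute names below are assumed from the wider library, not shown in this diff:
    id = "movie_assistant"
    name = "Movie Assistant"
    instructions = "Answer questions about movies."
    model = "gpt-4o-mini"
    structured_output = MovieFacts


# With structured_output set, run() returns a MovieFacts instance rather than a plain
# string, which is why the diff widens run()'s return type annotation from str to Any:
facts = MovieAssistant().run("When was The Matrix released?")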
51 changes: 18 additions & 33 deletions django_ai_assistant/helpers/assistants.py
@@ -1,6 +1,5 @@
 import abc
 import inspect
-import json
 import re
 from typing import Annotated, Any, ClassVar, Dict, Sequence, Type, TypedDict, cast

@@ -453,25 +452,6 @@ class AgentState(TypedDict):

         def setup(state: AgentState):
             system_prompt = self.get_instructions()
-
-            if self.structured_output:
-                # If Pydantic
-                if inspect.isclass(self.structured_output) and issubclass(
-                    self.structured_output, BaseModel
-                ):
-                    schema = json.dumps(self.structured_output.model_json_schema())
-                    schema_information = (
-                        f"Your JSON output must have the following schema:\n{schema}\n"
-                        if schema
-                        else ""
-                    )
-                    json_info = (
-                        "In the last step of this chat you will be asked to respond in JSON. "
-                        + schema_information
-                        + "Don't generate JSON until you are explicitly told to. "
-                    )
-                    system_prompt += "\n" + json_info
-
             return {"messages": [SystemMessage(content=system_prompt)]}

         def history(state: AgentState):
@@ -514,18 +494,23 @@ def tool_selector(state: AgentState):
return "continue"

def record_response(state: AgentState):
# Structured output must happen in the end, to avoid disabling tool calling.
# Tool calling + structured output is not supported by OpenAI:
if self.structured_output:
# Structured output must happen in the end, to avoid disabling tool calling.
# Tool calling + structured output is not supported by OpenAI:
llm_with_structured_output = self.get_structured_output_llm()
response = llm_with_structured_output.invoke(
[
*state["messages"],
HumanMessage(
content="Use the information gathered in the conversation to answer with JSON."
),
]
messages = state["messages"]

# Change the original system prompt:
if isinstance(messages[0], SystemMessage):
messages[0].content += "\nUse the chat history to produce a JSON output."

# Add a final message asking for JSON generation / structured output:
json_request_message = HumanMessage(
content="Use the chat history to produce a JSON output."
)
messages.append(json_request_message)

llm_with_structured_output = self.get_structured_output_llm()
response = llm_with_structured_output.invoke(messages)
else:
response = state["messages"][-1].content

@@ -580,7 +565,7 @@ def invoke(self, *args: Any, thread_id: Any | None, **kwargs: Any) -> dict:
         return graph.invoke(*args, config=config, **kwargs)

     @with_cast_id
-    def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
+    def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> Any:
         """Run the assistant with the given message and thread ID.\n
         This is the higher-level method to run the assistant.\n
@@ -591,7 +576,7 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
             **kwargs: Additional keyword arguments to pass to the graph.
         Returns:
-            str: The assistant response to the user message.
+            Any: The assistant response to the user message.
         """
         return self.invoke(
             {
@@ -601,7 +586,7 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> str:
             **kwargs,
         )["output"]

-    def _run_as_tool(self, message: str, **kwargs: Any) -> str:
+    def _run_as_tool(self, message: str, **kwargs: Any) -> Any:
         return self.run(message, thread_id=None, **kwargs)

     def as_tool(self, description: str) -> BaseTool:
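The new record_response node steers the model toward JSON through the system prompt and a final human message, then invokes the LLM returned by get_structured_output_llm(). Assuming that helper wraps the chat model with LangChain's with_structured_output() (a plausible guess, not confirmed by this diff), the pattern boils down to this standalone sketch:

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class MovieFacts(BaseModel):  # hypothetical schema, mirroring the usage sketch above
    title: str
    release_year: int


messages = [
    SystemMessage(content="Answer questions about movies."),
    HumanMessage(content="When was The Matrix released?"),
]

# Mirror the record_response change: nudge the system prompt toward JSON and append
# an explicit final request for it.
messages[0].content += "\nUse the chat history to produce a JSON output."
messages.append(HumanMessage(content="Use the chat history to produce a JSON output."))

# with_structured_output() returns a runnable whose invoke() yields a MovieFacts instance.
llm = ChatOpenAI(model="gpt-4o-mini")
response = llm.with_structured_output(MovieFacts).invoke(messages)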
