From 6d939280d358e7f1edd794a3a2f684aa896f0d5f Mon Sep 17 00:00:00 2001
From: Arsenii Shatokhin
Date: Thu, 15 Aug 2024 09:42:08 +0400
Subject: [PATCH] Replaced llm_validator in pre made agents

---
 agency_swarm/__init__.py                      |  1 +
 agency_swarm/agents/Devid/Devid.py            |  2 +-
 agency_swarm/agents/Devid/tools/ChangeFile.py |  5 +-
 agency_swarm/agents/Devid/tools/FileWriter.py |  2 +-
 .../Devid/tools/util/format_file_deps.py      | 18 ++--
 agency_swarm/util/__init__.py                 |  3 +-
 agency_swarm/util/validators.py               | 94 +++++++++++++++++++
 7 files changed, 112 insertions(+), 13 deletions(-)
 create mode 100644 agency_swarm/util/validators.py

diff --git a/agency_swarm/__init__.py b/agency_swarm/__init__.py
index 5416c1d5..266398e5 100644
--- a/agency_swarm/__init__.py
+++ b/agency_swarm/__init__.py
@@ -5,3 +5,4 @@
 from .util import set_openai_client
 from .util import get_openai_client
 from .util.streaming import AgencyEventHandler
+from .util.validators import llm_validator
diff --git a/agency_swarm/agents/Devid/Devid.py b/agency_swarm/agents/Devid/Devid.py
index ffd3f790..d57f00cc 100644
--- a/agency_swarm/agents/Devid/Devid.py
+++ b/agency_swarm/agents/Devid/Devid.py
@@ -2,7 +2,7 @@ import re
 
 from agency_swarm.agents import Agent
 from agency_swarm.tools import FileSearch
-from instructor import llm_validator
+from agency_swarm import llm_validator
 
 
 class Devid(Agent):
diff --git a/agency_swarm/agents/Devid/tools/ChangeFile.py b/agency_swarm/agents/Devid/tools/ChangeFile.py
index 08d2ac6c..672c2c4d 100644
--- a/agency_swarm/agents/Devid/tools/ChangeFile.py
+++ b/agency_swarm/agents/Devid/tools/ChangeFile.py
@@ -2,12 +2,11 @@
 from enum import Enum
 from typing import Literal, Optional, List
 
-from instructor import OpenAISchema
-from pydantic import Field, model_validator, field_validator
+from pydantic import Field, model_validator, field_validator, BaseModel
 
 from agency_swarm import BaseTool
 
-class LineChange(OpenAISchema):
+class LineChange(BaseModel):
     """
     Line changes to be made.
""" diff --git a/agency_swarm/agents/Devid/tools/FileWriter.py b/agency_swarm/agents/Devid/tools/FileWriter.py index d84a02d0..73d639e1 100644 --- a/agency_swarm/agents/Devid/tools/FileWriter.py +++ b/agency_swarm/agents/Devid/tools/FileWriter.py @@ -3,7 +3,7 @@ import json import os -from instructor import llm_validator, OpenAISchema +from agency_swarm import llm_validator from agency_swarm import get_openai_client from agency_swarm.tools import BaseTool diff --git a/agency_swarm/agents/Devid/tools/util/format_file_deps.py b/agency_swarm/agents/Devid/tools/util/format_file_deps.py index 211fb10e..b848d335 100644 --- a/agency_swarm/agents/Devid/tools/util/format_file_deps.py +++ b/agency_swarm/agents/Devid/tools/util/format_file_deps.py @@ -1,5 +1,4 @@ -from instructor import OpenAISchema -from pydantic import Field +from pydantic import Field, BaseModel from typing import List, Literal from agency_swarm import get_openai_client @@ -13,11 +12,11 @@ def format_file_deps(v): with open(file, 'r') as f: content = f.read() - class Dependency(OpenAISchema): + class Dependency(BaseModel): type: Literal['class', 'function', 'import'] = Field(..., description="The type of the dependency.") name: str = Field(..., description="The name of the dependency, matching the import or definition.") - class Dependencies(OpenAISchema): + class Dependencies(BaseModel): dependencies: List[Dependency] = Field([], description="The dependencies extracted from the file.") def append_dependencies(self): @@ -29,7 +28,7 @@ def append_dependencies(self): result += f"File path: {file}\n" result += f"Functions: {functions}\nClasses: {classes}\nImports: {imports}\nVariables: {variables}\n\n" - resp = client.chat.completions.create( + completion = client.beta.chat.completions.parse( messages=[ { "role": "system", @@ -42,9 +41,14 @@ def append_dependencies(self): ], model="gpt-3.5-turbo", temperature=0, - response_model=Dependencies + response_format=Dependencies ) - resp.append_dependencies() + if completion.choices[0].message.refusal: + raise ValueError(completion.choices[0].message.refusal) + + model = completion.choices[0].message.parsed + + model.append_dependencies() return result \ No newline at end of file diff --git a/agency_swarm/util/__init__.py b/agency_swarm/util/__init__.py index cae60587..b4c5feb0 100644 --- a/agency_swarm/util/__init__.py +++ b/agency_swarm/util/__init__.py @@ -1,4 +1,5 @@ from .cli.create_agent_template import create_agent_template from .cli.import_agent import import_agent from .oai import set_openai_key, get_openai_client, set_openai_client -from .files import determine_file_type \ No newline at end of file +from .files import determine_file_type +from .validators import llm_validator \ No newline at end of file diff --git a/agency_swarm/util/validators.py b/agency_swarm/util/validators.py new file mode 100644 index 00000000..0dd36abb --- /dev/null +++ b/agency_swarm/util/validators.py @@ -0,0 +1,94 @@ +from openai import OpenAI +from typing import Callable +from pydantic import Field +from agency_swarm.tools import BaseTool + +class Validator(BaseTool): + """ + Validate if an attribute is correct and if not, + return a new value with an error message + """ + + is_valid: bool = Field( + default=True, + description="Whether the attribute is valid based on the requirements", + ) + reason: str = Field( + default=None, + description="The error message if the attribute is not valid, otherwise None", + ) + fixed_value: str = Field( + default=None, + description="If the attribute is not valid, 
suggest a new value for the attribute", + ) + + def run(self): + pass + +def llm_validator( + statement: str, + client: OpenAI=None, + allow_override: bool = False, + model: str = "gpt-3.5-turbo", + temperature: float = 0, +) -> Callable[[str], str]: + """ + Create a validator that uses the LLM to validate an attribute + + ## Usage + + ```python + from agency_swarm import llm_validator + from pydantic import Field, field_validator + + class User(BaseTool): + name: str = Annotated[str, llm_validator("The name must be a full name all lowercase") + age: int = Field(description="The age of the person") + + try: + user = User(name="Jason Liu", age=20) + except ValidationError as e: + print(e) + ``` + + ``` + 1 validation error for User + name + The name is valid but not all lowercase (type=value_error.llm_validator) + ``` + + Note that there, the error message is written by the LLM, and the error type is `value_error.llm_validator`. + + Parameters: + statement (str): The statement to validate + model (str): The LLM to use for validation (default: "gpt-3.5-turbo-0613") + temperature (float): The temperature to use for the LLM (default: 0) + openai_client (OpenAI): The OpenAI client to use (default: None) + """ + def llm(v: str) -> str: + resp = client.beta.chat.completions.parse( + response_format=Validator, + messages=[ + { + "role": "system", + "content": "You are a world class validation model. Capable to determine if the following value is valid for the statement, if it is not, explain why and suggest a new value.", + }, + { + "role": "user", + "content": f"Does `{v}` follow the rules: {statement}", + }, + ], + model=model, + temperature=temperature, + ) + + # If the response is not valid, return the reason, this could be used in + # the future to generate a better response, via reasking mechanism. + assert resp.is_valid, resp.reason + + if allow_override and not resp.is_valid and resp.fixed_value is not None: + # If the value is not valid, but we allow override, return the fixed value + return resp.fixed_value + return v + + return llm \ No newline at end of file
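
For reviewers, here is a minimal usage sketch (not part of the patch) showing one way the replacement `agency_swarm.llm_validator` can be attached to a Pydantic v2 field, following the pattern sketched in the new docstring. The `SummaryTool` class, its `summary` field, and the statement text are hypothetical; only `llm_validator`, `get_openai_client`, and `BaseTool` come from the library itself.

```python
from typing import Annotated

from pydantic import BeforeValidator, Field, ValidationError

from agency_swarm import get_openai_client, llm_validator
from agency_swarm.tools import BaseTool

# Reuse the shared client so the validator does not construct its own.
client = get_openai_client()


class SummaryTool(BaseTool):
    """Hypothetical tool whose 'summary' field is checked by the LLM."""

    summary: Annotated[
        str,
        # The statement below is enforced by the model at validation time.
        BeforeValidator(llm_validator("The summary must be a single, neutral sentence", client=client)),
    ] = Field(..., description="One-sentence summary of the proposed change.")

    def run(self):
        return self.summary


try:
    SummaryTool(summary="THIS IS THE BEST PATCH EVER!!!")
except ValidationError as e:
    # The assert inside llm_validator raises, and Pydantic reports the
    # model-written reason as a validation error.
    print(e)
```

Because the inner validator raises a plain `AssertionError` carrying the model's `reason`, Pydantic surfaces it as an ordinary `ValidationError`, so callers such as the pre-made agents need no `instructor`-specific handling.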