Replaced llm_validator in pre-made agents
Showing 7 changed files with 112 additions and 13 deletions.
@@ -1,4 +1,5 @@
from .cli.create_agent_template import create_agent_template
from .cli.import_agent import import_agent
from .oai import set_openai_key, get_openai_client, set_openai_client
from .files import determine_file_type
from .validators import llm_validator
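The hunk above adds `llm_validator` to the package's re-exports next to the existing helpers. As a quick, hypothetical smoke test of the returned callable (assuming the top-level `agency_swarm` package re-exports it as well, as the docstring in the new module below does, and that `OPENAI_API_KEY` is set):

```python
from agency_swarm import llm_validator

# llm_validator returns a plain callable that raises on invalid input
# and returns the value unchanged when the LLM deems it valid.
check_lowercase = llm_validator("The value must be entirely lowercase")
print(check_lowercase("hello world"))  # "hello world"
```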
@@ -0,0 +1,94 @@
from typing import Callable, Optional

from openai import OpenAI
from pydantic import Field

from agency_swarm.tools import BaseTool


class Validator(BaseTool):
    """
    Validate if an attribute is correct and, if not,
    return a new value with an error message.
    """

    is_valid: bool = Field(
        default=True,
        description="Whether the attribute is valid based on the requirements",
    )
    reason: Optional[str] = Field(
        default=None,
        description="The error message if the attribute is not valid, otherwise None",
    )
    fixed_value: Optional[str] = Field(
        default=None,
        description="If the attribute is not valid, suggest a new value for the attribute",
    )

    def run(self):
        # Used only as a structured response format; never executed as a tool.
        pass


def llm_validator(
    statement: str,
    client: Optional[OpenAI] = None,
    allow_override: bool = False,
    model: str = "gpt-3.5-turbo",
    temperature: float = 0,
) -> Callable[[str], str]:
    """
    Create a validator that uses the LLM to validate an attribute.

    ## Usage
    ```python
    from typing import Annotated

    from pydantic import AfterValidator, Field, ValidationError

    from agency_swarm import llm_validator
    from agency_swarm.tools import BaseTool

    class User(BaseTool):
        name: Annotated[str, AfterValidator(llm_validator("The name must be a full name all lowercase"))]
        age: int = Field(description="The age of the person")

        def run(self):
            pass

    try:
        user = User(name="Jason Liu", age=20)
    except ValidationError as e:
        print(e)
    ```
    ```
    1 validation error for User
    name
      The name is valid but not all lowercase (type=value_error.llm_validator)
    ```
    Note that here the error message is written by the LLM, and the error type is `value_error.llm_validator`.

    Parameters:
        statement (str): The statement to validate against
        client (OpenAI): The OpenAI client to use; if None, a default client is created (default: None)
        allow_override (bool): Whether to return the LLM's suggested fixed_value instead of raising when the value is invalid (default: False)
        model (str): The LLM to use for validation (default: "gpt-3.5-turbo")
        temperature (float): The temperature to use for the LLM (default: 0)
    """
    def llm(v: str) -> str:
        # Fall back to a default client if none was provided.
        llm_client = client if client is not None else OpenAI()

        resp = llm_client.beta.chat.completions.parse(
            response_format=Validator,
            messages=[
                {
                    "role": "system",
                    "content": "You are a world class validation model, capable of determining "
                    "whether the following value is valid for the statement. "
                    "If it is not, explain why and suggest a new value.",
                },
                {
                    "role": "user",
                    "content": f"Does `{v}` follow the rules: {statement}",
                },
            ],
            model=model,
            temperature=temperature,
        )

        # The parsed Validator instance lives on the first choice's message.
        result = resp.choices[0].message.parsed

        if not result.is_valid:
            # If overriding is allowed and the model suggested a replacement,
            # return the fixed value instead of failing validation.
            if allow_override and result.fixed_value is not None:
                return result.fixed_value
            # Otherwise surface the model's reason as the validation error. This could
            # be used in the future to generate a better response, via a re-asking mechanism.
            raise ValueError(result.reason)

        return v

    return llm
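For completeness, here is a small sketch of the `allow_override=True` path, where an invalid value is replaced by the model's suggested `fixed_value` instead of raising. The `Ticket` class and its field are purely illustrative, and a configured `OPENAI_API_KEY` is assumed:

```python
from typing import Annotated

from pydantic import AfterValidator, Field

from agency_swarm import llm_validator
from agency_swarm.tools import BaseTool


class Ticket(BaseTool):
    # Illustrative field: an invalid title is rewritten by the LLM rather than rejected.
    title: Annotated[
        str,
        AfterValidator(
            llm_validator("The title must be lowercase and under ten words", allow_override=True)
        ),
    ] = Field(description="A short, lowercase ticket title")

    def run(self):
        pass


ticket = Ticket(title="URGENT!!! EVERYTHING IS BROKEN PLEASE HELP")
print(ticket.title)  # likely the LLM's lowercase rewrite, e.g. "everything is broken, please help"
```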