Skip to content

Commit

Permalink
Replaced llm_validator in pre made agents
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Aug 15, 2024
1 parent a29cf33 commit 6d93928
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 13 deletions.
1 change: 1 addition & 0 deletions agency_swarm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .util import set_openai_client
from .util import get_openai_client
from .util.streaming import AgencyEventHandler
from .util.validators import llm_validator
2 changes: 1 addition & 1 deletion agency_swarm/agents/Devid/Devid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from agency_swarm.agents import Agent
from agency_swarm.tools import FileSearch
from instructor import llm_validator
from agency_swarm import llm_validator


class Devid(Agent):
Expand Down
5 changes: 2 additions & 3 deletions agency_swarm/agents/Devid/tools/ChangeFile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from enum import Enum
from typing import Literal, Optional, List

from instructor import OpenAISchema
from pydantic import Field, model_validator, field_validator
from pydantic import Field, model_validator, field_validator, BaseModel

from agency_swarm import BaseTool

class LineChange(OpenAISchema):
class LineChange(BaseModel):
"""
Line changes to be made.
"""
Expand Down
2 changes: 1 addition & 1 deletion agency_swarm/agents/Devid/tools/FileWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json

import os
from instructor import llm_validator, OpenAISchema
from agency_swarm import llm_validator

from agency_swarm import get_openai_client
from agency_swarm.tools import BaseTool
Expand Down
18 changes: 11 additions & 7 deletions agency_swarm/agents/Devid/tools/util/format_file_deps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from instructor import OpenAISchema
from pydantic import Field
from pydantic import Field, BaseModel
from typing import List, Literal

from agency_swarm import get_openai_client
Expand All @@ -13,11 +12,11 @@ def format_file_deps(v):
with open(file, 'r') as f:
content = f.read()

class Dependency(OpenAISchema):
class Dependency(BaseModel):
type: Literal['class', 'function', 'import'] = Field(..., description="The type of the dependency.")
name: str = Field(..., description="The name of the dependency, matching the import or definition.")

class Dependencies(OpenAISchema):
class Dependencies(BaseModel):
dependencies: List[Dependency] = Field([], description="The dependencies extracted from the file.")

def append_dependencies(self):
Expand All @@ -29,7 +28,7 @@ def append_dependencies(self):
result += f"File path: {file}\n"
result += f"Functions: {functions}\nClasses: {classes}\nImports: {imports}\nVariables: {variables}\n\n"

resp = client.chat.completions.create(
completion = client.beta.chat.completions.parse(
messages=[
{
"role": "system",
Expand All @@ -42,9 +41,14 @@ def append_dependencies(self):
],
model="gpt-3.5-turbo",
temperature=0,
response_model=Dependencies
response_format=Dependencies
)

resp.append_dependencies()
if completion.choices[0].message.refusal:
raise ValueError(completion.choices[0].message.refusal)

model = completion.choices[0].message.parsed

model.append_dependencies()

return result
3 changes: 2 additions & 1 deletion agency_swarm/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .cli.create_agent_template import create_agent_template
from .cli.import_agent import import_agent
from .oai import set_openai_key, get_openai_client, set_openai_client
from .files import determine_file_type
from .files import determine_file_type
from .validators import llm_validator
94 changes: 94 additions & 0 deletions agency_swarm/util/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from openai import OpenAI
from typing import Callable
from pydantic import Field
from agency_swarm.tools import BaseTool

class Validator(BaseTool):
"""
Validate if an attribute is correct and if not,
return a new value with an error message
"""

is_valid: bool = Field(
default=True,
description="Whether the attribute is valid based on the requirements",
)
reason: str = Field(
default=None,
description="The error message if the attribute is not valid, otherwise None",
)
fixed_value: str = Field(
default=None,
description="If the attribute is not valid, suggest a new value for the attribute",
)

def run(self):
pass

def llm_validator(
statement: str,
client: OpenAI=None,
allow_override: bool = False,
model: str = "gpt-3.5-turbo",
temperature: float = 0,
) -> Callable[[str], str]:
"""
Create a validator that uses the LLM to validate an attribute
## Usage
```python
from agency_swarm import llm_validator
from pydantic import Field, field_validator
class User(BaseTool):
name: str = Annotated[str, llm_validator("The name must be a full name all lowercase")
age: int = Field(description="The age of the person")
try:
user = User(name="Jason Liu", age=20)
except ValidationError as e:
print(e)
```
```
1 validation error for User
name
The name is valid but not all lowercase (type=value_error.llm_validator)
```
Note that there, the error message is written by the LLM, and the error type is `value_error.llm_validator`.
Parameters:
statement (str): The statement to validate
model (str): The LLM to use for validation (default: "gpt-3.5-turbo-0613")
temperature (float): The temperature to use for the LLM (default: 0)
openai_client (OpenAI): The OpenAI client to use (default: None)
"""
def llm(v: str) -> str:
resp = client.beta.chat.completions.parse(
response_format=Validator,
messages=[
{
"role": "system",
"content": "You are a world class validation model. Capable to determine if the following value is valid for the statement, if it is not, explain why and suggest a new value.",
},
{
"role": "user",
"content": f"Does `{v}` follow the rules: {statement}",
},
],
model=model,
temperature=temperature,
)

# If the response is not valid, return the reason, this could be used in
# the future to generate a better response, via reasking mechanism.
assert resp.is_valid, resp.reason

if allow_override and not resp.is_valid and resp.fixed_value is not None:
# If the value is not valid, but we allow override, return the fixed value
return resp.fixed_value
return v

return llm

0 comments on commit 6d93928

Please sign in to comment.