Skip to content

Commit

Permalink
Fix azure openai error strict mode unsupported
Browse files Browse the repository at this point in the history
  • Loading branch information
VRSEN committed Oct 16, 2024
1 parent e0f9fbd commit 688d40e
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 29 deletions.
29 changes: 26 additions & 3 deletions agency_swarm/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,12 +547,35 @@ def _check_parameters(self, assistant_settings, debug=False):
print(f"Instructions mismatch: {self.instructions} != {assistant_settings['instructions']}")
return False

tools_diff = DeepDiff(self.get_oai_tools(), assistant_settings['tools'], ignore_order=True)
def sort_tool(tool):
return json.dumps(tool, sort_keys=True)

# Sort the tools in both local and assistant settings
local_tools = sorted(self.get_oai_tools(), key=sort_tool)
assistant_tools = sorted(assistant_settings['tools'], key=sort_tool)

# Remove 'strict' from all assistant tools if it's set to False
for tool in assistant_tools:
if isinstance(tool, dict) and 'function' in tool:
if 'strict' in tool['function'] and tool['function']['strict'] is False:
tool['function'].pop('strict', None)

# Ignore specific differences in file_search tool
for tool in assistant_tools:
if isinstance(tool, dict) and tool.get('type') == 'file_search':
if 'file_search' in tool:
tool['file_search'].pop('ranking_options', None)
for tool in local_tools:
if isinstance(tool, dict) and tool.get('type') == 'file_search':
if 'file_search' in tool:
tool['file_search'].pop('ranking_options', None)

tools_diff = DeepDiff(local_tools, assistant_tools, ignore_order=True)
if tools_diff != {}:
if debug:
print(f"Tools mismatch: {tools_diff}")
print("local tools: ", self.get_oai_tools())
print("assistant tools: ", assistant_settings['tools'])
print("local tools: ", local_tools)
print("assistant tools: ", assistant_tools)
return False

if self.temperature != assistant_settings['temperature']:
Expand Down
3 changes: 0 additions & 3 deletions agency_swarm/tools/BaseTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ def openai_schema(cls):
if "$defs" in schema["parameters"]:
for def_ in schema["parameters"]["$defs"].values():
def_["additionalProperties"] = False
else:
if "strict" in schema:
del schema["strict"]

return schema

Expand Down
27 changes: 6 additions & 21 deletions agency_swarm/tools/oai/FileSearch.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,8 @@
from typing import Optional
from pydantic import BaseModel, field_validator, Field
from typing import Dict, Union, Optional
from openai.types.beta.file_search_tool import FileSearchTool
from openai.types.beta.file_search_tool import FileSearch as OpenAIFileSearch

class FileSearchConfig(BaseModel):
max_num_results: int = Field(50, description="Optional override for the maximum number of results")
ranking_options: Optional[Dict[str, Union[str, float]]] = Field(
{'ranker': 'default_2024_08_21', 'score_threshold': 0.0},
description="The ranking options for the file search. If not specified, the file search tool will use the auto ranker and a score_threshold of 0."
)
class FileSearchConfig(OpenAIFileSearch):
pass

@field_validator('max_num_results')
def check_max_num_results(cls, v):
if not 1 <= v <= 50:
raise ValueError('file_search.max_num_results must be between 1 and 50 inclusive')
return v
class FileSearch(BaseModel):
type: str = "file_search"

file_search: Optional[FileSearchConfig] = None

class Config:
exclude_none = True
class FileSearch(FileSearchTool):
type: str = "file_search"
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
openai==1.41.0
openai==1.51.2
docstring_parser==0.16
pydantic==2.8.2
datamodel-code-generator==0.25.8
Expand Down
2 changes: 1 addition & 1 deletion tests/demos/demo_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(self, **kwargs):

agency = Agency([
ceo, [ceo, test_agent, test_agent2],
], shared_instructions="", async_tool_calls=False)
], shared_instructions="", settings_path="./test_settings.json")

# agency.demo_gradio()

Expand Down

0 comments on commit 688d40e

Please sign in to comment.