From 7390e8bb08bef3985f0dc6b9fe3c1270d6dec84b Mon Sep 17 00:00:00 2001 From: Arsenii Shatokhin Date: Thu, 7 Nov 2024 10:57:42 +0400 Subject: [PATCH] Added predicted outputs to file writer --- agency_swarm/agents/Devid/tools/FileWriter.py | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/agency_swarm/agents/Devid/tools/FileWriter.py b/agency_swarm/agents/Devid/tools/FileWriter.py index e714fd51..f070fd0d 100644 --- a/agency_swarm/agents/Devid/tools/FileWriter.py +++ b/agency_swarm/agents/Devid/tools/FileWriter.py @@ -78,11 +78,11 @@ def run(self): if self.mode == "modify": message += f"\nThe existing file content is as follows:" - + try: with open(self.file_path, 'r') as file: - prev_content = file.read() - message += f"\n\n```{prev_content}```" + file_content = file.read() + message += f"\n\n```{file_content}```" except Exception as e: return f'Error reading {self.file_path}: {e}' @@ -102,11 +102,22 @@ def run(self): n = 0 error_message = "" while n < 3: - resp = client.chat.completions.create( - messages=messages, - model="gpt-4o", - temperature=0, - ) + if self.mode == "modify": + resp = client.chat.completions.create( + messages=messages, + model="gpt-4o", + temperature=0, + prediction={ + "type": "content", + "content": file_content + } + ) + else: + resp = client.chat.completions.create( + messages=messages, + model="gpt-4o", + temperature=0, + ) content = resp.choices[0].message.content @@ -236,9 +247,18 @@ def validate_documentation(cls, v): if __name__ == "__main__": - tool = FileWriter( + # Test case for 'write' mode + tool_write = FileWriter( requirements="Write a program that takes a list of integers as input and returns the sum of all the integers in the list.", mode="write", - file_path="test.py", + file_path="test_write.py", + ) + print(tool_write.run()) + + # Test case for 'modify' mode + tool_modify = FileWriter( + requirements="Modify the program to also return the product of all the integers in the list.", + mode="modify", + file_path="test_write.py", ) - print(tool.run()) + print(tool_modify.run())